diff --git a/bpf/kprobe_pwru.c b/bpf/kprobe_pwru.c index 092f53a4..73c8d950 100644 --- a/bpf/kprobe_pwru.c +++ b/bpf/kprobe_pwru.c @@ -28,6 +28,7 @@ const static bool TRUE = true; +const static u32 ZERO = 0; volatile const static __u64 BPF_PROG_ADDR = 0; @@ -65,9 +66,6 @@ struct tuple { u8 pad; } __attribute__((packed)); -u64 print_skb_id = 0; -u64 print_shinfo_id = 0; - enum event_type { EVENT_TYPE_KPROBE = 0, EVENT_TYPE_TC = 1, @@ -81,8 +79,8 @@ struct event_t { u64 caller_addr; u64 skb_head; u64 ts; - typeof(print_skb_id) print_skb_id; - typeof(print_shinfo_id) print_shinfo_id; + u64 print_skb_id; + u64 print_shinfo_id; struct skb_meta meta; struct tuple tuple; s64 print_stack_id; @@ -164,15 +162,27 @@ struct print_shinfo_value { char str[PRINT_SHINFO_STR_SIZE]; }; struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 256); + __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); + __uint(max_entries, 1); __type(key, u32); + __type(value, u32); +} print_skb_id_map SEC(".maps"); +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 256); + __type(key, u64); __type(value, struct print_skb_value); } print_skb_map SEC(".maps"); struct { - __uint(type, BPF_MAP_TYPE_ARRAY); - __uint(max_entries, 256); + __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY); + __uint(max_entries, 1); __type(key, u32); + __type(value, u32); +} print_shinfo_id_map SEC(".maps"); +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, 256); + __type(key, u64); __type(value, struct print_shinfo_value); } print_shinfo_map SEC(".maps"); #endif @@ -312,45 +322,42 @@ set_tuple(struct sk_buff *skb, struct tuple *tpl) { __set_tuple(tpl, skb_head, l3_off, is_ipv4); } +static __always_inline u64 +sync_fetch_and_add(void *id_map) { + u32 *id = bpf_map_lookup_elem(id_map, &ZERO); + if (id) + return ((*id)++) | ((u64)bpf_get_smp_processor_id() << 32); + return 0; +} static __always_inline void -set_skb_btf(struct sk_buff *skb, typeof(print_skb_id) *event_id) { +set_skb_btf(struct sk_buff *skb, u64 *event_id) { #ifdef OUTPUT_SKB static struct btf_ptr p = {}; - struct print_skb_value *v; - typeof(print_skb_id) id; - long n; + static struct print_skb_value v = {}; + u64 id; p.type_id = bpf_core_type_id_kernel(struct sk_buff); p.ptr = skb; - id = __sync_fetch_and_add(&print_skb_id, 1) % 256; + *event_id = sync_fetch_and_add(&print_skb_id_map); - v = bpf_map_lookup_elem(&print_skb_map, (u32 *) &id); - if (!v) { + v.len = bpf_snprintf_btf(v.str, PRINT_SKB_STR_SIZE, &p, sizeof(p), 0); + if (v.len < 0) { return; } - n = bpf_snprintf_btf(v->str, PRINT_SKB_STR_SIZE, &p, sizeof(p), 0); - if (n < 0) { - return; - } - - v->len = n; - - *event_id = id; + bpf_map_update_elem(&print_skb_map, event_id, &v, BPF_ANY); #endif } static __always_inline void -set_shinfo_btf(struct sk_buff *skb, typeof(print_shinfo_id) *event_id) { +set_shinfo_btf(struct sk_buff *skb, u64 *event_id) { #ifdef OUTPUT_SKB struct skb_shared_info *shinfo; static struct btf_ptr p = {}; - struct print_shinfo_value *v; - typeof(print_shinfo_id) id; + static struct print_shinfo_value v = {}; unsigned char *head; unsigned int end; - long n; /* skb_shared_info is located at the end of skb data. * When CONFIG_NET_SKBUFF_DATA_USES_OFFSET is enabled, skb->end @@ -366,21 +373,14 @@ set_shinfo_btf(struct sk_buff *skb, typeof(print_shinfo_id) *event_id) { p.type_id = bpf_core_type_id_kernel(struct skb_shared_info); p.ptr = shinfo; - id = __sync_fetch_and_add(&print_shinfo_id, 1) % 256; + *event_id = sync_fetch_and_add(&print_shinfo_id_map); - v = bpf_map_lookup_elem(&print_shinfo_map, (u32 *) &id); - if (!v) { + v.len = bpf_snprintf_btf(v.str, PRINT_SHINFO_STR_SIZE, &p, sizeof(p), 0); + if (v.len < 0) { return; } - n = bpf_snprintf_btf(v->str, PRINT_SHINFO_STR_SIZE, &p, sizeof(p), 0); - if (n < 0) { - return; - } - - v->len = n; - - *event_id = id; + bpf_map_update_elem(&print_shinfo_map, event_id, &v, BPF_ANY); #endif } diff --git a/internal/pwru/output.go b/internal/pwru/output.go index b284b5a8..189f7371 100644 --- a/internal/pwru/output.go +++ b/internal/pwru/output.go @@ -5,7 +5,6 @@ package pwru import ( - "encoding/binary" "encoding/json" "errors" "fmt" @@ -287,39 +286,37 @@ func getStackData(event *Event, o *output) (stackData string) { } func getSkbData(event *Event, o *output) (skbData string) { - id := uint32(event.PrintSkbId) + id := uint64(event.PrintSkbId) b, err := o.printSkbMap.LookupBytes(&id) if err != nil { return "" } - length := binary.NativeEndian.Uint32(b[:4]) + defer o.printSkbMap.Delete(&id) - // Bounds check - if int(length+4) > len(b) { + if len(b) < 4 { return "" } - return "\n" + string(b[4:4+length]) + return "\n" + string(b[4:]) } func getShinfoData(event *Event, o *output) (shinfoData string) { - id := uint32(event.PrintShinfoId) + id := uint64(event.PrintShinfoId) b, err := o.printShinfoMap.LookupBytes(&id) if err != nil { return "" } - length := binary.NativeEndian.Uint32(b[:4]) + defer o.printShinfoMap.Delete(&id) - // Bounds check - if int(length+4) > len(b) { + if len(b) < 4 { return "" } - return "\n" + string(b[4:4+length]) + return "\n" + string(b[4:]) } func getMetaData(event *Event, o *output) (metaData string) {