Handle packet rate limits, too
[flowspec-xdp] / xdp.c
diff --git a/xdp.c b/xdp.c
index 8003b89e0066b4311fd8eb482a922718914166e4..fc8f2c3bff94b7d24954ccc34dc3fc598a41b148 100644 (file)
--- a/xdp.c
+++ b/xdp.c
@@ -144,28 +144,38 @@ static const int XDP_PASS = 0;
 static const int XDP_DROP = 1;
 
 static long drop_cnt_map[RULECNT + STATIC_RULE_CNT];
-#define INCREMENT_MATCH(reason) drop_cnt_map[reason] += 1;
+#define INCREMENT_MATCH(reason) { drop_cnt_map[reason] += 1; drop_cnt_map[reason] += data_end - pktdata; }
 
 #else
 #include <linux/bpf.h>
 #include <bpf/bpf_helpers.h>
 
-struct bpf_map_def SEC("maps") drop_cnt_map = {
-       .type = BPF_MAP_TYPE_PERCPU_ARRAY,
-       .key_size = sizeof(uint32_t),
-       .value_size = sizeof(long),
-       .max_entries = RULECNT + STATIC_RULE_CNT,
+struct match_counter {
+       uint64_t bytes;
+       uint64_t packets;
 };
+struct {
+       __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
+       __uint(max_entries, RULECNT + STATIC_RULE_CNT);
+       __u32 *key;
+       struct match_counter *value;
+} drop_cnt_map SEC(".maps");
+
 #define INCREMENT_MATCH(reason) { \
-       long *value = bpf_map_lookup_elem(&drop_cnt_map, &reason); \
-       if (value) \
-               *value += 1; \
+       struct match_counter *value = bpf_map_lookup_elem(&drop_cnt_map, &reason); \
+       if (value) { \
+               value->bytes += data_end - pktdata; \
+               value->packets += 1; \
+       } \
 }
 
 #ifdef RATE_CNT
 struct ratelimit {
        struct bpf_spin_lock lock;
-       int64_t sent_bytes;
+       union {
+               int64_t sent_bytes;
+               int64_t sent_packets;
+       } rate;
        int64_t sent_time;
 };
 struct {
@@ -186,30 +196,32 @@ int xdp_drop_prog(struct xdp_md *ctx)
        unsigned short eth_proto;
 
        {
+               // DO_RETURN in CHECK_LEN relies on pktdata being set to calculate packet length.
+               // That said, we don't want to overflow, so just set packet length to 0 here.
+               pktdata = data_end;
                CHECK_LEN((size_t)ctx->data, ethhdr);
                const struct ethhdr *const eth = (void*)(size_t)ctx->data;
+               pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr);
 
 #if PARSE_8021Q == PARSE
                if (likely(eth->h_proto == BE16(ETH_P_8021Q))) {
                        CHECK_LEN((size_t)ctx->data, ethhdr_vlan);
                        const struct ethhdr_vlan *const eth_vlan = (void*)(size_t)ctx->data;
-
+                       pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr_vlan);
 #ifdef REQ_8021Q
                        if (unlikely((eth_vlan->tci & BE16(0xfff)) != BE16(REQ_8021Q)))
                                DO_RETURN(VLAN_DROP, XDP_DROP);
 #endif
-
                        eth_proto = eth_vlan->h_proto;
-                       pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr_vlan);
 #else
                if (unlikely(eth->h_proto == BE16(ETH_P_8021Q))) {
+                       pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr_vlan);
                        DO_RETURN(VLAN_DROP, PARSE_8021Q);
 #endif
                } else {
 #ifdef REQ_8021Q
                        DO_RETURN(VLAN_DROP, XDP_DROP);
 #else
-                       pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr);
                        eth_proto = eth->h_proto;
 #endif
                }