]> git.bitcoin.ninja Git - flowspec-xdp/commitdiff
Use common per-source lookup between v4 and v6/64
authorMatt Corallo <git@bluematt.me>
Mon, 17 Jun 2024 06:27:11 +0000 (06:27 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 17 Jun 2024 06:29:59 +0000 (06:29 +0000)
Since v4 has 32 bits of slack in the struct anyway, we might as
well just make it a u64 and unify the codepaths.

genrules.py
siphash.h
xdp.c

index 647152aaa847576d3be781ae36cb8cdb409edb02..cdbeb5b89f61f5a0c8cdcb6f64d445ea9dce3e89 100755 (executable)
@@ -433,9 +433,8 @@ with open("rules.h", "w") as out:
     use_v6_frags = False
     stats_rulecnt = 0
     ratelimitcnt = 0
-    v4persrcratelimits = []
-    v5persrcratelimits = []
-    v6persrcratelimits = []
+    persrcratelimits64 = []
+    persrcratelimits128 = []
 
     lastrule = None
     for line in sys.stdin.readlines():
@@ -559,22 +558,22 @@ with open("rules.h", "w") as out:
                             if proto == 4:
                                 if mid_byte > 32:
                                     continue
-                                first_action += f"const uint32_t srcip = ip->saddr & MASK4({mid_byte});\n"
-                                first_action += f"void *rate_map = &v4_src_rate_{len(v4persrcratelimits)};\n"
-                                first_action += f"int matched = check_v4_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
-                                v4persrcratelimits.append((high_byte + 1) * 4096)
+                                first_action += f"const uint64_t srcip = ip->saddr & MASK4({mid_byte});\n"
+                                first_action += f"void *rate_map = &src_rate_64_{len(persrcratelimits64)};\n"
+                                first_action += f"int matched = check_persrc_ratelimit_64(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
+                                persrcratelimits64.append((high_byte + 1) * 4096)
                             elif mid_byte <= 64:
                                 first_action += f"const uint64_t srcip = BE128BEHIGH64(ip6->saddr & MASK6({mid_byte}));\n"
-                                first_action += f"void *rate_map = &v5_src_rate_{len(v5persrcratelimits)};\n"
-                                first_action += f"int matched = check_v5_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
-                                v5persrcratelimits.append((high_byte + 1) * 4096)
+                                first_action += f"void *rate_map = &src_rate_64_{len(persrcratelimits64)};\n"
+                                first_action += f"int matched = check_persrc_ratelimit_64(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
+                                persrcratelimits64.append((high_byte + 1) * 4096)
                             else:
                                 if mid_byte > 128:
                                     continue
                                 first_action += f"const uint128_t srcip = ip6->saddr & MASK6({mid_byte});\n"
-                                first_action += f"void *rate_map = &v6_src_rate_{len(v6persrcratelimits)};\n"
-                                first_action += f"int matched = check_v6_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
-                                v6persrcratelimits.append((high_byte + 1) * 4096)
+                                first_action += f"void *rate_map = &src_rate_128_{len(persrcratelimits128)};\n"
+                                first_action += f"int matched = check_persrc_ratelimit_128(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
+                                persrcratelimits128.append((high_byte + 1) * 4096)
                         first_action +=  "if (matched) {\n"
                         first_action +=  "\t{stats_replace}\n"
                         first_action +=  "\treturn XDP_DROP;\n"
@@ -647,9 +646,7 @@ with open("rules.h", "w") as out:
         if use_v6_frags:
             out.write("#define PARSE_V6_FRAG PARSE\n")
     with open("maps.h", "w") as out:
-        for idx, limit in enumerate(v4persrcratelimits):
-            out.write(f"SRC_RATE_DEFINE(4, {idx}, {limit})\n")
-        for idx, limit in enumerate(v5persrcratelimits):
-            out.write(f"SRC_RATE_DEFINE(5, {idx}, {limit})\n")
-        for idx, limit in enumerate(v6persrcratelimits):
-            out.write(f"SRC_RATE_DEFINE(6, {idx}, {limit})\n")
+        for idx, limit in enumerate(persrcratelimits64):
+            out.write(f"SRC_RATE_DEFINE(64, {idx}, {limit})\n")
+        for idx, limit in enumerate(persrcratelimits128):
+            out.write(f"SRC_RATE_DEFINE(128, {idx}, {limit})\n")
index 298a3a2dd626470915f21c7fe39fa0f0d56e827d..1d748b9beab3298622b9d7b566f276825d151322 100644 (file)
--- a/siphash.h
+++ b/siphash.h
@@ -145,10 +145,6 @@ static inline uint64_t siphash(const void *in, const size_t inlen, const uint8_t
 static uint64_t siphash_uint64_t(const uint64_t in) {
        return siphash(&in, sizeof(uint64_t), COMPILE_TIME_RAND);
 }
-__attribute__((always_inline))
-static inline uint64_t siphash_uint32_t(const uint32_t in) {
-       return siphash_uint64_t(in);
-}
 static uint64_t siphash_uint128_t(const __uint128_t in) {
        return siphash(&in, sizeof(__uint128_t), COMPILE_TIME_RAND);
 }
diff --git a/xdp.c b/xdp.c
index d5bd0b05ccd40b5331ebfcd7566cb3075cec3d39..414c98320179445740dc19d13c152c35b5d1a7ca 100644 (file)
--- a/xdp.c
+++ b/xdp.c
@@ -266,28 +266,28 @@ if (rate) { \
 } \
 } while(0);
 
-#define CREATE_PERSRC_LOOKUP(IPV, IP_TYPE) \
-struct persrc_rate##IPV##_entry { \
+#define CREATE_PERSRC_LOOKUP(LEN, IP_TYPE) \
+struct persrc_rate_##LEN##_entry { \
        uint64_t pkts_and_time; \
        IP_TYPE srcip; \
 }; \
  \
-struct persrc_rate##IPV##_bucket { \
+struct persrc_rate_##LEN##_bucket { \
        struct bpf_spin_lock lock; \
-       struct persrc_rate##IPV##_entry entries[]; \
+       struct persrc_rate_##LEN##_entry entries[]; \
 }; \
  \
-static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_limit, int64_t cur_time_masked, uint64_t amt, uint64_t limit_ns_per_pkt) { \
+static int check_persrc_ratelimit_##LEN(IP_TYPE key, void *map, size_t map_limit, int64_t cur_time_masked, uint64_t amt, uint64_t limit_ns_per_pkt) { \
        uint64_t hash = siphash_##IP_TYPE(key); \
  \
        const uint32_t map_key = hash % SRC_HASH_MAX_PARALLELISM; \
-       struct persrc_rate##IPV##_bucket *buckets = bpf_map_lookup_elem(map, &map_key); \
+       struct persrc_rate_##LEN##_bucket *buckets = bpf_map_lookup_elem(map, &map_key); \
        if (!buckets) return 0; \
  \
        hash >>= SRC_HASH_MAX_PARALLELISM_POW; \
        map_limit >>= SRC_HASH_MAX_PARALLELISM_POW; \
  \
-       struct persrc_rate##IPV##_entry *first_bucket = &buckets->entries[(hash % map_limit) & (~(SRC_HASH_BUCKET_COUNT - 1))]; \
+       struct persrc_rate_##LEN##_entry *first_bucket = &buckets->entries[(hash % map_limit) & (~(SRC_HASH_BUCKET_COUNT - 1))]; \
        bpf_spin_lock(&buckets->lock); \
  \
        uint64_t bucket_idx = SRC_HASH_BUCKET_COUNT; \
@@ -309,7 +309,7 @@ static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_li
                } \
        } \
        if (bucket_idx >= SRC_HASH_BUCKET_COUNT) bucket_idx = min_sent_idx; \
-       struct persrc_rate##IPV##_entry *entry = &first_bucket[bucket_idx]; \
+       struct persrc_rate_##LEN##_entry *entry = &first_bucket[bucket_idx]; \
        if (entry->srcip != key) { \
                entry->srcip = key; \
                entry->pkts_and_time = 0; \
@@ -320,21 +320,20 @@ static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_li
        return matched; \
 }
 
-CREATE_PERSRC_LOOKUP(6, uint128_t)
-CREATE_PERSRC_LOOKUP(5, uint64_t) // IPv6 matching no more than a /64
-CREATE_PERSRC_LOOKUP(4, uint32_t)
+CREATE_PERSRC_LOOKUP(128, uint128_t)
+CREATE_PERSRC_LOOKUP(64, uint64_t) // IPv6 matching no more than a /64 and IPv4
 
-#define SRC_RATE_DEFINE(IPV, n, limit) \
-struct persrc_rate##IPV##_bucket_##n { \
+#define SRC_RATE_DEFINE(LEN, n, limit) \
+struct persrc_rate_##LEN##_bucket_##n { \
        struct bpf_spin_lock lock; \
-       struct persrc_rate##IPV##_entry entries[limit / SRC_HASH_MAX_PARALLELISM]; \
+       struct persrc_rate_##LEN##_entry entries[limit / SRC_HASH_MAX_PARALLELISM]; \
 }; \
 struct { \
        __uint(type, BPF_MAP_TYPE_ARRAY); \
        __uint(max_entries, SRC_HASH_MAX_PARALLELISM); \
        uint32_t *key; \
-       struct persrc_rate##IPV##_bucket_##n *value; \
-} v##IPV##_src_rate_##n SEC(".maps");
+       struct persrc_rate_##LEN##_bucket_##n *value; \
+} src_rate_##LEN##_##n SEC(".maps");
 
 #include "maps.h"