]> git.bitcoin.ninja Git - flowspec-xdp/commitdiff
Don't fail to print dropcount when zero rules are installed main
authorMatt Corallo <git@bluematt.me>
Tue, 10 Sep 2024 01:02:38 +0000 (01:02 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 10 Sep 2024 01:06:52 +0000 (01:06 +0000)
collision_prob.py
dropcount.sh
genrules.py
install.sh
siphash.h
xdp.c

index 36df0c94c3e1911cdb8854cecb7f921abaa082ee..9d3ee7dfdcf3c35b81324852ace189e62334e402 100644 (file)
@@ -40,36 +40,40 @@ def print_entry(e, t, b):
 #print_entry(300, 20000, 1)
 #print_entry(300, 20000*2, 2)
 
-print("Table and bucket sizes mapped to rough element count which has a 1% bucket-overflow probability")
-print("Note that we currently have a hard-coded bucket size of 16 elements")
+print("For each table size below, we list the number of sources before a bucket storing 16 entries")
+print(" has a 1% chance of overflowing, cause it to spuriously accept packets from all 16 sources.")
+#print("Note that we currently have a hard-coded bucket size of 16 elements")
+print()
+print("Entry sizes are generally 16 bytes for IPv4 or IPv6 matching source /64s or less.")
+print(" When matching IPv6 sources longer than /64s entries are 32 bytes.")
 print()
 
 print("128K table * 16 bytes = %dMiB." % (128*16/1024))
-print_entry(4000, 128*1024, 4)
-print_entry(15000, 128*1024, 8)
+#print_entry(4000, 128*1024, 4)
+#print_entry(15000, 128*1024, 8)
 print_entry(33000, 128*1024, 16)
-print_entry(53000, 128*1024, 32)
+#print_entry(53000, 128*1024, 32)
 
 print("256K table * 16 bytes = %dMiB." % (256*16/1024))
-print_entry(7000, 256*1024, 4)
-print_entry(28000, 256*1024, 8)
+#print_entry(7000, 256*1024, 4)
+#print_entry(28000, 256*1024, 8)
 print_entry(63000, 256*1024, 16)
-print_entry(104000, 256*1024, 32)
+#print_entry(104000, 256*1024, 32)
 
 print("512K table * 16 bytes = %dMiB." % (512*16/1024))
-print_entry(13000, 512*1024, 4)
-print_entry(52000, 512*1024, 8)
+#print_entry(13000, 512*1024, 4)
+#print_entry(52000, 512*1024, 8)
 print_entry(119000, 512*1024, 16)
-print_entry(200000, 512*1024, 32)
+#print_entry(200000, 512*1024, 32)
 
 print("1M table * 16 bytes = %dMiB." % (1024*16/1024))
-print_entry(23000, 1024*1024, 4)
-print_entry(95000, 1024*1024, 8)
+#print_entry(23000, 1024*1024, 4)
+#print_entry(95000, 1024*1024, 8)
 print_entry(227000, 1024*1024, 16)
-print_entry(387000, 1024*1024, 32)
+#print_entry(387000, 1024*1024, 32)
 
 print("2M table * 16 bytes = %dMiB." % (2*1024*16/1024))
-print_entry(40000, 2*1024*1024, 4)
-print_entry(175000, 2*1024*1024, 8)
+#print_entry(40000, 2*1024*1024, 4)
+#print_entry(175000, 2*1024*1024, 8)
 print_entry(431000, 2*1024*1024, 16)
-print_entry(749000, 2*1024*1024, 32)
+#print_entry(749000, 2*1024*1024, 32)
index 6278955bb1e967fb60eeb846647c7632a0cfea8f..18d96a231bd2b546c67443a5b598aed988f3a12e 100755 (executable)
@@ -34,8 +34,10 @@ echo "$MAP_CONTENTS" | {
        echo -e "${PACKETS[2]}\t$(( ${BYTES[2]} / 1000 ))\tInvalid/rejected IHL IPv4 field"
        echo -e "${PACKETS[3]}\t$(( ${BYTES[3]} / 1000 ))\tRejected IPv6 fragments"
        C=4
-       while read LINE; do
-               echo -e "${PACKETS["$C"]}\t$(( ${BYTES["$C"]} / 1000 ))\t$LINE"
-               C=$(( $C + 1 ))
-       done < "$(dirname ${BASH_SOURCE[0]})/installed-rules.txt"
+       if [ "$(cat "$(dirname ${BASH_SOURCE[0]})/installed-rules.txt" | wc -c)" -gt "1" ]; then
+               while read LINE; do
+                       echo -e "${PACKETS["$C"]}\t$(( ${BYTES["$C"]} / 1000 ))\t$LINE"
+                       C=$(( $C + 1 ))
+               done < "$(dirname ${BASH_SOURCE[0]})/installed-rules.txt"
+       fi
 }
index a9f3d16db50c54ee04830757a4d5482fd8d7f39c..cdbeb5b89f61f5a0c8cdcb6f64d445ea9dce3e89 100755 (executable)
@@ -348,9 +348,9 @@ class RuleNode:
             elif selfsrc is not None and othersrc is not None:
                 o = selfsrc.ord(othersrc)
 
-        if o == ORD_LESS:
-            return True
-        return self.action < other.action
+        if o == ORD_EQUAL:
+            return [a for a in self.action if type(a) != IpRule] < [b for b in other.action if type(b) != IpRule]
+        return o == ORD_LESS
 
     def maybe_join(self, neighbor):
         if self.ty == RuleAction.CONDITIONS and neighbor.ty == RuleAction.CONDITIONS:
@@ -433,9 +433,8 @@ with open("rules.h", "w") as out:
     use_v6_frags = False
     stats_rulecnt = 0
     ratelimitcnt = 0
-    v4persrcratelimits = []
-    v5persrcratelimits = []
-    v6persrcratelimits = []
+    persrcratelimits64 = []
+    persrcratelimits128 = []
 
     lastrule = None
     for line in sys.stdin.readlines():
@@ -543,7 +542,7 @@ with open("rules.h", "w") as out:
                         value *= 2**(exp-127)
 
                         first_action =   "int64_t time_masked = bpf_ktime_get_ns() & RATE_TIME_MASK;\n"
-                        first_action += f"int64_t per_pkt_ns = (1000000000LL << RATE_BUCKET_INTEGER_BITS) / {math.floor(value)};\n"
+                        first_action += f"int64_t per_pkt_ns = 1000000000LL / {math.floor(value)};\n"
                         if ty == "0x8006" or ty == "0x8306":
                             first_action += "uint64_t amt = data_end - pktdata;\n"
                         else:
@@ -559,22 +558,22 @@ with open("rules.h", "w") as out:
                             if proto == 4:
                                 if mid_byte > 32:
                                     continue
-                                first_action += f"const uint32_t srcip = ip->saddr & MASK4({mid_byte});\n"
-                                first_action += f"void *rate_map = &v4_src_rate_{len(v4persrcratelimits)};\n"
-                                first_action += f"int matched = check_v4_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
-                                v4persrcratelimits.append((high_byte + 1) * 4096)
+                                first_action += f"const uint64_t srcip = ip->saddr & MASK4({mid_byte});\n"
+                                first_action += f"void *rate_map = &src_rate_64_{len(persrcratelimits64)};\n"
+                                first_action += f"int matched = check_persrc_ratelimit_64(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
+                                persrcratelimits64.append((high_byte + 1) * 4096)
                             elif mid_byte <= 64:
                                 first_action += f"const uint64_t srcip = BE128BEHIGH64(ip6->saddr & MASK6({mid_byte}));\n"
-                                first_action += f"void *rate_map = &v5_src_rate_{len(v5persrcratelimits)};\n"
-                                first_action += f"int matched = check_v5_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
-                                v5persrcratelimits.append((high_byte + 1) * 4096)
+                                first_action += f"void *rate_map = &src_rate_64_{len(persrcratelimits64)};\n"
+                                first_action += f"int matched = check_persrc_ratelimit_64(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
+                                persrcratelimits64.append((high_byte + 1) * 4096)
                             else:
                                 if mid_byte > 128:
                                     continue
                                 first_action += f"const uint128_t srcip = ip6->saddr & MASK6({mid_byte});\n"
-                                first_action += f"void *rate_map = &v6_src_rate_{len(v6persrcratelimits)};\n"
-                                first_action += f"int matched = check_v6_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
-                                v6persrcratelimits.append((high_byte + 1) * 4096)
+                                first_action += f"void *rate_map = &src_rate_128_{len(persrcratelimits128)};\n"
+                                first_action += f"int matched = check_persrc_ratelimit_128(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n"
+                                persrcratelimits128.append((high_byte + 1) * 4096)
                         first_action +=  "if (matched) {\n"
                         first_action +=  "\t{stats_replace}\n"
                         first_action +=  "\treturn XDP_DROP;\n"
@@ -647,9 +646,7 @@ with open("rules.h", "w") as out:
         if use_v6_frags:
             out.write("#define PARSE_V6_FRAG PARSE\n")
     with open("maps.h", "w") as out:
-        for idx, limit in enumerate(v4persrcratelimits):
-            out.write(f"SRC_RATE_DEFINE(4, {idx}, {limit})\n")
-        for idx, limit in enumerate(v5persrcratelimits):
-            out.write(f"SRC_RATE_DEFINE(5, {idx}, {limit})\n")
-        for idx, limit in enumerate(v6persrcratelimits):
-            out.write(f"SRC_RATE_DEFINE(6, {idx}, {limit})\n")
+        for idx, limit in enumerate(persrcratelimits64):
+            out.write(f"SRC_RATE_DEFINE(64, {idx}, {limit})\n")
+        for idx, limit in enumerate(persrcratelimits128):
+            out.write(f"SRC_RATE_DEFINE(128, {idx}, {limit})\n")
index 2a47a1dd1ffe57cddb1e8473ef633ae6d8771cc9..197155c722d7b7bf24aaea92a62c8036da4ed7a3 100755 (executable)
@@ -17,18 +17,19 @@ $(birdc show route table flowspec6 primary all)"
 
 echo "const uint8_t COMPILE_TIME_RAND[] = { $(dd if=/dev/urandom of=/dev/stdout bs=1 count=8 2>/dev/null | hexdump -e '4/1 "0x%02x, "') };" > rand.h
 
+[ "$CLANG" != "" ] || CLANG="clang"
+[ "$LLC" != "" ] || LLC="llc"
+[ "$LLVM_LINK" != "" ] || LLVM_LINK="llvm-link"
+
 STATS_RULES="$(echo "$RULES" | ./genrules.py --8021q=drop-vlan --v6frag=ignore-parse-if-rule --ihl=parse-options)"
-clang $CLANG_ARGS -g -std=c99 -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -Wno-unused-function -O3 -emit-llvm -c xdp.c -o xdp.bc
+$CLANG $CLANG_ARGS -g -std=c99 -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -Wno-unused-function -Oz -emit-llvm -c xdp.c -o xdp.bc
 if [ "$2" != "" ]; then
-       clang $4 -g -std=c99 -pedantic -Wall -Wextra -Wno-pointer-arith -O3 -emit-llvm -c "$2" -o wrapper.bc
-       llvm-link xdp.bc wrapper.bc | llc -O3 -march=bpf -mcpu=probe -filetype=obj -o xdp
+       $CLANG $4 -g -std=c99 -pedantic -Wall -Wextra -Wno-pointer-arith -Oz -emit-llvm -c "$2" -o wrapper.bc
+       $LLVM_LINK xdp.bc wrapper.bc | $LLC -O3 -march=bpf -mcpu=probe -filetype=obj -o xdp
 else
-       cat xdp.bc | llc -O3 -march=bpf -mcpu=probe -filetype=obj -o xdp
+       cat xdp.bc | $LLC -O3 -march=bpf -mcpu=probe -filetype=obj -o xdp
 fi
 
-echo "Before unload drop count was:"
-./dropcount.sh || echo "Not loaded"
-
 # Note that sometimes the automated fallback does not work properly so we have to || generic here
 ip -force link set "$1" xdpoffload obj xdp sec $XDP_SECTION || (
        echo "Failed to install in NIC, testing in driver..." && ip -force link set "$1" xdpdrv obj xdp sec $XDP_SECTION || (
index 298a3a2dd626470915f21c7fe39fa0f0d56e827d..cfcaad96824dc056a501ff5a18910fe7b04bd455 100644 (file)
--- a/siphash.h
+++ b/siphash.h
@@ -18,6 +18,7 @@
 
 #include <stddef.h>
 #include <stdint.h>
+#include <string.h>
 
 /* default: SipHash-2-4 */
 #ifndef cROUNDS
@@ -73,8 +74,7 @@
 #endif
 
 __attribute__((always_inline))
-static inline uint64_t siphash(const void *in, const size_t inlen, const uint8_t k[16]) {
-    const unsigned char *ni = (const unsigned char *)in;
+static inline uint64_t siphash(const uint64_t *in, const size_t inwords, const uint8_t k[16]) {
     const unsigned char *kk = (const unsigned char *)k;
 
     uint64_t v0 = UINT64_C(0x736f6d6570736575);
@@ -85,16 +85,16 @@ static inline uint64_t siphash(const void *in, const size_t inlen, const uint8_t
     uint64_t k1 = U8TO64_LE(kk + 8);
     uint64_t m;
     int i;
-    const unsigned char *end = ni + inlen - (inlen % sizeof(uint64_t));
-    const int left = inlen & 7;
-    uint64_t b = ((uint64_t)inlen) << 56;
+    size_t j;
+    uint64_t b = ((uint64_t)inwords) << (56 + 3);
     v3 ^= k1;
     v2 ^= k0;
     v1 ^= k1;
     v0 ^= k0;
 
-    for (; ni != end; ni += 8) {
-        m = U8TO64_LE(ni);
+    for (j = 0; j < inwords; ++j) {
+        m = *in;
+        in += 1;
         v3 ^= m;
 
         TRACE;
@@ -104,31 +104,23 @@ static inline uint64_t siphash(const void *in, const size_t inlen, const uint8_t
         v0 ^= m;
     }
 
-    switch (left) {
-    case 7:
-        b |= ((uint64_t)ni[6]) << 48;
-    case 6:
-        b |= ((uint64_t)ni[5]) << 40;
-    case 5:
-        b |= ((uint64_t)ni[4]) << 32;
-    case 4:
-        b |= ((uint64_t)ni[3]) << 24;
-    case 3:
-        b |= ((uint64_t)ni[2]) << 16;
-    case 2:
-        b |= ((uint64_t)ni[1]) << 8;
-    case 1:
-        b |= ((uint64_t)ni[0]);
-        break;
-    case 0:
-        break;
-    }
-
-    v3 ^= b;
+    // Generally, here siphash writes any extra bytes that weren't an even
+    // multiple of eight as well as the length (in the form of `b`). Then,
+    // because we've written fresh attacker-controlled data into our state, we
+    // do an extra `cROUNDS` `SIPROUND`s. This ensures we have
+    // `cROUNDS` + `dROUNDS` `SIPROUND`s between any attacker-controlled data
+    // and the output, which for SipHash 1-3 means the four rounds required for
+    // good mixing.
+    //
+    // However, in our use-case the input is always a multiple of eight bytes
+    // and the attacker doesn't control the length. Thus, we skip the extra
+    // round here, giving us a very slightly tweaked SipHash 1-2 which is
+    // equivalent to SipHash 1-3 with a fixed input of N*8+7 bytes.
+    /*v3 ^= b;
 
     TRACE;
     for (i = 0; i < cROUNDS; ++i)
-        SIPROUND;
+        SIPROUND;*/
 
     v0 ^= b;
     v2 ^= 0xff;
@@ -143,12 +135,10 @@ static inline uint64_t siphash(const void *in, const size_t inlen, const uint8_t
 
 #include "rand.h"
 static uint64_t siphash_uint64_t(const uint64_t in) {
-       return siphash(&in, sizeof(uint64_t), COMPILE_TIME_RAND);
-}
-__attribute__((always_inline))
-static inline uint64_t siphash_uint32_t(const uint32_t in) {
-       return siphash_uint64_t(in);
+       return siphash(&in, 1, COMPILE_TIME_RAND);
 }
 static uint64_t siphash_uint128_t(const __uint128_t in) {
-       return siphash(&in, sizeof(__uint128_t), COMPILE_TIME_RAND);
+       uint64_t words[2];
+       memcpy(words, &in, sizeof(__uint128_t));
+       return siphash(words, 2, COMPILE_TIME_RAND);
 }
diff --git a/xdp.c b/xdp.c
index 187c220b0261fdaad0acc5ec7d48c6d36a4b86e5..414c98320179445740dc19d13c152c35b5d1a7ca 100644 (file)
--- a/xdp.c
+++ b/xdp.c
@@ -178,6 +178,11 @@ struct {
 }
 
 // Rate limits are done in a static-sized leaky bucket with a decimal counter
+//
+// They are stored in a single uint64_t with the top RATE_BUCKET_BITS holding
+// the packet count/size and the remaining low bits holding the the time (as a
+// fixed-point decimal).
+//
 // Bucket size is always exactly (1 << RATE_BUCKET_INTEGER_BITS)
 #define RATE_BUCKET_DECIMAL_BITS 8
 #define RATE_BUCKET_INTEGER_BITS 4
@@ -193,7 +198,7 @@ struct {
 #ifdef RATE_CNT
 struct ratelimit {
        struct bpf_spin_lock lock;
-       uint64_t sent_time;
+       uint64_t pkts_and_time;
 };
 struct {
        __uint(type, BPF_MAP_TYPE_ARRAY);
@@ -229,19 +234,31 @@ struct {
 #define DO_RATE_LIMIT(do_lock, rate, time_masked, amt_in_pkt, limit_ns_per_pkt, matchbool) do { \
 if (rate) { \
        do_lock; \
-       int64_t bucket_pkts = (rate->sent_time & (~RATE_TIME_MASK)) >> (64 - RATE_BUCKET_BITS); \
-       /* We mask the top 12 bits, so date overflows every 52 days, handled below */ \
-       int64_t time_diff = time_masked - ((int64_t)(rate->sent_time & RATE_TIME_MASK)); \
-       if (unlikely(time_diff < -1000000000 || time_diff > 16000000000)) { \
+       int64_t bucket_pkts = (rate->pkts_and_time & (~RATE_TIME_MASK)) >> (64 - RATE_BUCKET_BITS); \
+       /* We mask the top 12 bits, so date overflows every 52 days, resetting the counter */ \
+       int64_t time_diff = time_masked - ((int64_t)(rate->pkts_and_time & RATE_TIME_MASK)); \
+       if (unlikely(time_diff < RATE_MIN_TIME_OFFSET || time_diff > RATE_MAX_TIME_OFFSET)) { \
                bucket_pkts = 0; \
        } else { \
                if (unlikely(time_diff < 0)) { time_diff = 0; } \
-               int64_t pkts_since_last = (time_diff << RATE_BUCKET_BITS) * ((uint64_t)amt_in_pkt) / ((uint64_t)limit_ns_per_pkt); \
-               bucket_pkts -= pkts_since_last; \
+               /* To avoid storing too many bits, we make a simplifying assumption that all packets */ \
+               /* hit by a rule are the same size. Thus, when a rule is denominated in bytes rather */ \
+               /* than packets, we can keep counting packets and simply adjust the ratelimit by the*/ \
+               /* size of the packet we're looking at. */ \
+               /* Thus, here, we simply reduce our packet counter by the */ \
+               /* time difference / (our ns/packet limit * the size of the current packet). */ \
+               /* We shift by RATE_BUCKET_DECIMAL_BITS first since we're calculating whole packets. */ \
+               int64_t pkts_allowed_since_last_update = \
+                       (time_diff << RATE_BUCKET_DECIMAL_BITS) / (((uint64_t)amt_in_pkt) * ((uint64_t)limit_ns_per_pkt)); \
+               bucket_pkts -= pkts_allowed_since_last_update; \
        } \
-       if (bucket_pkts < (((1 << RATE_BUCKET_INTEGER_BITS) - 1) << RATE_BUCKET_DECIMAL_BITS)) { \
+       /* Accept as long as we can add one to our bucket without overflow */ \
+       const int64_t MAX_PACKETS = (1 << RATE_BUCKET_INTEGER_BITS) - 2; \
+       if (bucket_pkts <= (MAX_PACKETS << RATE_BUCKET_DECIMAL_BITS)) { \
                if (unlikely(bucket_pkts < 0)) bucket_pkts = 0; \
-               rate->sent_time = time_masked | ((bucket_pkts + (1 << RATE_BUCKET_DECIMAL_BITS)) << (64 - RATE_BUCKET_BITS)); \
+               int64_t new_packet_count = bucket_pkts + (1 << RATE_BUCKET_DECIMAL_BITS); \
+               if (new_packet_count < 0) { new_packet_count = 0; } \
+               rate->pkts_and_time = time_masked | (new_packet_count << (64 - RATE_BUCKET_BITS)); \
                matchbool = 0; \
        } else { \
                matchbool = 1; \
@@ -249,51 +266,53 @@ if (rate) { \
 } \
 } while(0);
 
-#define CREATE_PERSRC_LOOKUP(IPV, IP_TYPE) \
-struct persrc_rate##IPV##_entry { \
-       uint64_t sent_time; \
+#define CREATE_PERSRC_LOOKUP(LEN, IP_TYPE) \
+struct persrc_rate_##LEN##_entry { \
+       uint64_t pkts_and_time; \
        IP_TYPE srcip; \
 }; \
  \
-struct persrc_rate##IPV##_bucket { \
+struct persrc_rate_##LEN##_bucket { \
        struct bpf_spin_lock lock; \
-       struct persrc_rate##IPV##_entry entries[]; \
+       struct persrc_rate_##LEN##_entry entries[]; \
 }; \
  \
-static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_limit, int64_t cur_time_masked, uint64_t amt, uint64_t limit_ns_per_pkt) { \
+static int check_persrc_ratelimit_##LEN(IP_TYPE key, void *map, size_t map_limit, int64_t cur_time_masked, uint64_t amt, uint64_t limit_ns_per_pkt) { \
        uint64_t hash = siphash_##IP_TYPE(key); \
  \
        const uint32_t map_key = hash % SRC_HASH_MAX_PARALLELISM; \
-       struct persrc_rate##IPV##_bucket *buckets = bpf_map_lookup_elem(map, &map_key); \
+       struct persrc_rate_##LEN##_bucket *buckets = bpf_map_lookup_elem(map, &map_key); \
        if (!buckets) return 0; \
  \
        hash >>= SRC_HASH_MAX_PARALLELISM_POW; \
        map_limit >>= SRC_HASH_MAX_PARALLELISM_POW; \
  \
-       struct persrc_rate##IPV##_entry *first_bucket = &buckets->entries[(hash % map_limit) & (~(SRC_HASH_BUCKET_COUNT - 1))]; \
+       struct persrc_rate_##LEN##_entry *first_bucket = &buckets->entries[(hash % map_limit) & (~(SRC_HASH_BUCKET_COUNT - 1))]; \
        bpf_spin_lock(&buckets->lock); \
  \
-       uint64_t min_sent_idx = 0; /* Must be uint64_t or BPF verifier gets lost and thinks it can be any value */ \
-       uint64_t min_sent_time = UINT64_MAX; \
+       uint64_t bucket_idx = SRC_HASH_BUCKET_COUNT; \
+       uint64_t min_sent_idx = 0; \
+       uint64_t min_time = UINT64_MAX; \
        for (uint64_t i = 0; i < SRC_HASH_BUCKET_COUNT; i++) { \
                if (first_bucket[i].srcip == key) { \
-                       min_sent_idx = i; \
+                       bucket_idx = i; \
                        break; \
                } \
-               int64_t time_offset = ((int64_t)cur_time_masked) - (first_bucket[i].sent_time & RATE_TIME_MASK); \
+               int64_t time_offset = ((int64_t)cur_time_masked) - (first_bucket[i].pkts_and_time & RATE_TIME_MASK); \
                if (time_offset < RATE_MIN_TIME_OFFSET || time_offset > RATE_MAX_TIME_OFFSET) { \
+                       min_time = 0; \
                        min_sent_idx = i; \
-                       break; \
                } \
-               if ((first_bucket[i].sent_time & RATE_TIME_MASK) < min_sent_time) { \
-                       min_sent_time = first_bucket[i].sent_time & RATE_TIME_MASK; \
+               if ((first_bucket[i].pkts_and_time & RATE_TIME_MASK) < min_time) { \
+                       min_time = first_bucket[i].pkts_and_time & RATE_TIME_MASK; \
                        min_sent_idx = i; \
                } \
        } \
-       struct persrc_rate##IPV##_entry *entry = &first_bucket[min_sent_idx]; \
+       if (bucket_idx >= SRC_HASH_BUCKET_COUNT) bucket_idx = min_sent_idx; \
+       struct persrc_rate_##LEN##_entry *entry = &first_bucket[bucket_idx]; \
        if (entry->srcip != key) { \
                entry->srcip = key; \
-               entry->sent_time = 0; \
+               entry->pkts_and_time = 0; \
        } \
        int matched = 0; \
        DO_RATE_LIMIT(, entry, cur_time_masked, amt, limit_ns_per_pkt, matched); \
@@ -301,21 +320,20 @@ static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_li
        return matched; \
 }
 
-CREATE_PERSRC_LOOKUP(6, uint128_t)
-CREATE_PERSRC_LOOKUP(5, uint64_t) // IPv6 matching no more than a /64
-CREATE_PERSRC_LOOKUP(4, uint32_t)
+CREATE_PERSRC_LOOKUP(128, uint128_t)
+CREATE_PERSRC_LOOKUP(64, uint64_t) // IPv6 matching no more than a /64 and IPv4
 
-#define SRC_RATE_DEFINE(IPV, n, limit) \
-struct persrc_rate##IPV##_bucket_##n { \
+#define SRC_RATE_DEFINE(LEN, n, limit) \
+struct persrc_rate_##LEN##_bucket_##n { \
        struct bpf_spin_lock lock; \
-       struct persrc_rate##IPV##_entry entries[limit / SRC_HASH_MAX_PARALLELISM]; \
+       struct persrc_rate_##LEN##_entry entries[limit / SRC_HASH_MAX_PARALLELISM]; \
 }; \
 struct { \
        __uint(type, BPF_MAP_TYPE_ARRAY); \
        __uint(max_entries, SRC_HASH_MAX_PARALLELISM); \
        uint32_t *key; \
-       struct persrc_rate##IPV##_bucket_##n *value; \
-} v##IPV##_src_rate_##n SEC(".maps");
+       struct persrc_rate_##LEN##_bucket_##n *value; \
+} src_rate_##LEN##_##n SEC(".maps");
 
 #include "maps.h"