From 69928d8ec232115827d08726240847e81cc54bea Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 17 Jun 2024 06:27:11 +0000 Subject: [PATCH] Use common per-source lookup between v4 and v6/64 Since v4 has 32 bits of slack in the struct anyway, we might as well just make it a u64 and unify the codepaths. --- genrules.py | 35 ++++++++++++++++------------------- siphash.h | 4 ---- xdp.c | 31 +++++++++++++++---------------- 3 files changed, 31 insertions(+), 39 deletions(-) diff --git a/genrules.py b/genrules.py index 647152a..cdbeb5b 100755 --- a/genrules.py +++ b/genrules.py @@ -433,9 +433,8 @@ with open("rules.h", "w") as out: use_v6_frags = False stats_rulecnt = 0 ratelimitcnt = 0 - v4persrcratelimits = [] - v5persrcratelimits = [] - v6persrcratelimits = [] + persrcratelimits64 = [] + persrcratelimits128 = [] lastrule = None for line in sys.stdin.readlines(): @@ -559,22 +558,22 @@ with open("rules.h", "w") as out: if proto == 4: if mid_byte > 32: continue - first_action += f"const uint32_t srcip = ip->saddr & MASK4({mid_byte});\n" - first_action += f"void *rate_map = &v4_src_rate_{len(v4persrcratelimits)};\n" - first_action += f"int matched = check_v4_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n" - v4persrcratelimits.append((high_byte + 1) * 4096) + first_action += f"const uint64_t srcip = ip->saddr & MASK4({mid_byte});\n" + first_action += f"void *rate_map = &src_rate_64_{len(persrcratelimits64)};\n" + first_action += f"int matched = check_persrc_ratelimit_64(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n" + persrcratelimits64.append((high_byte + 1) * 4096) elif mid_byte <= 64: first_action += f"const uint64_t srcip = BE128BEHIGH64(ip6->saddr & MASK6({mid_byte}));\n" - first_action += f"void *rate_map = &v5_src_rate_{len(v5persrcratelimits)};\n" - first_action += f"int matched = check_v5_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n" - v5persrcratelimits.append((high_byte + 1) * 4096) + first_action += f"void *rate_map = &src_rate_64_{len(persrcratelimits64)};\n" + first_action += f"int matched = check_persrc_ratelimit_64(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n" + persrcratelimits64.append((high_byte + 1) * 4096) else: if mid_byte > 128: continue first_action += f"const uint128_t srcip = ip6->saddr & MASK6({mid_byte});\n" - first_action += f"void *rate_map = &v6_src_rate_{len(v6persrcratelimits)};\n" - first_action += f"int matched = check_v6_persrc_ratelimit(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n" - v6persrcratelimits.append((high_byte + 1) * 4096) + first_action += f"void *rate_map = &src_rate_128_{len(persrcratelimits128)};\n" + first_action += f"int matched = check_persrc_ratelimit_128(srcip, rate_map, {(high_byte + 1) * 4096}, time_masked, amt, per_pkt_ns);\n" + persrcratelimits128.append((high_byte + 1) * 4096) first_action += "if (matched) {\n" first_action += "\t{stats_replace}\n" first_action += "\treturn XDP_DROP;\n" @@ -647,9 +646,7 @@ with open("rules.h", "w") as out: if use_v6_frags: out.write("#define PARSE_V6_FRAG PARSE\n") with open("maps.h", "w") as out: - for idx, limit in enumerate(v4persrcratelimits): - out.write(f"SRC_RATE_DEFINE(4, {idx}, {limit})\n") - for idx, limit in enumerate(v5persrcratelimits): - out.write(f"SRC_RATE_DEFINE(5, {idx}, {limit})\n") - for idx, limit in enumerate(v6persrcratelimits): - out.write(f"SRC_RATE_DEFINE(6, {idx}, {limit})\n") + for idx, limit in enumerate(persrcratelimits64): + out.write(f"SRC_RATE_DEFINE(64, {idx}, {limit})\n") + for idx, limit in enumerate(persrcratelimits128): + out.write(f"SRC_RATE_DEFINE(128, {idx}, {limit})\n") diff --git a/siphash.h b/siphash.h index 298a3a2..1d748b9 100644 --- a/siphash.h +++ b/siphash.h @@ -145,10 +145,6 @@ static inline uint64_t siphash(const void *in, const size_t inlen, const uint8_t static uint64_t siphash_uint64_t(const uint64_t in) { return siphash(&in, sizeof(uint64_t), COMPILE_TIME_RAND); } -__attribute__((always_inline)) -static inline uint64_t siphash_uint32_t(const uint32_t in) { - return siphash_uint64_t(in); -} static uint64_t siphash_uint128_t(const __uint128_t in) { return siphash(&in, sizeof(__uint128_t), COMPILE_TIME_RAND); } diff --git a/xdp.c b/xdp.c index d5bd0b0..414c983 100644 --- a/xdp.c +++ b/xdp.c @@ -266,28 +266,28 @@ if (rate) { \ } \ } while(0); -#define CREATE_PERSRC_LOOKUP(IPV, IP_TYPE) \ -struct persrc_rate##IPV##_entry { \ +#define CREATE_PERSRC_LOOKUP(LEN, IP_TYPE) \ +struct persrc_rate_##LEN##_entry { \ uint64_t pkts_and_time; \ IP_TYPE srcip; \ }; \ \ -struct persrc_rate##IPV##_bucket { \ +struct persrc_rate_##LEN##_bucket { \ struct bpf_spin_lock lock; \ - struct persrc_rate##IPV##_entry entries[]; \ + struct persrc_rate_##LEN##_entry entries[]; \ }; \ \ -static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_limit, int64_t cur_time_masked, uint64_t amt, uint64_t limit_ns_per_pkt) { \ +static int check_persrc_ratelimit_##LEN(IP_TYPE key, void *map, size_t map_limit, int64_t cur_time_masked, uint64_t amt, uint64_t limit_ns_per_pkt) { \ uint64_t hash = siphash_##IP_TYPE(key); \ \ const uint32_t map_key = hash % SRC_HASH_MAX_PARALLELISM; \ - struct persrc_rate##IPV##_bucket *buckets = bpf_map_lookup_elem(map, &map_key); \ + struct persrc_rate_##LEN##_bucket *buckets = bpf_map_lookup_elem(map, &map_key); \ if (!buckets) return 0; \ \ hash >>= SRC_HASH_MAX_PARALLELISM_POW; \ map_limit >>= SRC_HASH_MAX_PARALLELISM_POW; \ \ - struct persrc_rate##IPV##_entry *first_bucket = &buckets->entries[(hash % map_limit) & (~(SRC_HASH_BUCKET_COUNT - 1))]; \ + struct persrc_rate_##LEN##_entry *first_bucket = &buckets->entries[(hash % map_limit) & (~(SRC_HASH_BUCKET_COUNT - 1))]; \ bpf_spin_lock(&buckets->lock); \ \ uint64_t bucket_idx = SRC_HASH_BUCKET_COUNT; \ @@ -309,7 +309,7 @@ static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_li } \ } \ if (bucket_idx >= SRC_HASH_BUCKET_COUNT) bucket_idx = min_sent_idx; \ - struct persrc_rate##IPV##_entry *entry = &first_bucket[bucket_idx]; \ + struct persrc_rate_##LEN##_entry *entry = &first_bucket[bucket_idx]; \ if (entry->srcip != key) { \ entry->srcip = key; \ entry->pkts_and_time = 0; \ @@ -320,21 +320,20 @@ static int check_v##IPV##_persrc_ratelimit(IP_TYPE key, void *map, size_t map_li return matched; \ } -CREATE_PERSRC_LOOKUP(6, uint128_t) -CREATE_PERSRC_LOOKUP(5, uint64_t) // IPv6 matching no more than a /64 -CREATE_PERSRC_LOOKUP(4, uint32_t) +CREATE_PERSRC_LOOKUP(128, uint128_t) +CREATE_PERSRC_LOOKUP(64, uint64_t) // IPv6 matching no more than a /64 and IPv4 -#define SRC_RATE_DEFINE(IPV, n, limit) \ -struct persrc_rate##IPV##_bucket_##n { \ +#define SRC_RATE_DEFINE(LEN, n, limit) \ +struct persrc_rate_##LEN##_bucket_##n { \ struct bpf_spin_lock lock; \ - struct persrc_rate##IPV##_entry entries[limit / SRC_HASH_MAX_PARALLELISM]; \ + struct persrc_rate_##LEN##_entry entries[limit / SRC_HASH_MAX_PARALLELISM]; \ }; \ struct { \ __uint(type, BPF_MAP_TYPE_ARRAY); \ __uint(max_entries, SRC_HASH_MAX_PARALLELISM); \ uint32_t *key; \ - struct persrc_rate##IPV##_bucket_##n *value; \ -} v##IPV##_src_rate_##n SEC(".maps"); + struct persrc_rate_##LEN##_bucket_##n *value; \ +} src_rate_##LEN##_##n SEC(".maps"); #include "maps.h" -- 2.39.5