Drop ports_valid flag, it just wastes a register
[flowspec-xdp] / genrules.py
index a8d23a779b95f7b9b237786cfc84fac5b3d6e8b9..5d90e8c773a414b218a7db7be03f2b5eb3e085f9 100755 (executable)
@@ -199,7 +199,7 @@ def ip_to_rule(proto, inip, ty, offset):
        break;"""
 
 def fragment_to_rule(ipproto, rules):
-    ast = parse_ast(rules, parse_frag_expr, True)
+    ast = parse_ast(rules, parse_frag_expr, False)
     return "if (!( " + ast.write(ipproto) + " )) break;"
 
 def len_to_rule(rules):
@@ -239,10 +239,10 @@ def dscp_to_rule(proto, rules):
 def port_to_rule(ty, rules):
     if ty == "port" :
         ast = parse_ast(rules, parse_numbers_expr, True)
-        return "if (!ports_valid) break;\nif (!( " + ast.write("sport", "dport") + " )) break;"
+        return "if (sport == -1 || dport == -1) break;\nif (!( " + ast.write("sport", "dport") + " )) break;"
 
     ast = parse_ast(rules, parse_numbers_expr, True)
-    return "if (!ports_valid) break;\nif (!( " + ast.write(ty) + " )) break;"
+    return "if (" + ty + " == -1) break;\nif (!( " + ast.write(ty) + " )) break;"
 
 def tcp_flags_to_rule(rules):
     ast = parse_ast(rules, parse_bit_expr, False)
@@ -296,8 +296,8 @@ with open("rules.h", "w") as out:
     use_v6_frags = False
     rulecnt = 0
     ratelimitcnt = 0
-    v4persrcratelimitcnt = 0
-    v6persrcratelimitcnt = 0
+    v4persrcratelimits = []
+    v6persrcratelimits = []
 
     lastrule = None
     for line in sys.stdin.readlines():
@@ -378,7 +378,8 @@ with open("rules.h", "w") as out:
                 if len(blocks[1].strip()) != 10: # Should be 0x12345678
                     continue
                 ty = blocks[1].strip()[:6]
-                high_byte = int(blocks[1].strip()[8:], 16)
+                high_byte = int(blocks[1].strip()[6:8], 16)
+                mid_byte = int(blocks[1].strip()[8:], 16)
                 low_bytes = int(blocks[2].strip(") \n"), 16)
                 if ty == "0x8006" or ty == "0x800c" or ty == "0x8306" or ty == "0x830c":
                     if first_action is not None:
@@ -418,17 +419,17 @@ with open("rules.h", "w") as out:
                             spin_lock = "/* No locking as we're per-CPU */"
                             spin_unlock = "/* No locking as we're per-CPU */"
                             if proto == 4:
-                                if high_byte > 32:
+                                if mid_byte > 32:
                                     continue
-                                first_action += f"const uint32_t srcip = ip->saddr & MASK4({high_byte});\n"
-                                first_action += f"void *rate_map = &v4_src_rate_{v4persrcratelimitcnt};\n"
-                                v4persrcratelimitcnt += 1
+                                first_action += f"const uint32_t srcip = ip->saddr & MASK4({mid_byte});\n"
+                                first_action += f"void *rate_map = &v4_src_rate_{len(v4persrcratelimits)};\n"
+                                v4persrcratelimits.append((high_byte + 1) * 1024)
                             else:
-                                if high_byte > 128:
+                                if mid_byte > 128:
                                     continue
-                                first_action += f"const uint128_t srcip = ip6->saddr & MASK6({high_byte});\n"
-                                first_action += f"void *rate_map = &v6_src_rate_{v6persrcratelimitcnt};\n"
-                                v6persrcratelimitcnt += 1
+                                first_action += f"const uint128_t srcip = ip6->saddr & MASK6({mid_byte});\n"
+                                first_action += f"void *rate_map = &v6_src_rate_{len(v6persrcratelimits)};\n"
+                                v6persrcratelimits.append((high_byte + 1) * 1024)
                             first_action += f"struct percpu_ratelimit *rate = bpf_map_lookup_elem(rate_map, &srcip);\n"
                         first_action +=  "if (rate) {\n"
                         first_action += f"\t{spin_lock}\n"
@@ -501,10 +502,6 @@ with open("rules.h", "w") as out:
     out.write(f"#define RULECNT {rulecnt}\n")
     if ratelimitcnt != 0:
         out.write(f"#define RATE_CNT {ratelimitcnt}\n")
-    if v4persrcratelimitcnt != 0:
-        out.write(f"#define V4_SRC_RATE_CNT {v4persrcratelimitcnt}\n")
-    if v6persrcratelimitcnt != 0:
-        out.write(f"#define V6_SRC_RATE_CNT {v6persrcratelimitcnt}\n")
     if rules4 != "":
         out.write("#define NEED_V4_PARSE\n")
         out.write("#define RULES4 {\\\n" + rules4 + "}\n")
@@ -514,3 +511,8 @@ with open("rules.h", "w") as out:
     if args.v6frag == "ignore-parse-if-rule":
         if use_v6_frags:
             out.write("#define PARSE_V6_FRAG PARSE\n")
+    with open("maps.h", "w") as out:
+        for idx, limit in enumerate(v4persrcratelimits):
+            out.write(f"V4_SRC_RATE_DEFINE({idx}, {limit})\n")
+        for idx, limit in enumerate(v6persrcratelimits):
+            out.write(f"V6_SRC_RATE_DEFINE({idx}, {limit})\n")