Partially implement sorting
authorMatt Corallo <git@bluematt.me>
Thu, 9 Dec 2021 22:51:51 +0000 (22:51 +0000)
committerMatt Corallo <git@bluematt.me>
Fri, 10 Dec 2021 01:44:10 +0000 (01:44 +0000)
README.md
genrules.py
test.sh

index 5f6e917d0f9115450f4ee311f1a12b217788dda1..9ed1c187174ece2125df9b5d5cf5c742b261988d 100644 (file)
--- a/README.md
+++ b/README.md
@@ -6,9 +6,9 @@ to an XDP program. It currently supports the entire flowspec match grammar, rate
 action packet match counting (sample bit) and terminal bit, and traffic marking. The redirect
 community is not supported.
 
 action packet match counting (sample bit) and terminal bit, and traffic marking. The redirect
 community is not supported.
 
-Note that correctly sorting rules is *not* implemented as it requires implementing the flowspec
-wire serialization format and it may better be done inside bird/birdc. Thus, be vary careful using
-the terminal bit in the traffict action community.
+Note that correctly sorting rules is *not* fully implemented as it requires implementing the
+flowspec wire serialization format and it may be better done inside bird/birdc. Thus, be very
+careful using the terminal bit in the traffic action community.
 
 In addition to the communities specified in RFC 8955, two additional communities are supported which
 provide rate-limiting on a per-source basis. When the upper two bytes in an extended community are
 
 In addition to the communities specified in RFC 8955, two additional communities are supported which
 provide rate-limiting on a per-source basis. When the upper two bytes in an extended community are
index 3865b6b69f42edf1cab46786213bc380d8196df8..a9f3d16db50c54ee04830757a4d5482fd8d7f39c 100755 (executable)
@@ -12,6 +12,10 @@ IP_PROTO_ICMPV6 = 58
 IP_PROTO_TCP = 6
 IP_PROTO_UDP = 17
 
 IP_PROTO_TCP = 6
 IP_PROTO_UDP = 17
 
+ORD_LESS = 0
+ORD_GREATER = 1
+ORD_EQUAL = 2
+
 class ASTAction(Enum):
     OR = 1
     AND = 2
 class ASTAction(Enum):
     OR = 1
     AND = 2
@@ -175,28 +179,67 @@ def parse_bit_expr(expr):
     return ASTNode(ASTAction.EXPR, BitExpr(expr))
 
 
     return ASTNode(ASTAction.EXPR, BitExpr(expr))
 
 
+class IpRule:
+    """A source/destination prefix match that can be ordered and rendered as C.
+
+    ty is the header field name interpolated into the generated C condition
+    (RuleNode.__lt__ looks for "daddr" and "saddr"), offset is the IPv6 prefix
+    offset (None is stored as 0), net is an ipaddress network object and
+    proto is the IP version (ip_to_rule passes 4 or 6).
+    """
+    def __init__(self, ty, offset, net, proto):
+        self.ty = ty
+        self.offset = offset
+        # Normalise a missing offset to 0 so offsets can be compared numerically in ord().
+        if offset is None:
+            self.offset = 0
+        self.net = net
+        self.proto = proto
+
+    def ord(self, other):
+        """Compare with other and return ORD_LESS, ORD_GREATER or ORD_EQUAL.
+
+        Both rules must have the same ty and proto (asserted). The lower offset
+        sorts first. For equal offsets, overlapping networks sort the longer
+        (more specific) prefix first, and non-overlapping networks fall back
+        to the ordering of the ipaddress network objects.
+        """
+        assert self.ty == other.ty
+        assert self.proto == other.proto
+        if self.offset < other.offset:
+            return ORD_LESS
+        if self.offset > other.offset:
+            return ORD_GREATER
+
+        if self.net.overlaps(other.net):
+            # One network contains the other: the more specific prefix sorts first.
+            if self.net.prefixlen > other.net.prefixlen:
+                return ORD_LESS
+            elif self.net.prefixlen < other.net.prefixlen:
+                return ORD_GREATER
+            # Same prefix length while overlapping falls through to ORD_EQUAL below.
+        else:
+            if self.net < other.net:
+                return ORD_LESS
+            else:
+                # Disjoint networks cannot compare equal, so the only option left is greater.
+                assert self.net > other.net
+                return ORD_GREATER
+
+        return ORD_EQUAL
+
+    def __lt__(self, other):
+        """Return True when ord() places this rule strictly before other."""
+        return self.ord(other) == ORD_LESS
+
+    def __eq__(self, other):
+        """Return True only for another IpRule with identical ty, offset, net and proto."""
+        return type(other) == IpRule and self.ty == other.ty and self.offset == other.offset and self.net == other.net and self.proto == other.proto
+
+    def __str__(self):
+        """Render the match as a C `if (...) break;` snippet that bails out on a mismatch.
+
+        The MASK4/MASK6/MASK6_OFFS/BIGEND32/BIGEND128 macros are not defined
+        in this block; presumably they come from the C source that includes
+        the generated header (TODO confirm).
+        """
+        if self.proto == 4:
+            # ip_to_rule asserts that IPv4 rules are given no offset, so it was stored as 0.
+            assert self.offset == 0
+            return f"""if ((ip->{self.ty} & MASK4({self.net.prefixlen})) != BIGEND32({int(self.net.network_address)}ULL))
+       break;"""
+        else:
+            # Split the 128-bit address into four 32-bit words, most significant word first.
+            u32s = [(int(self.net.network_address) >> (3*32)) & 0xffffffff,
+                    (int(self.net.network_address) >> (2*32)) & 0xffffffff,
+                    (int(self.net.network_address) >> (1*32)) & 0xffffffff,
+                    (int(self.net.network_address) >> (0*32)) & 0xffffffff]
+            if self.offset == 0:
+                mask = f"MASK6({self.net.prefixlen})"
+            else:
+                mask = f"MASK6_OFFS({self.offset}, {self.net.prefixlen})"
+            return f"""if ((ip6->{self.ty} & {mask}) != (BIGEND128({u32s[0]}ULL, {u32s[1]}ULL, {u32s[2]}ULL, {u32s[3]}ULL) & {mask}))
+       break;"""
 def ip_to_rule(proto, inip, ty, offset):
     if proto == 4:
         assert offset is None
         net = ipaddress.IPv4Network(inip.strip())
 def ip_to_rule(proto, inip, ty, offset):
     if proto == 4:
         assert offset is None
         net = ipaddress.IPv4Network(inip.strip())
-        if net.prefixlen == 0:
-            return ""
-        return f"""if ((ip->{ty} & MASK4({net.prefixlen})) != BIGEND32({int(net.network_address)}ULL))
-       break;"""
+        return IpRule(ty, offset, net, 4)
     else:
         net = ipaddress.IPv6Network(inip.strip())
     else:
         net = ipaddress.IPv6Network(inip.strip())
-        if net.prefixlen == 0:
-            return ""
-        u32s = [(int(net.network_address) >> (3*32)) & 0xffffffff,
-                (int(net.network_address) >> (2*32)) & 0xffffffff,
-                (int(net.network_address) >> (1*32)) & 0xffffffff,
-                (int(net.network_address) >> (0*32)) & 0xffffffff]
-        if offset is None:
-            mask = f"MASK6({net.prefixlen})"
-        else:
-            mask = f"MASK6_OFFS({offset}, {net.prefixlen})"
-        return f"""if ((ip6->{ty} & {mask}) != (BIGEND128({u32s[0]}ULL, {u32s[1]}ULL, {u32s[2]}ULL, {u32s[3]}ULL) & {mask}))
-       break;"""
+        return IpRule(ty, offset, net, 6)
 
 def fragment_to_rule(ipproto, rules):
     ast = parse_ast(rules, parse_frag_expr, False)
 
 def fragment_to_rule(ipproto, rules):
     ast = parse_ast(rules, parse_frag_expr, False)
@@ -278,6 +321,37 @@ class RuleNode:
             assert type(action) == list
             assert type(inner) == RuleNode
 
             assert type(action) == list
             assert type(inner) == RuleNode
 
+    def __lt__(self, other):
+        """Return True when this CONDITIONS rule must sort before other.
+
+        Rules are ordered by destination prefix first, then by source prefix
+        (a rule that has a source prefix sorts before one that has none). If
+        the prefixes do not decide the order, the raw condition lists are
+        compared as a tie-breaker.
+
+        Both self and other must be RuleAction.CONDITIONS nodes (asserted).
+        """
+        assert self.ty == RuleAction.CONDITIONS
+        assert other.ty == RuleAction.CONDITIONS
+
+        # RFC first has us sort by dest, then source, then other conditions. We don't implement the
+        # other conditions because doing so requires re-implementing the Flowspec wire format,
+        # which isn't trivial. However, we do implement the source/dest sorting in the hopes it
+        # allows us to group rules according to source/dest IP and hopefully LLVM optimizes out
+        # later rules.
+
+        def find_ip_rule(conditions, field):
+            # First IpRule in conditions matching the given header field, or None.
+            return next((c for c in conditions if isinstance(c, IpRule) and c.ty == field), None)
+
+        o = ORD_EQUAL
+
+        selfdest = find_ip_rule(self.action, "daddr")
+        # Must be looked up in other.action: searching self.action here would compare the rule
+        # with itself and always yield ORD_EQUAL.
+        otherdest = find_ip_rule(other.action, "daddr")
+        if selfdest is not None and otherdest is not None:
+            o = selfdest.ord(otherdest)
+
+        if o == ORD_EQUAL:
+            selfsrc = find_ip_rule(self.action, "saddr")
+            othersrc = find_ip_rule(other.action, "saddr")
+            if selfsrc is not None and othersrc is None:
+                return True
+            elif selfsrc is None and othersrc is not None:
+                return False
+            elif selfsrc is not None and othersrc is not None:
+                o = selfsrc.ord(othersrc)
+
+        if o == ORD_LESS:
+            return True
+        if o == ORD_GREATER:
+            # The prefixes already decided that other sorts first; do not let the list
+            # tie-breaker below contradict that.
+            return False
+        return self.action < other.action
+
     def maybe_join(self, neighbor):
         if self.ty == RuleAction.CONDITIONS and neighbor.ty == RuleAction.CONDITIONS:
             overlapping_conditions = [x for x in self.action if x in neighbor.action]
     def maybe_join(self, neighbor):
         if self.ty == RuleAction.CONDITIONS and neighbor.ty == RuleAction.CONDITIONS:
             overlapping_conditions = [x for x in self.action if x in neighbor.action]
@@ -311,7 +385,7 @@ class RuleNode:
         if self.ty == RuleAction.CONDITIONS:
             out.write(pfx + "do {\\\n")
             for cond in self.action:
         if self.ty == RuleAction.CONDITIONS:
             out.write(pfx + "do {\\\n")
             for cond in self.action:
-                out.write("\t" + pfx + cond.strip().replace("\n", " \\\n\t" + pfx) + " \\\n")
+                out.write("\t" + pfx + str(cond).replace("\n", " \\\n\t" + pfx) + " \\\n")
             self.inner.write(out, pfx)
             out.write(pfx + "} while(0);\\\n")
         elif self.ty == RuleAction.LIST:
             self.inner.write(out, pfx)
             out.write(pfx + "} while(0);\\\n")
         elif self.ty == RuleAction.LIST:
@@ -384,7 +458,7 @@ with open("rules.h", "w") as out:
             conditions = []
             def write_rule(r):
                 global conditions
             conditions = []
             def write_rule(r):
                 global conditions
-                conditions.append(r + "\n")
+                conditions.append(r)
 
             rule = t[1].split("}")[0].strip()
             for step in rule.split(";"):
 
             rule = t[1].split("}")[0].strip()
             for step in rule.split(";"):
@@ -547,12 +621,12 @@ with open("rules.h", "w") as out:
     if ratelimitcnt != 0:
         out.write(f"#define RATE_CNT {ratelimitcnt}\n")
 
     if ratelimitcnt != 0:
         out.write(f"#define RATE_CNT {ratelimitcnt}\n")
 
-    # Here we should probably sort the rules according to flowspec's sorting rules. We don't bother
-    # however, because its annoying.
-
     if len(rules4) != 0:
         out.write("#define NEED_V4_PARSE\n")
         out.write("#define RULES4 {\\\n")
     if len(rules4) != 0:
         out.write("#define NEED_V4_PARSE\n")
         out.write("#define RULES4 {\\\n")
+        # First sort the rules according to the RFC, then make it a single
+        # LIST rule and call flatten() to unify redundant conditions
+        rules4.sort()
         rules4 = RuleNode(RuleAction.LIST, None, rules4)
         rules4.flatten()
         rules4.write(out)
         rules4 = RuleNode(RuleAction.LIST, None, rules4)
         rules4.flatten()
         rules4.write(out)
@@ -561,6 +635,9 @@ with open("rules.h", "w") as out:
     if len(rules6) != 0:
         out.write("#define NEED_V6_PARSE\n")
         out.write("#define RULES6 {\\\n")
     if len(rules6) != 0:
         out.write("#define NEED_V6_PARSE\n")
         out.write("#define RULES6 {\\\n")
+        # First sort the rules according to the RFC, then make it a single
+        # LIST rule and call flatten() to unify redundant conditions
+        rules6.sort()
         rules6 = RuleNode(RuleAction.LIST, None, rules6)
         rules6.flatten()
         rules6.write(out)
         rules6 = RuleNode(RuleAction.LIST, None, rules6)
         rules6.flatten()
         rules6.write(out)
diff --git a/test.sh b/test.sh
index e17ec4bcc9a170ed9f946500f50c6ea91993174e..ad7a111bde187dddee74f9b19a658957e6bedb08 100755 (executable)
--- a/test.sh
+++ b/test.sh
@@ -2,9 +2,15 @@
 
 set -e
 
 
 set -e
 
+# DROP, sample, and change DSCP
 COMMUNITY_DROP="
        Type: static univ
 COMMUNITY_DROP="
        Type: static univ
-       BGP.ext_community: (generic, 0x80060000, 0x0) (generic, 0x80070000, 0xf) (generic, 0x80090000, 0x3f)"
+       BGP.ext_community: (generic, 0x80060000, 0x0) (generic, 0x80070000, 0x3) (generic, 0x80090000, 0x3f)"
+# Sample and stop processing new rules
+COMMUNITY_TERMINAL_ACCEPT="
+       Type: static univ
+       BGP.ext_community: (generic, 0x80070000, 0x2)"
+
 
 DO_TEST() {
        clang -g -std=c99 -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -Wno-unused-function -Wno-tautological-constant-out-of-range-compare -Wno-unused-function -Wno-visibility -O3 -emit-llvm -c xdp.c -o xdp.bc
 
 DO_TEST() {
        clang -g -std=c99 -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -Wno-unused-function -Wno-tautological-constant-out-of-range-compare -Wno-unused-function -Wno-visibility -O3 -emit-llvm -c xdp.c -o xdp.bc
@@ -93,6 +99,16 @@ DO_TEST XDP_DROP
 echo "flow6 { icmp code != 0; };$COMMUNITY_DROP" | ./genrules.py --ihl=drop-options --8021q=drop-vlan --v6frag=drop-frags
 DO_TEST XDP_PASS
 
 echo "flow6 { icmp code != 0; };$COMMUNITY_DROP" | ./genrules.py --ihl=drop-options --8021q=drop-vlan --v6frag=drop-frags
 DO_TEST XDP_PASS
 
+# Test ordering of source addresses. If we hit TERMINAL_ACCEPT first (because it's a more specific
+# prefix), then we'll pass, otherwise we'll drop.
+echo "flow6 { src 2a01:4f8:130:71d2::2/128; };$COMMUNITY_TERMINAL_ACCEPT
+flow6 { src 2a01::/16; }; $COMMUNITY_DROP" | ./genrules.py --ihl=accept-options --8021q=accept-vlan --v6frag=ignore
+DO_TEST XDP_PASS
+
+echo "flow6 { src 2a01::/16; };$COMMUNITY_TERMINAL_ACCEPT
+flow6 { src 2a01:4f8::/32; };$COMMUNITY_DROP" | ./genrules.py --ihl=accept-options --8021q=accept-vlan --v6frag=ignore
+DO_TEST XDP_DROP
+
 TEST_PKT='#define TEST \
 "\xcc\x2d\xe0\xf5\x02\xe1\x00\x0d\xb9\x50\x42\xfe\x81\x00\x00\x03" \
 "\x08\x00\x45\xfc\x00\x54\xda\x85\x40\x00\x40\x01\x67\xc6\x0a\x45" \
 TEST_PKT='#define TEST \
 "\xcc\x2d\xe0\xf5\x02\xe1\x00\x0d\xb9\x50\x42\xfe\x81\x00\x00\x03" \
 "\x08\x00\x45\xfc\x00\x54\xda\x85\x40\x00\x40\x01\x67\xc6\x0a\x45" \