Initial checkin
authorMatt Corallo <git@bluematt.me>
Sat, 3 Apr 2021 04:07:27 +0000 (00:07 -0400)
committerMatt Corallo <git@bluematt.me>
Sat, 3 Apr 2021 16:57:11 +0000 (12:57 -0400)
genrules.py [new file with mode: 0755]
test.sh [new file with mode: 0755]
xdp.c [new file with mode: 0644]

diff --git a/genrules.py b/genrules.py
new file mode 100755 (executable)
index 0000000..3f5d9e0
--- /dev/null
@@ -0,0 +1,295 @@
+#!/usr/bin/env python3
+
+import sys
+import ipaddress
+from enum import Enum
+
+IP_PROTO_ICMP = 1
+IP_PROTO_ICMPV6 = 58
+IP_PROTO_TCP = 6
+IP_PROTO_UDP = 17
+
+if len(sys.argv) > 2 and sys.argv[2].startswith("parse_ihl"):
+    PARSE_IHL = True
+else:
+    PARSE_IHL = False
+if len(sys.argv) > 3 and sys.argv[3].startswith("parse_exthdr"):
+    PARSE_EXTHDR = True
+else:
+    PARSE_EXTHDR = False
+
+
+class ASTAction(Enum):
+    OR = 1,
+    AND = 2,
+    NOT = 3,
+    EXPR = 4
+class ASTNode:
+    def __init__(self, action, left, right=None):
+        self.action = action
+        self.left = left
+        if right is None:
+            assert action == ASTAction.EXPR or action == ASTAction.NOT
+        else:
+            self.right = right
+
+    def write(self, expr_param, expr_param2=None):
+        if self.action == ASTAction.OR:
+            return "(" + self.left.write(expr_param, expr_param2) + ") || (" + self.right.write(expr_param, expr_param2) + ")"
+        if self.action == ASTAction.AND:
+            return "(" + self.left.write(expr_param, expr_param2) + ") && (" + self.right.write(expr_param, expr_param2) + ")"
+        if self.action == ASTAction.NOT:
+            return "!(" + self.left.write(expr_param, expr_param2) + ")"
+        if self.action == ASTAction.EXPR:
+            return self.left.write(expr_param, expr_param2)
+
+def parse_ast(expr, parse_expr):
+    expr = expr.strip()
+
+    and_split = expr.split("&&", 1)
+    or_split = expr.split("||", 1)
+    if len(and_split) > 1 and not "||" in and_split[0]:
+        return ASTNode(ASTAction.AND, parse_ast(and_split[0], parse_expr), parse_ast(and_split[1], parse_expr))
+    if len(or_split) > 1:
+        assert not "&&" in or_split[0]
+        return ASTNode(ASTAction.OR, parse_ast(or_split[0], parse_expr), parse_ast(or_split[1], parse_expr))
+
+    comma_split = expr.split(",", 1)
+    if len(comma_split) > 1:
+        return ASTNode(ASTAction.OR, parse_ast(comma_split[0], parse_expr), parse_ast(comma_split[1], parse_expr))
+
+    if expr.startswith("!"):
+        return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr))
+
+    return parse_expr(expr)
+
+
+class NumbersAction(Enum):
+    EQ = "=="
+    GT = ">"
+    GTOE = ">="
+    LT = "<"
+    LTOE = "<="
+class NumbersExpr:
+    def __init__(self, action, val):
+        self.action = action
+        self.val = val
+
+    def write(self, param, param2):
+        if param2 is not None:
+            return "(" + param + self.action.value + self.val + ") || (" + param2 + self.action.value + self.val + ")"
+        return param + self.action.value + self.val
+
+def parse_numbers_expr(expr):
+    space_split = expr.split(" ")
+    if expr.startswith(">="):
+        assert len(space_split) == 2
+        return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, space_split[1]))
+    if expr.startswith(">"):
+        assert len(space_split) == 2
+        return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GT, space_split[1]))
+    if expr.startswith("<="):
+        assert len(space_split) == 2
+        return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, space_split[1]))
+    if expr.startswith("<"):
+        assert len(space_split) == 2
+        return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LT, space_split[1]))
+    if ".." in expr:
+        rangesplit = expr.split("..")
+        assert len(rangesplit) == 2
+        #XXX: Are ranges really inclusive,inclusive?
+        left = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, rangesplit[0]))
+        right = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, rangesplit[1]))
+        return ASTNode(ASTAction.AND, left, right)
+    
+    if expr.startswith("= "):
+        expr = expr[2:]
+    return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, expr))
+
+class FragExpr:
+    def __init__(self, val):
+        if val == "is_fragment":
+            self.rule = "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0"
+        elif val == "first_fragment":
+            self.rule = "(ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0"
+        elif val == "dont_fragment":
+            self.rule = "(ip->frag_off & BE16(IP_DF)) != 0"
+        elif val == "last_fragment":
+            self.rule = "(ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0"
+        else:
+            assert False
+
+    def write(self, _param, _param2):
+        return self.rule
+
+def parse_frag_expr(expr):
+    return ASTNode(ASTAction.EXPR, FragExpr(expr))
+
+class BitExpr:
+    def __init__(self, val):
+        s = val.split("/")
+        assert len(s) == 2
+        self.match = s[0]
+        self.mask = s[1]
+
+    def write(self, param, _param2):
+        return f"({param} & {self.mask}) == {self.match}"
+
+def parse_bit_expr(expr):
+    return ASTNode(ASTAction.EXPR, BitExpr(expr))
+
+
+def ip_to_rule(proto, inip, ty, offset):
+    if proto == 4:
+        assert offset is None
+        net = ipaddress.IPv4Network(inip.strip())
+        if net.prefixlen == 0:
+            return ""
+        return f"""if ((ip->{ty} & MASK4({net.prefixlen})) != BIGEND32({int(net.network_address)}ULL))
+       break;"""
+    else:
+        net = ipaddress.IPv6Network(inip.strip())
+        if net.prefixlen == 0:
+            return ""
+        u32s = [(int(net.network_address) >> (3*32)) & 0xffffffff,
+                (int(net.network_address) >> (2*32)) & 0xffffffff,
+                (int(net.network_address) >> (1*32)) & 0xffffffff,
+                (int(net.network_address) >> (0*32)) & 0xffffffff]
+        if offset is None:
+            mask = f"MASK6({net.prefixlen})"
+        else:
+            mask = f"MASK6_OFFS({offset}, {net.prefixlen})"
+        return f"""if ((ip6->{ty} & {mask}) != (BIGEND128({u32s[0]}ULL, {u32s[1]}ULL, {u32s[2]}ULL, {u32s[3]}ULL) & {mask}))
+       break;"""
+
+def fragment_to_rule(ipproto, rules):
+    if ipproto == 6:
+        assert False # XXX: unimplemented
+    ast = parse_ast(rules, parse_frag_expr)
+    return "if (!( " + ast.write(()) + " )) break;"
+
+def len_to_rule(rules):
+    ast = parse_ast(rules, parse_numbers_expr)
+    return "if (!( " + ast.write("(data_end - pktdata)") + " )) break;"
+def proto_to_rule(ipproto, proto):
+    ast = parse_ast(proto, parse_numbers_expr)
+
+    if ipproto == 4:
+        return "if (!( " + ast.write("ip->protocol") + " )) break;"
+    else:
+        if PARSE_EXTHDR:
+            assert False # XXX: unimplemented
+        return "if (!( " + ast.write("ip6->nexthdr") + " )) break;"
+
+def icmp_type_to_rule(proto, ty):
+    ast = parse_ast(ty, parse_numbers_expr)
+    if proto == 4:
+        return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->type") + " )) break;"
+    else:
+        return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_type") + " )) break;"
+
+def icmp_code_to_rule(proto, code):
+    ast = parse_ast(code, parse_numbers_expr)
+    if proto == 4:
+        return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->code") + " )) break;"
+    else:
+        return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_code") + " )) break;"
+
+def dscp_to_rule(proto, rules):
+    ast = parse_ast(rules, parse_numbers_expr)
+
+    if proto == 4:
+        return "if (!( " + ast.write("((ip->tos & 0xfc) >> 2)") + " )) break;"
+    else:
+        return "if (!( " + ast.write("((ip6->priority << 4) | ((ip6->flow_lbl[0] & 0xc0) >> 4) >> 2)") + " )) break;"
+
+def port_to_rule(ty, rules):
+    if ty == "port" :
+        ast = parse_ast(rules, parse_numbers_expr)
+        return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write("sport", "dport") + " )) break;"
+
+    ast = parse_ast(rules, parse_numbers_expr)
+    return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write(ty) + " )) break;"
+
+def tcp_flags_to_rule(rules):
+    ast = parse_ast(rules, parse_bit_expr)
+
+    return f"""if (tcp == NULL) break;
+if (!( {ast.write("(ntohs(tcp->flags) & 0xfff)")} )) break;"""
+
+def flow_label_to_rule(rules):
+    ast = parse_ast(rules, parse_bit_expr)
+
+    return f"""if (ip6 == NULL) break;
+if (!( {ast.write("((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8) | (((uint32_t)ip6->flow_lbl[1]) << 1*8) | (uint32_t)ip6->flow_lbl[0])")} )) break;"""
+
+with open("rules.h", "w") as out:
+    if len(sys.argv) > 1 and sys.argv[1] == "parse_8021q":
+        out.write("#define PARSE_8021Q\n")
+    if len(sys.argv) > 1 and sys.argv[1].startswith("req_8021q="):
+        out.write("#define PARSE_8021Q\n")
+        out.write(f"#define REQ_8021Q {sys.argv[1][10:]}\n")
+
+
+    out.write("#define RULES \\\n")
+
+    def write_rule(r):
+        out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n")
+
+    for line in sys.stdin.readlines():
+        t = line.split("{")
+        if len(t) != 2:
+            continue
+        if t[0].strip() == "flow4":
+            proto = 4
+            out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n")
+            out.write("\tdo {\\\n")
+        elif t[0].strip() == "flow6":
+            proto = 6
+            out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n")
+            out.write("\tdo {\\\n")
+        else:
+            continue
+
+        rule = t[1].split("}")[0].strip()
+        for step in rule.split(";"):
+            if step.strip().startswith("src") or step.strip().startswith("dst"):
+                nets = step.strip()[3:].strip().split(" ")
+                if len(nets) > 1:
+                    assert nets[1] == "offset"
+                    offset = nets[2]
+                else:
+                    offset = None
+                if step.strip().startswith("src"):
+                    write_rule(ip_to_rule(proto, nets[0], "saddr", offset))
+                else:
+                    write_rule(ip_to_rule(proto, nets[0], "daddr", offset))
+            elif step.strip().startswith("proto") and proto == 4:
+                write_rule(proto_to_rule(4, step.strip()[6:]))
+            elif step.strip().startswith("next header") and proto == 6:
+                write_rule(proto_to_rule(6, step.strip()[12:]))
+            elif step.strip().startswith("icmp type"):
+                write_rule(icmp_type_to_rule(proto, step.strip()[10:]))
+            elif step.strip().startswith("icmp code"):
+                write_rule(icmp_code_to_rule(proto, step.strip()[10:]))
+            elif step.strip().startswith("sport") or step.strip().startswith("dport") or step.strip().startswith("port"):
+                write_rule(port_to_rule(step.strip().split(" ")[0], step.strip().split(" ", 1)[1]))
+            elif step.strip().startswith("length"):
+                write_rule(len_to_rule(step.strip()[7:]))
+            elif step.strip().startswith("dscp"):
+                write_rule(dscp_to_rule(proto, step.strip()[5:]))
+            elif step.strip().startswith("tcp flags"):
+                write_rule(tcp_flags_to_rule(step.strip()[10:]))
+            elif step.strip().startswith("label"):
+                write_rule(flow_label_to_rule(step.strip()[6:]))
+            elif step.strip().startswith("fragment"):
+                write_rule(fragment_to_rule(proto, step.strip()[9:]))
+            elif step.strip() == "":
+                pass
+            else:
+                assert False
+        out.write("\t\treturn XDP_DROP;\\\n")
+        out.write("\t} while(0);\\\n}\\\n")
+
+    out.write("\n")
diff --git a/test.sh b/test.sh
new file mode 100755 (executable)
index 0000000..03bc656
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,149 @@
+#!/bin/bash
+
+set -e
+
+TEST_PKT='#define TEST \
+"\x00\x17\x10\x95\xe8\x96\x00\x0d\xb9\x50\x11\x4c\x08\x00\x45\x00" \
+"\x00\x8c\x7d\x0f\x00\x00\x40\x11\x3a\x31\x48\xe5\x68\xce\x67\x63" \
+"\xaa\x0a\xdd\x9d\x10\x92\x00\x78\xc3\xaa\x04\x00\x00\x00\x47\x89" \
+"\x49\xb1\x1f\x0e\x00\x00\x00\x00\x00\x00\xa7\xee\xab\xa4\xc6\x09" \
+"\xe7\x0f\x41\xfc\xd5\x75\x1d\xc4\x97\xfa\xd7\x96\x8c\x1f\x19\x54" \
+"\xa7\x74\x08\x5c\x28\xfe\xd9\x32\x4b\xe0\x62\x55\xeb\xb4\x1e\x36" \
+"\x5f\xf5\x38\x48\x18\x75\x57\x9a\x05\x7e\x3d\xb1\x55\x79\x0f\xd0" \
+"\x8c\x79\x72\x90\xb7\x16\x12\x18\xa1\x97\x53\xf1\x49\x0a\x35\x40" \
+"\xc2\x8b\x72\x7a\x38\x22\x04\x96\x01\xd3\x7e\x47\x5d\xaa\x03\xb0" \
+"\xb5\xc3\xa9\xa6\x21\x14\xc7\xd9\x71\x07"'
+
+# Test all the things...
+echo "flow4 { src 72.229.104.206/32; dst 103.99.170.10/32; proto = 17; sport = 56733; dport = 4242; length = 140; dscp 0/0xff; fragment !dont_fragment && !is_fragment && !first_fragment && !last_fragment };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { port = 4242; icmp code = 0; };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+# Some port tests...
+echo "flow4 { port = 4242 && = 56733; };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { port = 4242 || 1; sport = 56733 };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { port = 4242 && 1 };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { icmp code != 0; };" | ./genrules.py parse_8021q
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+TEST_PKT='#define TEST \
+"\x00\x0d\xb9\x50\x11\x4c\x00\x17\x10\x95\xe8\x96\x86\xdd\x60\x00" \
+"\x00\x00\x00\x20\x06\x37\x2a\x01\x04\xf8\x01\x30\x71\xd2\x00\x00" \
+"\x00\x00\x00\x00\x00\x02\x26\x20\x00\x6e\xa0\x00\x20\x01\x00\x00" \
+"\x00\x00\x00\x00\x00\x06\x20\x8d\xc2\x72\xff\x5f\x50\xa7\x1a\xfb" \
+"\x41\xed\x80\x10\x06\xef\x87\x8c\x00\x00\x01\x01\x08\x0a\x98\x3d" \
+"\x75\xde\xeb\x22\xd6\x80"'
+
+# Some v6 TCP tests...
+echo "flow6 { src 2a01:4f8:130:71d2::2/128; dst 2620:6e:a000:2001::6/128; next header 6; port 8333 && 49778; tcp flags 0x010/0xfff;};" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow6 { src 0:4f8:130:71d2::2/128 offset 16; dst 0:0:a000:2001::/64 offset 32; next header 6; port 8333 && 49778; tcp flags 0x010/0xfff;};" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow6 { icmp code != 0; };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+TEST_PKT='#define TEST \
+"\xcc\x2d\xe0\xf5\x02\xe1\x00\x0d\xb9\x50\x42\xfe\x81\x00\x00\x03" \
+"\x08\x00\x45\x00\x00\x54\xda\x85\x40\x00\x40\x01\x67\xc6\x0a\x45" \
+"\x1e\x51\xd1\xfa\xfd\xcc\x08\x00\x18\x82\x7e\xda\x00\x02\xc8\xc4" \
+"\x67\x60\x00\x00\x00\x00\x69\xa9\x08\x00\x00\x00\x00\x00\x10\x11" \
+"\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21" \
+"\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31" \
+"\x32\x33\x34\x35\x36\x37"'
+
+# ICMP and VLAN tests
+echo "flow4 { src 10.0.0.0/8; dst 209.250.0.0/16; proto = 1; icmp type 8; icmp code >= 0; length < 100; fragment dont_fragment; };" | ./genrules.py parse_8021q
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { icmp type 8; icmp code > 0; };" | ./genrules.py parse_8021q
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { icmp type 9; };" | ./genrules.py parse_8021q
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { src 10.0.0.0/8; dst 209.250.0.0/16; proto = 1; icmp type 8; icmp code >= 0; length < 100; fragment dont_fragment; };" | ./genrules.py req_8021q=3
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { src 0.0.0.0/32; };" | ./genrules.py req_8021q=4
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { src 0.0.0.0/32; };" | ./genrules.py req_8021q=3
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow4 { port 42; };" | ./genrules.py parse_8021q
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+TEST_PKT='#define TEST \
+"\x00\x0d\xb9\x50\x11\x4c\x00\x17\x10\x95\xe8\x96\x86\xdd\x60\x0a" \
+"\xb8\x00\x00\x40\x3a\x3e\x20\x01\x04\x70\x00\x00\x02\xc8\x00\x00" \
+"\x00\x00\x00\x00\x00\x02\x26\x20\x00\x6e\xa0\x00\x00\x01\x00\x00" \
+"\x00\x00\x00\x00\xca\xfe\x81\x00\x40\x94\x85\x14\x00\x13\x00\x00" \
+"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" \
+"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" \
+"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" \
+"\x00\x00\x00\x00\x00\x00"'
+
+# ICMPv6 tests
+echo "flow6 { icmp type 129; icmp code 0; };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_DROP" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow6 { icmp code != 0; };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow6 { tcp flags 0x0/0x0; };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+echo "flow6 { port 42; };" | ./genrules.py
+echo "$TEST_PKT" >> rules.h
+echo "#define TEST_EXP XDP_PASS" >> rules.h
+clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp
+
+
diff --git a/xdp.c b/xdp.c
new file mode 100644 (file)
index 0000000..87b7eb5
--- /dev/null
+++ b/xdp.c
@@ -0,0 +1,243 @@
+#include <stdint.h>
+#include <endian.h>
+#include <linux/if_ether.h>
+#include <linux/ip.h>
+#include <linux/udp.h>
+#include <linux/icmp.h>
+#include <linux/icmpv6.h>
+#include <arpa/inet.h>
+
+#define NULL (void*)0
+
+/* IP flags. */
+#define IP_CE          0x8000          /* Flag: "Congestion"           */
+#define IP_DF          0x4000          /* Flag: "Don't Fragment"       */
+#define IP_MF          0x2000          /* Flag: "More Fragments"       */
+#define IP_OFFSET      0x1FFF          /* "Fragment Offset" part       */
+
+#define IP_PROTO_TCP 6
+#define IP_PROTO_UDP 17
+#define IP_PROTO_ICMP 1
+#define IP_PROTO_ICMPV6 58
+
+typedef __uint128_t uint128_t;
+
+// Our own ipv6hdr that uses uint128_t
+struct ip6hdr {
+#if defined(__LITTLE_ENDIAN_BITFIELD)
+       __u8    priority:4,
+               version:4;
+#elif defined(__BIG_ENDIAN_BITFIELD)
+       __u8    version:4,
+               priority:4;
+#else
+#error "Please fix <asm/byteorder.h>"
+#endif
+       __u8    flow_lbl[3];
+
+       __be16  payload_len;
+       __u8            nexthdr;
+       __u8            hop_limit;
+
+       uint128_t       saddr;
+       uint128_t       daddr;
+} __attribute__((packed));
+
+// Our own ethhdr with optional vlan tags
+struct ethhdr_vlan {
+       unsigned char   h_dest[ETH_ALEN];       /* destination eth addr */
+       unsigned char   h_source[ETH_ALEN];     /* source ether addr    */
+       __be16          vlan_magic;             /* 0x8100 */
+       __be16          tci;            /* PCP (3 bits), DEI (1 bit), and VLAN (12 bits) */
+       __be16          h_proto;                /* packet type ID field */
+} __attribute__((packed));
+
+// Our own tcphdr without the flags blown up
+struct tcphdr {
+       __be16  source;
+       __be16  dest;
+       __be32  seq;
+       __be32  ack_seq;
+       __u16   flags;
+       __be16  window;
+       __sum16 check;
+       __be16  urg_ptr;
+} __attribute__((packed));
+
+// Note that all operations on uint128s *stay* in Network byte order!
+
+#if defined(__LITTLE_ENDIAN)
+#define BIGEND32(v) ((v >> 3*8) | ((v >> 8) & 0xff00) | ((v << 8) & 0xff0000) | (v << 3*8) & 0xff000000)
+#elif defined(__BIG_ENDIAN)
+#define BIGEND32(v) (v)
+#else
+#error "Need endian info"
+#endif
+
+#if defined(__LITTLE_ENDIAN)
+#define BIGEND128(a, b, c, d) ( \
+               (((uint128_t)BIGEND32(d)) << 3*32) | \
+               (((uint128_t)BIGEND32(c)) << 2*32) | \
+               (((uint128_t)BIGEND32(b)) << 1*32) | \
+               (((uint128_t)BIGEND32(a)) << 0*32))
+#define HTON128(a) BIGEND128(a >> 3*32, a >> 2*32, a >> 1*32, a>> 0*32)
+// Yes, somehow macro'ing this changes LLVM's view of htons...
+#define BE16(a) ((((uint16_t)(a & 0xff00)) >> 8) | (((uint16_t)(a & 0xff)) << 8))
+#elif defined(__BIG_ENDIAN)
+#define BIGEND128(a, b, c, d) ((((uint128_t)a) << 3*32) | (((uint128_t)b) << 2*32) | (((uint128_t)c) << 1*32) | (((uint128_t)d) << 0*32))
+#define HTON128(a) (a)
+#else
+#error "Need endian info"
+#endif
+
+#define MASK4(pfxlen) BIGEND32(~((((uint32_t)1) << (32 - pfxlen)) - 1))
+#define MASK6(pfxlen) HTON128(~((((uint128_t)1) << (128 - pfxlen)) - 1))
+#define MASK6_OFFS(offs, pfxlen) HTON128((~((((uint128_t)1) << (128 - pfxlen)) - 1)) & ((((uint128_t)1) << (128 - offs)) - 1))
+
+// Note rules.h may also define PARSE_8021Q and REQ_8021Q
+// Note rules.h may also define PARSE_IHL
+#include "rules.h"
+
+#ifdef TEST
+// 64 bit version of xdp_md for testing
+struct xdp_md {
+       __u64 data;
+       __u64 data_end;
+       __u64 data_meta;
+       /* Below access go through struct xdp_rxq_info */
+       __u64 ingress_ifindex; /* rxq->dev->ifindex */
+       __u64 rx_queue_index;  /* rxq->queue_index  */
+
+       __u64 egress_ifindex;  /* txq->dev->ifindex */
+};
+static const int XDP_PASS = 0;
+static const int XDP_DROP = 1;
+#else
+#include <linux/bpf.h>
+#include <bpf/bpf_helpers.h>
+
+SEC("xdp_drop")
+#endif
+int xdp_drop_prog(struct xdp_md *ctx)
+{
+       const void *const data_end = (void *)(size_t)ctx->data_end;
+
+       const void * pktdata;
+       unsigned short eth_proto;
+
+       {
+               if ((void*)(size_t)ctx->data + sizeof(struct ethhdr) > data_end)
+                       return XDP_DROP;
+               const struct ethhdr *const eth = (void*)(size_t)ctx->data;
+
+#ifdef PARSE_8021Q
+               if (eth->h_proto == BE16(ETH_P_8021Q)) {
+                       if ((void*)(size_t)ctx->data + sizeof(struct ethhdr_vlan) > data_end)
+                               return XDP_DROP;
+                       const struct ethhdr_vlan *const eth_vlan = (void*)(size_t)ctx->data;
+
+#ifdef REQ_8021Q
+                       if ((eth_vlan->tci & BE16(0xfff)) != BE16(REQ_8021Q))
+                               return XDP_DROP;
+#endif
+
+                       eth_proto = eth_vlan->h_proto;
+                       pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr_vlan);
+               } else {
+#ifdef REQ_8021Q
+                       return XDP_DROP;
+#else
+                       pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr);
+                       eth_proto = eth->h_proto;
+#endif
+               }
+#else
+               pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr);
+               eth_proto = eth->h_proto;
+#endif
+       }
+
+       const struct tcphdr *tcp = NULL;
+       const struct udphdr *udp = NULL;
+       const struct icmphdr *icmp = NULL;
+       const struct icmp6hdr *icmpv6 = NULL;
+       const struct iphdr *ip = NULL;
+       const struct ip6hdr *ip6 = NULL;
+       const void *l4hdr = NULL;
+       if (eth_proto == BE16(ETH_P_IP)) {
+               if (pktdata + sizeof(struct iphdr) > data_end)
+                       return XDP_DROP;
+               ip = (struct iphdr*) pktdata;
+
+#ifdef PARSE_IHL
+               if (ip->ihl < 5) return XDP_DROP;
+               l4hdr = pktdata + ip->ihl * 4;
+#else
+               if (ip->ihl != 5) return XDP_DROP;
+               l4hdr = pktdata + 5*4;
+#endif
+               if (ip->protocol == IP_PROTO_TCP) {
+                       if (l4hdr + sizeof(struct tcphdr) > data_end)
+                               return XDP_DROP;
+                       tcp = (struct tcphdr*) l4hdr;
+               } else if (ip->protocol == IP_PROTO_UDP) {
+                       if (l4hdr + sizeof(struct udphdr) > data_end)
+                               return XDP_DROP;
+                       udp = (struct udphdr*) l4hdr;
+               } else if (ip->protocol == IP_PROTO_ICMP) {
+                       if (l4hdr + sizeof(struct icmphdr) > data_end)
+                               return XDP_DROP;
+                       icmp = (struct icmphdr*) l4hdr;
+               }
+       } else if (eth_proto == BE16(ETH_P_IPV6)) {
+               if (pktdata + sizeof(struct ip6hdr) > data_end)
+                       return XDP_DROP;
+               ip6 = (struct ip6hdr*) pktdata;
+
+               l4hdr = pktdata + 40;
+               if (ip6->nexthdr == IP_PROTO_TCP) {
+                       if (l4hdr + sizeof(struct tcphdr) > data_end)
+                               return XDP_DROP;
+                       tcp = (struct tcphdr*) l4hdr;
+               } else if (ip6->nexthdr == IP_PROTO_UDP) {
+                       if (l4hdr + sizeof(struct udphdr) > data_end)
+                               return XDP_DROP;
+                       udp = (struct udphdr*) l4hdr;
+               } else if (ip6->nexthdr == IP_PROTO_ICMPV6) {
+                       if (l4hdr + sizeof(struct icmp6hdr) > data_end)
+                               return XDP_DROP;
+                       icmpv6 = (struct icmp6hdr*) l4hdr;
+               }
+               // TODO: Handle some options?
+       } else {
+               return XDP_PASS;
+       }
+
+       uint16_t sport, dport; // Host Endian! Only valid with tcp || udp
+       if (tcp != NULL) {
+               sport = BE16(tcp->source);
+               dport = BE16(tcp->dest);
+       } else if (udp != NULL) {
+               sport = BE16(udp->source);
+               dport = BE16(udp->dest);
+       }
+
+       RULES
+
+       return XDP_PASS;
+}
+
+#ifdef TEST
+#include <assert.h>
+#include <string.h>
+
+const char d[] = TEST;
+int main() {
+       struct xdp_md test = {
+               .data = (uint64_t)d,
+               // -1 because sizeof includes a trailing null in the "string"
+               .data_end = (uint64_t)(d + sizeof(d) - 1),
+       };
+       assert(xdp_drop_prog(&test) == TEST_EXP);
+}
+#endif