From a0ba2a08e5f98bc4f731b384ac13e11f610ffdda Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Sat, 3 Apr 2021 00:07:27 -0400 Subject: [PATCH] Initial checkin --- genrules.py | 295 ++++++++++++++++++++++++++++++++++++++++++++++++++++ test.sh | 149 ++++++++++++++++++++++++++ xdp.c | 243 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 687 insertions(+) create mode 100755 genrules.py create mode 100755 test.sh create mode 100644 xdp.c diff --git a/genrules.py b/genrules.py new file mode 100755 index 0000000..3f5d9e0 --- /dev/null +++ b/genrules.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3 + +import sys +import ipaddress +from enum import Enum + +IP_PROTO_ICMP = 1 +IP_PROTO_ICMPV6 = 58 +IP_PROTO_TCP = 6 +IP_PROTO_UDP = 17 + +if len(sys.argv) > 2 and sys.argv[2].startswith("parse_ihl"): + PARSE_IHL = True +else: + PARSE_IHL = False +if len(sys.argv) > 3 and sys.argv[3].startswith("parse_exthdr"): + PARSE_EXTHDR = True +else: + PARSE_EXTHDR = False + + +class ASTAction(Enum): + OR = 1, + AND = 2, + NOT = 3, + EXPR = 4 +class ASTNode: + def __init__(self, action, left, right=None): + self.action = action + self.left = left + if right is None: + assert action == ASTAction.EXPR or action == ASTAction.NOT + else: + self.right = right + + def write(self, expr_param, expr_param2=None): + if self.action == ASTAction.OR: + return "(" + self.left.write(expr_param, expr_param2) + ") || (" + self.right.write(expr_param, expr_param2) + ")" + if self.action == ASTAction.AND: + return "(" + self.left.write(expr_param, expr_param2) + ") && (" + self.right.write(expr_param, expr_param2) + ")" + if self.action == ASTAction.NOT: + return "!(" + self.left.write(expr_param, expr_param2) + ")" + if self.action == ASTAction.EXPR: + return self.left.write(expr_param, expr_param2) + +def parse_ast(expr, parse_expr): + expr = expr.strip() + + and_split = expr.split("&&", 1) + or_split = expr.split("||", 1) + if len(and_split) > 1 and not "||" in and_split[0]: + return ASTNode(ASTAction.AND, parse_ast(and_split[0], parse_expr), parse_ast(and_split[1], parse_expr)) + if len(or_split) > 1: + assert not "&&" in or_split[0] + return ASTNode(ASTAction.OR, parse_ast(or_split[0], parse_expr), parse_ast(or_split[1], parse_expr)) + + comma_split = expr.split(",", 1) + if len(comma_split) > 1: + return ASTNode(ASTAction.OR, parse_ast(comma_split[0], parse_expr), parse_ast(comma_split[1], parse_expr)) + + if expr.startswith("!"): + return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr)) + + return parse_expr(expr) + + +class NumbersAction(Enum): + EQ = "==" + GT = ">" + GTOE = ">=" + LT = "<" + LTOE = "<=" +class NumbersExpr: + def __init__(self, action, val): + self.action = action + self.val = val + + def write(self, param, param2): + if param2 is not None: + return "(" + param + self.action.value + self.val + ") || (" + param2 + self.action.value + self.val + ")" + return param + self.action.value + self.val + +def parse_numbers_expr(expr): + space_split = expr.split(" ") + if expr.startswith(">="): + assert len(space_split) == 2 + return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, space_split[1])) + if expr.startswith(">"): + assert len(space_split) == 2 + return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GT, space_split[1])) + if expr.startswith("<="): + assert len(space_split) == 2 + return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, space_split[1])) + if expr.startswith("<"): + assert len(space_split) == 2 + return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LT, space_split[1])) + if ".." in expr: + rangesplit = expr.split("..") + assert len(rangesplit) == 2 + #XXX: Are ranges really inclusive,inclusive? + left = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, rangesplit[0])) + right = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, rangesplit[1])) + return ASTNode(ASTAction.AND, left, right) + + if expr.startswith("= "): + expr = expr[2:] + return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, expr)) + +class FragExpr: + def __init__(self, val): + if val == "is_fragment": + self.rule = "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0" + elif val == "first_fragment": + self.rule = "(ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0" + elif val == "dont_fragment": + self.rule = "(ip->frag_off & BE16(IP_DF)) != 0" + elif val == "last_fragment": + self.rule = "(ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0" + else: + assert False + + def write(self, _param, _param2): + return self.rule + +def parse_frag_expr(expr): + return ASTNode(ASTAction.EXPR, FragExpr(expr)) + +class BitExpr: + def __init__(self, val): + s = val.split("/") + assert len(s) == 2 + self.match = s[0] + self.mask = s[1] + + def write(self, param, _param2): + return f"({param} & {self.mask}) == {self.match}" + +def parse_bit_expr(expr): + return ASTNode(ASTAction.EXPR, BitExpr(expr)) + + +def ip_to_rule(proto, inip, ty, offset): + if proto == 4: + assert offset is None + net = ipaddress.IPv4Network(inip.strip()) + if net.prefixlen == 0: + return "" + return f"""if ((ip->{ty} & MASK4({net.prefixlen})) != BIGEND32({int(net.network_address)}ULL)) + break;""" + else: + net = ipaddress.IPv6Network(inip.strip()) + if net.prefixlen == 0: + return "" + u32s = [(int(net.network_address) >> (3*32)) & 0xffffffff, + (int(net.network_address) >> (2*32)) & 0xffffffff, + (int(net.network_address) >> (1*32)) & 0xffffffff, + (int(net.network_address) >> (0*32)) & 0xffffffff] + if offset is None: + mask = f"MASK6({net.prefixlen})" + else: + mask = f"MASK6_OFFS({offset}, {net.prefixlen})" + return f"""if ((ip6->{ty} & {mask}) != (BIGEND128({u32s[0]}ULL, {u32s[1]}ULL, {u32s[2]}ULL, {u32s[3]}ULL) & {mask})) + break;""" + +def fragment_to_rule(ipproto, rules): + if ipproto == 6: + assert False # XXX: unimplemented + ast = parse_ast(rules, parse_frag_expr) + return "if (!( " + ast.write(()) + " )) break;" + +def len_to_rule(rules): + ast = parse_ast(rules, parse_numbers_expr) + return "if (!( " + ast.write("(data_end - pktdata)") + " )) break;" + +def proto_to_rule(ipproto, proto): + ast = parse_ast(proto, parse_numbers_expr) + + if ipproto == 4: + return "if (!( " + ast.write("ip->protocol") + " )) break;" + else: + if PARSE_EXTHDR: + assert False # XXX: unimplemented + return "if (!( " + ast.write("ip6->nexthdr") + " )) break;" + +def icmp_type_to_rule(proto, ty): + ast = parse_ast(ty, parse_numbers_expr) + if proto == 4: + return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->type") + " )) break;" + else: + return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_type") + " )) break;" + +def icmp_code_to_rule(proto, code): + ast = parse_ast(code, parse_numbers_expr) + if proto == 4: + return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->code") + " )) break;" + else: + return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_code") + " )) break;" + +def dscp_to_rule(proto, rules): + ast = parse_ast(rules, parse_numbers_expr) + + if proto == 4: + return "if (!( " + ast.write("((ip->tos & 0xfc) >> 2)") + " )) break;" + else: + return "if (!( " + ast.write("((ip6->priority << 4) | ((ip6->flow_lbl[0] & 0xc0) >> 4) >> 2)") + " )) break;" + +def port_to_rule(ty, rules): + if ty == "port" : + ast = parse_ast(rules, parse_numbers_expr) + return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write("sport", "dport") + " )) break;" + + ast = parse_ast(rules, parse_numbers_expr) + return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write(ty) + " )) break;" + +def tcp_flags_to_rule(rules): + ast = parse_ast(rules, parse_bit_expr) + + return f"""if (tcp == NULL) break; +if (!( {ast.write("(ntohs(tcp->flags) & 0xfff)")} )) break;""" + +def flow_label_to_rule(rules): + ast = parse_ast(rules, parse_bit_expr) + + return f"""if (ip6 == NULL) break; +if (!( {ast.write("((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8) | (((uint32_t)ip6->flow_lbl[1]) << 1*8) | (uint32_t)ip6->flow_lbl[0])")} )) break;""" + +with open("rules.h", "w") as out: + if len(sys.argv) > 1 and sys.argv[1] == "parse_8021q": + out.write("#define PARSE_8021Q\n") + if len(sys.argv) > 1 and sys.argv[1].startswith("req_8021q="): + out.write("#define PARSE_8021Q\n") + out.write(f"#define REQ_8021Q {sys.argv[1][10:]}\n") + + + out.write("#define RULES \\\n") + + def write_rule(r): + out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n") + + for line in sys.stdin.readlines(): + t = line.split("{") + if len(t) != 2: + continue + if t[0].strip() == "flow4": + proto = 4 + out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n") + out.write("\tdo {\\\n") + elif t[0].strip() == "flow6": + proto = 6 + out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n") + out.write("\tdo {\\\n") + else: + continue + + rule = t[1].split("}")[0].strip() + for step in rule.split(";"): + if step.strip().startswith("src") or step.strip().startswith("dst"): + nets = step.strip()[3:].strip().split(" ") + if len(nets) > 1: + assert nets[1] == "offset" + offset = nets[2] + else: + offset = None + if step.strip().startswith("src"): + write_rule(ip_to_rule(proto, nets[0], "saddr", offset)) + else: + write_rule(ip_to_rule(proto, nets[0], "daddr", offset)) + elif step.strip().startswith("proto") and proto == 4: + write_rule(proto_to_rule(4, step.strip()[6:])) + elif step.strip().startswith("next header") and proto == 6: + write_rule(proto_to_rule(6, step.strip()[12:])) + elif step.strip().startswith("icmp type"): + write_rule(icmp_type_to_rule(proto, step.strip()[10:])) + elif step.strip().startswith("icmp code"): + write_rule(icmp_code_to_rule(proto, step.strip()[10:])) + elif step.strip().startswith("sport") or step.strip().startswith("dport") or step.strip().startswith("port"): + write_rule(port_to_rule(step.strip().split(" ")[0], step.strip().split(" ", 1)[1])) + elif step.strip().startswith("length"): + write_rule(len_to_rule(step.strip()[7:])) + elif step.strip().startswith("dscp"): + write_rule(dscp_to_rule(proto, step.strip()[5:])) + elif step.strip().startswith("tcp flags"): + write_rule(tcp_flags_to_rule(step.strip()[10:])) + elif step.strip().startswith("label"): + write_rule(flow_label_to_rule(step.strip()[6:])) + elif step.strip().startswith("fragment"): + write_rule(fragment_to_rule(proto, step.strip()[9:])) + elif step.strip() == "": + pass + else: + assert False + out.write("\t\treturn XDP_DROP;\\\n") + out.write("\t} while(0);\\\n}\\\n") + + out.write("\n") diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..03bc656 --- /dev/null +++ b/test.sh @@ -0,0 +1,149 @@ +#!/bin/bash + +set -e + +TEST_PKT='#define TEST \ +"\x00\x17\x10\x95\xe8\x96\x00\x0d\xb9\x50\x11\x4c\x08\x00\x45\x00" \ +"\x00\x8c\x7d\x0f\x00\x00\x40\x11\x3a\x31\x48\xe5\x68\xce\x67\x63" \ +"\xaa\x0a\xdd\x9d\x10\x92\x00\x78\xc3\xaa\x04\x00\x00\x00\x47\x89" \ +"\x49\xb1\x1f\x0e\x00\x00\x00\x00\x00\x00\xa7\xee\xab\xa4\xc6\x09" \ +"\xe7\x0f\x41\xfc\xd5\x75\x1d\xc4\x97\xfa\xd7\x96\x8c\x1f\x19\x54" \ +"\xa7\x74\x08\x5c\x28\xfe\xd9\x32\x4b\xe0\x62\x55\xeb\xb4\x1e\x36" \ +"\x5f\xf5\x38\x48\x18\x75\x57\x9a\x05\x7e\x3d\xb1\x55\x79\x0f\xd0" \ +"\x8c\x79\x72\x90\xb7\x16\x12\x18\xa1\x97\x53\xf1\x49\x0a\x35\x40" \ +"\xc2\x8b\x72\x7a\x38\x22\x04\x96\x01\xd3\x7e\x47\x5d\xaa\x03\xb0" \ +"\xb5\xc3\xa9\xa6\x21\x14\xc7\xd9\x71\x07"' + +# Test all the things... +echo "flow4 { src 72.229.104.206/32; dst 103.99.170.10/32; proto = 17; sport = 56733; dport = 4242; length = 140; dscp 0/0xff; fragment !dont_fragment && !is_fragment && !first_fragment && !last_fragment };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { port = 4242; icmp code = 0; };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +# Some port tests... +echo "flow4 { port = 4242 && = 56733; };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { port = 4242 || 1; sport = 56733 };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { port = 4242 && 1 };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { icmp code != 0; };" | ./genrules.py parse_8021q +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +TEST_PKT='#define TEST \ +"\x00\x0d\xb9\x50\x11\x4c\x00\x17\x10\x95\xe8\x96\x86\xdd\x60\x00" \ +"\x00\x00\x00\x20\x06\x37\x2a\x01\x04\xf8\x01\x30\x71\xd2\x00\x00" \ +"\x00\x00\x00\x00\x00\x02\x26\x20\x00\x6e\xa0\x00\x20\x01\x00\x00" \ +"\x00\x00\x00\x00\x00\x06\x20\x8d\xc2\x72\xff\x5f\x50\xa7\x1a\xfb" \ +"\x41\xed\x80\x10\x06\xef\x87\x8c\x00\x00\x01\x01\x08\x0a\x98\x3d" \ +"\x75\xde\xeb\x22\xd6\x80"' + +# Some v6 TCP tests... +echo "flow6 { src 2a01:4f8:130:71d2::2/128; dst 2620:6e:a000:2001::6/128; next header 6; port 8333 && 49778; tcp flags 0x010/0xfff;};" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow6 { src 0:4f8:130:71d2::2/128 offset 16; dst 0:0:a000:2001::/64 offset 32; next header 6; port 8333 && 49778; tcp flags 0x010/0xfff;};" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow6 { icmp code != 0; };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +TEST_PKT='#define TEST \ +"\xcc\x2d\xe0\xf5\x02\xe1\x00\x0d\xb9\x50\x42\xfe\x81\x00\x00\x03" \ +"\x08\x00\x45\x00\x00\x54\xda\x85\x40\x00\x40\x01\x67\xc6\x0a\x45" \ +"\x1e\x51\xd1\xfa\xfd\xcc\x08\x00\x18\x82\x7e\xda\x00\x02\xc8\xc4" \ +"\x67\x60\x00\x00\x00\x00\x69\xa9\x08\x00\x00\x00\x00\x00\x10\x11" \ +"\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21" \ +"\x22\x23\x24\x25\x26\x27\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31" \ +"\x32\x33\x34\x35\x36\x37"' + +# ICMP and VLAN tests +echo "flow4 { src 10.0.0.0/8; dst 209.250.0.0/16; proto = 1; icmp type 8; icmp code >= 0; length < 100; fragment dont_fragment; };" | ./genrules.py parse_8021q +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { icmp type 8; icmp code > 0; };" | ./genrules.py parse_8021q +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { icmp type 9; };" | ./genrules.py parse_8021q +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { src 10.0.0.0/8; dst 209.250.0.0/16; proto = 1; icmp type 8; icmp code >= 0; length < 100; fragment dont_fragment; };" | ./genrules.py req_8021q=3 +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { src 0.0.0.0/32; };" | ./genrules.py req_8021q=4 +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { src 0.0.0.0/32; };" | ./genrules.py req_8021q=3 +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow4 { port 42; };" | ./genrules.py parse_8021q +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +TEST_PKT='#define TEST \ +"\x00\x0d\xb9\x50\x11\x4c\x00\x17\x10\x95\xe8\x96\x86\xdd\x60\x0a" \ +"\xb8\x00\x00\x40\x3a\x3e\x20\x01\x04\x70\x00\x00\x02\xc8\x00\x00" \ +"\x00\x00\x00\x00\x00\x02\x26\x20\x00\x6e\xa0\x00\x00\x01\x00\x00" \ +"\x00\x00\x00\x00\xca\xfe\x81\x00\x40\x94\x85\x14\x00\x13\x00\x00" \ +"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" \ +"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" \ +"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" \ +"\x00\x00\x00\x00\x00\x00"' + +# ICMPv6 tests +echo "flow6 { icmp type 129; icmp code 0; };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_DROP" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow6 { icmp code != 0; };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow6 { tcp flags 0x0/0x0; };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + +echo "flow6 { port 42; };" | ./genrules.py +echo "$TEST_PKT" >> rules.h +echo "#define TEST_EXP XDP_PASS" >> rules.h +clang -std=c99 -fsanitize=address -pedantic -Wall -Wextra -Wno-pointer-arith -Wno-unused-variable -O0 -g xdp.c -o xdp && ./xdp + + diff --git a/xdp.c b/xdp.c new file mode 100644 index 0000000..87b7eb5 --- /dev/null +++ b/xdp.c @@ -0,0 +1,243 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#define NULL (void*)0 + +/* IP flags. */ +#define IP_CE 0x8000 /* Flag: "Congestion" */ +#define IP_DF 0x4000 /* Flag: "Don't Fragment" */ +#define IP_MF 0x2000 /* Flag: "More Fragments" */ +#define IP_OFFSET 0x1FFF /* "Fragment Offset" part */ + +#define IP_PROTO_TCP 6 +#define IP_PROTO_UDP 17 +#define IP_PROTO_ICMP 1 +#define IP_PROTO_ICMPV6 58 + +typedef __uint128_t uint128_t; + +// Our own ipv6hdr that uses uint128_t +struct ip6hdr { +#if defined(__LITTLE_ENDIAN_BITFIELD) + __u8 priority:4, + version:4; +#elif defined(__BIG_ENDIAN_BITFIELD) + __u8 version:4, + priority:4; +#else +#error "Please fix " +#endif + __u8 flow_lbl[3]; + + __be16 payload_len; + __u8 nexthdr; + __u8 hop_limit; + + uint128_t saddr; + uint128_t daddr; +} __attribute__((packed)); + +// Our own ethhdr with optional vlan tags +struct ethhdr_vlan { + unsigned char h_dest[ETH_ALEN]; /* destination eth addr */ + unsigned char h_source[ETH_ALEN]; /* source ether addr */ + __be16 vlan_magic; /* 0x8100 */ + __be16 tci; /* PCP (3 bits), DEI (1 bit), and VLAN (12 bits) */ + __be16 h_proto; /* packet type ID field */ +} __attribute__((packed)); + +// Our own tcphdr without the flags blown up +struct tcphdr { + __be16 source; + __be16 dest; + __be32 seq; + __be32 ack_seq; + __u16 flags; + __be16 window; + __sum16 check; + __be16 urg_ptr; +} __attribute__((packed)); + +// Note that all operations on uint128s *stay* in Network byte order! + +#if defined(__LITTLE_ENDIAN) +#define BIGEND32(v) ((v >> 3*8) | ((v >> 8) & 0xff00) | ((v << 8) & 0xff0000) | (v << 3*8) & 0xff000000) +#elif defined(__BIG_ENDIAN) +#define BIGEND32(v) (v) +#else +#error "Need endian info" +#endif + +#if defined(__LITTLE_ENDIAN) +#define BIGEND128(a, b, c, d) ( \ + (((uint128_t)BIGEND32(d)) << 3*32) | \ + (((uint128_t)BIGEND32(c)) << 2*32) | \ + (((uint128_t)BIGEND32(b)) << 1*32) | \ + (((uint128_t)BIGEND32(a)) << 0*32)) +#define HTON128(a) BIGEND128(a >> 3*32, a >> 2*32, a >> 1*32, a>> 0*32) +// Yes, somehow macro'ing this changes LLVM's view of htons... +#define BE16(a) ((((uint16_t)(a & 0xff00)) >> 8) | (((uint16_t)(a & 0xff)) << 8)) +#elif defined(__BIG_ENDIAN) +#define BIGEND128(a, b, c, d) ((((uint128_t)a) << 3*32) | (((uint128_t)b) << 2*32) | (((uint128_t)c) << 1*32) | (((uint128_t)d) << 0*32)) +#define HTON128(a) (a) +#else +#error "Need endian info" +#endif + +#define MASK4(pfxlen) BIGEND32(~((((uint32_t)1) << (32 - pfxlen)) - 1)) +#define MASK6(pfxlen) HTON128(~((((uint128_t)1) << (128 - pfxlen)) - 1)) +#define MASK6_OFFS(offs, pfxlen) HTON128((~((((uint128_t)1) << (128 - pfxlen)) - 1)) & ((((uint128_t)1) << (128 - offs)) - 1)) + +// Note rules.h may also define PARSE_8021Q and REQ_8021Q +// Note rules.h may also define PARSE_IHL +#include "rules.h" + +#ifdef TEST +// 64 bit version of xdp_md for testing +struct xdp_md { + __u64 data; + __u64 data_end; + __u64 data_meta; + /* Below access go through struct xdp_rxq_info */ + __u64 ingress_ifindex; /* rxq->dev->ifindex */ + __u64 rx_queue_index; /* rxq->queue_index */ + + __u64 egress_ifindex; /* txq->dev->ifindex */ +}; +static const int XDP_PASS = 0; +static const int XDP_DROP = 1; +#else +#include +#include + +SEC("xdp_drop") +#endif +int xdp_drop_prog(struct xdp_md *ctx) +{ + const void *const data_end = (void *)(size_t)ctx->data_end; + + const void * pktdata; + unsigned short eth_proto; + + { + if ((void*)(size_t)ctx->data + sizeof(struct ethhdr) > data_end) + return XDP_DROP; + const struct ethhdr *const eth = (void*)(size_t)ctx->data; + +#ifdef PARSE_8021Q + if (eth->h_proto == BE16(ETH_P_8021Q)) { + if ((void*)(size_t)ctx->data + sizeof(struct ethhdr_vlan) > data_end) + return XDP_DROP; + const struct ethhdr_vlan *const eth_vlan = (void*)(size_t)ctx->data; + +#ifdef REQ_8021Q + if ((eth_vlan->tci & BE16(0xfff)) != BE16(REQ_8021Q)) + return XDP_DROP; +#endif + + eth_proto = eth_vlan->h_proto; + pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr_vlan); + } else { +#ifdef REQ_8021Q + return XDP_DROP; +#else + pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr); + eth_proto = eth->h_proto; +#endif + } +#else + pktdata = (const void *)(long)ctx->data + sizeof(struct ethhdr); + eth_proto = eth->h_proto; +#endif + } + + const struct tcphdr *tcp = NULL; + const struct udphdr *udp = NULL; + const struct icmphdr *icmp = NULL; + const struct icmp6hdr *icmpv6 = NULL; + const struct iphdr *ip = NULL; + const struct ip6hdr *ip6 = NULL; + const void *l4hdr = NULL; + if (eth_proto == BE16(ETH_P_IP)) { + if (pktdata + sizeof(struct iphdr) > data_end) + return XDP_DROP; + ip = (struct iphdr*) pktdata; + +#ifdef PARSE_IHL + if (ip->ihl < 5) return XDP_DROP; + l4hdr = pktdata + ip->ihl * 4; +#else + if (ip->ihl != 5) return XDP_DROP; + l4hdr = pktdata + 5*4; +#endif + if (ip->protocol == IP_PROTO_TCP) { + if (l4hdr + sizeof(struct tcphdr) > data_end) + return XDP_DROP; + tcp = (struct tcphdr*) l4hdr; + } else if (ip->protocol == IP_PROTO_UDP) { + if (l4hdr + sizeof(struct udphdr) > data_end) + return XDP_DROP; + udp = (struct udphdr*) l4hdr; + } else if (ip->protocol == IP_PROTO_ICMP) { + if (l4hdr + sizeof(struct icmphdr) > data_end) + return XDP_DROP; + icmp = (struct icmphdr*) l4hdr; + } + } else if (eth_proto == BE16(ETH_P_IPV6)) { + if (pktdata + sizeof(struct ip6hdr) > data_end) + return XDP_DROP; + ip6 = (struct ip6hdr*) pktdata; + + l4hdr = pktdata + 40; + if (ip6->nexthdr == IP_PROTO_TCP) { + if (l4hdr + sizeof(struct tcphdr) > data_end) + return XDP_DROP; + tcp = (struct tcphdr*) l4hdr; + } else if (ip6->nexthdr == IP_PROTO_UDP) { + if (l4hdr + sizeof(struct udphdr) > data_end) + return XDP_DROP; + udp = (struct udphdr*) l4hdr; + } else if (ip6->nexthdr == IP_PROTO_ICMPV6) { + if (l4hdr + sizeof(struct icmp6hdr) > data_end) + return XDP_DROP; + icmpv6 = (struct icmp6hdr*) l4hdr; + } + // TODO: Handle some options? + } else { + return XDP_PASS; + } + + uint16_t sport, dport; // Host Endian! Only valid with tcp || udp + if (tcp != NULL) { + sport = BE16(tcp->source); + dport = BE16(tcp->dest); + } else if (udp != NULL) { + sport = BE16(udp->source); + dport = BE16(udp->dest); + } + + RULES + + return XDP_PASS; +} + +#ifdef TEST +#include +#include + +const char d[] = TEST; +int main() { + struct xdp_md test = { + .data = (uint64_t)d, + // -1 because sizeof includes a trailing null in the "string" + .data_end = (uint64_t)(d + sizeof(d) - 1), + }; + assert(xdp_drop_prog(&test) == TEST_EXP); +} +#endif -- 2.30.2