+#!/usr/bin/env python3
+
+import sys
+import ipaddress
+from enum import Enum
+
+IP_PROTO_ICMP = 1
+IP_PROTO_ICMPV6 = 58
+IP_PROTO_TCP = 6
+IP_PROTO_UDP = 17
+
+if len(sys.argv) > 2 and sys.argv[2].startswith("parse_ihl"):
+ PARSE_IHL = True
+else:
+ PARSE_IHL = False
+if len(sys.argv) > 3 and sys.argv[3].startswith("parse_exthdr"):
+ PARSE_EXTHDR = True
+else:
+ PARSE_EXTHDR = False
+
+
+class ASTAction(Enum):
+ OR = 1,
+ AND = 2,
+ NOT = 3,
+ EXPR = 4
+class ASTNode:
+ def __init__(self, action, left, right=None):
+ self.action = action
+ self.left = left
+ if right is None:
+ assert action == ASTAction.EXPR or action == ASTAction.NOT
+ else:
+ self.right = right
+
+ def write(self, expr_param, expr_param2=None):
+ if self.action == ASTAction.OR:
+ return "(" + self.left.write(expr_param, expr_param2) + ") || (" + self.right.write(expr_param, expr_param2) + ")"
+ if self.action == ASTAction.AND:
+ return "(" + self.left.write(expr_param, expr_param2) + ") && (" + self.right.write(expr_param, expr_param2) + ")"
+ if self.action == ASTAction.NOT:
+ return "!(" + self.left.write(expr_param, expr_param2) + ")"
+ if self.action == ASTAction.EXPR:
+ return self.left.write(expr_param, expr_param2)
+
+def parse_ast(expr, parse_expr):
+ expr = expr.strip()
+
+ and_split = expr.split("&&", 1)
+ or_split = expr.split("||", 1)
+ if len(and_split) > 1 and not "||" in and_split[0]:
+ return ASTNode(ASTAction.AND, parse_ast(and_split[0], parse_expr), parse_ast(and_split[1], parse_expr))
+ if len(or_split) > 1:
+ assert not "&&" in or_split[0]
+ return ASTNode(ASTAction.OR, parse_ast(or_split[0], parse_expr), parse_ast(or_split[1], parse_expr))
+
+ comma_split = expr.split(",", 1)
+ if len(comma_split) > 1:
+ return ASTNode(ASTAction.OR, parse_ast(comma_split[0], parse_expr), parse_ast(comma_split[1], parse_expr))
+
+ if expr.startswith("!"):
+ return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr))
+
+ return parse_expr(expr)
+
+
+class NumbersAction(Enum):
+ EQ = "=="
+ GT = ">"
+ GTOE = ">="
+ LT = "<"
+ LTOE = "<="
+class NumbersExpr:
+ def __init__(self, action, val):
+ self.action = action
+ self.val = val
+
+ def write(self, param, param2):
+ if param2 is not None:
+ return "(" + param + self.action.value + self.val + ") || (" + param2 + self.action.value + self.val + ")"
+ return param + self.action.value + self.val
+
+def parse_numbers_expr(expr):
+ space_split = expr.split(" ")
+ if expr.startswith(">="):
+ assert len(space_split) == 2
+ return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, space_split[1]))
+ if expr.startswith(">"):
+ assert len(space_split) == 2
+ return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GT, space_split[1]))
+ if expr.startswith("<="):
+ assert len(space_split) == 2
+ return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, space_split[1]))
+ if expr.startswith("<"):
+ assert len(space_split) == 2
+ return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LT, space_split[1]))
+ if ".." in expr:
+ rangesplit = expr.split("..")
+ assert len(rangesplit) == 2
+ #XXX: Are ranges really inclusive,inclusive?
+ left = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, rangesplit[0]))
+ right = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, rangesplit[1]))
+ return ASTNode(ASTAction.AND, left, right)
+
+ if expr.startswith("= "):
+ expr = expr[2:]
+ return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, expr))
+
+class FragExpr:
+ def __init__(self, val):
+ if val == "is_fragment":
+ self.rule = "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0"
+ elif val == "first_fragment":
+ self.rule = "(ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0"
+ elif val == "dont_fragment":
+ self.rule = "(ip->frag_off & BE16(IP_DF)) != 0"
+ elif val == "last_fragment":
+ self.rule = "(ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0"
+ else:
+ assert False
+
+ def write(self, _param, _param2):
+ return self.rule
+
+def parse_frag_expr(expr):
+ return ASTNode(ASTAction.EXPR, FragExpr(expr))
+
+class BitExpr:
+ def __init__(self, val):
+ s = val.split("/")
+ assert len(s) == 2
+ self.match = s[0]
+ self.mask = s[1]
+
+ def write(self, param, _param2):
+ return f"({param} & {self.mask}) == {self.match}"
+
+def parse_bit_expr(expr):
+ return ASTNode(ASTAction.EXPR, BitExpr(expr))
+
+
+def ip_to_rule(proto, inip, ty, offset):
+ if proto == 4:
+ assert offset is None
+ net = ipaddress.IPv4Network(inip.strip())
+ if net.prefixlen == 0:
+ return ""
+ return f"""if ((ip->{ty} & MASK4({net.prefixlen})) != BIGEND32({int(net.network_address)}ULL))
+ break;"""
+ else:
+ net = ipaddress.IPv6Network(inip.strip())
+ if net.prefixlen == 0:
+ return ""
+ u32s = [(int(net.network_address) >> (3*32)) & 0xffffffff,
+ (int(net.network_address) >> (2*32)) & 0xffffffff,
+ (int(net.network_address) >> (1*32)) & 0xffffffff,
+ (int(net.network_address) >> (0*32)) & 0xffffffff]
+ if offset is None:
+ mask = f"MASK6({net.prefixlen})"
+ else:
+ mask = f"MASK6_OFFS({offset}, {net.prefixlen})"
+ return f"""if ((ip6->{ty} & {mask}) != (BIGEND128({u32s[0]}ULL, {u32s[1]}ULL, {u32s[2]}ULL, {u32s[3]}ULL) & {mask}))
+ break;"""
+
+def fragment_to_rule(ipproto, rules):
+ if ipproto == 6:
+ assert False # XXX: unimplemented
+ ast = parse_ast(rules, parse_frag_expr)
+ return "if (!( " + ast.write(()) + " )) break;"
+
+def len_to_rule(rules):
+ ast = parse_ast(rules, parse_numbers_expr)
+ return "if (!( " + ast.write("(data_end - pktdata)") + " )) break;"
+
+def proto_to_rule(ipproto, proto):
+ ast = parse_ast(proto, parse_numbers_expr)
+
+ if ipproto == 4:
+ return "if (!( " + ast.write("ip->protocol") + " )) break;"
+ else:
+ if PARSE_EXTHDR:
+ assert False # XXX: unimplemented
+ return "if (!( " + ast.write("ip6->nexthdr") + " )) break;"
+
+def icmp_type_to_rule(proto, ty):
+ ast = parse_ast(ty, parse_numbers_expr)
+ if proto == 4:
+ return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->type") + " )) break;"
+ else:
+ return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_type") + " )) break;"
+
+def icmp_code_to_rule(proto, code):
+ ast = parse_ast(code, parse_numbers_expr)
+ if proto == 4:
+ return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->code") + " )) break;"
+ else:
+ return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_code") + " )) break;"
+
+def dscp_to_rule(proto, rules):
+ ast = parse_ast(rules, parse_numbers_expr)
+
+ if proto == 4:
+ return "if (!( " + ast.write("((ip->tos & 0xfc) >> 2)") + " )) break;"
+ else:
+ return "if (!( " + ast.write("((ip6->priority << 4) | ((ip6->flow_lbl[0] & 0xc0) >> 4) >> 2)") + " )) break;"
+
+def port_to_rule(ty, rules):
+ if ty == "port" :
+ ast = parse_ast(rules, parse_numbers_expr)
+ return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write("sport", "dport") + " )) break;"
+
+ ast = parse_ast(rules, parse_numbers_expr)
+ return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write(ty) + " )) break;"
+
+def tcp_flags_to_rule(rules):
+ ast = parse_ast(rules, parse_bit_expr)
+
+ return f"""if (tcp == NULL) break;
+if (!( {ast.write("(ntohs(tcp->flags) & 0xfff)")} )) break;"""
+
+def flow_label_to_rule(rules):
+ ast = parse_ast(rules, parse_bit_expr)
+
+ return f"""if (ip6 == NULL) break;
+if (!( {ast.write("((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8) | (((uint32_t)ip6->flow_lbl[1]) << 1*8) | (uint32_t)ip6->flow_lbl[0])")} )) break;"""
+
+with open("rules.h", "w") as out:
+ if len(sys.argv) > 1 and sys.argv[1] == "parse_8021q":
+ out.write("#define PARSE_8021Q\n")
+ if len(sys.argv) > 1 and sys.argv[1].startswith("req_8021q="):
+ out.write("#define PARSE_8021Q\n")
+ out.write(f"#define REQ_8021Q {sys.argv[1][10:]}\n")
+
+
+ out.write("#define RULES \\\n")
+
+ def write_rule(r):
+ out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n")
+
+ for line in sys.stdin.readlines():
+ t = line.split("{")
+ if len(t) != 2:
+ continue
+ if t[0].strip() == "flow4":
+ proto = 4
+ out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n")
+ out.write("\tdo {\\\n")
+ elif t[0].strip() == "flow6":
+ proto = 6
+ out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n")
+ out.write("\tdo {\\\n")
+ else:
+ continue
+
+ rule = t[1].split("}")[0].strip()
+ for step in rule.split(";"):
+ if step.strip().startswith("src") or step.strip().startswith("dst"):
+ nets = step.strip()[3:].strip().split(" ")
+ if len(nets) > 1:
+ assert nets[1] == "offset"
+ offset = nets[2]
+ else:
+ offset = None
+ if step.strip().startswith("src"):
+ write_rule(ip_to_rule(proto, nets[0], "saddr", offset))
+ else:
+ write_rule(ip_to_rule(proto, nets[0], "daddr", offset))
+ elif step.strip().startswith("proto") and proto == 4:
+ write_rule(proto_to_rule(4, step.strip()[6:]))
+ elif step.strip().startswith("next header") and proto == 6:
+ write_rule(proto_to_rule(6, step.strip()[12:]))
+ elif step.strip().startswith("icmp type"):
+ write_rule(icmp_type_to_rule(proto, step.strip()[10:]))
+ elif step.strip().startswith("icmp code"):
+ write_rule(icmp_code_to_rule(proto, step.strip()[10:]))
+ elif step.strip().startswith("sport") or step.strip().startswith("dport") or step.strip().startswith("port"):
+ write_rule(port_to_rule(step.strip().split(" ")[0], step.strip().split(" ", 1)[1]))
+ elif step.strip().startswith("length"):
+ write_rule(len_to_rule(step.strip()[7:]))
+ elif step.strip().startswith("dscp"):
+ write_rule(dscp_to_rule(proto, step.strip()[5:]))
+ elif step.strip().startswith("tcp flags"):
+ write_rule(tcp_flags_to_rule(step.strip()[10:]))
+ elif step.strip().startswith("label"):
+ write_rule(flow_label_to_rule(step.strip()[6:]))
+ elif step.strip().startswith("fragment"):
+ write_rule(fragment_to_rule(proto, step.strip()[9:]))
+ elif step.strip() == "":
+ pass
+ else:
+ assert False
+ out.write("\t\treturn XDP_DROP;\\\n")
+ out.write("\t} while(0);\\\n}\\\n")
+
+ out.write("\n")