# Node kinds for the expression tree built by parse_ast(). The members used in
# this file are OR, AND, NOT and EXPR (their definitions are not visible here).
14 class ASTAction(Enum):
# ASTNode constructor: `action` is an ASTAction; `left` is the first child node
# (for EXPR it is the leaf object whose write() renders C text, e.g. NumbersExpr,
# FragExpr or BitExpr); `right` is the second child for binary OR/AND nodes.
# NOTE(review): the attribute assignments and the condition guarding the
# assert below are not visible in this listing. Presumably the assert only
# applies when `right` is omitted, so single-child nodes must be EXPR or NOT.
20 def __init__(self, action, left, right=None):
24 assert action == ASTAction.EXPR or action == ASTAction.NOT
def write(self, expr_param, expr_param2=None):
    """Render this subtree as C boolean-expression source.

    Args:
        expr_param: C expression for the value under test; forwarded to every
            leaf's ``write``.
        expr_param2: Optional second C expression, forwarded unchanged to the
            leaves (only NumbersExpr makes use of it; FragExpr and BitExpr
            ignore their second argument).

    Returns:
        The generated C text. Operands of ``||``, ``&&`` and ``!`` are always
        parenthesized, so the output never depends on C operator precedence.

    Raises:
        ValueError: if ``self.action`` is not OR, AND, NOT or EXPR. Previously
            such a node silently produced None, which surfaced later as an
            unrelated TypeError during string concatenation.
    """
    if self.action == ASTAction.EXPR:
        # Leaf node: self.left is the leaf object (NumbersExpr/FragExpr/BitExpr).
        return self.left.write(expr_param, expr_param2)
    if self.action == ASTAction.NOT:
        return "!(" + self.left.write(expr_param, expr_param2) + ")"
    if self.action == ASTAction.OR:
        operator = " || "
    elif self.action == ASTAction.AND:
        operator = " && "
    else:
        raise ValueError(f"unsupported AST action: {self.action!r}")
    lhs = self.left.write(expr_param, expr_param2)
    rhs = self.right.write(expr_param, expr_param2)
    return "(" + lhs + ")" + operator + "(" + rhs + ")"
# Recursively parse a rule expression string into an ASTNode tree.
# `parse_expr` is the leaf parser applied to a term with no operators left
# (callers pass parse_numbers_expr, parse_frag_expr or parse_bit_expr).
# NOTE(review): original lines 39-40 and 45 (presumably input normalization
# and the condition guarding the "||" branch) are not visible in this listing.
38 def parse_ast(expr, parse_expr):
41 and_split = expr.split("&&", 1)
42 or_split = expr.split("||", 1)
# Split at the first "&&" unless an "||" occurs before it. Because each split
# happens at the first operator, chains group to the right:
# "a && b || c" parses as a && (b || c), not C's (a && b) || c.
43 if len(and_split) > 1 and not "||" in and_split[0]:
44 return ASTNode(ASTAction.AND, parse_ast(and_split[0], parse_expr), parse_ast(and_split[1], parse_expr))
46 assert not "&&" in or_split[0]
47 return ASTNode(ASTAction.OR, parse_ast(or_split[0], parse_expr), parse_ast(or_split[1], parse_expr))
# A comma-separated list is an OR of its terms.
49 comma_split = expr.split(",", 1)
50 if len(comma_split) > 1:
51 return ASTNode(ASTAction.OR, parse_ast(comma_split[0], parse_expr), parse_ast(comma_split[1], parse_expr))
# A leading "!" negates the rest of the term.
53 if expr.startswith("!"):
54 return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr))
# No operators left: hand the term to the leaf parser.
56 return parse_expr(expr)
# Comparison operators for NumbersExpr leaves. NumbersExpr.write() splices each
# member's value verbatim between operand and constant, so the values are
# presumably C operator strings such as ">=" (members used here: GT, GTOE, LT,
# LTOE, EQ; their definitions are not visible in this listing).
59 class NumbersAction(Enum):
# NumbersExpr constructor: `action` is a NumbersAction and `val` is the
# constant operand as text (write() concatenates it directly into C source).
# NOTE(review): the body is not visible in this listing; presumably it stores
# both as self.action / self.val, which write() reads.
66 def __init__(self, action, val):
def write(self, param, param2):
    """Render this comparison as C text: ``<param><op><val>``.

    ``self.action.value`` supplies the operator text and ``self.val`` the
    constant; both are spliced in verbatim. When ``param2`` is not None the
    same comparison is also applied to it, and the two parenthesized results
    are joined with ``||``. This happens per leaf, so inside an AND of two
    leaves (e.g. a range) each bound is OR-ed independently.
    """
    comparison = self.action.value + self.val
    if param2 is None:
        return param + comparison
    return "(" + param + comparison + ") || (" + param2 + comparison + ")"
# Leaf parser for numeric rule terms. parse_ast uses it for length, proto/next
# header, icmp type/code, dscp and port rules. It returns an EXPR node wrapping
# a NumbersExpr, or an AND of two such nodes for a range.
# Comparison forms must be exactly "<op> <value>" with a single space between
# operator and value; the asserts below enforce this.
# NOTE(review): these asserts are the only input validation and are stripped
# when Python runs with -O; raising ValueError would be safer.
75 def parse_numbers_expr(expr):
76 space_split = expr.split(" ")
# Two-character operators are tested before their one-character prefixes so
# that ">= 5" is not mistaken for ">".
77 if expr.startswith(">="):
78 assert len(space_split) == 2
79 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, space_split[1]))
80 if expr.startswith(">"):
81 assert len(space_split) == 2
82 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GT, space_split[1]))
83 if expr.startswith("<="):
84 assert len(space_split) == 2
85 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, space_split[1]))
86 if expr.startswith("<"):
87 assert len(space_split) == 2
88 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LT, space_split[1]))
# Range form "lo..hi", inclusive at both ends: (>= lo) AND (<= hi).
# NOTE(review): original line 89 (presumably the condition selecting this
# branch) is not visible in this listing.
90 rangesplit = expr.split("..")
91 assert len(rangesplit) == 2
92 #XXX: Are ranges really inclusive,inclusive?
93 left = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, rangesplit[0]))
94 right = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, rangesplit[1]))
95 return ASTNode(ASTAction.AND, left, right)
# Equality. NOTE(review): original line 98 is not visible here; presumably it
# strips the "= " prefix from `expr` before `expr` is used as the operand.
97 if expr.startswith("= "):
99 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, expr))
# Leaf for fragment predicates. parse_frag_expr maps the rule keywords
# is_fragment / first_fragment / dont_fragment / last_fragment onto the
# members IF / FF / DF / LF (member definitions are not visible here).
101 class FragExpr(Enum):
# Emit the C test for this predicate. `ipproto` is the flow's IP version as
# passed from fragment_to_rule. The IPv4 forms inspect ip->frag_off; the IPv6
# forms inspect the frag6 pointer (NULL presumably means the C parser found no
# fragment header). The second argument is unused.
# NOTE(review): the lines that branch on `ipproto` (original 108 and 117-119)
# are not visible in this listing.
107 def write(self, ipproto, _param2):
# IPv4: any fragment has MF set or a non-zero offset; the first fragment has
# MF set with offset 0; the last has MF clear with a non-zero offset.
109 if self == FragExpr.IF:
110 return "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0"
111 elif self == FragExpr.FF:
112 return "((ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0)"
113 elif self == FragExpr.DF:
114 return "(ip->frag_off & BE16(IP_DF)) != 0"
115 elif self == FragExpr.LF:
116 return "((ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0)"
# IPv6 equivalents, using the fragment header's M flag and offset.
120 if self == FragExpr.IF:
121 return "frag6 != NULL"
122 elif self == FragExpr.FF:
123 return "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) != 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) == 0)"
124 elif self == FragExpr.DF:
# IPv6 has no Don't Fragment bit, so dont_fragment cannot be expressed.
125 assert False # No such thing in v6
126 elif self == FragExpr.LF:
127 return "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) == 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) != 0)"
def parse_frag_expr(expr):
    """Parse one fragment keyword into an EXPR leaf node wrapping a FragExpr.

    Args:
        expr: One leaf term handed over by parse_ast(); surrounding
            whitespace is ignored.

    Returns:
        ASTNode(ASTAction.EXPR, <FragExpr member>) for ``is_fragment``,
        ``first_fragment``, ``dont_fragment`` or ``last_fragment``.

    Raises:
        ValueError: for any other keyword. The previous version fell through
            and returned None, which only failed later (as an AttributeError
            on None.write) far from the offending rule text.
    """
    keyword = expr.strip()
    if keyword == "is_fragment":
        return ASTNode(ASTAction.EXPR, FragExpr.IF)
    if keyword == "first_fragment":
        return ASTNode(ASTAction.EXPR, FragExpr.FF)
    if keyword == "dont_fragment":
        return ASTNode(ASTAction.EXPR, FragExpr.DF)
    if keyword == "last_fragment":
        return ASTNode(ASTAction.EXPR, FragExpr.LF)
    raise ValueError(f"unknown fragment keyword: {expr!r}")
# BitExpr constructor: `val` is one raw bit-match term from the rule text
# (passed straight through by parse_bit_expr).
# NOTE(review): the parsing of `val` into self.mask / self.match (both read
# by write()) happens in lines not visible in this listing; confirm the
# accepted syntax against the full file.
144 def __init__(self, val):
def write(self, param, _param2):
    """Render a C test that the bits of ``param`` selected by ``self.mask`` equal ``self.match``.

    The second argument exists only to match the other leaves' ``write``
    signature and is ignored.
    """
    template = "({} & {}) == {}"
    return template.format(param, self.mask, self.match)
def parse_bit_expr(expr):
    """Leaf parser for bit-match terms: wrap *expr* in a BitExpr inside an EXPR node."""
    leaf = BitExpr(expr)
    return ASTNode(ASTAction.EXPR, leaf)
# Build the C statement for a src/dst prefix match. `ty` names the header
# field ("saddr" or "daddr", as passed from the rule loop) and `inip` is the
# prefix text.
# - IPv4 (`offset` must be None): compares the masked field against the
#   network address. ipaddress.IPv4Network() is strict by default and rejects
#   prefixes with host bits set, so the address side needs no mask.
# - IPv6: splits the 128-bit network address into four 32-bit words, most
#   significant first, for BIGEND128. Both sides are masked with MASK6, or
#   with MASK6_OFFS when an `offset` is given.
# A /0 prefix is special-cased (presumably returning no check).
# NOTE(review): the v4/v6 selection on `proto`, the /0 return values, the
# `offset is None` test and the closing lines of both generated triple-quoted
# strings (original 158, 162, 164-165, 168, 173, 175, 178-179) are not visible
# in this listing.
157 def ip_to_rule(proto, inip, ty, offset):
159 assert offset is None
160 net = ipaddress.IPv4Network(inip.strip())
161 if net.prefixlen == 0:
163 return f"""if ((ip->{ty} & MASK4({net.prefixlen})) != BIGEND32({int(net.network_address)}ULL))
166 net = ipaddress.IPv6Network(inip.strip())
167 if net.prefixlen == 0:
169 u32s = [(int(net.network_address) >> (3*32)) & 0xffffffff,
170 (int(net.network_address) >> (2*32)) & 0xffffffff,
171 (int(net.network_address) >> (1*32)) & 0xffffffff,
172 (int(net.network_address) >> (0*32)) & 0xffffffff]
174 mask = f"MASK6({net.prefixlen})"
176 mask = f"MASK6_OFFS({offset}, {net.prefixlen})"
177 return f"""if ((ip6->{ty} & {mask}) != (BIGEND128({u32s[0]}ULL, {u32s[1]}ULL, {u32s[2]}ULL, {u32s[3]}ULL) & {mask}))
def fragment_to_rule(ipproto, rules):
    """Translate a `fragment` rule into a C guard.

    Args:
        ipproto: IP version of the enclosing flow rule, forwarded to
            FragExpr.write().
        rules: The fragment expression text; its leaf terms are parsed by
            parse_frag_expr.

    Returns:
        A C statement that breaks out unless the fragment condition holds.
    """
    tree = parse_ast(rules, parse_frag_expr)
    condition = tree.write(ipproto)
    return "if (!( " + condition + " )) break;"
def len_to_rule(rules):
    """Translate a `length` rule into a C guard that breaks out unless it holds.

    The tested value is the C expression ``(data_end - pktdata)``
    (presumably the byte length of the packet data).
    """
    tree = parse_ast(rules, parse_numbers_expr)
    packet_len = "(data_end - pktdata)"
    return "if (!( " + tree.write(packet_len) + " )) break;"
# Translate a protocol rule into a C guard that breaks out of the enclosing
# do/while unless it holds: IPv4 tests ip->protocol, IPv6 tests ip6->nexthdr.
# NOTE(review): the lines choosing between the two returns (original 190-191,
# 193) are not visible in this listing. ip6->nexthdr is presumably the fixed
# header's Next Header field, so packets carrying extension headers would be
# matched on the extension type rather than the upper-layer protocol. Confirm
# this is intended.
188 def proto_to_rule(ipproto, proto):
189 ast = parse_ast(proto, parse_numbers_expr)
192 return "if (!( " + ast.write("ip->protocol") + " )) break;"
194 return "if (!( " + ast.write("ip6->nexthdr") + " )) break;"
# Translate an `icmp type` rule into C that breaks out unless an ICMP (IPv4)
# or ICMPv6 header pointer is non-NULL and its type field satisfies the
# numeric expression.
# NOTE(review): the proto-based branch selection (original 198, 200) is not
# visible in this listing.
196 def icmp_type_to_rule(proto, ty):
197 ast = parse_ast(ty, parse_numbers_expr)
199 return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->type") + " )) break;"
201 return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_type") + " )) break;"
# Translate an `icmp code` rule into C that breaks out unless an ICMP (IPv4)
# or ICMPv6 header pointer is non-NULL and its code field satisfies the
# numeric expression.
# NOTE(review): the proto-based branch selection (original 205, 207) is not
# visible in this listing.
203 def icmp_code_to_rule(proto, code):
204 ast = parse_ast(code, parse_numbers_expr)
206 return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->code") + " )) break;"
208 return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_code") + " )) break;"
def dscp_to_rule(proto, rules):
    """Translate a `dscp` rule into a C guard on the packet's DSCP value.

    DSCP is the upper six bits of the 8-bit IPv4 TOS / IPv6 Traffic Class.
    In IPv6 the Traffic Class straddles two fields: its high nibble is
    ``ip6->priority`` and its low nibble is the top nibble of
    ``ip6->flow_lbl[0]``.

    Args:
        proto: 4 selects the IPv4 form; any other value takes the IPv6 form
            (the rule loop presumably only passes 4 or 6).
        rules: The numeric rule expression, parsed with parse_numbers_expr.

    Returns:
        A C statement that breaks out unless the DSCP test holds.
    """
    ast = parse_ast(rules, parse_numbers_expr)
    if proto == 4:
        dscp = "((ip->tos & 0xfc) >> 2)"
    else:
        # Rebuild the Traffic Class from priority and the top two bits of
        # flow_lbl[0] (the ECN bits are masked off), then shift the whole
        # value right by two. The outer parentheses are the fix: in C `>>`
        # binds tighter than `|`, so the previous expression shifted only the
        # flow_lbl part and left priority two bits too high.
        dscp = "(((ip6->priority << 4) | ((ip6->flow_lbl[0] & 0xc0) >> 4)) >> 2)"
    return "if (!( " + ast.write(dscp) + " )) break;"
def port_to_rule(ty, rules):
    """Translate a `port`, `sport` or `dport` rule into a C guard.

    Args:
        ty: ``"port"`` to match when *either* the source or the destination
            port satisfies the expression; otherwise the name of the C
            variable to test (``"sport"`` or ``"dport"``, as passed by the
            rule loop).
        rules: The numeric rule expression, parsed with parse_numbers_expr.

    Returns:
        C source that breaks out unless a TCP or UDP header is present and
        the port test passes.
    """
    ast = parse_ast(rules, parse_numbers_expr)
    if ty == "port":
        # Evaluate the whole expression against each port and OR the results.
        # Passing both ports down to the leaves (ast.write("sport", "dport"))
        # ORs them per comparison instead, so a range 50..60 matched
        # sport=1, dport=100: each bound was satisfied by a different port.
        cond = "(" + ast.write("sport") + ") || (" + ast.write("dport") + ")"
    else:
        cond = ast.write(ty)
    return "if (tcp == NULL && udp == NULL) break;\nif (!( " + cond + " )) break;"
# Translate a `tcp flags` rule (bit-match terms parsed via BitExpr) into C
# that breaks out unless a TCP header was found and the flag test passes.
# The tested value is ntohs(tcp->flags) & 0xfff. Presumably this is the 16-bit
# data-offset/flags word converted to host order, with the 4-bit data offset
# masked off. Confirm against the C definition of tcp->flags.
226 def tcp_flags_to_rule(rules):
227 ast = parse_ast(rules, parse_bit_expr)
229 return f"""if (tcp == NULL) break;
230 if (!( {ast.write("(ntohs(tcp->flags) & 0xfff)")} )) break;"""
def flow_label_to_rule(rules):
    """Translate an IPv6 `label` rule into a C guard on the 20-bit Flow Label.

    The Flow Label is the low nibble of ``ip6->flow_lbl[0]`` followed by the
    full bytes ``flow_lbl[1]`` and ``flow_lbl[2]`` (network byte order).

    Returns:
        C source that breaks out unless an IPv6 header is present and the
        bit-match test on the reassembled label passes.
    """
    ast = parse_ast(rules, parse_bit_expr)
    # Fix: the least significant byte is flow_lbl[2]. The previous expression
    # OR-ed flow_lbl[0] in a second time and never read flow_lbl[2].
    label = ("((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8) | "
             "(((uint32_t)ip6->flow_lbl[1]) << 1*8) | "
             "(uint32_t)ip6->flow_lbl[2])")
    return "if (ip6 == NULL) break;\nif (!( " + ast.write(label) + " )) break;"
# Script entry point. Reads rule lines from stdin (each starting with "flow4"
# or "flow6", with ";"-separated match steps) and writes C preprocessor
# definitions to rules.h.
# NOTE(review): rules.h is opened with "w" (truncated) before the command line
# is parsed, so a bad invocation leaves an empty rules.h behind. Parsing the
# arguments first would avoid that.
238 with open("rules.h", "w") as out:
239 parse = argparse.ArgumentParser()
240 parse.add_argument("--ihl", dest="ihl", required=True, choices=["drop-options","accept-options","parse-options"])
241 parse.add_argument("--8021q", dest="vlan", required=True, choices=["drop-vlan","accept-vlan","parse-vlan"])
242 parse.add_argument("--require-8021q", dest="vlan_tag")
243 args = parse.parse_args(sys.argv[1:])
# Map --ihl onto PARSE_IHL: drop, pass, or parse packets carrying IPv4 options.
245 if args.ihl == "drop-options":
246 out.write("#define PARSE_IHL XDP_DROP\n")
247 elif args.ihl == "accept-options":
248 out.write("#define PARSE_IHL XDP_PASS\n")
249 elif args.ihl == "parse-options":
250 out.write("#define PARSE_IHL PARSE\n")
# Map --8021q onto PARSE_8021Q the same way for VLAN-tagged frames.
252 if args.vlan == "drop-vlan":
253 out.write("#define PARSE_8021Q XDP_DROP\n")
254 elif args.vlan == "accept-vlan":
255 out.write("#define PARSE_8021Q XDP_PASS\n")
256 elif args.vlan == "parse-vlan":
257 out.write("#define PARSE_8021Q PARSE\n")
# --require-8021q is only meaningful together with parse-vlan. The handling of
# the mismatch (original line 261) is not visible in this listing.
259 if args.vlan_tag is not None:
260 if args.vlan != "parse-vlan":
262 out.write("#define REQ_8021Q " + args.vlan_tag + "\n")
# All generated rule code goes into one multi-line RULES macro.
267 out.write("#define RULES \\\n")
# Presumably the body of write_rule (its def line is not visible). It emits one
# rule fragment, turning each embedded newline into backslash-newline so
# multi-line C stays inside the RULES macro.
270 out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n")
# One rule per stdin line. `t` is presumably the line split at "{" (its
# assignment, original 273-275, is not visible): t[0] is the flow4/flow6
# keyword and t[1] holds the steps up to "}".
272 for line in sys.stdin.readlines():
# Each rule is wrapped in "if (eth_proto == ...) { do { ... } while(0); }".
# Every generated condition `break`s out of the do/while on mismatch, so
# only a packet passing all steps reaches "return XDP_DROP".
276 if t[0].strip() == "flow4":
279 out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n")
280 out.write("\tdo {\\\n")
281 elif t[0].strip() == "flow6":
284 out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n")
285 out.write("\tdo {\\\n")
289 rule = t[1].split("}")[0].strip()
290 for step in rule.split(";"):
# src/dst step: a prefix, optionally followed by "offset <n>". The lines that
# compute `offset` are not visible; presumably it is None when absent.
291 if step.strip().startswith("src") or step.strip().startswith("dst"):
292 nets = step.strip()[3:].strip().split(" ")
294 assert nets[1] == "offset"
298 if step.strip().startswith("src"):
299 write_rule(ip_to_rule(proto, nets[0], "saddr", offset))
301 write_rule(ip_to_rule(proto, nets[0], "daddr", offset))
# Remaining steps: keyword prefix, then the rest of the step (sliced past the
# keyword and its following space) is handed to the matching *_to_rule.
302 elif step.strip().startswith("proto") and proto == 4:
303 write_rule(proto_to_rule(4, step.strip()[6:]))
304 elif step.strip().startswith("next header") and proto == 6:
305 write_rule(proto_to_rule(6, step.strip()[12:]))
306 elif step.strip().startswith("icmp type"):
307 write_rule(icmp_type_to_rule(proto, step.strip()[10:]))
308 elif step.strip().startswith("icmp code"):
309 write_rule(icmp_code_to_rule(proto, step.strip()[10:]))
310 elif step.strip().startswith("sport") or step.strip().startswith("dport") or step.strip().startswith("port"):
311 write_rule(port_to_rule(step.strip().split(" ")[0], step.strip().split(" ", 1)[1]))
312 elif step.strip().startswith("length"):
313 write_rule(len_to_rule(step.strip()[7:]))
314 elif step.strip().startswith("dscp"):
315 write_rule(dscp_to_rule(proto, step.strip()[5:]))
316 elif step.strip().startswith("tcp flags"):
317 write_rule(tcp_flags_to_rule(step.strip()[10:]))
318 elif step.strip().startswith("label"):
319 write_rule(flow_label_to_rule(step.strip()[6:]))
320 elif step.strip().startswith("fragment"):
321 write_rule(fragment_to_rule(proto, step.strip()[9:]))
# Empty steps (e.g. after a trailing ";"). Their handling and the fallback for
# unknown steps (original 323-325) are not visible in this listing.
322 elif step.strip() == "":
326 out.write("\t\treturn XDP_DROP;\\\n")
327 out.write("\t} while(0);\\\n}\\\n")
# Tell the C side which IP versions need parsing. The conditions guarding
# these lines (original 328-330, 332) are not visible in this listing;
# presumably they depend on whether any flow4/flow6 rule was seen.
331 out.write("#define NEED_V4_PARSE\n")
333 out.write("#define NEED_V6_PARSE\n")