12 if len(sys.argv) > 2 and sys.argv[2].startswith("parse_ihl"):
17 class ASTAction(Enum):
23 def __init__(self, action, left, right=None):
27 assert action == ASTAction.EXPR or action == ASTAction.NOT
def write(self, expr_param, expr_param2=None):
    """Render this AST node as a C boolean expression.

    Args:
        expr_param: C expression substituted into every leaf of the tree.
        expr_param2: Optional second C expression.  When given, the whole
            tree is rendered once against ``expr_param`` and once against
            ``expr_param2`` and the two renderings are OR-ed, so the result
            is true when *either* operand satisfies the complete expression
            (FlowSpec ``port`` semantics: source OR destination port).

    Returns:
        The rendered C expression as a string.

    Raises:
        ValueError: if the node carries an action this method does not handle.
    """
    if expr_param2 is not None:
        # Distribute the alternative operand at the root, not per leaf.
        # Expanding per leaf turns "1000..2000" into
        #   (sport >= 1000 || dport >= 1000) && (sport <= 2000 || dport <= 2000)
        # which also matches sport=80 / dport=50000 although neither port is
        # inside the range.
        return "(" + self.write(expr_param) + ") || (" + self.write(expr_param2) + ")"
    if self.action == ASTAction.OR:
        return "(" + self.left.write(expr_param, None) + ") || (" + self.right.write(expr_param, None) + ")"
    if self.action == ASTAction.AND:
        return "(" + self.left.write(expr_param, None) + ") && (" + self.right.write(expr_param, None) + ")"
    if self.action == ASTAction.NOT:
        return "!(" + self.left.write(expr_param, None) + ")"
    if self.action == ASTAction.EXPR:
        # Leaf: the wrapped expression object (numeric / fragment / bitmask
        # term) renders itself; its write() takes (param, param2) positionally.
        return self.left.write(expr_param, None)
    # Previously this fell through and returned None, which only surfaced
    # later as an opaque TypeError during string concatenation.
    raise ValueError(f"unhandled AST action: {self.action!r}")
41 def parse_ast(expr, parse_expr):
44 and_split = expr.split("&&", 1)
45 or_split = expr.split("||", 1)
46 if len(and_split) > 1 and not "||" in and_split[0]:
47 return ASTNode(ASTAction.AND, parse_ast(and_split[0], parse_expr), parse_ast(and_split[1], parse_expr))
49 assert not "&&" in or_split[0]
50 return ASTNode(ASTAction.OR, parse_ast(or_split[0], parse_expr), parse_ast(or_split[1], parse_expr))
52 comma_split = expr.split(",", 1)
53 if len(comma_split) > 1:
54 return ASTNode(ASTAction.OR, parse_ast(comma_split[0], parse_expr), parse_ast(comma_split[1], parse_expr))
56 if expr.startswith("!"):
57 return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr))
59 return parse_expr(expr)
62 class NumbersAction(Enum):
69 def __init__(self, action, val):
def write(self, param, param2):
    """Render this numeric comparison as C source.

    Produces ``<param><op><val>``, where ``<op>`` is ``self.action.value``
    and ``<val>`` is ``self.val``.  When ``param2`` is not None, the same
    comparison is also applied to ``param2`` and the two parenthesised
    results are joined with ``||``.
    """
    op = self.action.value
    rhs = self.val
    lhs_check = param + op + rhs
    if param2 is None:
        return lhs_check
    return "(" + lhs_check + ") || (" + param2 + op + rhs + ")"
78 def parse_numbers_expr(expr):
79 space_split = expr.split(" ")
80 if expr.startswith(">="):
81 assert len(space_split) == 2
82 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, space_split[1]))
83 if expr.startswith(">"):
84 assert len(space_split) == 2
85 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GT, space_split[1]))
86 if expr.startswith("<="):
87 assert len(space_split) == 2
88 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, space_split[1]))
89 if expr.startswith("<"):
90 assert len(space_split) == 2
91 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LT, space_split[1]))
93 rangesplit = expr.split("..")
94 assert len(rangesplit) == 2
95 #XXX: Are ranges really inclusive,inclusive?
96 left = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, rangesplit[0]))
97 right = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, rangesplit[1]))
98 return ASTNode(ASTAction.AND, left, right)
100 if expr.startswith("= "):
102 return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, expr))
104 class FragExpr(Enum):
110 def write(self, ipproto, _param2):
112 if self == FragExpr.IF:
113 return "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0"
114 elif self == FragExpr.FF:
115 return "((ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0)"
116 elif self == FragExpr.DF:
117 return "(ip->frag_off & BE16(IP_DF)) != 0"
118 elif self == FragExpr.LF:
119 return "((ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0)"
123 if self == FragExpr.IF:
124 return "frag6 != NULL"
125 elif self == FragExpr.FF:
126 return "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) != 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) == 0)"
127 elif self == FragExpr.DF:
128 assert False # No such thing in v6
129 elif self == FragExpr.LF:
130 return "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) == 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) != 0)"
134 def parse_frag_expr(expr):
135 if expr == "is_fragment":
136 return ASTNode(ASTAction.EXPR, FragExpr.IF)
137 elif expr == "first_fragment":
138 return ASTNode(ASTAction.EXPR, FragExpr.FF)
139 elif expr == "dont_fragment":
140 return ASTNode(ASTAction.EXPR, FragExpr.DF)
141 elif expr == "last_fragment":
142 return ASTNode(ASTAction.EXPR, FragExpr.LF)
147 def __init__(self, val):
def write(self, param, _param2):
    """Render a masked-equality test: ``(<param> & <mask>) == <match>``.

    The second operand is accepted only so the signature matches the other
    leaf expression types; it is ignored.
    """
    template = "({} & {}) == {}"
    return template.format(param, self.mask, self.match)
def parse_bit_expr(expr):
    """Wrap one bitmask term in an EXPR leaf node for parse_ast.

    The term text is handed unchanged to BitExpr, whose write() renders a
    ``(value & mask) == match`` comparison.
    """
    leaf = BitExpr(expr)
    return ASTNode(ASTAction.EXPR, leaf)
160 def ip_to_rule(proto, inip, ty, offset):
162 assert offset is None
163 net = ipaddress.IPv4Network(inip.strip())
164 if net.prefixlen == 0:
166 return f"""if ((ip->{ty} & MASK4({net.prefixlen})) != BIGEND32({int(net.network_address)}ULL))
169 net = ipaddress.IPv6Network(inip.strip())
170 if net.prefixlen == 0:
172 u32s = [(int(net.network_address) >> (3*32)) & 0xffffffff,
173 (int(net.network_address) >> (2*32)) & 0xffffffff,
174 (int(net.network_address) >> (1*32)) & 0xffffffff,
175 (int(net.network_address) >> (0*32)) & 0xffffffff]
177 mask = f"MASK6({net.prefixlen})"
179 mask = f"MASK6_OFFS({offset}, {net.prefixlen})"
180 return f"""if ((ip6->{ty} & {mask}) != (BIGEND128({u32s[0]}ULL, {u32s[1]}ULL, {u32s[2]}ULL, {u32s[3]}ULL) & {mask}))
def fragment_to_rule(ipproto, rules):
    """Return a C guard that breaks out unless the fragment expression matches.

    ``rules`` is parsed with parse_frag_expr.  ``ipproto`` is passed through as
    the first argument of each FragExpr.write call, which has separate
    IPv4 (ip->frag_off) and IPv6 (frag6) renderings.
    """
    tree = parse_ast(rules, parse_frag_expr)
    condition = tree.write(ipproto)
    return "if (!( " + condition + " )) break;"
def len_to_rule(rules):
    """Return a C guard that breaks out unless the length expression matches.

    The C expression compared is ``(data_end - pktdata)``.  ``rules`` uses the
    numeric match syntax understood by parse_numbers_expr.
    """
    length_expr = "(data_end - pktdata)"
    condition = parse_ast(rules, parse_numbers_expr).write(length_expr)
    return "if (!( " + condition + " )) break;"
191 def proto_to_rule(ipproto, proto):
192 ast = parse_ast(proto, parse_numbers_expr)
195 return "if (!( " + ast.write("ip->protocol") + " )) break;"
197 return "if (!( " + ast.write("ip6->nexthdr") + " )) break;"
199 def icmp_type_to_rule(proto, ty):
200 ast = parse_ast(ty, parse_numbers_expr)
202 return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->type") + " )) break;"
204 return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_type") + " )) break;"
206 def icmp_code_to_rule(proto, code):
207 ast = parse_ast(code, parse_numbers_expr)
209 return "if (icmp == NULL) break;\nif (!( " + ast.write("icmp->code") + " )) break;"
211 return "if (icmpv6 == NULL) break;\nif (!( " + ast.write("icmpv6->icmp6_code") + " )) break;"
213 def dscp_to_rule(proto, rules):
# Build a C guard that breaks out unless the DSCP value satisfies `rules`
# (numeric match syntax, parsed with parse_numbers_expr).  Which of the two
# returns below runs is presumably selected on `proto` (4 vs 6) — confirm.
214 ast = parse_ast(rules, parse_numbers_expr)
# IPv4: DSCP is taken as the top six bits of ip->tos.
217 return "if (!( " + ast.write("((ip->tos & 0xfc) >> 2)") + " )) break;"
# IPv6: the intent appears to be ((priority << 4) | (flow_lbl[0] >> 4)) >> 2,
# i.e. the top six bits of the traffic class.
# NOTE(review): in C, ">>" binds tighter than "|", so the expression below
# actually evaluates as (priority << 4) | ((flow_lbl[0] & 0xc0) >> 6).  That
# leaves the priority nibble shifted two bits too far.  The DSCP should be
# (priority << 2) | ((flow_lbl[0] & 0xc0) >> 6) — confirm and fix.
219 return "if (!( " + ast.write("((ip6->priority << 4) | ((ip6->flow_lbl[0] & 0xc0) >> 4) >> 2)") + " )) break;"
221 def port_to_rule(ty, rules):
223 ast = parse_ast(rules, parse_numbers_expr)
224 return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write("sport", "dport") + " )) break;"
226 ast = parse_ast(rules, parse_numbers_expr)
227 return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write(ty) + " )) break;"
def tcp_flags_to_rule(rules):
    """Return the C guard for a TCP-flags bitmask match.

    Packets without a TCP header (tcp == NULL) never match.  The value from
    ntohs(tcp->flags) is masked with 0xfff before each bitmask term in
    ``rules`` (parsed with parse_bit_expr) is applied.
    """
    flags_word = "(ntohs(tcp->flags) & 0xfff)"
    condition = parse_ast(rules, parse_bit_expr).write(flags_word)
    return f"if (tcp == NULL) break;\nif (!( {condition} )) break;"
def flow_label_to_rule(rules):
    """Return the C guard for an IPv6 flow-label bitmask match.

    Packets without an IPv6 header (ip6 == NULL) never match.  The 20-bit flow
    label spans the low nibble of ip6->flow_lbl[0] followed by flow_lbl[1] and
    flow_lbl[2].  It is reassembled into a uint32_t and tested against each
    bitmask term in ``rules`` (parsed with parse_bit_expr).
    """
    ast = parse_ast(rules, parse_bit_expr)
    # Bug fix: the low byte of the label is flow_lbl[2].  The previous code ORed
    # in flow_lbl[0] a second time, so the traffic-class bits leaked into the
    # low byte and the real last byte of the label was never compared.
    flow_label = ("((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8) | "
                  "(((uint32_t)ip6->flow_lbl[1]) << 1*8) | "
                  "(uint32_t)ip6->flow_lbl[2])")
    condition = ast.write(flow_label)
    return f"if (ip6 == NULL) break;\nif (!( {condition} )) break;"
241 with open("rules.h", "w") as out:
242 if len(sys.argv) > 1 and sys.argv[1] == "parse_8021q":
243 out.write("#define PARSE_8021Q\n")
244 if len(sys.argv) > 1 and sys.argv[1].startswith("req_8021q="):
245 out.write("#define PARSE_8021Q\n")
246 out.write(f"#define REQ_8021Q {sys.argv[1][10:]}\n")
249 out.write("#define RULES \\\n")
252 out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n")
254 for line in sys.stdin.readlines():
258 if t[0].strip() == "flow4":
260 out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n")
261 out.write("\tdo {\\\n")
262 elif t[0].strip() == "flow6":
264 out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n")
265 out.write("\tdo {\\\n")
269 rule = t[1].split("}")[0].strip()
270 for step in rule.split(";"):
271 if step.strip().startswith("src") or step.strip().startswith("dst"):
272 nets = step.strip()[3:].strip().split(" ")
274 assert nets[1] == "offset"
278 if step.strip().startswith("src"):
279 write_rule(ip_to_rule(proto, nets[0], "saddr", offset))
281 write_rule(ip_to_rule(proto, nets[0], "daddr", offset))
282 elif step.strip().startswith("proto") and proto == 4:
283 write_rule(proto_to_rule(4, step.strip()[6:]))
284 elif step.strip().startswith("next header") and proto == 6:
285 write_rule(proto_to_rule(6, step.strip()[12:]))
286 elif step.strip().startswith("icmp type"):
287 write_rule(icmp_type_to_rule(proto, step.strip()[10:]))
288 elif step.strip().startswith("icmp code"):
289 write_rule(icmp_code_to_rule(proto, step.strip()[10:]))
290 elif step.strip().startswith("sport") or step.strip().startswith("dport") or step.strip().startswith("port"):
291 write_rule(port_to_rule(step.strip().split(" ")[0], step.strip().split(" ", 1)[1]))
292 elif step.strip().startswith("length"):
293 write_rule(len_to_rule(step.strip()[7:]))
294 elif step.strip().startswith("dscp"):
295 write_rule(dscp_to_rule(proto, step.strip()[5:]))
296 elif step.strip().startswith("tcp flags"):
297 write_rule(tcp_flags_to_rule(step.strip()[10:]))
298 elif step.strip().startswith("label"):
299 write_rule(flow_label_to_rule(step.strip()[6:]))
300 elif step.strip().startswith("fragment"):
301 write_rule(fragment_to_rule(proto, step.strip()[9:]))
302 elif step.strip() == "":
306 out.write("\t\treturn XDP_DROP;\\\n")
307 out.write("\t} while(0);\\\n}\\\n")