Only parse v4/v6 if we have relevant rules for them
[flowspec-xdp] / genrules.py
1 #!/usr/bin/env python3
2
3 import sys
4 import ipaddress
5 from enum import Enum
6 import argparse
7
8
# IANA IP protocol numbers (values of the IPv4 "protocol" / IPv6 "next
# header" field), listed in numeric order.
IP_PROTO_ICMP = 1
IP_PROTO_TCP = 6
IP_PROTO_UDP = 17
IP_PROTO_ICMPV6 = 58
13
class ASTAction(Enum):
    """Kind of node in a parsed flowspec match expression."""
    # Plain ints: the original trailing commas silently made these tuples.
    OR = 1
    AND = 2
    NOT = 3
    EXPR = 4


class ASTNode:
    """Boolean expression tree node that renders itself as a C expression.

    For OR/AND, ``left`` and ``right`` are child ASTNodes. For NOT, ``left`` is
    the single child. For EXPR, ``left`` is a leaf object exposing
    ``write(param, param2)``. ``right`` is None for unary (NOT/EXPR) nodes.
    """

    def __init__(self, action, left, right=None):
        self.action = action
        self.left = left
        if right is None:
            # Only the binary operators may (and must) carry a right operand.
            assert action == ASTAction.EXPR or action == ASTAction.NOT
        # Always define the attribute so unary nodes can be inspected safely.
        self.right = right

    def write(self, expr_param, expr_param2=None):
        """Render this subtree as a parenthesized C boolean expression.

        ``expr_param`` / ``expr_param2`` are forwarded unchanged to every leaf's
        ``write``; their meaning is defined by the leaf type.

        Raises:
            ValueError: if ``action`` is not an ASTAction this method knows.
        """
        if self.action == ASTAction.EXPR:
            return self.left.write(expr_param, expr_param2)
        left = self.left.write(expr_param, expr_param2)
        if self.action == ASTAction.NOT:
            return "!(" + left + ")"
        right = self.right.write(expr_param, expr_param2)
        if self.action == ASTAction.OR:
            return "(" + left + ") || (" + right + ")"
        if self.action == ASTAction.AND:
            return "(" + left + ") && (" + right + ")"
        raise ValueError(f"unknown AST action: {self.action!r}")


def parse_ast(expr, parse_expr):
    """Parse a flowspec match expression into an ASTNode tree.

    Precedence, loosest first:
      * ``||`` and ``,`` (both logical OR),
      * ``&&``,
      * a leading ``!`` (NOT).
    Whatever remains is a single term and is handed to ``parse_expr``, which
    must return an ASTNode for it.

    Example: ``"a && b || c"`` parses as ``(a && b) || c``.
    """
    expr = expr.strip()

    # Split at the first OR-level separator, so "&&" groups bind tighter.
    or_positions = [(expr.find(sep), sep) for sep in ("||", ",") if sep in expr]
    if or_positions:
        pos, sep = min(or_positions)
        return ASTNode(ASTAction.OR,
                       parse_ast(expr[:pos], parse_expr),
                       parse_ast(expr[pos + len(sep):], parse_expr))

    left, sep, right = expr.partition("&&")
    if sep:
        return ASTNode(ASTAction.AND,
                       parse_ast(left, parse_expr),
                       parse_ast(right, parse_expr))

    if expr.startswith("!"):
        return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr))

    return parse_expr(expr)
57
58
class NumbersAction(Enum):
    """C comparison operator emitted for a numeric flowspec term."""
    EQ = "=="
    GT = ">"
    GTOE = ">="
    LT = "<"
    LTOE = "<="


class NumbersExpr:
    """Leaf term comparing a C lvalue against a literal operand.

    ``action`` is a NumbersAction; ``val`` is the operand text, copied
    verbatim into the generated C.
    """

    def __init__(self, action, val):
        self.action = action
        self.val = val

    def write(self, param, param2):
        """Return ``param<op>val``, or the OR of that test on both params."""
        operator = self.action.value
        first = "".join((param, operator, self.val))
        if param2 is None:
            return first
        second = "".join((param2, operator, self.val))
        return "(" + first + ") || (" + second + ")"
74
def _numbers_operand(text):
    """Return *text* stripped, asserting it is a single non-empty token."""
    operand = text.strip()
    assert operand and len(operand.split()) == 1, f"bad numeric operand: {text!r}"
    return operand


def parse_numbers_expr(expr):
    """Parse one numeric flowspec term into an ASTNode.

    Accepted forms:
      * ``>= N``, ``> N``, ``<= N``, ``< N``,
      * ``= N`` or bare ``N`` (equality),
      * the range ``A..B``, emitted as ``x >= A && x <= B``.
    The space after an operator is optional. Operands are emitted verbatim
    into C; a malformed operand fails an assertion instead of producing
    invalid C.
    """
    expr = expr.strip()

    # Two-character operators must be tried before their one-character prefixes.
    for text, action in ((">=", NumbersAction.GTOE),
                         (">", NumbersAction.GT),
                         ("<=", NumbersAction.LTOE),
                         ("<", NumbersAction.LT)):
        if expr.startswith(text):
            operand = _numbers_operand(expr[len(text):])
            return ASTNode(ASTAction.EXPR, NumbersExpr(action, operand))

    if ".." in expr:
        bounds = expr.split("..")
        assert len(bounds) == 2, f"bad numeric range: {expr!r}"
        # NOTE(review): ranges are treated as inclusive on both ends; confirm
        # against the rule source's range semantics.
        low = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, _numbers_operand(bounds[0])))
        high = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, _numbers_operand(bounds[1])))
        return ASTNode(ASTAction.AND, low, high)

    # "= N" and "=N" are explicit equality; a bare "N" means the same thing.
    if expr.startswith("="):
        expr = expr[1:]
    return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, _numbers_operand(expr)))
100
class FragExpr(Enum):
    """Keyword of a flowspec ``fragment`` component."""
    IF = 1  # is_fragment
    FF = 2  # first_fragment
    DF = 3  # dont_fragment (IPv4 only)
    LF = 4  # last_fragment

    def write(self, ipproto, _param2):
        """Return the C test for this keyword.

        ``ipproto == 4`` tests the IPv4 ``frag_off`` field. Any other value is
        treated as IPv6 and tests via the ``frag6`` pointer, whose non-NULL
        state is what ``is_fragment`` checks. ``dont_fragment`` has no IPv6
        form and fails an assertion there. The second operand is unused.
        """
        if ipproto == 4:
            tests = {
                FragExpr.IF: "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0",
                FragExpr.FF: "((ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0)",
                FragExpr.DF: "(ip->frag_off & BE16(IP_DF)) != 0",
                FragExpr.LF: "((ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0)",
            }
        else:
            tests = {
                FragExpr.IF: "frag6 != NULL",
                FragExpr.FF: "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) != 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) == 0)",
                FragExpr.LF: "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) == 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) != 0)",
            }
        assert self in tests
        return tests[self]
130
def parse_frag_expr(expr):
    """Map a fragment keyword (e.g. ``is_fragment``) to an EXPR AST leaf.

    Unknown keywords fail an assertion.
    """
    keyword_to_flag = {
        "is_fragment": FragExpr.IF,
        "first_fragment": FragExpr.FF,
        "dont_fragment": FragExpr.DF,
        "last_fragment": FragExpr.LF,
    }
    assert expr in keyword_to_flag
    return ASTNode(ASTAction.EXPR, keyword_to_flag[expr])
142
class BitExpr:
    """Bitmask term written as ``match/mask`` (e.g. ``0x03/0x0f``).

    Renders as ``(param & mask) == match``. Both halves are copied verbatim
    into the generated C. Input without exactly one ``/`` fails an assertion.
    """

    def __init__(self, val):
        parts = val.split("/")
        assert len(parts) == 2
        self.match, self.mask = parts

    def write(self, param, _param2):
        """Return the C test for *param*; the second operand is ignored."""
        return "({} & {}) == {}".format(param, self.mask, self.match)
152
def parse_bit_expr(expr):
    """Wrap a single ``match/mask`` term in an EXPR AST leaf."""
    leaf = BitExpr(expr)
    return ASTNode(ASTAction.EXPR, leaf)
155
156
def ip_to_rule(proto, inip, ty, offset):
    """Emit C that breaks out of the current rule when an address mismatches.

    proto: 4 for IPv4; any other value is treated as IPv6.
    inip: prefix text such as ``10.0.0.0/8``. ``ipaddress`` raises
        ValueError if host bits are set.
    ty: header field name to compare (``saddr`` / ``daddr``).
    offset: IPv6-only bit offset used with ``MASK6_OFFS``; must be None
        for IPv4.

    A /0 prefix matches every address, so it yields an empty string.
    MASK4 / BIGEND32 / MASK6 / MASK6_OFFS / BIGEND128 are C macros that are
    not defined in this script.
    """
    prefix = inip.strip()
    if proto != 4:
        net = ipaddress.IPv6Network(prefix)
        if not net.prefixlen:
            return ""
        addr = int(net.network_address)
        # Four 32-bit words, most significant first.
        words = [(addr >> shift) & 0xffffffff for shift in (96, 64, 32, 0)]
        if offset is None:
            mask = "MASK6({})".format(net.prefixlen)
        else:
            mask = "MASK6_OFFS({}, {})".format(offset, net.prefixlen)
        wanted = "BIGEND128({}ULL, {}ULL, {}ULL, {}ULL)".format(*words)
        return ("if ((ip6->{0} & {1}) != ({2} & {1}))\n"
                "        break;").format(ty, mask, wanted)

    assert offset is None
    net = ipaddress.IPv4Network(prefix)
    if not net.prefixlen:
        return ""
    return ("if ((ip->{0} & MASK4({1})) != BIGEND32({2}ULL))\n"
            "        break;").format(ty, net.prefixlen, int(net.network_address))
179
def fragment_to_rule(ipproto, rules):
    """Emit C for a ``fragment`` component; breaks when it does not match."""
    tree = parse_ast(rules, parse_frag_expr)
    # FragExpr leaves receive the IP version in place of a C field name.
    condition = tree.write(ipproto)
    return f"if (!( {condition} )) break;"
183
def len_to_rule(rules):
    """Emit C for a ``length`` component, tested against ``data_end - pktdata``.

    NOTE(review): presumably ``pktdata`` points at the start of the IP
    packet — confirm against the C side.
    """
    condition = parse_ast(rules, parse_numbers_expr).write("(data_end - pktdata)")
    return f"if (!( {condition} )) break;"
187  
def proto_to_rule(ipproto, proto):
    """Emit C matching the IPv4 protocol or IPv6 next-header field.

    NOTE(review): ``ip6->nexthdr`` is the first Next Header; when extension
    headers are present it is not the upper-layer protocol — confirm this is
    intended.
    """
    field = "ip->protocol" if ipproto == 4 else "ip6->nexthdr"
    condition = parse_ast(proto, parse_numbers_expr).write(field)
    return f"if (!( {condition} )) break;"
195
def icmp_type_to_rule(proto, ty):
    """Emit C for an ``icmp type`` component (ICMP for v4, ICMPv6 otherwise).

    Breaks when no ICMP header was parsed or the type does not match.
    """
    if proto == 4:
        header, field = "icmp", "icmp->type"
    else:
        header, field = "icmpv6", "icmpv6->icmp6_type"
    condition = parse_ast(ty, parse_numbers_expr).write(field)
    return f"if ({header} == NULL) break;\nif (!( {condition} )) break;"
202
def icmp_code_to_rule(proto, code):
    """Emit C for an ``icmp code`` component (ICMP for v4, ICMPv6 otherwise).

    Breaks when no ICMP header was parsed or the code does not match.
    """
    if proto == 4:
        header, field = "icmp", "icmp->code"
    else:
        header, field = "icmpv6", "icmpv6->icmp6_code"
    condition = parse_ast(code, parse_numbers_expr).write(field)
    return f"if ({header} == NULL) break;\nif (!( {condition} )) break;"
209
def dscp_to_rule(proto, rules):
    """Emit C matching the 6-bit DSCP value against numeric *rules*.

    IPv4: DSCP is the upper six bits of ``tos``.
    IPv6: assuming the Linux ``struct ipv6hdr`` layout, the Traffic Class is
    ``priority`` (high nibble) followed by the high nibble of ``flow_lbl[0]``.
    DSCP is its upper six bits, i.e. ``priority`` then the top two bits of
    ``flow_lbl[0]``.
    """
    ast = parse_ast(rules, parse_numbers_expr)

    if proto == 4:
        field = "((ip->tos & 0xfc) >> 2)"
    else:
        # Build the 6 DSCP bits directly. Shifting "(a << 4) | (b >> 4)" by 2
        # without outer parentheses is wrong in C: ">>" binds tighter than
        # "|", so only the flow_lbl part would be shifted.
        field = "((ip6->priority << 2) | ((ip6->flow_lbl[0] & 0xc0) >> 6))"
    return "if (!( " + ast.write(field) + " )) break;"
217
def port_to_rule(ty, rules):
    """Emit C for a ``port``, ``sport`` or ``dport`` component.

    ``sport`` / ``dport`` test the C variable of the same name. ``port``
    matches when the whole expression holds for the source port OR for the
    destination port. The expression is therefore rendered once per port and
    the two results are ORed, rather than ORing ports inside every leaf. The
    per-leaf form gave wrong results for ranges and negations, e.g.
    ``port 10..20`` matched sport=5, dport=25.
    Breaks when neither a TCP nor a UDP header was parsed.
    """
    ast = parse_ast(rules, parse_numbers_expr)
    if ty == "port":
        condition = "(" + ast.write("sport") + ") || (" + ast.write("dport") + ")"
    else:
        condition = ast.write(ty)
    return "if (tcp == NULL && udp == NULL) break;\nif (!( " + condition + " )) break;"
225
def tcp_flags_to_rule(rules):
    """Emit C for a ``tcp flags`` bitmask component.

    Tests the low 12 bits of ``ntohs(tcp->flags)``. Breaks when no TCP header
    was parsed or the bitmask expression does not match.
    """
    flags = "(ntohs(tcp->flags) & 0xfff)"
    condition = parse_ast(rules, parse_bit_expr).write(flags)
    return "if (tcp == NULL) break;\n" + f"if (!( {condition} )) break;"
231
def flow_label_to_rule(rules):
    """Emit C matching the 20-bit IPv6 Flow Label against bitmask *rules*.

    Assuming the Linux ``struct ipv6hdr`` layout, the label is the low nibble
    of ``flow_lbl[0]`` (bits 19..16), then ``flow_lbl[1]`` (15..8), then
    ``flow_lbl[2]`` (7..0). Breaks when no IPv6 header is available.
    """
    ast = parse_ast(rules, parse_bit_expr)
    # The low byte is flow_lbl[2]; reading flow_lbl[0] here again (as before)
    # mixed Traffic Class bits into the label.
    label = ("((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8)"
             " | (((uint32_t)ip6->flow_lbl[1]) << 1*8)"
             " | (uint32_t)ip6->flow_lbl[2])")
    return "if (ip6 == NULL) break;\nif (!( " + ast.write(label) + " )) break;"
237
def _parse_args(argv):
    """Parse and validate the command line.

    Called before rules.h is opened, so that ``--help`` or a usage error
    exits without truncating an existing rules.h.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--ihl", dest="ihl", required=True,
                        choices=["drop-options", "accept-options", "parse-options"])
    parser.add_argument("--8021q", dest="vlan", required=True,
                        choices=["drop-vlan", "accept-vlan", "parse-vlan"])
    parser.add_argument("--require-8021q", dest="vlan_tag")
    args = parser.parse_args(argv)
    # REQ_8021Q is only meaningful when VLAN headers are parsed.
    if args.vlan_tag is not None and args.vlan != "parse-vlan":
        parser.error("--require-8021q requires --8021q parse-vlan")
    return args


# Value of the PARSE_IHL / PARSE_8021Q macros for each command-line choice.
_IHL_ACTIONS = {
    "drop-options": "XDP_DROP",
    "accept-options": "XDP_PASS",
    "parse-options": "PARSE",
}
_VLAN_ACTIONS = {
    "drop-vlan": "XDP_DROP",
    "accept-vlan": "XDP_PASS",
    "parse-vlan": "PARSE",
}
# Rule keyword -> (IP version, ethertype macro guarding the rule).
_FLOW_KINDS = {
    "flow4": (4, "ETH_P_IP"),
    "flow6": (6, "ETH_P_IPV6"),
}


def _write_rule(out, rule):
    """Append *rule* (possibly multi-line) as continuation lines of RULES."""
    out.write("\t\t" + rule.replace("\n", " \\\n\t\t") + " \\\n")


def _step_to_rule(proto, step):
    """Translate one stripped ';'-separated component of a rule into C.

    Returns None for an empty component.

    Raises:
        ValueError: for a malformed ``offset`` clause or an unsupported
            component. This includes ``proto`` in flow6 and ``next header``
            in flow4. These are raised explicitly instead of asserted:
            silently skipping a component would widen the drop rule.
    """
    if step == "":
        return None
    if step.startswith("src") or step.startswith("dst"):
        nets = step[3:].split()
        if len(nets) > 1:
            if len(nets) != 3 or nets[1] != "offset":
                raise ValueError(f"bad address component: {step!r}")
            offset = nets[2]
        else:
            offset = None
        field = "saddr" if step.startswith("src") else "daddr"
        return ip_to_rule(proto, nets[0], field, offset)
    if step.startswith("proto") and proto == 4:
        return proto_to_rule(4, step[6:])
    if step.startswith("next header") and proto == 6:
        return proto_to_rule(6, step[12:])
    if step.startswith("icmp type"):
        return icmp_type_to_rule(proto, step[10:])
    if step.startswith("icmp code"):
        return icmp_code_to_rule(proto, step[10:])
    if step.startswith(("sport", "dport", "port")):
        port_kind, port_rules = step.split(" ", 1)
        return port_to_rule(port_kind, port_rules)
    if step.startswith("length"):
        return len_to_rule(step[7:])
    if step.startswith("dscp"):
        return dscp_to_rule(proto, step[5:])
    if step.startswith("tcp flags"):
        return tcp_flags_to_rule(step[10:])
    if step.startswith("label"):
        return flow_label_to_rule(step[6:])
    if step.startswith("fragment"):
        return fragment_to_rule(proto, step[9:])
    raise ValueError(f"unsupported flowspec component: {step!r}")


args = _parse_args(sys.argv[1:])

with open("rules.h", "w") as out:
    out.write("#define PARSE_IHL " + _IHL_ACTIONS[args.ihl] + "\n")
    out.write("#define PARSE_8021Q " + _VLAN_ACTIONS[args.vlan] + "\n")
    if args.vlan_tag is not None:
        out.write("#define REQ_8021Q " + args.vlan_tag + "\n")

    # Which address families appear in the rules; reported at the end as
    # NEED_V4_PARSE / NEED_V6_PARSE.
    use_v4 = False
    use_v6 = False

    out.write("#define RULES \\\n")

    for line in sys.stdin:
        t = line.split("{")
        if len(t) != 2:
            continue
        flow_kind = _FLOW_KINDS.get(t[0].strip())
        if flow_kind is None:
            continue
        proto, eth_proto_macro = flow_kind
        if proto == 4:
            use_v4 = True
        else:
            use_v6 = True
        out.write("if (eth_proto == htons(" + eth_proto_macro + ")) { \\\n")
        out.write("\tdo {\\\n")

        rule = t[1].split("}")[0].strip()
        for step in rule.split(";"):
            c_rule = _step_to_rule(proto, step.strip())
            if c_rule is not None:
                _write_rule(out, c_rule)
        # Each component emits "break" on mismatch, so reaching this point
        # inside the do/while means every component matched.
        out.write("\t\treturn XDP_DROP;\\\n")
        out.write("\t} while(0);\\\n}\\\n")

    out.write("\n")
    if use_v4:
        out.write("#define NEED_V4_PARSE\n")
    if use_v6:
        out.write("#define NEED_V6_PARSE\n")
332     if use_v6:
333         out.write("#define NEED_V6_PARSE\n")