Support v6 fragment parsing
[flowspec-xdp] / genrules.py
1 #!/usr/bin/env python3
2
3 import sys
4 import ipaddress
5 from enum import Enum
6
# IP protocol numbers, listed in numeric order.
IP_PROTO_ICMP = 1
IP_PROTO_TCP = 6
IP_PROTO_UDP = 17
IP_PROTO_ICMPV6 = 58
11
# True only when a second command-line argument is present and starts with
# "parse_ihl"; False otherwise.
PARSE_IHL = len(sys.argv) > 2 and sys.argv[2].startswith("parse_ihl")
16
class ASTAction(Enum):
    """Kinds of node that can appear in a parsed rule expression tree."""

    # Values are plain ints. Do not add trailing commas here: "OR = 1," would
    # silently make the member's value the one-element tuple (1,).
    OR = 1    # left || right
    AND = 2   # left && right
    NOT = 3   # !left
    EXPR = 4  # leaf expression held in "left"
class ASTNode:
    """A node of the boolean expression tree built by parse_ast.

    ``action`` is an ASTAction.  OR and AND nodes combine the ``left`` and
    ``right`` sub-nodes, NOT negates the sub-node in ``left``, and EXPR wraps
    a leaf object in ``left`` that provides ``write(param, param2)``.
    """

    def __init__(self, action, left, right=None):
        if right is None:
            # Only unary nodes may omit the right operand.
            assert action == ASTAction.EXPR or action == ASTAction.NOT
        self.action = action
        self.left = left
        # Always defined (None for unary nodes) so the attribute can be read
        # without risking an AttributeError.
        self.right = right

    def write(self, expr_param, expr_param2=None):
        """Return the condition text for this subtree.

        ``expr_param`` and ``expr_param2`` are passed unchanged down to every
        leaf's ``write`` method.

        Raises:
            ValueError: if ``action`` is not a known ASTAction, rather than
                silently returning None.
        """
        if self.action == ASTAction.EXPR:
            return self.left.write(expr_param, expr_param2)
        if self.action == ASTAction.NOT:
            return "!(" + self.left.write(expr_param, expr_param2) + ")"
        if self.action in (ASTAction.OR, ASTAction.AND):
            operator = " || " if self.action == ASTAction.OR else " && "
            lhs = self.left.write(expr_param, expr_param2)
            rhs = self.right.write(expr_param, expr_param2)
            return "(" + lhs + ")" + operator + "(" + rhs + ")"
        raise ValueError(f"unknown AST action: {self.action!r}")
40
def parse_ast(expr, parse_expr):
    """Parse a rule expression string into a tree of ASTNode objects.

    The stripped expression is split, in this order of preference:
      * at the first "&&", provided no "||" appears before it (AND node);
      * at the first "||" (OR node; no "&&" may appear before it);
      * at the first "," which is treated as OR;
      * a leading "!" produces a NOT node around the remainder.
    Anything else is handed to ``parse_expr`` and its result is returned.
    """
    expr = expr.strip()

    and_parts = expr.split("&&", 1)
    if len(and_parts) == 2 and "||" not in and_parts[0]:
        return ASTNode(
            ASTAction.AND,
            parse_ast(and_parts[0], parse_expr),
            parse_ast(and_parts[1], parse_expr),
        )

    or_parts = expr.split("||", 1)
    if len(or_parts) == 2:
        assert "&&" not in or_parts[0]
        return ASTNode(
            ASTAction.OR,
            parse_ast(or_parts[0], parse_expr),
            parse_ast(or_parts[1], parse_expr),
        )

    # A comma-separated list means "any of these".
    alternatives = expr.split(",", 1)
    if len(alternatives) == 2:
        return ASTNode(
            ASTAction.OR,
            parse_ast(alternatives[0], parse_expr),
            parse_ast(alternatives[1], parse_expr),
        )

    if expr[:1] == "!":
        return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr))

    return parse_expr(expr)
60
61
class NumbersAction(Enum):
    """Comparison operators for numeric matches.

    Each member's value is the operator text that NumbersExpr.write inserts
    between the parameter and the number.
    """

    EQ = "=="    # equal
    GT = ">"     # greater than
    GTOE = ">="  # greater than or equal
    LT = "<"     # less than
    LTOE = "<="  # less than or equal
class NumbersExpr:
    """One numeric comparison: an operator (``action``) and a value string."""

    def __init__(self, action, val):
        self.val = val
        self.action = action

    def write(self, param, param2):
        """Return ``param <op> val`` as text.

        When ``param2`` is not None the same comparison is emitted for both
        parameters and the two are joined with "||".
        """
        operator = self.action.value
        first = param + operator + self.val
        if param2 is None:
            return first
        return "(" + first + ") || (" + param2 + operator + self.val + ")"
77
def parse_numbers_expr(expr):
    """Parse one numeric match into an ASTNode.

    Accepted forms are ">= N", "> N", "<= N", "< N" (operator and number
    separated by exactly one space), "A..B" (both bounds inclusive), and a
    bare value optionally prefixed with "= ", which is an equality test.
    """
    tokens = expr.split(" ")

    # Two-character operators are listed before their one-character prefixes
    # so that ">=" is not mistaken for ">".
    prefixed = (
        (">=", NumbersAction.GTOE),
        (">", NumbersAction.GT),
        ("<=", NumbersAction.LTOE),
        ("<", NumbersAction.LT),
    )
    for prefix, operator in prefixed:
        if expr.startswith(prefix):
            assert len(tokens) == 2
            return ASTNode(ASTAction.EXPR, NumbersExpr(operator, tokens[1]))

    if ".." in expr:
        bounds = expr.split("..")
        assert len(bounds) == 2
        #XXX: Are ranges really inclusive,inclusive?
        lower = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, bounds[0]))
        upper = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, bounds[1]))
        return ASTNode(ASTAction.AND, lower, upper)

    value = expr[2:] if expr.startswith("= ") else expr
    return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, value))
103
class FragExpr(Enum):
    """Fragment match kinds; write() emits the matching condition text."""

    IF = 1  # "is_fragment"
    FF = 2  # "first_fragment"
    DF = 3  # "dont_fragment"
    LF = 4  # "last_fragment"

    def write(self, ipproto, _param2):
        """Return the condition text for this fragment test.

        ``ipproto`` equal to 4 selects the tests on ``ip->frag_off``; any
        other value selects the tests on ``frag6``.  ``_param2`` is ignored.
        DF has no counterpart in the non-v4 branch and trips an assertion.
        """
        # The tables are built here rather than in the class body, where a
        # dict attribute would be turned into an enum member.
        if ipproto == 4:
            conditions = {
                FragExpr.IF: "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0",
                FragExpr.FF: "((ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0)",
                FragExpr.DF: "(ip->frag_off & BE16(IP_DF)) != 0",
                FragExpr.LF: "((ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0)",
            }
        else:
            # No DF entry: No such thing in v6
            conditions = {
                FragExpr.IF: "frag6 != NULL",
                FragExpr.FF: "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) != 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) == 0)",
                FragExpr.LF: "(frag6 != NULL && (frag6->frag_off & BE16(IP6_MF)) == 0 && (frag6->frag_off & BE16(IP6_FRAGOFF)) != 0)",
            }

        if self in conditions:
            return conditions[self]
        assert False
133
def parse_frag_expr(expr):
    """Map a fragment keyword to an EXPR node wrapping the FragExpr member.

    Recognised keywords are "is_fragment", "first_fragment", "dont_fragment"
    and "last_fragment"; anything else trips an assertion.
    """
    keywords = {
        "is_fragment": FragExpr.IF,
        "first_fragment": FragExpr.FF,
        "dont_fragment": FragExpr.DF,
        "last_fragment": FragExpr.LF,
    }
    if expr in keywords:
        return ASTNode(ASTAction.EXPR, keywords[expr])
    assert False
145
class BitExpr:
    """A masked-bits match written as "match/mask"."""

    def __init__(self, val):
        pieces = val.split("/")
        # Exactly one "/" is required.
        assert len(pieces) == 2
        self.match = pieces[0]
        self.mask = pieces[1]

    def write(self, param, _param2):
        """Return text testing that ``param`` masked by ``mask`` equals ``match``.

        ``_param2`` is ignored.
        """
        return "({} & {}) == {}".format(param, self.mask, self.match)
155
def parse_bit_expr(expr):
    """Wrap a "match/mask" string in an EXPR node holding a BitExpr leaf."""
    leaf = BitExpr(expr)
    return ASTNode(ASTAction.EXPR, leaf)
158
159
def ip_to_rule(proto, inip, ty, offset):
    """Return the text of a prefix test on an address field.

    Args:
        proto: 4 parses ``inip`` as an IPv4 network; any other value parses
            it as an IPv6 network.
        inip: network in CIDR notation; surrounding whitespace is stripped.
        ty: name of the header field to test (inserted after "ip->"/"ip6->").
        offset: must be None when ``proto`` is 4; otherwise, when not None,
            it is passed as the first argument of MASK6_OFFS.

    Returns:
        An empty string for a zero-length prefix, otherwise an ``if (...)``
        line followed by an indented ``break;`` line.
    """
    if proto == 4:
        assert offset is None
        net4 = ipaddress.IPv4Network(inip.strip())
        if not net4.prefixlen:
            return ""
        address = int(net4.network_address)
        return (
            f"if ((ip->{ty} & MASK4({net4.prefixlen})) != BIGEND32({address}ULL))"
            "\n        break;"
        )

    net6 = ipaddress.IPv6Network(inip.strip())
    prefixlen = net6.prefixlen
    if not prefixlen:
        return ""

    # Split the 128-bit address into four 32-bit words, most significant first.
    address = int(net6.network_address)
    words = [(address >> shift) & 0xffffffff for shift in (96, 64, 32, 0)]
    value = ", ".join(f"{word}ULL" for word in words)

    if offset is None:
        mask = f"MASK6({prefixlen})"
    else:
        mask = f"MASK6_OFFS({offset}, {prefixlen})"
    return (
        f"if ((ip6->{ty} & {mask}) != (BIGEND128({value}) & {mask}))"
        "\n        break;"
    )
182
def fragment_to_rule(ipproto, rules):
    """Return a statement that breaks unless the fragment expression matches.

    ``rules`` is parsed with parse_frag_expr leaves and ``ipproto`` is passed
    to each leaf's write method.
    """
    condition = parse_ast(rules, parse_frag_expr).write(ipproto)
    return "".join(["if (!( ", condition, " )) break;"])
186
def len_to_rule(rules):
    """Return a statement that breaks unless the length expression matches.

    The value compared is ``(data_end - pktdata)``.
    """
    condition = parse_ast(rules, parse_numbers_expr).write("(data_end - pktdata)")
    return "".join(["if (!( ", condition, " )) break;"])
190  
def proto_to_rule(ipproto, proto):
    """Return a statement that breaks unless the protocol expression matches.

    The value compared is ``ip->protocol`` when ``ipproto`` is 4 and
    ``ip6->nexthdr`` otherwise.
    """
    ast = parse_ast(proto, parse_numbers_expr)
    field = "ip->protocol" if ipproto == 4 else "ip6->nexthdr"
    return "if (!( " + ast.write(field) + " )) break;"
198
def icmp_type_to_rule(proto, ty):
    """Return statements that break unless the ICMP type expression matches.

    A NULL check on the header pointer is emitted first: ``icmp`` when
    ``proto`` is 4, ``icmpv6`` otherwise.
    """
    ast = parse_ast(ty, parse_numbers_expr)
    if proto == 4:
        header, field = "icmp", "icmp->type"
    else:
        header, field = "icmpv6", "icmpv6->icmp6_type"
    return "if (" + header + " == NULL) break;\nif (!( " + ast.write(field) + " )) break;"
205
def icmp_code_to_rule(proto, code):
    """Return statements that break unless the ICMP code expression matches.

    A NULL check on the header pointer is emitted first: ``icmp`` when
    ``proto`` is 4, ``icmpv6`` otherwise.
    """
    ast = parse_ast(code, parse_numbers_expr)
    if proto == 4:
        header, field = "icmp", "icmp->code"
    else:
        header, field = "icmpv6", "icmpv6->icmp6_code"
    return "if (" + header + " == NULL) break;\nif (!( " + ast.write(field) + " )) break;"
212
def dscp_to_rule(proto, rules):
    """Return a statement that breaks unless the DSCP expression matches.

    Args:
        proto: 4 tests the IPv4 ``tos`` byte; any other value tests the IPv6
            ``priority`` and ``flow_lbl[0]`` fields.
        rules: numeric match expression, parsed with parse_numbers_expr.
    """
    ast = parse_ast(rules, parse_numbers_expr)

    if proto == 4:
        # The top six bits of the TOS byte.
        field = "((ip->tos & 0xfc) >> 2)"
    else:
        # Combine the priority nibble with the top two bits of flow_lbl[0],
        # then shift the combined value right by two.  The OR needs its own
        # parentheses: in C ">>" binds tighter than "|", so without them the
        # final shift would apply only to the flow_lbl term and the priority
        # bits would stay shifted left by four.
        field = "(((ip6->priority << 4) | ((ip6->flow_lbl[0] & 0xc0) >> 4)) >> 2)"

    return "if (!( " + ast.write(field) + " )) break;"
220
def port_to_rule(ty, rules):
    """Return statements that break unless the port expression matches.

    A check that ``tcp`` and ``udp`` are not both NULL is emitted first.
    When ``ty`` is "port" every comparison is written against both ``sport``
    and ``dport``; otherwise ``ty`` itself is used as the value to compare.
    """
    ast = parse_ast(rules, parse_numbers_expr)
    if ty == "port":
        # NOTE(review): the sport/dport OR is applied per comparison (see
        # NumbersExpr.write), so the two bounds of a range such as "1..10"
        # can be satisfied by different ports -- confirm this is intended.
        condition = ast.write("sport", "dport")
    else:
        condition = ast.write(ty)
    return "if (tcp == NULL && udp == NULL) break;\nif (!( " + condition + " )) break;"
228
def tcp_flags_to_rule(rules):
    """Return statements that break unless the TCP flags expression matches.

    A NULL check on ``tcp`` is emitted first.  ``rules`` is parsed with
    parse_bit_expr leaves and tested against ``ntohs(tcp->flags) & 0xfff``.
    """
    condition = parse_ast(rules, parse_bit_expr).write("(ntohs(tcp->flags) & 0xfff)")
    return f"if (tcp == NULL) break;\nif (!( {condition} )) break;"
234
def flow_label_to_rule(rules):
    """Return statements that break unless the flow label expression matches.

    A NULL check on ``ip6`` is emitted first.  ``rules`` is parsed with
    parse_bit_expr leaves and tested against a 20-bit value assembled from
    the three ``flow_lbl`` bytes.
    """
    ast = parse_ast(rules, parse_bit_expr)

    # Low nibble of flow_lbl[0] supplies the top four bits, flow_lbl[1] the
    # middle byte and flow_lbl[2] the low byte.  The low byte must come from
    # index 2; reading index 0 again would discard the last eight bits.
    label = (
        "((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8)"
        " | (((uint32_t)ip6->flow_lbl[1]) << 1*8)"
        " | (uint32_t)ip6->flow_lbl[2])"
    )
    return f"""if (ip6 == NULL) break;
if (!( {ast.write(label)} )) break;"""
240
# Read flow rules from stdin and write them to rules.h as one RULES macro.
with open("rules.h", "w") as out:
    # The first command-line argument optionally selects 802.1Q handling.
    vlan_arg = sys.argv[1] if len(sys.argv) > 1 else None
    if vlan_arg == "parse_8021q":
        out.write("#define PARSE_8021Q\n")
    if vlan_arg is not None and vlan_arg.startswith("req_8021q="):
        out.write("#define PARSE_8021Q\n")
        out.write(f"#define REQ_8021Q {vlan_arg[10:]}\n")


    out.write("#define RULES \\\n")

    def write_rule(r):
        # Indent the snippet and end every line with a continuation
        # backslash so that it stays inside the macro.
        continued = " \\\n\t\t".join(r.split("\n"))
        out.write("\t\t" + continued + " \\\n")

    for line in sys.stdin.readlines():
        # Only lines with exactly one "{" are considered.
        pieces = line.split("{")
        if len(pieces) != 2:
            continue

        family = pieces[0].strip()
        if family == "flow4":
            proto = 4
            out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n")
            out.write("\tdo {\\\n")
        elif family == "flow6":
            proto = 6
            out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n")
            out.write("\tdo {\\\n")
        else:
            continue

        # Everything between "{" and the first "}" is a ";"-separated list.
        body = pieces[1].split("}")[0].strip()
        for raw_step in body.split(";"):
            step = raw_step.strip()
            if step.startswith(("src", "dst")):
                nets = step[3:].strip().split(" ")
                if len(nets) > 1:
                    assert nets[1] == "offset"
                    offset = nets[2]
                else:
                    offset = None
                field = "saddr" if step.startswith("src") else "daddr"
                write_rule(ip_to_rule(proto, nets[0], field, offset))
            elif step.startswith("proto") and proto == 4:
                write_rule(proto_to_rule(4, step[6:]))
            elif step.startswith("next header") and proto == 6:
                write_rule(proto_to_rule(6, step[12:]))
            elif step.startswith("icmp type"):
                write_rule(icmp_type_to_rule(proto, step[10:]))
            elif step.startswith("icmp code"):
                write_rule(icmp_code_to_rule(proto, step[10:]))
            elif step.startswith(("sport", "dport", "port")):
                port_parts = step.split(" ", 1)
                write_rule(port_to_rule(port_parts[0], port_parts[1]))
            elif step.startswith("length"):
                write_rule(len_to_rule(step[7:]))
            elif step.startswith("dscp"):
                write_rule(dscp_to_rule(proto, step[5:]))
            elif step.startswith("tcp flags"):
                write_rule(tcp_flags_to_rule(step[10:]))
            elif step.startswith("label"):
                write_rule(flow_label_to_rule(step[6:]))
            elif step.startswith("fragment"):
                write_rule(fragment_to_rule(proto, step[9:]))
            elif step != "":
                # Unknown keyword.
                assert False
        # Reached only when none of the tests above hit "break".
        out.write("\t\treturn XDP_DROP;\\\n")
        out.write("\t} while(0);\\\n}\\\n")

    out.write("\n")