Initial checkin
[flowspec-xdp] / genrules.py
1 #!/usr/bin/env python3
2
3 import sys
4 import ipaddress
5 from enum import Enum
6
# Well-known IP protocol numbers.
IP_PROTO_ICMP = 1
IP_PROTO_ICMPV6 = 58
IP_PROTO_TCP = 6
IP_PROTO_UDP = 17

# Optional positional command-line switches (argv[1] is handled where
# rules.h is written):
#   argv[2] beginning with "parse_ihl"    -> PARSE_IHL
#   argv[3] beginning with "parse_exthdr" -> PARSE_EXTHDR (proto_to_rule
#       asserts it is unset, since extension-header walking is unimplemented)
# NOTE(review): PARSE_IHL is not read anywhere in this script - confirm
# whether it is still needed.
PARSE_IHL = len(sys.argv) > 2 and sys.argv[2].startswith("parse_ihl")
PARSE_EXTHDR = len(sys.argv) > 3 and sys.argv[3].startswith("parse_exthdr")
20
21
class ASTAction(Enum):
    """Node kinds in a parsed flowspec match expression."""

    # Plain ints: the previous trailing commas made these values tuples.
    OR = 1
    AND = 2
    NOT = 3
    EXPR = 4


class ASTNode:
    """A node in a boolean expression tree over leaf match expressions.

    For OR/AND nodes ``left`` and ``right`` are child ASTNodes. For NOT nodes
    ``left`` is the single child. For EXPR nodes ``left`` is a leaf object
    exposing ``write(param, param2)``. ``right`` is None on unary nodes.
    """

    def __init__(self, action, left, right=None):
        self.action = action
        self.left = left
        if right is None:
            # Only unary nodes may omit the right operand.
            assert action == ASTAction.EXPR or action == ASTAction.NOT
        # Always define ``right`` so the attribute exists on unary nodes too.
        self.right = right

    def write(self, expr_param, expr_param2=None):
        """Render this subtree as a parenthesised C boolean expression.

        ``expr_param`` / ``expr_param2`` are forwarded unchanged to every
        leaf's ``write`` method.

        Raises:
            ValueError: if ``action`` is not one of the ASTAction members.
        """
        if self.action == ASTAction.OR:
            return f"({self.left.write(expr_param, expr_param2)}) || ({self.right.write(expr_param, expr_param2)})"
        if self.action == ASTAction.AND:
            return f"({self.left.write(expr_param, expr_param2)}) && ({self.right.write(expr_param, expr_param2)})"
        if self.action == ASTAction.NOT:
            return f"!({self.left.write(expr_param, expr_param2)})"
        if self.action == ASTAction.EXPR:
            return self.left.write(expr_param, expr_param2)
        raise ValueError(f"unknown AST action: {self.action!r}")


def parse_ast(expr, parse_expr):
    """Parse a flowspec match expression into an ASTNode tree.

    Binding strength, loosest first: ``||``, ``&&``, ``,`` (treated as OR),
    then a leading ``!``. Whatever remains is handed to ``parse_expr``, which
    must return an ASTNode. Binary operators associate to the right, which is
    harmless because && and || are each associative. Parentheses are not
    supported.
    """
    expr = expr.strip()

    # Split on the loosest operator first so that && binds tighter than ||:
    # "a && b || c" is (a && b) || c, matching RFC 8955's rule that AND has
    # higher priority than OR.
    left, sep, right = expr.partition("||")
    if sep:
        return ASTNode(ASTAction.OR, parse_ast(left, parse_expr), parse_ast(right, parse_expr))

    left, sep, right = expr.partition("&&")
    if sep:
        return ASTNode(ASTAction.AND, parse_ast(left, parse_expr), parse_ast(right, parse_expr))

    left, sep, right = expr.partition(",")
    if sep:
        return ASTNode(ASTAction.OR, parse_ast(left, parse_expr), parse_ast(right, parse_expr))

    if expr.startswith("!"):
        return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr))

    return parse_expr(expr)
65
66
class NumbersAction(Enum):
    """C comparison operator emitted for a numeric flowspec term."""

    EQ = "=="
    GT = ">"
    GTOE = ">="
    LT = "<"
    LTOE = "<="


class NumbersExpr:
    """Leaf expression comparing a C operand against a numeric literal."""

    def __init__(self, action, val):
        self.action = action  # NumbersAction member supplying the operator text
        self.val = val        # literal text copied verbatim into the C output

    def write(self, param, param2):
        """Return C code comparing ``param`` (and ``param2`` if given) with ``val``.

        With a second operand the two comparisons are OR'ed together.
        """
        op = self.action.value
        first = param + op + self.val
        if param2 is None:
            return first
        return "(" + first + ") || (" + param2 + op + self.val + ")"
82
def parse_numbers_expr(expr):
    """Parse one numeric flowspec term into an ASTNode (leaf or range subtree).

    Accepted forms (whitespace after an operator is optional):
      ">= N", "> N", "<= N", "< N"  comparison against N
      "A..B"                        A <= x && x <= B
      "= N" or bare "N"             equality

    The number text is copied verbatim into the generated C expression.

    Raises:
        AssertionError: if an operand is missing or contains whitespace.
    """
    expr = expr.strip()

    # Two-character operators must be tried before their one-character prefixes.
    comparisons = (
        (">=", NumbersAction.GTOE),
        (">", NumbersAction.GT),
        ("<=", NumbersAction.LTOE),
        ("<", NumbersAction.LT),
    )
    for op, action in comparisons:
        if expr.startswith(op):
            val = expr[len(op):].strip()
            assert val and len(val.split()) == 1, f"malformed numeric match: {expr!r}"
            return ASTNode(ASTAction.EXPR, NumbersExpr(action, val))

    if ".." in expr:
        bounds = [b.strip() for b in expr.split("..")]
        assert len(bounds) == 2 and all(bounds), f"malformed numeric range: {expr!r}"
        # XXX: assumes "A..B" is inclusive at both ends (>= A && <= B) - confirm.
        left = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.GTOE, bounds[0]))
        right = ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.LTOE, bounds[1]))
        return ASTNode(ASTAction.AND, left, right)

    # Equality is spelled either "= N" / "=N" or just "N"; strip the operator
    # so it is not emitted as a stray "=" (which would produce "x===N").
    if expr.startswith("="):
        expr = expr[1:].strip()
    assert expr and len(expr.split()) == 1, f"malformed numeric match: {expr!r}"
    return ASTNode(ASTAction.EXPR, NumbersExpr(NumbersAction.EQ, expr))
108
class FragExpr:
    """Leaf expression for a single IPv4 ``fragment`` keyword.

    The generated C tests bits of ``ip->frag_off`` using the BE16 macro and
    the IP_MF / IP_DF / IP_OFFSET names; the arguments to ``write`` are ignored.
    """

    def __init__(self, val):
        known = (
            ("is_fragment", "(ip->frag_off & BE16(IP_MF|IP_OFFSET)) != 0"),
            ("first_fragment", "(ip->frag_off & BE16(IP_MF)) != 0 && (ip->frag_off & BE16(IP_OFFSET)) == 0"),
            ("dont_fragment", "(ip->frag_off & BE16(IP_DF)) != 0"),
            ("last_fragment", "(ip->frag_off & BE16(IP_MF)) == 0 && (ip->frag_off & BE16(IP_OFFSET)) != 0"),
        )
        for keyword, rule in known:
            if val == keyword:
                self.rule = rule
                break
        else:
            # Unrecognised fragment keyword.
            assert False

    def write(self, _param, _param2):
        """Return the precomputed C condition for this keyword."""
        return self.rule
124
def parse_frag_expr(expr):
    """Wrap one ``fragment`` keyword in an EXPR ASTNode leaf."""
    leaf = FragExpr(expr)
    return ASTNode(ASTAction.EXPR, leaf)
127
class BitExpr:
    """Leaf expression for a ``value/mask`` bitmask term."""

    def __init__(self, val):
        parts = val.split("/")
        # Exactly one "/" separating the value from the mask.
        assert len(parts) == 2
        self.match, self.mask = parts

    def write(self, param, _param2):
        """Return C code testing ``(param & mask) == match``."""
        return "({} & {}) == {}".format(param, self.mask, self.match)
137
def parse_bit_expr(expr):
    """Wrap one ``value/mask`` term in an EXPR ASTNode leaf."""
    leaf = BitExpr(expr)
    return ASTNode(ASTAction.EXPR, leaf)
140
141
def ip_to_rule(proto, inip, ty, offset):
    """Return C code that ``break``s unless address field ``ty`` is in prefix ``inip``.

    proto:  4 selects ``ip->`` / IPv4Network; any other value selects
            ``ip6->`` / IPv6Network.
    inip:   prefix text such as "10.0.0.0/8" (surrounding whitespace ignored).
    ty:     header field name, e.g. "saddr" or "daddr".
    offset: IPv6-only prefix offset text used with MASK6_OFFS, or None.
            Must be None for IPv4.
    A /0 prefix matches everything, so an empty string is returned for it.
    """
    if proto == 4:
        assert offset is None
        net = ipaddress.IPv4Network(inip.strip())
        if net.prefixlen == 0:
            return ""
        return (f"if ((ip->{ty} & MASK4({net.prefixlen})) != BIGEND32({int(net.network_address)}ULL))\n"
                "        break;")

    net = ipaddress.IPv6Network(inip.strip())
    if net.prefixlen == 0:
        return ""
    addr = int(net.network_address)
    # Four 32-bit words, most significant first.
    words = [(addr >> shift) & 0xffffffff for shift in (96, 64, 32, 0)]
    mask = f"MASK6({net.prefixlen})" if offset is None else f"MASK6_OFFS({offset}, {net.prefixlen})"
    value = ", ".join(f"{w}ULL" for w in words)
    return (f"if ((ip6->{ty} & {mask}) != (BIGEND128({value}) & {mask}))\n"
            "        break;")
164
def fragment_to_rule(ipproto, rules):
    """Return C code that ``break``s unless the fragment match ``rules`` holds.

    Only IPv4 is implemented (the leaves test ``ip->frag_off``);
    ipproto == 6 trips an assertion.
    """
    assert ipproto != 6  # XXX: IPv6 fragment matching is unimplemented
    condition = parse_ast(rules, parse_frag_expr).write(())
    return f"if (!( {condition} )) break;"
170
def len_to_rule(rules):
    """Return C code that ``break``s unless the length match ``rules`` holds.

    The compared quantity is the C expression ``data_end - pktdata``.
    """
    condition = parse_ast(rules, parse_numbers_expr).write("(data_end - pktdata)")
    return f"if (!( {condition} )) break;"
174  
def proto_to_rule(ipproto, proto):
    """Return C code that ``break``s unless the protocol match ``proto`` holds.

    IPv4 (ipproto == 4) compares ``ip->protocol``. Otherwise ``ip6->nexthdr``
    of the fixed IPv6 header is compared. Extension-header parsing is not
    implemented, so PARSE_EXTHDR must be unset.
    """
    ast = parse_ast(proto, parse_numbers_expr)
    if ipproto != 4:
        assert not PARSE_EXTHDR  # XXX: unimplemented
        field = "ip6->nexthdr"
    else:
        field = "ip->protocol"
    return f"if (!( {ast.write(field)} )) break;"
184
def icmp_type_to_rule(proto, ty):
    """Return C code requiring an ICMP header whose type matches ``ty``.

    proto == 4 uses the ``icmp`` pointer; anything else uses ``icmpv6``.
    """
    ast = parse_ast(ty, parse_numbers_expr)
    hdr, field = ("icmp", "icmp->type") if proto == 4 else ("icmpv6", "icmpv6->icmp6_type")
    return f"if ({hdr} == NULL) break;\nif (!( {ast.write(field)} )) break;"
191
def icmp_code_to_rule(proto, code):
    """Return C code requiring an ICMP header whose code matches ``code``.

    proto == 4 uses the ``icmp`` pointer; anything else uses ``icmpv6``.
    """
    ast = parse_ast(code, parse_numbers_expr)
    hdr, field = ("icmp", "icmp->code") if proto == 4 else ("icmpv6", "icmpv6->icmp6_code")
    return f"if ({hdr} == NULL) break;\nif (!( {ast.write(field)} )) break;"
198
def dscp_to_rule(proto, rules):
    """Return C code that ``break``s unless the packet's DSCP matches ``rules``.

    DSCP is the upper six bits of the IPv4 TOS byte or the IPv6 Traffic
    Class. In Linux ``struct ipv6hdr`` the Traffic Class is split between the
    4-bit ``priority`` field (its high nibble) and the top nibble of
    ``flow_lbl[0]``. DSCP is therefore ``priority << 2 | flow_lbl[0] >> 6``.
    """
    ast = parse_ast(rules, parse_numbers_expr)

    if proto == 4:
        return "if (!( " + ast.write("((ip->tos & 0xfc) >> 2)") + " )) break;"
    # Shift each piece into its final position separately. In C, `>>` binds
    # tighter than `|`, so a trailing `>> 2` would only apply to the right
    # operand, not to the combined Traffic Class.
    return "if (!( " + ast.write("((ip6->priority << 2) | ((ip6->flow_lbl[0] & 0xc0) >> 6))") + " )) break;"
206
def port_to_rule(ty, rules):
    """Return C code that ``break``s unless the TCP/UDP port match ``rules`` holds.

    ty is "sport" or "dport" (used directly as the C variable name), or
    "port". For "port" the complete expression is evaluated once against
    ``sport`` and once against ``dport``, and the packet matches if either
    evaluation succeeds.
    """
    ast = parse_ast(rules, parse_numbers_expr)
    if ty == "port":
        # Evaluate the whole expression per port. Distributing both ports into
        # every leaf comparison would OR them per comparison instead, so
        # "1000..2000" would accept sport=500 with dport=3000.
        condition = "(" + ast.write("sport") + ") || (" + ast.write("dport") + ")"
    else:
        condition = ast.write(ty)
    return "if (tcp == NULL && udp == NULL) break;\nif (!( " + condition + " )) break;"
214
def tcp_flags_to_rule(rules):
    """Return C code requiring a TCP header whose flags match ``rules``.

    ``rules`` is a bitmask expression of ``value/mask`` terms, tested against
    the low 12 bits of ``ntohs(tcp->flags)``.
    NOTE(review): assumes ``tcp->flags`` is the 16-bit word containing the
    TCP flag bits in the XDP program's header struct - confirm.
    """
    condition = parse_ast(rules, parse_bit_expr).write("(ntohs(tcp->flags) & 0xfff)")
    return "if (tcp == NULL) break;\n" + f"if (!( {condition} )) break;"
220
def flow_label_to_rule(rules):
    """Return C code requiring an IPv6 header whose 20-bit flow label matches ``rules``.

    ``rules`` is a bitmask expression of ``value/mask`` terms. The label is
    reassembled from ``ip6->flow_lbl`` in network byte order: the low nibble
    of byte 0, then bytes 1 and 2.
    """
    ast = parse_ast(rules, parse_bit_expr)
    # flow_lbl[0]'s high nibble belongs to the Traffic Class, so only its low
    # nibble is used. The least significant byte is flow_lbl[2] (previously
    # flow_lbl[0] was read twice).
    label = ("((((uint32_t)(ip6->flow_lbl[0] & 0xf)) << 2*8) | "
             "(((uint32_t)ip6->flow_lbl[1]) << 1*8) | "
             "(uint32_t)ip6->flow_lbl[2])")
    return "if (ip6 == NULL) break;\nif (!( " + ast.write(label) + " )) break;"
226
# Script body: read lines of the form "flow4 { ... }" / "flow6 { ... }" from
# stdin and write a RULES macro to rules.h. Each route becomes
#   if (eth_proto == ...) { do { <checks> return XDP_DROP; } while(0); }
# where every generated check `break`s out of the do/while on a mismatch.
with open("rules.h", "w") as out:
    # argv[1]: "parse_8021q" enables 802.1Q parsing; "req_8021q=<id>" enables
    # it as well and defines REQ_8021Q to the given text.
    if len(sys.argv) > 1 and sys.argv[1] == "parse_8021q":
        out.write("#define PARSE_8021Q\n")
    if len(sys.argv) > 1 and sys.argv[1].startswith("req_8021q="):
        out.write("#define PARSE_8021Q\n")
        out.write(f"#define REQ_8021Q {sys.argv[1][len('req_8021q='):]}\n")

    out.write("#define RULES \\\n")

    def write_rule(r):
        # Rules may span several lines; each line inside the RULES macro must
        # end with a backslash continuation.
        out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n")

    for line in sys.stdin:
        t = line.split("{")
        if len(t) != 2:
            continue
        flow_type = t[0].strip()
        if flow_type == "flow4":
            proto = 4
            out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n")
        elif flow_type == "flow6":
            proto = 6
            out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n")
        else:
            continue
        out.write("\tdo {\\\n")

        rule = t[1].split("}")[0].strip()
        for step in rule.split(";"):
            step = step.strip()
            if step.startswith(("src", "dst")):
                nets = step[3:].strip().split(" ")
                if len(nets) > 1:
                    if len(nets) != 3 or nets[1] != "offset":
                        raise ValueError(f"malformed prefix match: {step!r}")
                    offset = nets[2]
                else:
                    offset = None
                field = "saddr" if step.startswith("src") else "daddr"
                write_rule(ip_to_rule(proto, nets[0], field, offset))
            elif step.startswith("proto") and proto == 4:
                write_rule(proto_to_rule(4, step[len("proto "):]))
            elif step.startswith("next header") and proto == 6:
                write_rule(proto_to_rule(6, step[len("next header "):]))
            elif step.startswith("icmp type"):
                write_rule(icmp_type_to_rule(proto, step[len("icmp type "):]))
            elif step.startswith("icmp code"):
                write_rule(icmp_code_to_rule(proto, step[len("icmp code "):]))
            elif step.startswith(("sport", "dport", "port")):
                port_ty, port_rules = step.split(" ", 1)
                write_rule(port_to_rule(port_ty, port_rules))
            elif step.startswith("length"):
                write_rule(len_to_rule(step[len("length "):]))
            elif step.startswith("dscp"):
                write_rule(dscp_to_rule(proto, step[len("dscp "):]))
            elif step.startswith("tcp flags"):
                write_rule(tcp_flags_to_rule(step[len("tcp flags "):]))
            elif step.startswith("label"):
                write_rule(flow_label_to_rule(step[len("label "):]))
            elif step.startswith("fragment"):
                write_rule(fragment_to_rule(proto, step[len("fragment "):]))
            elif step == "":
                pass
            else:
                # Never skip an unknown component: omitting it would make the
                # generated DROP rule broader than requested. An assert would
                # be stripped under `python -O`, so raise explicitly.
                raise ValueError(f"unsupported flowspec match component: {step!r}")
        out.write("\t\treturn XDP_DROP;\\\n")
        out.write("\t} while(0);\\\n}\\\n")

    # Terminate the RULES macro with a line that has no continuation.
    out.write("\n")