X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=genrules.py;h=4468c1e00e52d58fb12c9d4769ae12c95c7d2704;hb=d4d0eb34912fd66bf9eb73d98ed57fd3def3336a;hp=40fc5c1f2d84969b959e5d30fdaad8c3c29cdf20;hpb=66e2db182f749072cac361b873dcc35d6377468e;p=flowspec-xdp diff --git a/genrules.py b/genrules.py index 40fc5c1..4468c1e 100755 --- a/genrules.py +++ b/genrules.py @@ -12,13 +12,18 @@ IP_PROTO_TCP = 6 IP_PROTO_UDP = 17 class ASTAction(Enum): - OR = 1, - AND = 2, - NOT = 3, - EXPR = 4 + OR = 1 + AND = 2 + NOT = 3 + FALSE = 4 + TRUE = 5 + EXPR = 6 class ASTNode: def __init__(self, action, left, right=None): self.action = action + if action == ASTAction.FALSE or action == ASTAction.TRUE: + assert left is None and right is None + return self.left = left if right is None: assert action == ASTAction.EXPR or action == ASTAction.NOT @@ -32,23 +37,32 @@ class ASTNode: return "(" + self.left.write(expr_param, expr_param2) + ") && (" + self.right.write(expr_param, expr_param2) + ")" if self.action == ASTAction.NOT: return "!(" + self.left.write(expr_param, expr_param2) + ")" + if self.action == ASTAction.FALSE: + return "0" + if self.action == ASTAction.TRUE: + return "1" if self.action == ASTAction.EXPR: return self.left.write(expr_param, expr_param2) def parse_ast(expr, parse_expr): expr = expr.strip() - and_split = expr.split("&&", 1) + comma_split = expr.split(",", 1) or_split = expr.split("||", 1) - if len(and_split) > 1 and not "||" in and_split[0]: - return ASTNode(ASTAction.AND, parse_ast(and_split[0], parse_expr), parse_ast(and_split[1], parse_expr)) + if len(comma_split) > 1 and not "||" in comma_split[0]: + return ASTNode(ASTAction.OR, parse_ast(comma_split[0], parse_expr), parse_ast(comma_split[1], parse_expr)) if len(or_split) > 1: - assert not "&&" in or_split[0] + assert not "," in or_split[0] return ASTNode(ASTAction.OR, parse_ast(or_split[0], parse_expr), parse_ast(or_split[1], parse_expr)) - comma_split = expr.split(",", 1) - if len(comma_split) > 1: - return ASTNode(ASTAction.OR, parse_ast(comma_split[0], parse_expr), parse_ast(comma_split[1], parse_expr)) + and_split = expr.split("&&", 1) + if len(and_split) > 1: + return ASTNode(ASTAction.AND, parse_ast(and_split[0], parse_expr), parse_ast(and_split[1], parse_expr)) + + if expr.strip() == "true": + return ASTNode(ASTAction.TRUE, None) + if expr.strip() == "false": + return ASTNode(ASTAction.FALSE, None) if expr.startswith("!"): return ASTNode(ASTAction.NOT, parse_ast(expr[1:], parse_expr)) @@ -218,10 +232,10 @@ def dscp_to_rule(proto, rules): def port_to_rule(ty, rules): if ty == "port" : ast = parse_ast(rules, parse_numbers_expr) - return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write("sport", "dport") + " )) break;" + return "if (!ports_valid) break;\nif (!( " + ast.write("sport", "dport") + " )) break;" ast = parse_ast(rules, parse_numbers_expr) - return "if (tcp == NULL && udp == NULL) break;\nif (!( " + ast.write(ty) + " )) break;" + return "if (!ports_valid) break;\nif (!( " + ast.write(ty) + " )) break;" def tcp_flags_to_rule(rules): ast = parse_ast(rules, parse_bit_expr) @@ -269,33 +283,31 @@ with open("rules.h", "w") as out: assert False out.write("#define REQ_8021Q " + args.vlan_tag + "\n") - use_v4 = False - use_v6 = False + rules6 = "" + rules4 = "" use_v6_frags = False rulecnt = 0 - out.write("#define RULES \\\n") - - def write_rule(r): - out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n") - for line in sys.stdin.readlines(): t = line.split("{") if len(t) != 2: continue if t[0].strip() == "flow4": proto = 4 - use_v4 = True - out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n") - out.write("\tdo {\\\n") + rules4 += "\tdo {\\\n" elif t[0].strip() == "flow6": proto = 6 - use_v6 = True - out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n") - out.write("\tdo {\\\n") + rules6 += "\tdo {\\\n" else: continue + def write_rule(r): + global rules4, rules6 + if proto == 6: + rules6 += "\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n" + else: + rules4 += "\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n" + rule = t[1].split("}")[0].strip() for step in rule.split(";"): if step.strip().startswith("src") or step.strip().startswith("dst"): @@ -335,17 +347,22 @@ with open("rules.h", "w") as out: pass else: assert False - out.write(f"\t\tconst uint32_t ruleidx = STATIC_RULE_CNT + {rulecnt};\\\n") - out.write("\t\tDO_RETURN(ruleidx, XDP_DROP);\\\n") - out.write("\t} while(0);\\\n}\\\n") + write_rule(f"const uint32_t ruleidx = STATIC_RULE_CNT + {rulecnt};") + write_rule("DO_RETURN(ruleidx, XDP_DROP);") + if proto == 6: + rules6 += "\t} while(0);\\\n" + else: + rules4 += "\t} while(0);\\\n" rulecnt += 1 out.write("\n") out.write(f"#define RULECNT {rulecnt}\n") - if use_v4: + if rules4 != "": out.write("#define NEED_V4_PARSE\n") - if use_v6: + out.write("#define RULES4 {\\\n" + rules4 + "}\n") + if rules6: out.write("#define NEED_V6_PARSE\n") + out.write("#define RULES6 {\\\n" + rules6 + "}\n") if args.v6frag == "ignore-parse-if-rule": if use_v6_frags: out.write("#define PARSE_V6_FRAG PARSE\n")