Help the BPF verifier somewhat by splitting v4 and v6 rules
authorMatt Corallo <git@bluematt.me>
Sun, 4 Apr 2021 20:31:34 +0000 (16:31 -0400)
committerMatt Corallo <git@bluematt.me>
Sun, 4 Apr 2021 20:35:27 +0000 (16:35 -0400)
genrules.py
xdp.c

index 6de93829c18529a10e2feb5c0dad285d4efa5c0b..edef6175e264f03f6edec910069c55ea92ae1c41 100755 (executable)
@@ -283,33 +283,31 @@ with open("rules.h", "w") as out:
             assert False
         out.write("#define REQ_8021Q " + args.vlan_tag + "\n")
 
             assert False
         out.write("#define REQ_8021Q " + args.vlan_tag + "\n")
 
-    use_v4 = False
-    use_v6 = False
+    rules6 = ""
+    rules4 = ""
     use_v6_frags = False
     rulecnt = 0
 
     use_v6_frags = False
     rulecnt = 0
 
-    out.write("#define RULES \\\n")
-
-    def write_rule(r):
-        out.write("\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n")
-
     for line in sys.stdin.readlines():
         t = line.split("{")
         if len(t) != 2:
             continue
         if t[0].strip() == "flow4":
             proto = 4
     for line in sys.stdin.readlines():
         t = line.split("{")
         if len(t) != 2:
             continue
         if t[0].strip() == "flow4":
             proto = 4
-            use_v4 = True
-            out.write("if (eth_proto == htons(ETH_P_IP)) { \\\n")
-            out.write("\tdo {\\\n")
+            rules4 += "\tdo {\\\n"
         elif t[0].strip() == "flow6":
             proto = 6
         elif t[0].strip() == "flow6":
             proto = 6
-            use_v6 = True
-            out.write("if (eth_proto == htons(ETH_P_IPV6)) { \\\n")
-            out.write("\tdo {\\\n")
+            rules6 += "\tdo {\\\n"
         else:
             continue
 
         else:
             continue
 
+        def write_rule(r):
+            global rules4, rules6
+            if proto == 6:
+                rules6 += "\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n"
+            else:
+                rules4 += "\t\t" + r.replace("\n", " \\\n\t\t") + " \\\n"
+
         rule = t[1].split("}")[0].strip()
         for step in rule.split(";"):
             if step.strip().startswith("src") or step.strip().startswith("dst"):
         rule = t[1].split("}")[0].strip()
         for step in rule.split(";"):
             if step.strip().startswith("src") or step.strip().startswith("dst"):
@@ -349,17 +347,22 @@ with open("rules.h", "w") as out:
                 pass
             else:
                 assert False
                 pass
             else:
                 assert False
-        out.write(f"\t\tconst uint32_t ruleidx = STATIC_RULE_CNT + {rulecnt};\\\n")
-        out.write("\t\tDO_RETURN(ruleidx, XDP_DROP);\\\n")
-        out.write("\t} while(0);\\\n}\\\n")
+        write_rule(f"const uint32_t ruleidx = STATIC_RULE_CNT + {rulecnt};")
+        write_rule("DO_RETURN(ruleidx, XDP_DROP);")
+        if proto == 6:
+            rules6 += "\t} while(0);\\\n"
+        else:
+            rules4 += "\t} while(0);\\\n"
         rulecnt += 1
 
     out.write("\n")
     out.write(f"#define RULECNT {rulecnt}\n")
         rulecnt += 1
 
     out.write("\n")
     out.write(f"#define RULECNT {rulecnt}\n")
-    if use_v4:
+    if rules4 != "":
         out.write("#define NEED_V4_PARSE\n")
         out.write("#define NEED_V4_PARSE\n")
-    if use_v6:
+        out.write("#define RULES4 {\\\n" + rules4 + "}\n")
+    if rules6:
         out.write("#define NEED_V6_PARSE\n")
         out.write("#define NEED_V6_PARSE\n")
+        out.write("#define RULES6 {\\\n" + rules6 + "}\n")
     if args.v6frag == "ignore-parse-if-rule":
         if use_v6_frags:
             out.write("#define PARSE_V6_FRAG PARSE\n")
     if args.v6frag == "ignore-parse-if-rule":
         if use_v6_frags:
             out.write("#define PARSE_V6_FRAG PARSE\n")
diff --git a/xdp.c b/xdp.c
index fc2198d8480dcfc0ab81f7252e70a794cd634c86..01feac657723f2e290a91f5d9721752626b5ee14 100644 (file)
--- a/xdp.c
+++ b/xdp.c
@@ -202,24 +202,15 @@ int xdp_drop_prog(struct xdp_md *ctx)
                }
        }
 
                }
        }
 
-#ifdef NEED_V4_PARSE
-       const struct iphdr *ip = NULL;
-       const struct icmphdr *icmp = NULL;
-#endif
-#ifdef NEED_V6_PARSE
-       const struct ip6hdr *ip6 = NULL;
-       const struct icmp6hdr *icmpv6 = NULL;
-       const struct ip6_fraghdr *frag6 = NULL;
-#endif
-
        const void *l4hdr = NULL;
        const struct tcphdr *tcp = NULL;
        const struct udphdr *udp = NULL;
        const void *l4hdr = NULL;
        const struct tcphdr *tcp = NULL;
        const struct udphdr *udp = NULL;
+       uint16_t sport, dport; // Host Endian! Only valid with tcp || udp
 
 #ifdef NEED_V4_PARSE
        if (eth_proto == BE16(ETH_P_IP)) {
                CHECK_LEN(pktdata, iphdr);
 
 #ifdef NEED_V4_PARSE
        if (eth_proto == BE16(ETH_P_IP)) {
                CHECK_LEN(pktdata, iphdr);
-               ip = (struct iphdr*) pktdata;
+               const struct iphdr *ip = (struct iphdr*) pktdata;
 
 #if PARSE_IHL == PARSE
                if (unlikely(ip->ihl < 5)) DO_RETURN(IHL_DROP, XDP_DROP);
 
 #if PARSE_IHL == PARSE
                if (unlikely(ip->ihl < 5)) DO_RETURN(IHL_DROP, XDP_DROP);
@@ -229,28 +220,36 @@ int xdp_drop_prog(struct xdp_md *ctx)
                l4hdr = pktdata + 5*4;
 #endif
 
                l4hdr = pktdata + 5*4;
 #endif
 
+               const struct icmphdr *icmp = NULL;
                if ((ip->frag_off & BE16(IP_OFFSET)) == 0) {
                        if (ip->protocol == IP_PROTO_TCP) {
                                CHECK_LEN(l4hdr, tcphdr);
                                tcp = (struct tcphdr*) l4hdr;
                if ((ip->frag_off & BE16(IP_OFFSET)) == 0) {
                        if (ip->protocol == IP_PROTO_TCP) {
                                CHECK_LEN(l4hdr, tcphdr);
                                tcp = (struct tcphdr*) l4hdr;
+                               sport = BE16(tcp->source);
+                               dport = BE16(tcp->dest);
                        } else if (ip->protocol == IP_PROTO_UDP) {
                                CHECK_LEN(l4hdr, udphdr);
                                udp = (struct udphdr*) l4hdr;
                        } else if (ip->protocol == IP_PROTO_UDP) {
                                CHECK_LEN(l4hdr, udphdr);
                                udp = (struct udphdr*) l4hdr;
+                               sport = BE16(udp->source);
+                               dport = BE16(udp->dest);
                        } else if (ip->protocol == IP_PROTO_ICMP) {
                                CHECK_LEN(l4hdr, icmphdr);
                                icmp = (struct icmphdr*) l4hdr;
                        }
                }
                        } else if (ip->protocol == IP_PROTO_ICMP) {
                                CHECK_LEN(l4hdr, icmphdr);
                                icmp = (struct icmphdr*) l4hdr;
                        }
                }
+
+               RULES4
        }
 #endif
 #ifdef NEED_V6_PARSE
        if (eth_proto == BE16(ETH_P_IPV6)) {
                CHECK_LEN(pktdata, ip6hdr);
        }
 #endif
 #ifdef NEED_V6_PARSE
        if (eth_proto == BE16(ETH_P_IPV6)) {
                CHECK_LEN(pktdata, ip6hdr);
-               ip6 = (struct ip6hdr*) pktdata;
+               const struct ip6hdr *ip6 = (struct ip6hdr*) pktdata;
 
                l4hdr = pktdata + 40;
 
                uint8_t v6nexthdr = ip6->nexthdr;
 
                l4hdr = pktdata + 40;
 
                uint8_t v6nexthdr = ip6->nexthdr;
+               const struct ip6_fraghdr *frag6 = NULL;
 #ifdef PARSE_V6_FRAG
 #if PARSE_V6_FRAG == PARSE
                if (ip6->nexthdr == IP6_PROTO_FRAG) {
 #ifdef PARSE_V6_FRAG
 #if PARSE_V6_FRAG == PARSE
                if (ip6->nexthdr == IP6_PROTO_FRAG) {
@@ -266,31 +265,27 @@ int xdp_drop_prog(struct xdp_md *ctx)
 #endif
                // TODO: Handle more options?
 
 #endif
                // TODO: Handle more options?
 
+               const struct icmp6hdr *icmpv6 = NULL;
                if (frag6 == NULL || (frag6->frag_off & BE16(IP6_FRAGOFF)) == 0) {
                        if (v6nexthdr == IP_PROTO_TCP) {
                                CHECK_LEN(l4hdr, tcphdr);
                                tcp = (struct tcphdr*) l4hdr;
                if (frag6 == NULL || (frag6->frag_off & BE16(IP6_FRAGOFF)) == 0) {
                        if (v6nexthdr == IP_PROTO_TCP) {
                                CHECK_LEN(l4hdr, tcphdr);
                                tcp = (struct tcphdr*) l4hdr;
+                               sport = BE16(tcp->source);
+                               dport = BE16(tcp->dest);
                        } else if (v6nexthdr == IP_PROTO_UDP) {
                                CHECK_LEN(l4hdr, udphdr);
                                udp = (struct udphdr*) l4hdr;
                        } else if (v6nexthdr == IP_PROTO_UDP) {
                                CHECK_LEN(l4hdr, udphdr);
                                udp = (struct udphdr*) l4hdr;
+                               sport = BE16(udp->source);
+                               dport = BE16(udp->dest);
                        } else if (v6nexthdr == IP6_PROTO_ICMPV6) {
                                CHECK_LEN(l4hdr, icmp6hdr);
                                icmpv6 = (struct icmp6hdr*) l4hdr;
                        }
                }
                        } else if (v6nexthdr == IP6_PROTO_ICMPV6) {
                                CHECK_LEN(l4hdr, icmp6hdr);
                                icmpv6 = (struct icmp6hdr*) l4hdr;
                        }
                }
-       }
-#endif
 
 
-       uint16_t sport, dport; // Host Endian! Only valid with tcp || udp
-       if (tcp != NULL) {
-               sport = BE16(tcp->source);
-               dport = BE16(tcp->dest);
-       } else if (udp != NULL) {
-               sport = BE16(udp->source);
-               dport = BE16(udp->dest);
+               RULES6
        }
        }
-
-       RULES
+#endif
 
        return XDP_PASS;
 }
 
        return XDP_PASS;
 }