Allow matching on local and remote IPs
[bpfnofrags] / swapper.h
1 #include <stdint.h>
2 #include <unistd.h>
3 #include <endian.h>
4 #include <linux/if_ether.h>
5 #include <linux/ip.h>
6 #include <linux/udp.h>
7 #include <linux/bpf.h>
8 #include <bpf/bpf_helpers.h>
9
10 #if defined(__LITTLE_ENDIAN)
11 #define BE16(a) ((((uint16_t)(a & 0xff00)) >> 8) | (((uint16_t)(a & 0xff)) << 8))
12 #define IP32(a, b, c, d) (((((uint32_t)a) & 0xff) << 0*8) | \
13                           ((((uint32_t)b) & 0xff) << 1*8) | \
14                           ((((uint32_t)c) & 0xff) << 2*8) | \
15                           ((((uint32_t)d) & 0xff) << 3*8))
16 #elif defined(__BIG_ENDIAN)
17 #define BE16(a) ((uint16_t)a)
18 #define IP32(a, b, c, d) (((((uint32_t)a) & 0xff) << 3*8) | \
19                           ((((uint32_t)b) & 0xff) << 2*8) | \
20                           ((((uint32_t)c) & 0xff) << 1*8) | \
21                           ((((uint32_t)d) & 0xff) << 0*8))
22 #else
23 #error "Need endian info"
24 #endif
25
26 #include "ip_filter.h"
27
28 /* IP flags. */
29 #define IP_CE           0x8000          /* Flag: "Congestion"           */
30 #define IP_DF           0x4000          /* Flag: "Don't Fragment"       */
31 #define IP_MF           0x2000          /* Flag: "More Fragments"       */
32 #define IP_OFFSET       0x1FFF          /* "Fragment Offset" part       */
33
34 #define IP_PROTO_UDP 17
35 #define IP_PROTO_FIRST_FRAG 253
36 #define IP_PROTO_SECOND_FRAG 254
37
38 #define unlikely(a) __builtin_expect(a, 0)
39 #define likely(a) __builtin_expect(a, 1)
40
41 struct packet_count {
42         uint64_t packets;
43 };
44 struct {
45         __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
46         __uint(max_entries, 2);
47         __u32 *key;
48         struct packet_count *value;
49 } frag_count_map SEC(".maps");
50
51 #define INC_COUNTER(counter_id) do { \
52         const int reason = counter_id; \
53         struct packet_count *value = bpf_map_lookup_elem(&frag_count_map, &reason); \
54         if (value) { \
55                 value->packets += 1; \
56         } \
57 } while (0)
58
59 static inline void _maybe_swap_egress(struct iphdr *ip) {
60         if (unlikely(ip->ihl != 5)) return;
61
62         CHECK_LOCAL_REMOTE(ip->saddr, ip->daddr);
63
64         if (ip->protocol == IP_PROTO_UDP) {
65                 if (ip->frag_off == BE16(IP_MF)) {
66                         int32_t chk = ~BE16(ip->check) & 0xffff;
67                         chk = chk - IP_MF - IP_PROTO_UDP + IP_PROTO_FIRST_FRAG;
68                         // We're only decreasing the checksum here
69                         if (unlikely(chk < 0)) { chk += 65535; }
70                         ip->check = ~BE16(chk);
71
72                         ip->frag_off = 0;
73                         ip->protocol = IP_PROTO_FIRST_FRAG;
74                         INC_COUNTER(0);
75                 } else if (ip->frag_off == BE16(185)) {
76                         int32_t chk = ~BE16(ip->check) & 0xffff;
77                         chk = chk - 185 - IP_PROTO_UDP + IP_PROTO_SECOND_FRAG;
78                         // We're only increasing the checksum here
79                         if (unlikely(chk > 0xffff)) { chk -= 65535; }
80                         ip->check = ~BE16(chk);
81
82                         ip->frag_off = 0;
83                         ip->protocol = IP_PROTO_SECOND_FRAG;
84                         INC_COUNTER(1);
85                 }
86         }
87 }
88
89 static inline void _maybe_swap_ingress(struct iphdr *ip) {
90         if (unlikely(ip->ihl != 5)) return;
91
92         if (ip->protocol == IP_PROTO_SECOND_FRAG) {
93                 int32_t chk = ~BE16(ip->check) & 0xffff;
94                 chk = chk + 185 + IP_PROTO_UDP - IP_PROTO_SECOND_FRAG;
95                 // We're only decreasing the checksum here
96                 if (unlikely(chk < 0)) { chk += 65535; }
97                 ip->check = ~BE16(chk);
98
99                 ip->frag_off = BE16(185);
100                 ip->protocol = IP_PROTO_UDP;
101                 INC_COUNTER(0);
102         } else if (ip->protocol == IP_PROTO_FIRST_FRAG) {
103                 int32_t chk = ~BE16(ip->check) & 0xffff;
104                 chk = chk + IP_MF + IP_PROTO_UDP - IP_PROTO_FIRST_FRAG;
105                 // We're only increasing the checksum here
106                 if (unlikely(chk > 0xffff)) { chk -= 65535; }
107                 ip->check = ~BE16(chk);
108
109                 ip->frag_off = BE16(IP_MF);
110                 ip->protocol = IP_PROTO_UDP;
111                 INC_COUNTER(1);
112         }
113 }
114
115
116 // Our own ethhdr with optional vlan tags
117 struct _ethhdr_vlan {
118         unsigned char   h_dest[ETH_ALEN];       /* destination eth addr */
119         unsigned char   h_source[ETH_ALEN];     /* source ether addr    */
120         __be16          vlan_magic;             /* 0x8100 */
121         __be16          tci;            /* PCP (3 bits), DEI (1 bit), and VLAN (12 bits) */
122         __be16          h_proto;                /* packet type ID field */
123 } __attribute__((packed));
124
125 #define _CHECK_LEN_RETURN(start, struc) \
126         if (unlikely((void*)(start) + sizeof(struct struc) > data_end)) return;
127
128 #define _CHECK_ETH_TO_HEADER(swap_fn) \
129         void * pktdata; \
130         unsigned short eth_proto; \
131  \
132         _CHECK_LEN_RETURN(data, ethhdr); \
133         struct ethhdr *const eth = (void*)data; \
134         pktdata = (void *)data + sizeof(struct ethhdr); \
135  \
136         if (eth->h_proto == BE16(ETH_P_8021Q)) { \
137                 _CHECK_LEN_RETURN(data, _ethhdr_vlan); \
138                 struct _ethhdr_vlan *const eth_vlan = (void*)data; \
139                 pktdata = (void *)data + sizeof(struct _ethhdr_vlan); \
140                 eth_proto = eth_vlan->h_proto; \
141         } else { \
142                 eth_proto = eth->h_proto; \
143         } \
144 \
145         if (eth_proto == BE16(ETH_P_IP)) { \
146                 _CHECK_LEN_RETURN(pktdata, iphdr); \
147                 struct iphdr *ip = (struct iphdr*) pktdata; \
148                 swap_fn(ip); \
149         } else if (eth_proto == BE16(ETH_P_IPV6)) { \
150                 /* TODO: Support v6? */ \
151         }
152
153 static inline void maybe_swap_egress_eth(void *data, void* data_end) {
154         _CHECK_ETH_TO_HEADER(_maybe_swap_egress);
155 }
156 static inline void maybe_swap_ingress_eth(void *data, void* data_end) {
157         _CHECK_ETH_TO_HEADER(_maybe_swap_ingress);
158 }