Use a custom header to support all packets and, importantly, frags
authorMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 18:03:58 +0000 (18:03 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 22:07:50 +0000 (22:07 +0000)
main.cpp

index e070706ae7546eb945f95543592e271c00245c08..c116e8994a27b53e6d0d3ec10a147b2704b26b73 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -29,7 +29,19 @@ int getrandom(void* buf, size_t len, unsigned int flags) {
 }
 #endif
 
-#define PACKET_READ_OFFS (20-16)
+struct pkt_hdr {
+       // Note that ordering of the first three uint16_ts must match the wire!
+       uint16_t len;
+       uint16_t id;
+       uint16_t frag_flags: 3,
+                offset    : 13;
+       uint8_t proto;
+       uint8_t padding;
+};
+static_assert(sizeof(struct pkt_hdr) == 8);
+
+// Have to leave room to add a TCP header plus above header (after dropping IP header)
+#define PACKET_READ_OFFS sizeof(pkt_hdr)
 #define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
 
 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
@@ -62,7 +74,7 @@ static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int
                fds[i] = fd;
        }
 
-       sprintf(buf, "ip link set %s mtu %d", dev, PACKET_READ_SIZE);
+       sprintf(buf, "ip link set dev %s mtu %ld", dev, PACKET_READ_SIZE - 20);
        err = system(buf);
        if (err) goto err;
 
@@ -85,7 +97,7 @@ err:
        return err;
 }
 
-static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t expected_type) {
+static int check_ip_header(const unsigned char* buf, ssize_t buf_len, int16_t expected_type, struct pkt_hdr *gen_hdr) {
        if (buf_len < 20) {
                // < size than IPv4?
                return -1;
@@ -103,31 +115,23 @@ static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t ex
                return -1;
        }
 
-       if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
+       /*if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
                //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
                return -1;
-       }
+       }*/
 
-       if (buf[9] != expected_type) {
+       if (expected_type > 0 && buf[9] != expected_type) {
                fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
                return -1;
        }
 
-       return header_size;
-}
-
-#define UDP_PORTS 4242
-static int check_udp_header(const unsigned char* buf, ssize_t buf_len) {
-       if (buf_len < 8) {
-               return -1;
+       if (gen_hdr) {
+               memcpy(gen_hdr, buf + 2, 6); // len, id, frag offset
+               gen_hdr->proto = buf[9];
+               gen_hdr->padding = 0;
        }
-       if (((((uint16_t)buf[0]) << 8) | (buf[1])) != UDP_PORTS) {
-               return -1;
-       }
-       if (((((uint16_t)buf[2]) << 8) | (buf[3])) != UDP_PORTS) {
-               return -1;
-       }
-       return 8;
+
+       return header_size;
 }
 
 void print_packet(const unsigned char* buf, ssize_t buf_len) {
@@ -231,16 +235,12 @@ static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int syna
        buf[17] = checksum >> 8;           // Checksum
 }
 
-// len is data length, not including headers!
-static void build_ip_header(unsigned char* buf, uint16_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
+static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
        buf[0 ] = (4 << 4) | 5;    // IPv4 + IHL of 5 (20 bytes)
        buf[1 ] = 0;               // DSCP 0 + ECN 0
-       buf[2 ] = (len + 20) >> 8; // Length
-       buf[3 ] = (len + 20);      // Length
-       memset(buf + 4, 0, 4);     // Identification and Fragment 0s
-       buf[6 ] = 1 << 6;          // DF bit
+       memcpy(buf + 2, &hdr, 6);  // length, identification, flags, and offset
        buf[8 ] = 255;             // TTL
-       buf[9 ] = proto;           // Protocol Number
+       buf[9 ] = hdr.proto;       // Protocol Number
        buf[10] = 0;               // Checksum
        buf[11] = 0;               // Checksum
 
@@ -252,21 +252,6 @@ static void build_ip_header(unsigned char* buf, uint16_t len, uint8_t proto, con
        buf[11] = checksum >> 8;
 }
 
-// length is data length, not including headers
-static void build_udp_header(unsigned char* buf, uint16_t len) {
-       buf[0] = UDP_PORTS >> 8;
-       buf[1] = UDP_PORTS & 0xff;
-       buf[2] = UDP_PORTS >> 8;
-       buf[3] = UDP_PORTS & 0xff;
-       buf[4] = (len + 8) >> 8;
-       buf[5] = (len + 8);
-
-       // Checksum 0 for now
-       buf[6] = 0;
-       buf[7] = 0;
-}
-
-
 const signed char p_util_hexdigit[256] =
 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
@@ -311,9 +296,7 @@ static struct sockaddr_in dest;
 static in_addr_t src, tun_src, tun_dest;
 static uint64_t tcp_init_magic;
 
-#define PENDING_MESSAGES_BUFF_SIZE (0x3000)
-#define PACKET_READ_OFFS (20-16)
-#define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
+#define PENDING_MESSAGES_BUFF_SIZE (0x800)
 #define THREAD_POLL_SLEEP_MICS 250
 struct MessageQueue {
        std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE + PACKET_READ_OFFS>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
@@ -327,7 +310,7 @@ static MessageQueue tcp_to_tun_queue;
 static std::chrono::steady_clock::time_point last_ack_recv;
 
 static void tcp_to_tun() {
-       unsigned char buf[1500];
+       unsigned char buf[PACKET_READ_SIZE + PACKET_READ_OFFS];
        struct sockaddr_in pkt_src;
        memset(&pkt_src, 0, sizeof(pkt_src));
 
@@ -363,7 +346,7 @@ static void tcp_to_tun_queue_process() {
                unsigned char* buf = std::get<1>(msg).data();
                ssize_t nread = std::get<2>(msg);
 
-               int header_size = check_ip_header(buf, nread, 0x06); // Only support TCP
+               int header_size = check_ip_header(buf, nread, 0x06, NULL); // Only support TCP
                if (header_size < 0)
                        continue;
 
@@ -424,11 +407,13 @@ static void tcp_to_tun_queue_process() {
                                fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
                        }
                } else if (!syn) {
-                       tcp_buf += tcp_header_size - 20 - 8; // IPv4 header is 20 bytes + UDP header is 8 bytes
+                       if (nread < (ssize_t)(tcp_header_size + header_size + sizeof(struct pkt_hdr))) continue;
+                       struct pkt_hdr hdr;
+                       memcpy(&hdr, tcp_buf + tcp_header_size, sizeof(struct pkt_hdr));
+                       tcp_buf += tcp_header_size + sizeof(struct pkt_hdr) - 20; // IPv4 header is 20 bytes
 
-                       // Replace TCP with IPv4 header
-                       build_udp_header(tcp_buf + 20, nread - tcp_header_size - header_size);
-                       build_ip_header(tcp_buf, nread - tcp_header_size - header_size + 8, 0x11, tun_dest, tun_src); // UDP
+                       // Replace TCP + pkt_hdr with IPv4 header
+                       build_ip_header(tcp_buf, hdr, tun_dest, tun_src);
 
                        write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8);
                }
@@ -482,24 +467,26 @@ static void tun_to_tcp_queue_process() {
                unsigned char* buf = std::get<1>(msg).data();
                ssize_t nread = std::get<2>(msg);
 
-               int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, 0x11); // Only support UDP
-               if (header_size < 0)
+               struct pkt_hdr hdr;
+               int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, -1, &hdr); // Any paket type is okay
+               if (header_size < 20)
                        continue;
 
-               if (header_size + 8 > nread) {
-                       fprintf(stderr, "Short UDP packet\n");
+               if (header_size > nread) {
+                       fprintf(stderr, "Short packet\n");
                        continue;
                }
 
-               if (check_udp_header(buf + PACKET_READ_OFFS + header_size, nread - header_size) != 8) {
-                       fprintf(stderr, "Bad UDP header\n");
+               if ((size_t)nread > 1500 - 20 - sizeof(struct pkt_hdr)) { // Packets must fit in 1500 bytes with a TCP and packet hdr
+                       fprintf(stderr, "Long packet\n");
                        continue;
                }
 
-               unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size + 8 - 20;
-               build_tcp_header(tcp_start, nread - header_size - 8, 0, 0, src, dest.sin_addr.s_addr);
+               unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size - 20 - sizeof(struct pkt_hdr);
+               memcpy(tcp_start + 20, &hdr, sizeof(struct pkt_hdr));
+               build_tcp_header(tcp_start, nread - header_size + sizeof(struct pkt_hdr), 0, 0, src, dest.sin_addr.s_addr);
 
-               ssize_t res = sendto(fdr, tcp_start, nread + 20 - 8 - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
+               ssize_t res = sendto(fdr, tcp_start, nread + 20 + sizeof(struct pkt_hdr) - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
                if (res < 0) {
                        int err = errno;
                        fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));