Remove the need for an ipip link to reduce packet size (duh, mtu)
authorMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 04:53:27 +0000 (04:53 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 21:20:38 +0000 (21:20 +0000)
main.cpp

index eff9b46a8f14ec4c43aebf23a167620004d51bb7..c5e73e2477a55cdad7c5600e67d8a019f426ec50 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -29,7 +29,8 @@ int getrandom(void* buf, size_t len, unsigned int flags) {
 }
 #endif
 
-#define PACKET_READ_SIZE 1500
+#define PACKET_READ_OFFS (20-16)
+#define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
 
 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
 {
@@ -115,6 +116,20 @@ static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t ex
        return header_size;
 }
 
+#define UDP_PORTS 4242
+static int check_udp_header(const unsigned char* buf, ssize_t buf_len) {
+       if (buf_len < 8) {
+               return -1;
+       }
+       if (((((uint16_t)buf[0]) << 8) | (buf[1])) != UDP_PORTS) {
+               return -1;
+       }
+       if (((((uint16_t)buf[2]) << 8) | (buf[3])) != UDP_PORTS) {
+               return -1;
+       }
+       return 8;
+}
+
 void print_packet(const unsigned char* buf, ssize_t buf_len) {
        for (ssize_t i = 0; i < buf_len; ) {
                for (int j = 0; i < buf_len && j < 20; i++, j++)
@@ -216,7 +231,8 @@ static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int syna
        buf[17] = checksum >> 8;           // Checksum
 }
 
-static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
+// len is data length, not including headers!
+static void build_ip_header(unsigned char* buf, uint16_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
        buf[0 ] = (4 << 4) | 5;    // IPv4 + IHL of 5 (20 bytes)
        buf[1 ] = 0;               // DSCP 0 + ECN 0
        buf[2 ] = (len + 20) >> 8; // Length
@@ -236,6 +252,20 @@ static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, con
        buf[11] = checksum >> 8;
 }
 
+// length is data length, not including headers
+static void build_udp_header(unsigned char* buf, uint16_t len) {
+       buf[0] = UDP_PORTS >> 8;
+       buf[1] = UDP_PORTS & 0xff;
+       buf[2] = UDP_PORTS >> 8;
+       buf[3] = UDP_PORTS & 0xff;
+       buf[4] = (len + 8) >> 8;
+       buf[5] = (len + 8);
+
+       // Checksum 0 for now
+       buf[6] = 0;
+       buf[7] = 0;
+}
+
 
 const signed char p_util_hexdigit[256] =
 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
@@ -278,14 +308,15 @@ uint32_t hex_to_num(const unsigned char* buf) {
 static int fdr;
 static int fd[TUN_IF_COUNT];
 static struct sockaddr_in dest;
-static in_addr_t src, tun_src, tun_dest, ipip_src, ipip_dest;
+static in_addr_t src, tun_src, tun_dest;
 static uint64_t tcp_init_magic;
 
 #define PENDING_MESSAGES_BUFF_SIZE (0x3000)
-#define PACKET_READ_SIZE 1500
+#define PACKET_READ_OFFS (20-16)
+#define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
 #define THREAD_POLL_SLEEP_MICS 50
 struct MessageQueue {
-       std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
+       std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE + PACKET_READ_OFFS>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
        std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
        MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
        MessageQueue(MessageQueue&& q) =delete;
@@ -393,15 +424,13 @@ static void tcp_to_tun_queue_process() {
                                fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
                        }
                } else if (!syn) {
-                       tcp_buf += tcp_header_size - 20;
+                       tcp_buf += tcp_header_size - 20 - 8; // IPv4 header is 20 bytes + UDP header is 8 bytes
 
                        // Replace TCP with IPv4 header
-                       //build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x01, ipip_dest, ipip_src); // ICMP
-                       build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x11, ipip_dest, ipip_src); // UDP
-                       // Add IPIP header
-                       build_ip_header(tcp_buf - 20, nread - tcp_header_size - header_size + 20, 0x04, tun_dest, tun_src);
+                       build_udp_header(tcp_buf + 20, nread - tcp_header_size - header_size);
+                       build_ip_header(tcp_buf, nread - tcp_header_size - header_size + 8, 0x11, tun_dest, tun_src); // UDP
 
-                       write(fd[0], tcp_buf - 20, nread - tcp_header_size - header_size + 40);
+                       write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8);
                }
        } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
 }
@@ -430,7 +459,7 @@ static void tun_to_tcp() {
                        continue;
 
                auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
-               memcpy(std::get<1>(new_msg).data(), buf, nread);
+               memcpy(std::get<1>(new_msg).data() + PACKET_READ_OFFS, buf, nread);
                std::get<2>(new_msg) = nread;
 
                queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
@@ -453,24 +482,24 @@ static void tun_to_tcp_queue_process() {
                unsigned char* buf = std::get<1>(msg).data();
                ssize_t nread = std::get<2>(msg);
 
-               int header_size = check_ip_header(buf, nread, 0x04); // Only support IPIP
+               int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, 0x11); // Only support UDP
                if (header_size < 0)
                        continue;
 
-               int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x11); // Only support UDP
-               //int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x01); // Only support ICMP
-               if (internal_header_size < 0)
+               if (header_size + 8 > nread) {
+                       fprintf(stderr, "Short UDP packet\n");
                        continue;
+               }
 
-               if (internal_header_size + header_size + 8 > nread) {
-                       fprintf(stderr, "Short UDP-in-IPIP packet\n");
+               if (check_udp_header(buf + PACKET_READ_OFFS + header_size, nread - header_size) != 8) {
+                       fprintf(stderr, "Bad UDP header\n");
                        continue;
                }
 
-               size_t tcp_start_offset = header_size + internal_header_size - 20;
-               build_tcp_header(buf + tcp_start_offset, nread - tcp_start_offset - 20, 0, 0, src, dest.sin_addr.s_addr);
+               unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size + 8 - 20;
+               build_tcp_header(tcp_start, nread - header_size - 8, 0, 0, src, dest.sin_addr.s_addr);
 
-               ssize_t res = sendto(fdr, buf + tcp_start_offset, nread - tcp_start_offset, 0, (struct sockaddr*)&dest, sizeof(dest));
+               ssize_t res = sendto(fdr, tcp_start, nread + 20 - 8 - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
                if (res < 0) {
                        int err = errno;
                        fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
@@ -555,17 +584,15 @@ int main(int argc, char* argv[]) {
        assert(argc > 1 && "Need tun name");
        assert(argc > 2 && "Need tun remote host");
        assert(argc > 3 && "Need tun local host");
-       assert(argc > 4 && "Need ipip remote host");
-       assert(argc > 5 && "Need ipip local host");
-       assert(argc > 6 && "Need server port");
-       assert(argc > 7 && "Need shared secret");
-       assert(argc > 8 && "Need mode (client or server)");
-       assert(argc > 9 && "Need src host");
-       if (std::string(argv[8]) == std::string("client"))
-               assert(argc > 10 && "Need dest host");
-
-       assert(std::string(argv[8]) == std::string("client") || std::string(argv[8]) == std::string("server"));
-       are_server = (std::string(argv[8]) == std::string("server"));
+       assert(argc > 4 && "Need server port");
+       assert(argc > 5 && "Need shared secret");
+       assert(argc > 6 && "Need mode (client or server)");
+       assert(argc > 7 && "Need src host");
+       if (std::string(argv[6]) == std::string("client"))
+               assert(argc > 8 && "Need dest host");
+
+       assert(std::string(argv[6]) == std::string("client") || std::string(argv[6]) == std::string("server"));
+       are_server = (std::string(argv[6]) == std::string("server"));
 
        //
        // Parse args into variables
@@ -579,25 +606,22 @@ int main(int argc, char* argv[]) {
        tun_dest = inet_addr(argv[2]);
        tun_src = inet_addr(argv[3]);
 
-       ipip_dest = inet_addr(argv[4]);
-       ipip_src = inet_addr(argv[5]);
-
        if (are_server) {
-               local_port = atoi(argv[6]);
+               local_port = atoi(argv[4]);
                remote_port = 0;
        } else {
                // Get local port in do_init() so that we pick a new one on reload
-               remote_port = atoi(argv[6]);
+               remote_port = atoi(argv[4]);
        }
 
-       tcp_init_magic = atoll(argv[7]);
+       tcp_init_magic = atoll(argv[5]);
 
-       src = inet_addr(argv[9]);
+       src = inet_addr(argv[7]);
 
        memset(&dest, 0, sizeof(dest));
        if (!are_server) {
                dest.sin_family = AF_INET;
-               dest.sin_addr.s_addr = inet_addr(argv[10]);
+               dest.sin_addr.s_addr = inet_addr(argv[8]);
        }
 
        //