Add missing <array> inclusion, which is needed for modern libstd++
[tunudptotcp] / main.cpp
index 9f0173f3bb5e125ddfe040fa5537957ec46afd0f..7c7cbed15fc503faf126d2f1cccafe06ca9acfb8 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -1,6 +1,5 @@
 #include <fcntl.h>
 #include <string.h>
-#include <stropts.h>
 #include <stdio.h>
 #include <unistd.h>
 #include <sys/socket.h>
@@ -9,13 +8,16 @@
 #include <errno.h>
 #include <arpa/inet.h>
 #include <stdint.h>
+#include <stdlib.h>
 #include <linux/if.h>
 #include <linux/if_tun.h>
 #include <assert.h>
 
+#include <array>
 #include <atomic>
 #include <chrono>
 #include <thread>
+#include <string>
 
 #if __has_include(<sys/random.h>)
 #include <sys/random.h>
@@ -28,11 +30,26 @@ int getrandom(void* buf, size_t len, unsigned int flags) {
 }
 #endif
 
+struct pkt_hdr {
+       // Note that ordering of the first three uint16_ts must match the wire!
+       uint16_t len;
+       uint16_t id;
+       uint16_t frag_flags: 3,
+                offset    : 13;
+       uint8_t proto;
+       uint8_t padding;
+};
+static_assert(sizeof(struct pkt_hdr) == 8);
+
+// Have to leave room to add a TCP header plus above header (after dropping IP header)
+#define PACKET_READ_OFFS sizeof(pkt_hdr)
+#define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
 
-static int tun_alloc(char *dev, int queues, int *fds)
+static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
 {
        struct ifreq ifr;
        int fd, err, i;
+       char buf[1024];
 
        if (!dev)
                return -1;
@@ -57,6 +74,23 @@ static int tun_alloc(char *dev, int queues, int *fds)
                }
                fds[i] = fd;
        }
+
+       sprintf(buf, "ip link set dev %s mtu %ld", dev, PACKET_READ_SIZE - 20);
+       err = system(buf);
+       if (err) goto err;
+
+       sprintf(buf, "ip addr add %s/32 dev %s", local_ip, dev);
+       err = system(buf);
+       if (err) goto err;
+
+       sprintf(buf, "ip link set %s up", dev);
+       err = system(buf);
+       if (err) goto err;
+
+       sprintf(buf, "ip route add %s/32 dev %s", remote_ip, dev);
+       err = system(buf);
+       if (err) goto err;
+
        return 0;
 err:
        for (--i; i >= 0; i--)
@@ -64,7 +98,7 @@ err:
        return err;
 }
 
-static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t expected_type) {
+static int check_ip_header(const unsigned char* buf, ssize_t buf_len, int16_t expected_type, struct pkt_hdr *gen_hdr) {
        if (buf_len < 20) {
                // < size than IPv4?
                return -1;
@@ -82,16 +116,22 @@ static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t ex
                return -1;
        }
 
-       if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
+       /*if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
                //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
                return -1;
-       }
+       }*/
 
-       if (buf[9] != expected_type) {
+       if (expected_type > 0 && buf[9] != expected_type) {
                fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
                return -1;
        }
 
+       if (gen_hdr) {
+               memcpy(gen_hdr, buf + 2, 6); // len, id, frag offset
+               gen_hdr->proto = buf[9];
+               gen_hdr->padding = 0;
+       }
+
        return header_size;
 }
 
@@ -147,8 +187,9 @@ static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
 static std::atomic<uint16_t> local_port(0);
 static uint16_t remote_port;
 static bool are_server;
+static uint64_t timestamps_magic;
 
-static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
+static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
        buf[0 ] = local_port >> 8;         // src port
        buf[1 ] = local_port;              // src port
        buf[2 ] = remote_port >> 8;        // dst port
@@ -166,8 +207,8 @@ static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int syna
        buf[10] = their_seq >> (8 * 1);    // ACK
        buf[11] = their_seq >> (8 * 0);    // ACK
 
-       bool longpkt = syn || synack;
-       buf[12] = (longpkt ? 6 : 5) << 4;  // data offset
+       unsigned char hdrlen = syn ? 36 : (synack ? 24 : 20);
+       buf[12] = (hdrlen/4) << 4;  // data offset
        if (syn)
                buf[13] = 1 << 1;              // SYN
        else if (synack)
@@ -184,27 +225,33 @@ static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int syna
        buf[18] = 0x00;                    // URG Pointer
        buf[19] = 0x00;                    // URG Pointer
 
-       if (longpkt) {
+       if (syn || synack) {
        buf[20] = 0x01;                    // NOP
        buf[21] = 0x03;                    // Window Scale
        buf[22] = 0x03;                    // Window Scale Option Length
        buf[23] = 0x0e;                    // 1GB Window Size (0xffff << 0x0e)
        }
+       if (syn) {
+       buf[24] = 0x01;                    // NOP
+       buf[25] = 0x01;                    // NOP
+       buf[26] = 8;                       // Timestamp
+       buf[27] = 10;                      // Timestamp Option Length
+       memcpy(buf + 28, &timestamps_magic, 8);
+       }
 
-       uint16_t checksum = tcp_checksum(buf, len + 20 + (longpkt ? 4 : 0), src_addr, dest_addr);
+       uint16_t checksum = tcp_checksum(buf, len + hdrlen, src_addr, dest_addr);
        buf[16] = checksum;                // Checksum
        buf[17] = checksum >> 8;           // Checksum
+
+       return hdrlen;
 }
 
-static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
+static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
        buf[0 ] = (4 << 4) | 5;    // IPv4 + IHL of 5 (20 bytes)
        buf[1 ] = 0;               // DSCP 0 + ECN 0
-       buf[2 ] = (len + 20) >> 8; // Length
-       buf[3 ] = (len + 20);      // Length
-       memset(buf + 4, 0, 4);     // Identification and Fragment 0s
-       buf[6 ] = 1 << 6;          // DF bit
+       memcpy(buf + 2, &hdr, 6);  // length, identification, flags, and offset
        buf[8 ] = 255;             // TTL
-       buf[9 ] = proto;           // Protocol Number
+       buf[9 ] = hdr.proto;       // Protocol Number
        buf[10] = 0;               // Checksum
        buf[11] = 0;               // Checksum
 
@@ -216,7 +263,6 @@ static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, con
        buf[11] = checksum >> 8;
 }
 
-
 const signed char p_util_hexdigit[256] =
 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
@@ -235,8 +281,8 @@ const signed char p_util_hexdigit[256] =
   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, };
 
-uint32_t hex_to_num(const char* buf) {
-       const char* pbegin = buf;
+uint32_t hex_to_num(const unsigned char* buf) {
+       const unsigned char* pbegin = buf;
        while (p_util_hexdigit[*buf] != -1)
                buf++;
        buf--;
@@ -258,14 +304,13 @@ uint32_t hex_to_num(const char* buf) {
 static int fdr;
 static int fd[TUN_IF_COUNT];
 static struct sockaddr_in dest;
-static in_addr_t src, tun_src, tun_dest, ipip_src, ipip_dest;
-static const uint64_t tcp_init_magic = 0x1badcafedeadbeefULL;
+static in_addr_t src, tun_src, tun_dest;
+static uint32_t starting_ack;
 
-#define PENDING_MESSAGES_BUFF_SIZE (0x3000)
-#define PACKET_READ_SIZE 1500
-#define THREAD_POLL_SLEEP_MICS 50
+#define PENDING_MESSAGES_BUFF_SIZE (0x800)
+#define THREAD_POLL_SLEEP_MICS 250
 struct MessageQueue {
-       std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
+       std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE + PACKET_READ_OFFS>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
        std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
        MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
        MessageQueue(MessageQueue&& q) =delete;
@@ -276,7 +321,7 @@ static MessageQueue tcp_to_tun_queue;
 static std::chrono::steady_clock::time_point last_ack_recv;
 
 static void tcp_to_tun() {
-       unsigned char buf[1500];
+       unsigned char buf[PACKET_READ_SIZE + PACKET_READ_OFFS];
        struct sockaddr_in pkt_src;
        memset(&pkt_src, 0, sizeof(pkt_src));
 
@@ -312,7 +357,7 @@ static void tcp_to_tun_queue_process() {
                unsigned char* buf = std::get<1>(msg).data();
                ssize_t nread = std::get<2>(msg);
 
-               int header_size = check_ip_header(buf, nread, 0x06); // Only support TCP
+               int header_size = check_ip_header(buf, nread, 0x06, NULL); // Only support TCP
                if (header_size < 0)
                        continue;
 
@@ -329,16 +374,31 @@ static void tcp_to_tun_queue_process() {
                bool ack = tcp_buf[13] & (1 << 4);
 
                if (are_server && syn && !ack) {
-                       // We're a server and just got a client
-                       if (tcp_buf[4 ] != uint8_t(tcp_init_magic >> (7 * 8)) ||
-                           tcp_buf[5 ] != uint8_t(tcp_init_magic >> (6 * 8)) ||
-                           tcp_buf[6 ] != uint8_t(tcp_init_magic >> (5 * 8)) ||
-                           tcp_buf[7 ] != uint8_t(tcp_init_magic >> (4 * 8)) ||
-                           tcp_buf[8 ] != uint8_t(tcp_init_magic >> (3 * 8)) ||
-                           tcp_buf[9 ] != uint8_t(tcp_init_magic >> (2 * 8)) ||
-                           tcp_buf[10] != uint8_t(tcp_init_magic >> (1 * 8)) ||
-                           tcp_buf[11] != uint8_t(tcp_init_magic >> (0 * 8)))
+                       uint32_t expected_ack = htobe32(starting_ack);
+                       if (memcmp(tcp_buf + 8, &expected_ack, 4))
                                continue;
+                       if (nread < 36 + header_size) { continue; }
+                       // We're a server and just got a client...walk options until we find timestamps
+                       const unsigned char* opt_buf = tcp_buf + 20;
+                       bool found_magic = false;
+                       while (!found_magic && opt_buf < buf + nread) {
+                               switch (*opt_buf) {
+                                       case 1: opt_buf += 1; break;
+                                       case 2: opt_buf += 4; break;
+                                       case 3: opt_buf += 3; break;
+                                       // SACK should never appear
+                                       case 8:
+                                               if (opt_buf + 10 <= buf + nread) {
+                                                       if (!memcmp(opt_buf + 2, &timestamps_magic, 8)) {
+                                                               found_magic = true;
+                                                               break;
+                                                       }
+                                               }
+                                               // Fall through
+                                       default: opt_buf = buf + nread;
+                               }
+                       }
+                       if (!found_magic) continue;
 
                        fprintf(stderr, "Got SYN, sending SYNACK\n");
                        remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
@@ -365,23 +425,23 @@ static void tcp_to_tun_queue_process() {
                        last_ack_recv = std::chrono::steady_clock::now();
 
                if (are_server && syn && !ack) {
-                       build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
+                       int len = build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
 
-                       ssize_t res = sendto(fdr, tcp_buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
+                       ssize_t res = sendto(fdr, tcp_buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
                        if (res < 0) {
                                int err = errno;
                                fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
                        }
                } else if (!syn) {
-                       tcp_buf += tcp_header_size - 20;
+                       if (nread < (ssize_t)(tcp_header_size + header_size + sizeof(struct pkt_hdr))) continue;
+                       struct pkt_hdr hdr;
+                       memcpy(&hdr, tcp_buf + tcp_header_size, sizeof(struct pkt_hdr));
+                       tcp_buf += tcp_header_size + sizeof(struct pkt_hdr) - 20; // IPv4 header is 20 bytes
 
-                       // Replace TCP with IPv4 header
-                       //build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x01, ipip_dest, ipip_src); // ICMP
-                       build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x11, ipip_dest, ipip_src); // UDP
-                       // Add IPIP header
-                       build_ip_header(tcp_buf - 20, nread - tcp_header_size - header_size + 20, 0x04, tun_dest, tun_src);
+                       // Replace TCP + pkt_hdr with IPv4 header
+                       build_ip_header(tcp_buf, hdr, tun_dest, tun_src);
 
-                       write(fd[0], tcp_buf - 20, nread - tcp_header_size - header_size + 40);
+                       write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8);
                }
        } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
 }
@@ -410,7 +470,7 @@ static void tun_to_tcp() {
                        continue;
 
                auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
-               memcpy(std::get<1>(new_msg).data(), buf, nread);
+               memcpy(std::get<1>(new_msg).data() + PACKET_READ_OFFS, buf, nread);
                std::get<2>(new_msg) = nread;
 
                queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
@@ -433,24 +493,26 @@ static void tun_to_tcp_queue_process() {
                unsigned char* buf = std::get<1>(msg).data();
                ssize_t nread = std::get<2>(msg);
 
-               int header_size = check_ip_header(buf, nread, 0x04); // Only support IPIP
-               if (header_size < 0)
+               struct pkt_hdr hdr;
+               int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, -1, &hdr); // Any paket type is okay
+               if (header_size < 20)
                        continue;
 
-               int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x11); // Only support UDP
-               //int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x01); // Only support ICMP
-               if (internal_header_size < 0)
+               if (header_size > nread) {
+                       fprintf(stderr, "Short packet\n");
                        continue;
+               }
 
-               if (internal_header_size + header_size + 8 > nread) {
-                       fprintf(stderr, "Short UDP-in-IPIP packet\n");
+               if ((size_t)nread > 1500 - 20 - sizeof(struct pkt_hdr)) { // Packets must fit in 1500 bytes with a TCP and packet hdr
+                       fprintf(stderr, "Long packet\n");
                        continue;
                }
 
-               size_t tcp_start_offset = header_size + internal_header_size - 20;
-               build_tcp_header(buf + tcp_start_offset, nread - tcp_start_offset - 20, 0, 0, src, dest.sin_addr.s_addr);
+               unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size - 20 - sizeof(struct pkt_hdr);
+               memcpy(tcp_start + 20, &hdr, sizeof(struct pkt_hdr));
+               build_tcp_header(tcp_start, nread - header_size + sizeof(struct pkt_hdr), 0, 0, src, dest.sin_addr.s_addr);
 
-               ssize_t res = sendto(fdr, buf + tcp_start_offset, nread - tcp_start_offset, 0, (struct sockaddr*)&dest, sizeof(dest));
+               ssize_t res = sendto(fdr, tcp_start, nread + 20 + sizeof(struct pkt_hdr) - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
                if (res < 0) {
                        int err = errno;
                        fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
@@ -475,10 +537,9 @@ int do_init() {
                local_port = local_port_tmp;
        }
 
-       uint32_t starting_ack = 0, starting_seq = 0;
-       memcpy(&starting_ack, &tcp_init_magic, 4);
-       memcpy(&starting_seq, ((const unsigned char*)&tcp_init_magic) + 4, 4);
        highest_recvd_seq = starting_ack;
+       uint32_t starting_seq;
+       assert(getrandom(&starting_seq, sizeof(starting_seq), 0) == sizeof(starting_seq));
        cur_seq = starting_seq;
 
        if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
@@ -490,8 +551,8 @@ int do_init() {
 
        unsigned char buf[1500];
        if (!are_server) {
-               build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
-               ssize_t res = sendto(fdr, buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
+               int len = build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
+               ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
                if (res < 0) {
                        int err = errno;
                        fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
@@ -508,8 +569,8 @@ int do_init() {
        if (!are_server) {
                fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
 
-               build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
-               ssize_t res = sendto(fdr, buf, 20, 0, (struct sockaddr*)&dest, sizeof(dest));
+               int len = build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
+               ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
                if (res < 0) {
                        int err = errno;
                        fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
@@ -535,16 +596,15 @@ int main(int argc, char* argv[]) {
        assert(argc > 1 && "Need tun name");
        assert(argc > 2 && "Need tun remote host");
        assert(argc > 3 && "Need tun local host");
-       assert(argc > 4 && "Need ipip remote host");
-       assert(argc > 5 && "Need ipip local host");
-       assert(argc > 6 && "Need server port");
-       assert(argc > 7 && "Need mode (client or server)");
-       assert(argc > 8 && "Need src host");
-       if (std::string(argv[7]) == std::string("client"))
-               assert(argc > 9 && "Need dest host");
+       assert(argc > 4 && "Need server port");
+       assert(argc > 5 && "Need shared secret");
+       assert(argc > 6 && "Need mode (client or server)");
+       assert(argc > 7 && "Need src host");
+       if (std::string(argv[6]) == std::string("client"))
+               assert(argc > 8 && "Need dest host");
 
-       assert(std::string(argv[7]) == std::string("client") || std::string(argv[7]) == std::string("server"));
-       are_server = (std::string(argv[7]) == std::string("server"));
+       assert(std::string(argv[6]) == std::string("client") || std::string(argv[6]) == std::string("server"));
+       are_server = (std::string(argv[6]) == std::string("server"));
 
        //
        // Parse args into variables
@@ -558,30 +618,31 @@ int main(int argc, char* argv[]) {
        tun_dest = inet_addr(argv[2]);
        tun_src = inet_addr(argv[3]);
 
-       ipip_dest = inet_addr(argv[4]);
-       ipip_src = inet_addr(argv[5]);
-
        if (are_server) {
-               local_port = atoi(argv[6]);
+               local_port = atoi(argv[4]);
                remote_port = 0;
        } else {
                // Get local port in do_init() so that we pick a new one on reload
-               remote_port = atoi(argv[6]);
+               remote_port = atoi(argv[4]);
        }
 
-       src = inet_addr(argv[8]);
+       uint64_t tcp_init_magic = atoll(argv[5]);
+       starting_ack = tcp_init_magic >> 32;
+       timestamps_magic = htobe64(tcp_init_magic);
+
+       src = inet_addr(argv[7]);
 
        memset(&dest, 0, sizeof(dest));
        if (!are_server) {
                dest.sin_family = AF_INET;
-               dest.sin_addr.s_addr = inet_addr(argv[9]);
+               dest.sin_addr.s_addr = inet_addr(argv[8]);
        }
 
        //
        // Create tun and bind to sockets...
        //
 
-       if (tun_alloc(tun_name, TUN_IF_COUNT, fd) != 0) {
+       if (tun_alloc(tun_name, argv[3], argv[2], TUN_IF_COUNT, fd) != 0) {
                fprintf(stderr, "Failed to alloc tun if\n");
                return -1;
        }