Randomize initial seq, use timestamps for magic
authorMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 21:08:38 +0000 (21:08 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 22:07:53 +0000 (22:07 +0000)
main.cpp

index 3fa3522a89f17f06ebb90042897e4fc6546b4b4d..381b6106220c926e8c737738634ceb0ea1b1ca78 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -186,6 +186,7 @@ static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
 static std::atomic<uint16_t> local_port(0);
 static uint16_t remote_port;
 static bool are_server;
+static uint64_t timestamps_magic;
 
 static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
        buf[0 ] = local_port >> 8;         // src port
@@ -205,8 +206,8 @@ static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synac
        buf[10] = their_seq >> (8 * 1);    // ACK
        buf[11] = their_seq >> (8 * 0);    // ACK
 
-       bool longpkt = syn || synack;
-       buf[12] = (longpkt ? 6 : 5) << 4;  // data offset
+       unsigned char hdrlen = syn ? 36 : (synack ? 24 : 20);
+       buf[12] = (hdrlen/4) << 4;  // data offset
        if (syn)
                buf[13] = 1 << 1;              // SYN
        else if (synack)
@@ -223,18 +224,25 @@ static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synac
        buf[18] = 0x00;                    // URG Pointer
        buf[19] = 0x00;                    // URG Pointer
 
-       if (longpkt) {
+       if (syn || synack) {
        buf[20] = 0x01;                    // NOP
        buf[21] = 0x03;                    // Window Scale
        buf[22] = 0x03;                    // Window Scale Option Length
        buf[23] = 0x0e;                    // 1GB Window Size (0xffff << 0x0e)
        }
+       if (syn) {
+       buf[24] = 0x01;                    // NOP
+       buf[25] = 0x01;                    // NOP
+       buf[26] = 8;                       // Timestamp
+       buf[27] = 10;                      // Timestamp Option Length
+       memcpy(buf + 28, &timestamps_magic, 8);
+       }
 
-       uint16_t checksum = tcp_checksum(buf, len + 20 + (longpkt ? 4 : 0), src_addr, dest_addr);
+       uint16_t checksum = tcp_checksum(buf, len + hdrlen, src_addr, dest_addr);
        buf[16] = checksum;                // Checksum
        buf[17] = checksum >> 8;           // Checksum
 
-       return 20 + (longpkt ? 4 : 0);
+       return hdrlen;
 }
 
 static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
@@ -296,7 +304,7 @@ static int fdr;
 static int fd[TUN_IF_COUNT];
 static struct sockaddr_in dest;
 static in_addr_t src, tun_src, tun_dest;
-static uint32_t starting_seq, starting_ack;
+static uint32_t starting_ack;
 
 #define PENDING_MESSAGES_BUFF_SIZE (0x800)
 #define THREAD_POLL_SLEEP_MICS 250
@@ -365,8 +373,31 @@ static void tcp_to_tun_queue_process() {
                bool ack = tcp_buf[13] & (1 << 4);
 
                if (are_server && syn && !ack) {
-                       if (memcmp(tcp_buf + 4, &starting_seq, 4) || memcmp(tcp_buf + 8, &starting_ack, 4))
+                       uint32_t expected_ack = htobe32(starting_ack);
+                       if (memcmp(tcp_buf + 8, &expected_ack, 4))
                                continue;
+                       if (nread < 36 + header_size) { continue; }
+                       // We're a server and just got a client...walk options until we find timestamps
+                       const unsigned char* opt_buf = tcp_buf + 20;
+                       bool found_magic = false;
+                       while (!found_magic && opt_buf < buf + nread) {
+                               switch (*opt_buf) {
+                                       case 1: opt_buf += 1; break;
+                                       case 2: opt_buf += 4; break;
+                                       case 3: opt_buf += 3; break;
+                                       // SACK should never appear
+                                       case 8:
+                                               if (opt_buf + 10 <= buf + nread) {
+                                                       if (!memcmp(opt_buf + 2, &timestamps_magic, 8)) {
+                                                               found_magic = true;
+                                                               break;
+                                                       }
+                                               }
+                                               // Fall through
+                                       default: opt_buf = buf + nread;
+                               }
+                       }
+                       if (!found_magic) continue;
 
                        fprintf(stderr, "Got SYN, sending SYNACK\n");
                        remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
@@ -506,6 +537,8 @@ int do_init() {
        }
 
        highest_recvd_seq = starting_ack;
+       uint32_t starting_seq;
+       assert(getrandom(&starting_seq, sizeof(starting_seq), 0) == sizeof(starting_seq));
        cur_seq = starting_seq;
 
        if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
@@ -593,8 +626,8 @@ int main(int argc, char* argv[]) {
        }
 
        uint64_t tcp_init_magic = atoll(argv[5]);
-       starting_seq = htobe32(tcp_init_magic);
-       starting_ack = htobe32(tcp_init_magic >> 32);
+       starting_ack = tcp_init_magic >> 32;
+       timestamps_magic = htobe64(tcp_init_magic);
 
        src = inet_addr(argv[7]);