Handle initial seq endianness correctly
authorMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 06:31:02 +0000 (06:31 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 16 Mar 2022 22:07:53 +0000 (22:07 +0000)
main.cpp

index c116e8994a27b53e6d0d3ec10a147b2704b26b73..3fa3522a89f17f06ebb90042897e4fc6546b4b4d 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -187,7 +187,7 @@ static std::atomic<uint16_t> local_port(0);
 static uint16_t remote_port;
 static bool are_server;
 
-static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
+static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
        buf[0 ] = local_port >> 8;         // src port
        buf[1 ] = local_port;              // src port
        buf[2 ] = remote_port >> 8;        // dst port
@@ -233,6 +233,8 @@ static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int syna
        uint16_t checksum = tcp_checksum(buf, len + 20 + (longpkt ? 4 : 0), src_addr, dest_addr);
        buf[16] = checksum;                // Checksum
        buf[17] = checksum >> 8;           // Checksum
+
+       return 20 + (longpkt ? 4 : 0);
 }
 
 static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
@@ -294,7 +296,7 @@ static int fdr;
 static int fd[TUN_IF_COUNT];
 static struct sockaddr_in dest;
 static in_addr_t src, tun_src, tun_dest;
-static uint64_t tcp_init_magic;
+static uint32_t starting_seq, starting_ack;
 
 #define PENDING_MESSAGES_BUFF_SIZE (0x800)
 #define THREAD_POLL_SLEEP_MICS 250
@@ -363,15 +365,7 @@ static void tcp_to_tun_queue_process() {
                bool ack = tcp_buf[13] & (1 << 4);
 
                if (are_server && syn && !ack) {
-                       // We're a server and just got a client
-                       if (tcp_buf[4 ] != uint8_t(tcp_init_magic >> (7 * 8)) ||
-                           tcp_buf[5 ] != uint8_t(tcp_init_magic >> (6 * 8)) ||
-                           tcp_buf[6 ] != uint8_t(tcp_init_magic >> (5 * 8)) ||
-                           tcp_buf[7 ] != uint8_t(tcp_init_magic >> (4 * 8)) ||
-                           tcp_buf[8 ] != uint8_t(tcp_init_magic >> (3 * 8)) ||
-                           tcp_buf[9 ] != uint8_t(tcp_init_magic >> (2 * 8)) ||
-                           tcp_buf[10] != uint8_t(tcp_init_magic >> (1 * 8)) ||
-                           tcp_buf[11] != uint8_t(tcp_init_magic >> (0 * 8)))
+                       if (memcmp(tcp_buf + 4, &starting_seq, 4) || memcmp(tcp_buf + 8, &starting_ack, 4))
                                continue;
 
                        fprintf(stderr, "Got SYN, sending SYNACK\n");
@@ -399,9 +393,9 @@ static void tcp_to_tun_queue_process() {
                        last_ack_recv = std::chrono::steady_clock::now();
 
                if (are_server && syn && !ack) {
-                       build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
+                       int len = build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
 
-                       ssize_t res = sendto(fdr, tcp_buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
+                       ssize_t res = sendto(fdr, tcp_buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
                        if (res < 0) {
                                int err = errno;
                                fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
@@ -511,9 +505,6 @@ int do_init() {
                local_port = local_port_tmp;
        }
 
-       uint32_t starting_ack = 0, starting_seq = 0;
-       memcpy(&starting_ack, &tcp_init_magic, 4);
-       memcpy(&starting_seq, ((const unsigned char*)&tcp_init_magic) + 4, 4);
        highest_recvd_seq = starting_ack;
        cur_seq = starting_seq;
 
@@ -526,8 +517,8 @@ int do_init() {
 
        unsigned char buf[1500];
        if (!are_server) {
-               build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
-               ssize_t res = sendto(fdr, buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
+               int len = build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
+               ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
                if (res < 0) {
                        int err = errno;
                        fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
@@ -544,8 +535,8 @@ int do_init() {
        if (!are_server) {
                fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
 
-               build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
-               ssize_t res = sendto(fdr, buf, 20, 0, (struct sockaddr*)&dest, sizeof(dest));
+               int len = build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
+               ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
                if (res < 0) {
                        int err = errno;
                        fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
@@ -601,7 +592,9 @@ int main(int argc, char* argv[]) {
                remote_port = atoi(argv[4]);
        }
 
-       tcp_init_magic = atoll(argv[5]);
+       uint64_t tcp_init_magic = atoll(argv[5]);
+       starting_seq = htobe32(tcp_init_magic);
+       starting_ack = htobe32(tcp_init_magic >> 32);
 
        src = inet_addr(argv[7]);