381b6106220c926e8c737738634ceb0ea1b1ca78
[tunudptotcp] / main.cpp
1 #include <fcntl.h>
2 #include <string.h>
3 #include <stdio.h>
4 #include <unistd.h>
5 #include <sys/socket.h>
6 #include <sys/ioctl.h>
7 #include <netinet/in.h>
8 #include <errno.h>
9 #include <arpa/inet.h>
10 #include <stdint.h>
11 #include <stdlib.h>
12 #include <linux/if.h>
13 #include <linux/if_tun.h>
14 #include <assert.h>
15
16 #include <atomic>
17 #include <chrono>
18 #include <thread>
19 #include <string>
20
21 #if __has_include(<sys/random.h>)
22 #include <sys/random.h>
23 #else
24 int getrandom(void* buf, size_t len, unsigned int flags) {
25         FILE* f = fopen("/dev/urandom", "r");
26         int res = fread(buf, len, 1, f);
27         fclose(f);
28         return res;
29 }
30 #endif
31
32 struct pkt_hdr {
33         // Note that ordering of the first three uint16_ts must match the wire!
34         uint16_t len;
35         uint16_t id;
36         uint16_t frag_flags: 3,
37                  offset    : 13;
38         uint8_t proto;
39         uint8_t padding;
40 };
41 static_assert(sizeof(struct pkt_hdr) == 8);
42
43 // Have to leave room to add a TCP header plus above header (after dropping IP header)
44 #define PACKET_READ_OFFS sizeof(pkt_hdr)
45 #define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
46
47 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
48 {
49         struct ifreq ifr;
50         int fd, err, i;
51         char buf[1024];
52
53         if (!dev)
54                 return -1;
55
56         memset(&ifr, 0, sizeof(ifr));
57
58         /* Flags: IFF_TUN   - TUN device (no Ethernet headers) 
59         *         IFF_TAP   - TAP device  
60         *
61         *         IFF_NO_PI - Do not provide packet information  
62         */ 
63         ifr.ifr_flags = IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE;
64         strncpy(ifr.ifr_name, dev, IFNAMSIZ);
65
66         for (i = 0; i < queues; i++) {
67                 if((fd = open("/dev/net/tun", O_RDWR)) < 0)
68                         goto err;
69                 err = ioctl(fd, TUNSETIFF, (void *) &ifr);
70                 if (err) {
71                         close(fd);
72                         goto err;
73                 }
74                 fds[i] = fd;
75         }
76
77         sprintf(buf, "ip link set dev %s mtu %ld", dev, PACKET_READ_SIZE - 20);
78         err = system(buf);
79         if (err) goto err;
80
81         sprintf(buf, "ip addr add %s/32 dev %s", local_ip, dev);
82         err = system(buf);
83         if (err) goto err;
84
85         sprintf(buf, "ip link set %s up", dev);
86         err = system(buf);
87         if (err) goto err;
88
89         sprintf(buf, "ip route add %s/32 dev %s", remote_ip, dev);
90         err = system(buf);
91         if (err) goto err;
92
93         return 0;
94 err:
95         for (--i; i >= 0; i--)
96                 close(fds[i]);
97         return err;
98 }
99
100 static int check_ip_header(const unsigned char* buf, ssize_t buf_len, int16_t expected_type, struct pkt_hdr *gen_hdr) {
101         if (buf_len < 20) {
102                 // < size than IPv4?
103                 return -1;
104         }
105
106         if ((buf[0] & 0xf0) != (4 << 4)) {
107                 // Only support IPv4
108                 return -1;
109         }
110
111         uint8_t num_words = buf[0] & 0xf;
112         int header_size = num_words * 4;
113         if (header_size < 20) {
114                 fprintf(stderr, "Invalid IPv4 IHL size (%d)\n", header_size);
115                 return -1;
116         }
117
118         /*if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
119                 //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
120                 return -1;
121         }*/
122
123         if (expected_type > 0 && buf[9] != expected_type) {
124                 fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
125                 return -1;
126         }
127
128         if (gen_hdr) {
129                 memcpy(gen_hdr, buf + 2, 6); // len, id, frag offset
130                 gen_hdr->proto = buf[9];
131                 gen_hdr->padding = 0;
132         }
133
134         return header_size;
135 }
136
137 void print_packet(const unsigned char* buf, ssize_t buf_len) {
138         for (ssize_t i = 0; i < buf_len; ) {
139                 for (int j = 0; i < buf_len && j < 20; i++, j++)
140                         fprintf(stderr, "%02x", buf[i]);
141                 fprintf(stderr, "\n");
142         }
143 }
144
145 static uint32_t data_checksum(const unsigned char* buf, size_t len) {
146         uint32_t sum = 0;
147
148         while (len > 1)
149         {
150                 sum += *buf++;
151                 sum += (*buf++) << 8;
152                 if (sum & 0x80000000)
153                         sum = (sum & 0xFFFF) + (sum >> 16);
154                 len -= 2;
155         }
156
157         if (len & 1)
158                 sum += *buf;
159
160         return sum;
161 }
162
163 static uint16_t finalize_data_checksum(uint32_t sum) {
164         while (sum >> 16)
165                 sum = (sum & 0xFFFF) + (sum >> 16);
166
167         return (uint16_t)(~sum);
168 }
169
170 static uint16_t tcp_checksum(const unsigned char* buff, size_t len, in_addr_t src_addr, in_addr_t dest_addr)
171 {
172         uint16_t *ip_src = (uint16_t*)&src_addr, *ip_dst = (uint16_t*)&dest_addr;
173         uint32_t sum = data_checksum(buff, len);
174
175         sum += *(ip_src++);
176         sum += *ip_src;
177         sum += *(ip_dst++);
178         sum += *ip_dst;
179         sum += htons(IPPROTO_TCP);
180         sum += htons(len);
181
182         return finalize_data_checksum(sum);
183 }
184
185 static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
186 static std::atomic<uint16_t> local_port(0);
187 static uint16_t remote_port;
188 static bool are_server;
189 static uint64_t timestamps_magic;
190
191 static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
192         buf[0 ] = local_port >> 8;         // src port
193         buf[1 ] = local_port;              // src port
194         buf[2 ] = remote_port >> 8;        // dst port
195         buf[3 ] = remote_port;             // dst port
196
197         uint32_t seq = cur_seq.fetch_add((syn || synack) ? 1 : len, std::memory_order_acq_rel);
198         buf[4 ] = seq >> (8 * 3);          // SEQ
199         buf[5 ] = seq >> (8 * 2);          // SEQ
200         buf[6 ] = seq >> (8 * 1);          // SEQ
201         buf[7 ] = seq >> (8 * 0);          // SEQ
202
203         uint32_t their_seq = highest_recvd_seq.load(std::memory_order_relaxed);
204         buf[8 ] = their_seq >> (8 * 3);    // ACK
205         buf[9 ] = their_seq >> (8 * 2);    // ACK
206         buf[10] = their_seq >> (8 * 1);    // ACK
207         buf[11] = their_seq >> (8 * 0);    // ACK
208
209         unsigned char hdrlen = syn ? 36 : (synack ? 24 : 20);
210         buf[12] = (hdrlen/4) << 4;  // data offset
211         if (syn)
212                 buf[13] = 1 << 1;              // SYN
213         else if (synack)
214                 buf[13] = (1 << 1) | (1 << 4); // SYN + ACK
215         else if (len == 0)
216                 buf[13] = 1 << 4;              // ACK
217         else
218                 buf[13] = 3 << 3;              // PSH + ACK
219         buf[14] = 0xff;                    // Window Size
220         buf[15] = 0xff;                    // Window Size
221
222         buf[16] = 0x00;                    // Checksum
223         buf[17] = 0x00;                    // Checksum
224         buf[18] = 0x00;                    // URG Pointer
225         buf[19] = 0x00;                    // URG Pointer
226
227         if (syn || synack) {
228         buf[20] = 0x01;                    // NOP
229         buf[21] = 0x03;                    // Window Scale
230         buf[22] = 0x03;                    // Window Scale Option Length
231         buf[23] = 0x0e;                    // 1GB Window Size (0xffff << 0x0e)
232         }
233         if (syn) {
234         buf[24] = 0x01;                    // NOP
235         buf[25] = 0x01;                    // NOP
236         buf[26] = 8;                       // Timestamp
237         buf[27] = 10;                      // Timestamp Option Length
238         memcpy(buf + 28, &timestamps_magic, 8);
239         }
240
241         uint16_t checksum = tcp_checksum(buf, len + hdrlen, src_addr, dest_addr);
242         buf[16] = checksum;                // Checksum
243         buf[17] = checksum >> 8;           // Checksum
244
245         return hdrlen;
246 }
247
248 static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
249         buf[0 ] = (4 << 4) | 5;    // IPv4 + IHL of 5 (20 bytes)
250         buf[1 ] = 0;               // DSCP 0 + ECN 0
251         memcpy(buf + 2, &hdr, 6);  // length, identification, flags, and offset
252         buf[8 ] = 255;             // TTL
253         buf[9 ] = hdr.proto;       // Protocol Number
254         buf[10] = 0;               // Checksum
255         buf[11] = 0;               // Checksum
256
257         memcpy(buf + 12, &src_addr, 4);
258         memcpy(buf + 16, &dest_addr, 4);
259
260         uint16_t checksum = finalize_data_checksum(data_checksum(buf, 20));
261         buf[10] = checksum;
262         buf[11] = checksum >> 8;
263 }
264
265 const signed char p_util_hexdigit[256] =
266 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
267   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
268   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
269   0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
270   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
271   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
272   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
273   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
274   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
275   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
276   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
277   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
278   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
279   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
280   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
281   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, };
282
283 uint32_t hex_to_num(const unsigned char* buf) {
284         const unsigned char* pbegin = buf;
285         while (p_util_hexdigit[*buf] != -1)
286                 buf++;
287         buf--;
288         uint32_t res = 0;
289         unsigned char* p1 = (unsigned char*)&res;
290         unsigned char* pend = p1 + 4;
291         while (buf >= pbegin && p1 < pend) {
292                 *p1 = p_util_hexdigit[*buf--];
293                 if (buf >= pbegin) {
294                         *p1 |= ((unsigned char)p_util_hexdigit[*buf--] << 4);
295                         p1++;
296                 }
297         }
298
299         return res;
300 }
301
302 #define TUN_IF_COUNT 4
303 static int fdr;
304 static int fd[TUN_IF_COUNT];
305 static struct sockaddr_in dest;
306 static in_addr_t src, tun_src, tun_dest;
307 static uint32_t starting_ack;
308
309 #define PENDING_MESSAGES_BUFF_SIZE (0x800)
310 #define THREAD_POLL_SLEEP_MICS 250
311 struct MessageQueue {
312         std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE + PACKET_READ_OFFS>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
313         std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
314         MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
315         MessageQueue(MessageQueue&& q) =delete;
316         MessageQueue(MessageQueue& q) =delete;
317 };
318
319 static MessageQueue tcp_to_tun_queue;
320 static std::chrono::steady_clock::time_point last_ack_recv;
321
322 static void tcp_to_tun() {
323         unsigned char buf[PACKET_READ_SIZE + PACKET_READ_OFFS];
324         struct sockaddr_in pkt_src;
325         memset(&pkt_src, 0, sizeof(pkt_src));
326
327         while (1) {
328                 socklen_t hostsz = sizeof(pkt_src);
329                 ssize_t nread = recvfrom(fdr, buf, sizeof(buf), 0, (struct sockaddr*)&pkt_src, &hostsz);
330                 if (nread < 0) {
331                         fprintf (stderr, "Failed to read tcp raw sock\n");
332                         exit(-1);
333                 }
334
335                 if (tcp_to_tun_queue.nextPendingMessage == (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
336                         continue;
337
338                 auto& new_msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextUndefinedMessage];
339                 std::get<0>(new_msg) = pkt_src;
340                 memcpy(std::get<1>(new_msg).data(), buf, nread);
341                 std::get<2>(new_msg) = nread;
342
343                 tcp_to_tun_queue.nextUndefinedMessage = (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
344         }
345 }
346
347 static void tcp_to_tun_queue_process() {
348         do {
349                 while (tcp_to_tun_queue.nextUndefinedMessage == tcp_to_tun_queue.nextPendingMessage) {
350                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
351                 }
352
353                 auto& msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextPendingMessage];
354
355                 const sockaddr_in& pkt_src = std::get<0>(msg);
356                 unsigned char* buf = std::get<1>(msg).data();
357                 ssize_t nread = std::get<2>(msg);
358
359                 int header_size = check_ip_header(buf, nread, 0x06, NULL); // Only support TCP
360                 if (header_size < 0)
361                         continue;
362
363                 if (nread - header_size < 20) {
364                         fprintf(stderr, "Short TCP packet\n");
365                         continue;
366                 }
367
368                 unsigned char* tcp_buf = buf + header_size;
369
370                 if (((tcp_buf[2] << 8) | tcp_buf[3]) != local_port) continue;
371
372                 bool syn = tcp_buf[13] & (1 << 1);
373                 bool ack = tcp_buf[13] & (1 << 4);
374
375                 if (are_server && syn && !ack) {
376                         uint32_t expected_ack = htobe32(starting_ack);
377                         if (memcmp(tcp_buf + 8, &expected_ack, 4))
378                                 continue;
379                         if (nread < 36 + header_size) { continue; }
380                         // We're a server and just got a client...walk options until we find timestamps
381                         const unsigned char* opt_buf = tcp_buf + 20;
382                         bool found_magic = false;
383                         while (!found_magic && opt_buf < buf + nread) {
384                                 switch (*opt_buf) {
385                                         case 1: opt_buf += 1; break;
386                                         case 2: opt_buf += 4; break;
387                                         case 3: opt_buf += 3; break;
388                                         // SACK should never appear
389                                         case 8:
390                                                 if (opt_buf + 10 <= buf + nread) {
391                                                         if (!memcmp(opt_buf + 2, &timestamps_magic, 8)) {
392                                                                 found_magic = true;
393                                                                 break;
394                                                         }
395                                                 }
396                                                 // Fall through
397                                         default: opt_buf = buf + nread;
398                                 }
399                         }
400                         if (!found_magic) continue;
401
402                         fprintf(stderr, "Got SYN, sending SYNACK\n");
403                         remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
404                         dest = pkt_src;
405                 }
406
407                 if (((tcp_buf[0] << 8) | tcp_buf[1]) != remote_port) continue;
408                 if (pkt_src.sin_addr.s_addr != dest.sin_addr.s_addr) continue;
409
410                 uint8_t num_words = (tcp_buf[12] & 0xf0) >> 4;
411                 int tcp_header_size = num_words * 4;
412                 if (tcp_header_size < 20) {
413                         fprintf(stderr, "Invalid TCP header size (%d)\n", tcp_header_size);
414                         continue;
415                 }
416
417                 highest_recvd_seq = ((((uint32_t)tcp_buf[4]) << (8 * 3)) |
418                                      (((uint32_t)tcp_buf[5]) << (8 * 2)) |
419                                      (((uint32_t)tcp_buf[6]) << (8 * 1)) |
420                                      (((uint32_t)tcp_buf[7]) << (8 * 0))) +
421                                      (syn ? 1 : nread - header_size - tcp_header_size);
422
423                 if (ack)
424                         last_ack_recv = std::chrono::steady_clock::now();
425
426                 if (are_server && syn && !ack) {
427                         int len = build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
428
429                         ssize_t res = sendto(fdr, tcp_buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
430                         if (res < 0) {
431                                 int err = errno;
432                                 fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
433                         }
434                 } else if (!syn) {
435                         if (nread < (ssize_t)(tcp_header_size + header_size + sizeof(struct pkt_hdr))) continue;
436                         struct pkt_hdr hdr;
437                         memcpy(&hdr, tcp_buf + tcp_header_size, sizeof(struct pkt_hdr));
438                         tcp_buf += tcp_header_size + sizeof(struct pkt_hdr) - 20; // IPv4 header is 20 bytes
439
440                         // Replace TCP + pkt_hdr with IPv4 header
441                         build_ip_header(tcp_buf, hdr, tun_dest, tun_src);
442
443                         write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8);
444                 }
445         } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
446 }
447
448 static std::atomic_int tun_if_thread(0), tun_if_process_thread(0);
449 static std::atomic_bool pause_tun_read_reinit_tcp(false);
450 static MessageQueue tun_to_tcp_queue[TUN_IF_COUNT];
451
452 static void tun_to_tcp() {
453         unsigned char buf[PACKET_READ_SIZE];
454
455         int thread = tun_if_thread.fetch_add(1);
456         MessageQueue& queue = tun_to_tcp_queue[thread];
457
458         while (1) {
459                 ssize_t nread = read(fd[thread], buf, sizeof(buf));
460                 if (pause_tun_read_reinit_tcp)
461                         continue;
462
463                 if (nread < 0) {
464                         fprintf (stderr, "Failed to read tun if\n");
465                         continue;
466                 }
467
468                 if (queue.nextPendingMessage == (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
469                         continue;
470
471                 auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
472                 memcpy(std::get<1>(new_msg).data() + PACKET_READ_OFFS, buf, nread);
473                 std::get<2>(new_msg) = nread;
474
475                 queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
476         }
477 }
478
479 static void tun_to_tcp_queue_process() {
480         int thread = tun_if_process_thread.fetch_add(1);
481         MessageQueue& queue = tun_to_tcp_queue[thread];
482
483         do {
484                 while (queue.nextUndefinedMessage == queue.nextPendingMessage) {
485                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
486                 }
487                 if (pause_tun_read_reinit_tcp)
488                         continue;
489
490                 auto& msg = queue.messagesPendingRingBuff[queue.nextPendingMessage];
491
492                 unsigned char* buf = std::get<1>(msg).data();
493                 ssize_t nread = std::get<2>(msg);
494
495                 struct pkt_hdr hdr;
496                 int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, -1, &hdr); // Any paket type is okay
497                 if (header_size < 20)
498                         continue;
499
500                 if (header_size > nread) {
501                         fprintf(stderr, "Short packet\n");
502                         continue;
503                 }
504
505                 if ((size_t)nread > 1500 - 20 - sizeof(struct pkt_hdr)) { // Packets must fit in 1500 bytes with a TCP and packet hdr
506                         fprintf(stderr, "Long packet\n");
507                         continue;
508                 }
509
510                 unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size - 20 - sizeof(struct pkt_hdr);
511                 memcpy(tcp_start + 20, &hdr, sizeof(struct pkt_hdr));
512                 build_tcp_header(tcp_start, nread - header_size + sizeof(struct pkt_hdr), 0, 0, src, dest.sin_addr.s_addr);
513
514                 ssize_t res = sendto(fdr, tcp_start, nread + 20 + sizeof(struct pkt_hdr) - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
515                 if (res < 0) {
516                         int err = errno;
517                         fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
518                 }
519
520         } while ((queue.nextPendingMessage = (queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
521 }
522
523 int do_init() {
524         //
525         // Send SYN and SYN/ACK
526         //
527
528         if (!are_server) {
529                 if (local_port) // Doing a re-init
530                         pause_tun_read_reinit_tcp = true;
531
532                 uint16_t local_port_tmp = 0;
533                 while (local_port_tmp < 1024)
534                         assert(getrandom(&local_port_tmp, sizeof(local_port_tmp), 0) == sizeof(local_port_tmp));
535
536                 local_port = local_port_tmp;
537         }
538
539         highest_recvd_seq = starting_ack;
540         uint32_t starting_seq;
541         assert(getrandom(&starting_seq, sizeof(starting_seq), 0) == sizeof(starting_seq));
542         cur_seq = starting_seq;
543
544         if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
545                 std::thread t(&tcp_to_tun);
546                 std::thread t2(&tcp_to_tun_queue_process);
547                 t.detach();
548                 t2.detach();
549         }
550
551         unsigned char buf[1500];
552         if (!are_server) {
553                 int len = build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
554                 ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
555                 if (res < 0) {
556                         int err = errno;
557                         fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
558                         return -1;
559                 }
560         }
561
562         int i;
563         for (i = 0; i < 1000 && highest_recvd_seq == starting_ack; i++)
564                 std::this_thread::sleep_for(std::chrono::milliseconds(10));
565         if (i == 1000) // Will come back in 10 seconds
566                 return 0;
567
568         if (!are_server) {
569                 fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
570
571                 int len = build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
572                 ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
573                 if (res < 0) {
574                         int err = errno;
575                         fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
576                         return -1;
577                 }
578         }
579
580         if (pause_tun_read_reinit_tcp) {
581                 pause_tun_read_reinit_tcp = false;
582         } else {
583                 for (int i = 0; i < TUN_IF_COUNT; i++) {
584                         std::thread t3(&tun_to_tcp);
585                         std::thread t4(&tun_to_tcp_queue_process);
586                         t3.detach();
587                         t4.detach();
588                 }
589         }
590
591         return 0;
592 }
593
594 int main(int argc, char* argv[]) {
595         assert(argc > 1 && "Need tun name");
596         assert(argc > 2 && "Need tun remote host");
597         assert(argc > 3 && "Need tun local host");
598         assert(argc > 4 && "Need server port");
599         assert(argc > 5 && "Need shared secret");
600         assert(argc > 6 && "Need mode (client or server)");
601         assert(argc > 7 && "Need src host");
602         if (std::string(argv[6]) == std::string("client"))
603                 assert(argc > 8 && "Need dest host");
604
605         assert(std::string(argv[6]) == std::string("client") || std::string(argv[6]) == std::string("server"));
606         are_server = (std::string(argv[6]) == std::string("server"));
607
608         //
609         // Parse args into variables
610         //
611
612         char tun_name[IFNAMSIZ];
613
614         memset(tun_name, 0, sizeof(tun_name));
615         strcpy(tun_name, argv[1]);
616
617         tun_dest = inet_addr(argv[2]);
618         tun_src = inet_addr(argv[3]);
619
620         if (are_server) {
621                 local_port = atoi(argv[4]);
622                 remote_port = 0;
623         } else {
624                 // Get local port in do_init() so that we pick a new one on reload
625                 remote_port = atoi(argv[4]);
626         }
627
628         uint64_t tcp_init_magic = atoll(argv[5]);
629         starting_ack = tcp_init_magic >> 32;
630         timestamps_magic = htobe64(tcp_init_magic);
631
632         src = inet_addr(argv[7]);
633
634         memset(&dest, 0, sizeof(dest));
635         if (!are_server) {
636                 dest.sin_family = AF_INET;
637                 dest.sin_addr.s_addr = inet_addr(argv[8]);
638         }
639
640         //
641         // Create tun and bind to sockets...
642         //
643
644         if (tun_alloc(tun_name, argv[3], argv[2], TUN_IF_COUNT, fd) != 0) {
645                 fprintf(stderr, "Failed to alloc tun if\n");
646                 return -1;
647         }
648
649         fdr = socket(AF_INET, SOCK_RAW, IPPROTO_TCP);
650         if (fdr < 0) {
651                 fprintf(stderr, "Failed to get raw socket\n");
652                 return -1;
653         }
654
655         int res = do_init();
656         if (res)
657                 return res;
658
659         while (true) {
660                 std::this_thread::sleep_for(std::chrono::seconds(15));
661                 if (!are_server && last_ack_recv < std::chrono::steady_clock::now() - std::chrono::seconds(15)) {
662                         res = do_init();
663                         if (res)
664                                 return res;
665                 }
666         }
667 }