Remove the need for an ipip link to reduce packet size (duh, mtu)
[tunudptotcp] / main.cpp
1 #include <fcntl.h>
2 #include <string.h>
3 #include <stdio.h>
4 #include <unistd.h>
5 #include <sys/socket.h>
6 #include <sys/ioctl.h>
7 #include <netinet/in.h>
8 #include <errno.h>
9 #include <arpa/inet.h>
10 #include <stdint.h>
11 #include <stdlib.h>
12 #include <linux/if.h>
13 #include <linux/if_tun.h>
14 #include <assert.h>
15
16 #include <atomic>
17 #include <chrono>
18 #include <thread>
19 #include <string>
20
21 #if __has_include(<sys/random.h>)
22 #include <sys/random.h>
23 #else
24 int getrandom(void* buf, size_t len, unsigned int flags) {
25         FILE* f = fopen("/dev/urandom", "r");
26         int res = fread(buf, len, 1, f);
27         fclose(f);
28         return res;
29 }
30 #endif
31
32 #define PACKET_READ_OFFS (20-16)
33 #define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
34
35 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
36 {
37         struct ifreq ifr;
38         int fd, err, i;
39         char buf[1024];
40
41         if (!dev)
42                 return -1;
43
44         memset(&ifr, 0, sizeof(ifr));
45
46         /* Flags: IFF_TUN   - TUN device (no Ethernet headers) 
47         *         IFF_TAP   - TAP device  
48         *
49         *         IFF_NO_PI - Do not provide packet information  
50         */ 
51         ifr.ifr_flags = IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE;
52         strncpy(ifr.ifr_name, dev, IFNAMSIZ);
53
54         for (i = 0; i < queues; i++) {
55                 if((fd = open("/dev/net/tun", O_RDWR)) < 0)
56                         goto err;
57                 err = ioctl(fd, TUNSETIFF, (void *) &ifr);
58                 if (err) {
59                         close(fd);
60                         goto err;
61                 }
62                 fds[i] = fd;
63         }
64
65         sprintf(buf, "ip link set %s mtu %d", dev, PACKET_READ_SIZE);
66         err = system(buf);
67         if (err) goto err;
68
69         sprintf(buf, "ip addr add %s/32 dev %s", local_ip, dev);
70         err = system(buf);
71         if (err) goto err;
72
73         sprintf(buf, "ip link set %s up", dev);
74         err = system(buf);
75         if (err) goto err;
76
77         sprintf(buf, "ip route add %s/32 dev %s", remote_ip, dev);
78         err = system(buf);
79         if (err) goto err;
80
81         return 0;
82 err:
83         for (--i; i >= 0; i--)
84                 close(fds[i]);
85         return err;
86 }
87
88 static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t expected_type) {
89         if (buf_len < 20) {
90                 // < size than IPv4?
91                 return -1;
92         }
93
94         if ((buf[0] & 0xf0) != (4 << 4)) {
95                 // Only support IPv4
96                 return -1;
97         }
98
99         uint8_t num_words = buf[0] & 0xf;
100         int header_size = num_words * 4;
101         if (header_size < 20) {
102                 fprintf(stderr, "Invalid IPv4 IHL size (%d)\n", header_size);
103                 return -1;
104         }
105
106         if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
107                 //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
108                 return -1;
109         }
110
111         if (buf[9] != expected_type) {
112                 fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
113                 return -1;
114         }
115
116         return header_size;
117 }
118
119 #define UDP_PORTS 4242
120 static int check_udp_header(const unsigned char* buf, ssize_t buf_len) {
121         if (buf_len < 8) {
122                 return -1;
123         }
124         if (((((uint16_t)buf[0]) << 8) | (buf[1])) != UDP_PORTS) {
125                 return -1;
126         }
127         if (((((uint16_t)buf[2]) << 8) | (buf[3])) != UDP_PORTS) {
128                 return -1;
129         }
130         return 8;
131 }
132
133 void print_packet(const unsigned char* buf, ssize_t buf_len) {
134         for (ssize_t i = 0; i < buf_len; ) {
135                 for (int j = 0; i < buf_len && j < 20; i++, j++)
136                         fprintf(stderr, "%02x", buf[i]);
137                 fprintf(stderr, "\n");
138         }
139 }
140
141 static uint32_t data_checksum(const unsigned char* buf, size_t len) {
142         uint32_t sum = 0;
143
144         while (len > 1)
145         {
146                 sum += *buf++;
147                 sum += (*buf++) << 8;
148                 if (sum & 0x80000000)
149                         sum = (sum & 0xFFFF) + (sum >> 16);
150                 len -= 2;
151         }
152
153         if (len & 1)
154                 sum += *buf;
155
156         return sum;
157 }
158
159 static uint16_t finalize_data_checksum(uint32_t sum) {
160         while (sum >> 16)
161                 sum = (sum & 0xFFFF) + (sum >> 16);
162
163         return (uint16_t)(~sum);
164 }
165
166 static uint16_t tcp_checksum(const unsigned char* buff, size_t len, in_addr_t src_addr, in_addr_t dest_addr)
167 {
168         uint16_t *ip_src = (uint16_t*)&src_addr, *ip_dst = (uint16_t*)&dest_addr;
169         uint32_t sum = data_checksum(buff, len);
170
171         sum += *(ip_src++);
172         sum += *ip_src;
173         sum += *(ip_dst++);
174         sum += *ip_dst;
175         sum += htons(IPPROTO_TCP);
176         sum += htons(len);
177
178         return finalize_data_checksum(sum);
179 }
180
181 static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
182 static std::atomic<uint16_t> local_port(0);
183 static uint16_t remote_port;
184 static bool are_server;
185
186 static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
187         buf[0 ] = local_port >> 8;         // src port
188         buf[1 ] = local_port;              // src port
189         buf[2 ] = remote_port >> 8;        // dst port
190         buf[3 ] = remote_port;             // dst port
191
192         uint32_t seq = cur_seq.fetch_add((syn || synack) ? 1 : len, std::memory_order_acq_rel);
193         buf[4 ] = seq >> (8 * 3);          // SEQ
194         buf[5 ] = seq >> (8 * 2);          // SEQ
195         buf[6 ] = seq >> (8 * 1);          // SEQ
196         buf[7 ] = seq >> (8 * 0);          // SEQ
197
198         uint32_t their_seq = highest_recvd_seq.load(std::memory_order_relaxed);
199         buf[8 ] = their_seq >> (8 * 3);    // ACK
200         buf[9 ] = their_seq >> (8 * 2);    // ACK
201         buf[10] = their_seq >> (8 * 1);    // ACK
202         buf[11] = their_seq >> (8 * 0);    // ACK
203
204         bool longpkt = syn || synack;
205         buf[12] = (longpkt ? 6 : 5) << 4;  // data offset
206         if (syn)
207                 buf[13] = 1 << 1;              // SYN
208         else if (synack)
209                 buf[13] = (1 << 1) | (1 << 4); // SYN + ACK
210         else if (len == 0)
211                 buf[13] = 1 << 4;              // ACK
212         else
213                 buf[13] = 3 << 3;              // PSH + ACK
214         buf[14] = 0xff;                    // Window Size
215         buf[15] = 0xff;                    // Window Size
216
217         buf[16] = 0x00;                    // Checksum
218         buf[17] = 0x00;                    // Checksum
219         buf[18] = 0x00;                    // URG Pointer
220         buf[19] = 0x00;                    // URG Pointer
221
222         if (longpkt) {
223         buf[20] = 0x01;                    // NOP
224         buf[21] = 0x03;                    // Window Scale
225         buf[22] = 0x03;                    // Window Scale Option Length
226         buf[23] = 0x0e;                    // 1GB Window Size (0xffff << 0x0e)
227         }
228
229         uint16_t checksum = tcp_checksum(buf, len + 20 + (longpkt ? 4 : 0), src_addr, dest_addr);
230         buf[16] = checksum;                // Checksum
231         buf[17] = checksum >> 8;           // Checksum
232 }
233
234 // len is data length, not including headers!
235 static void build_ip_header(unsigned char* buf, uint16_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
236         buf[0 ] = (4 << 4) | 5;    // IPv4 + IHL of 5 (20 bytes)
237         buf[1 ] = 0;               // DSCP 0 + ECN 0
238         buf[2 ] = (len + 20) >> 8; // Length
239         buf[3 ] = (len + 20);      // Length
240         memset(buf + 4, 0, 4);     // Identification and Fragment 0s
241         buf[6 ] = 1 << 6;          // DF bit
242         buf[8 ] = 255;             // TTL
243         buf[9 ] = proto;           // Protocol Number
244         buf[10] = 0;               // Checksum
245         buf[11] = 0;               // Checksum
246
247         memcpy(buf + 12, &src_addr, 4);
248         memcpy(buf + 16, &dest_addr, 4);
249
250         uint16_t checksum = finalize_data_checksum(data_checksum(buf, 20));
251         buf[10] = checksum;
252         buf[11] = checksum >> 8;
253 }
254
255 // length is data length, not including headers
256 static void build_udp_header(unsigned char* buf, uint16_t len) {
257         buf[0] = UDP_PORTS >> 8;
258         buf[1] = UDP_PORTS & 0xff;
259         buf[2] = UDP_PORTS >> 8;
260         buf[3] = UDP_PORTS & 0xff;
261         buf[4] = (len + 8) >> 8;
262         buf[5] = (len + 8);
263
264         // Checksum 0 for now
265         buf[6] = 0;
266         buf[7] = 0;
267 }
268
269
270 const signed char p_util_hexdigit[256] =
271 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
272   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
273   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
274   0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
275   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
276   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
277   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
278   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
279   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
280   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
281   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
282   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
283   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
284   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
285   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
286   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, };
287
288 uint32_t hex_to_num(const unsigned char* buf) {
289         const unsigned char* pbegin = buf;
290         while (p_util_hexdigit[*buf] != -1)
291                 buf++;
292         buf--;
293         uint32_t res = 0;
294         unsigned char* p1 = (unsigned char*)&res;
295         unsigned char* pend = p1 + 4;
296         while (buf >= pbegin && p1 < pend) {
297                 *p1 = p_util_hexdigit[*buf--];
298                 if (buf >= pbegin) {
299                         *p1 |= ((unsigned char)p_util_hexdigit[*buf--] << 4);
300                         p1++;
301                 }
302         }
303
304         return res;
305 }
306
307 #define TUN_IF_COUNT 4
308 static int fdr;
309 static int fd[TUN_IF_COUNT];
310 static struct sockaddr_in dest;
311 static in_addr_t src, tun_src, tun_dest;
312 static uint64_t tcp_init_magic;
313
314 #define PENDING_MESSAGES_BUFF_SIZE (0x3000)
315 #define PACKET_READ_OFFS (20-16)
316 #define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
317 #define THREAD_POLL_SLEEP_MICS 50
318 struct MessageQueue {
319         std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE + PACKET_READ_OFFS>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
320         std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
321         MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
322         MessageQueue(MessageQueue&& q) =delete;
323         MessageQueue(MessageQueue& q) =delete;
324 };
325
326 static MessageQueue tcp_to_tun_queue;
327 static std::chrono::steady_clock::time_point last_ack_recv;
328
329 static void tcp_to_tun() {
330         unsigned char buf[1500];
331         struct sockaddr_in pkt_src;
332         memset(&pkt_src, 0, sizeof(pkt_src));
333
334         while (1) {
335                 socklen_t hostsz = sizeof(pkt_src);
336                 ssize_t nread = recvfrom(fdr, buf, sizeof(buf), 0, (struct sockaddr*)&pkt_src, &hostsz);
337                 if (nread < 0) {
338                         fprintf (stderr, "Failed to read tcp raw sock\n");
339                         exit(-1);
340                 }
341
342                 if (tcp_to_tun_queue.nextPendingMessage == (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
343                         continue;
344
345                 auto& new_msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextUndefinedMessage];
346                 std::get<0>(new_msg) = pkt_src;
347                 memcpy(std::get<1>(new_msg).data(), buf, nread);
348                 std::get<2>(new_msg) = nread;
349
350                 tcp_to_tun_queue.nextUndefinedMessage = (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
351         }
352 }
353
354 static void tcp_to_tun_queue_process() {
355         do {
356                 while (tcp_to_tun_queue.nextUndefinedMessage == tcp_to_tun_queue.nextPendingMessage) {
357                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
358                 }
359
360                 auto& msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextPendingMessage];
361
362                 const sockaddr_in& pkt_src = std::get<0>(msg);
363                 unsigned char* buf = std::get<1>(msg).data();
364                 ssize_t nread = std::get<2>(msg);
365
366                 int header_size = check_ip_header(buf, nread, 0x06); // Only support TCP
367                 if (header_size < 0)
368                         continue;
369
370                 if (nread - header_size < 20) {
371                         fprintf(stderr, "Short TCP packet\n");
372                         continue;
373                 }
374
375                 unsigned char* tcp_buf = buf + header_size;
376
377                 if (((tcp_buf[2] << 8) | tcp_buf[3]) != local_port) continue;
378
379                 bool syn = tcp_buf[13] & (1 << 1);
380                 bool ack = tcp_buf[13] & (1 << 4);
381
382                 if (are_server && syn && !ack) {
383                         // We're a server and just got a client
384                         if (tcp_buf[4 ] != uint8_t(tcp_init_magic >> (7 * 8)) ||
385                             tcp_buf[5 ] != uint8_t(tcp_init_magic >> (6 * 8)) ||
386                             tcp_buf[6 ] != uint8_t(tcp_init_magic >> (5 * 8)) ||
387                             tcp_buf[7 ] != uint8_t(tcp_init_magic >> (4 * 8)) ||
388                             tcp_buf[8 ] != uint8_t(tcp_init_magic >> (3 * 8)) ||
389                             tcp_buf[9 ] != uint8_t(tcp_init_magic >> (2 * 8)) ||
390                             tcp_buf[10] != uint8_t(tcp_init_magic >> (1 * 8)) ||
391                             tcp_buf[11] != uint8_t(tcp_init_magic >> (0 * 8)))
392                                 continue;
393
394                         fprintf(stderr, "Got SYN, sending SYNACK\n");
395                         remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
396                         dest = pkt_src;
397                 }
398
399                 if (((tcp_buf[0] << 8) | tcp_buf[1]) != remote_port) continue;
400                 if (pkt_src.sin_addr.s_addr != dest.sin_addr.s_addr) continue;
401
402                 uint8_t num_words = (tcp_buf[12] & 0xf0) >> 4;
403                 int tcp_header_size = num_words * 4;
404                 if (tcp_header_size < 20) {
405                         fprintf(stderr, "Invalid TCP header size (%d)\n", tcp_header_size);
406                         continue;
407                 }
408
409                 highest_recvd_seq = ((((uint32_t)tcp_buf[4]) << (8 * 3)) |
410                                      (((uint32_t)tcp_buf[5]) << (8 * 2)) |
411                                      (((uint32_t)tcp_buf[6]) << (8 * 1)) |
412                                      (((uint32_t)tcp_buf[7]) << (8 * 0))) +
413                                      (syn ? 1 : nread - header_size - tcp_header_size);
414
415                 if (ack)
416                         last_ack_recv = std::chrono::steady_clock::now();
417
418                 if (are_server && syn && !ack) {
419                         build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
420
421                         ssize_t res = sendto(fdr, tcp_buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
422                         if (res < 0) {
423                                 int err = errno;
424                                 fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
425                         }
426                 } else if (!syn) {
427                         tcp_buf += tcp_header_size - 20 - 8; // IPv4 header is 20 bytes + UDP header is 8 bytes
428
429                         // Replace TCP with IPv4 header
430                         build_udp_header(tcp_buf + 20, nread - tcp_header_size - header_size);
431                         build_ip_header(tcp_buf, nread - tcp_header_size - header_size + 8, 0x11, tun_dest, tun_src); // UDP
432
433                         write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8);
434                 }
435         } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
436 }
437
438 static std::atomic_int tun_if_thread(0), tun_if_process_thread(0);
439 static std::atomic_bool pause_tun_read_reinit_tcp(false);
440 static MessageQueue tun_to_tcp_queue[TUN_IF_COUNT];
441
442 static void tun_to_tcp() {
443         unsigned char buf[PACKET_READ_SIZE];
444
445         int thread = tun_if_thread.fetch_add(1);
446         MessageQueue& queue = tun_to_tcp_queue[thread];
447
448         while (1) {
449                 ssize_t nread = read(fd[thread], buf, sizeof(buf));
450                 if (pause_tun_read_reinit_tcp)
451                         continue;
452
453                 if (nread < 0) {
454                         fprintf (stderr, "Failed to read tun if\n");
455                         continue;
456                 }
457
458                 if (queue.nextPendingMessage == (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
459                         continue;
460
461                 auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
462                 memcpy(std::get<1>(new_msg).data() + PACKET_READ_OFFS, buf, nread);
463                 std::get<2>(new_msg) = nread;
464
465                 queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
466         }
467 }
468
469 static void tun_to_tcp_queue_process() {
470         int thread = tun_if_process_thread.fetch_add(1);
471         MessageQueue& queue = tun_to_tcp_queue[thread];
472
473         do {
474                 while (queue.nextUndefinedMessage == queue.nextPendingMessage) {
475                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
476                 }
477                 if (pause_tun_read_reinit_tcp)
478                         continue;
479
480                 auto& msg = queue.messagesPendingRingBuff[queue.nextPendingMessage];
481
482                 unsigned char* buf = std::get<1>(msg).data();
483                 ssize_t nread = std::get<2>(msg);
484
485                 int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, 0x11); // Only support UDP
486                 if (header_size < 0)
487                         continue;
488
489                 if (header_size + 8 > nread) {
490                         fprintf(stderr, "Short UDP packet\n");
491                         continue;
492                 }
493
494                 if (check_udp_header(buf + PACKET_READ_OFFS + header_size, nread - header_size) != 8) {
495                         fprintf(stderr, "Bad UDP header\n");
496                         continue;
497                 }
498
499                 unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size + 8 - 20;
500                 build_tcp_header(tcp_start, nread - header_size - 8, 0, 0, src, dest.sin_addr.s_addr);
501
502                 ssize_t res = sendto(fdr, tcp_start, nread + 20 - 8 - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
503                 if (res < 0) {
504                         int err = errno;
505                         fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
506                 }
507
508         } while ((queue.nextPendingMessage = (queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
509 }
510
511 int do_init() {
512         //
513         // Send SYN and SYN/ACK
514         //
515
516         if (!are_server) {
517                 if (local_port) // Doing a re-init
518                         pause_tun_read_reinit_tcp = true;
519
520                 uint16_t local_port_tmp = 0;
521                 while (local_port_tmp < 1024)
522                         assert(getrandom(&local_port_tmp, sizeof(local_port_tmp), 0) == sizeof(local_port_tmp));
523
524                 local_port = local_port_tmp;
525         }
526
527         uint32_t starting_ack = 0, starting_seq = 0;
528         memcpy(&starting_ack, &tcp_init_magic, 4);
529         memcpy(&starting_seq, ((const unsigned char*)&tcp_init_magic) + 4, 4);
530         highest_recvd_seq = starting_ack;
531         cur_seq = starting_seq;
532
533         if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
534                 std::thread t(&tcp_to_tun);
535                 std::thread t2(&tcp_to_tun_queue_process);
536                 t.detach();
537                 t2.detach();
538         }
539
540         unsigned char buf[1500];
541         if (!are_server) {
542                 build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
543                 ssize_t res = sendto(fdr, buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
544                 if (res < 0) {
545                         int err = errno;
546                         fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
547                         return -1;
548                 }
549         }
550
551         int i;
552         for (i = 0; i < 1000 && highest_recvd_seq == starting_ack; i++)
553                 std::this_thread::sleep_for(std::chrono::milliseconds(10));
554         if (i == 1000) // Will come back in 10 seconds
555                 return 0;
556
557         if (!are_server) {
558                 fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
559
560                 build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
561                 ssize_t res = sendto(fdr, buf, 20, 0, (struct sockaddr*)&dest, sizeof(dest));
562                 if (res < 0) {
563                         int err = errno;
564                         fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
565                         return -1;
566                 }
567         }
568
569         if (pause_tun_read_reinit_tcp) {
570                 pause_tun_read_reinit_tcp = false;
571         } else {
572                 for (int i = 0; i < TUN_IF_COUNT; i++) {
573                         std::thread t3(&tun_to_tcp);
574                         std::thread t4(&tun_to_tcp_queue_process);
575                         t3.detach();
576                         t4.detach();
577                 }
578         }
579
580         return 0;
581 }
582
583 int main(int argc, char* argv[]) {
584         assert(argc > 1 && "Need tun name");
585         assert(argc > 2 && "Need tun remote host");
586         assert(argc > 3 && "Need tun local host");
587         assert(argc > 4 && "Need server port");
588         assert(argc > 5 && "Need shared secret");
589         assert(argc > 6 && "Need mode (client or server)");
590         assert(argc > 7 && "Need src host");
591         if (std::string(argv[6]) == std::string("client"))
592                 assert(argc > 8 && "Need dest host");
593
594         assert(std::string(argv[6]) == std::string("client") || std::string(argv[6]) == std::string("server"));
595         are_server = (std::string(argv[6]) == std::string("server"));
596
597         //
598         // Parse args into variables
599         //
600
601         char tun_name[IFNAMSIZ];
602
603         memset(tun_name, 0, sizeof(tun_name));
604         strcpy(tun_name, argv[1]);
605
606         tun_dest = inet_addr(argv[2]);
607         tun_src = inet_addr(argv[3]);
608
609         if (are_server) {
610                 local_port = atoi(argv[4]);
611                 remote_port = 0;
612         } else {
613                 // Get local port in do_init() so that we pick a new one on reload
614                 remote_port = atoi(argv[4]);
615         }
616
617         tcp_init_magic = atoll(argv[5]);
618
619         src = inet_addr(argv[7]);
620
621         memset(&dest, 0, sizeof(dest));
622         if (!are_server) {
623                 dest.sin_family = AF_INET;
624                 dest.sin_addr.s_addr = inet_addr(argv[8]);
625         }
626
627         //
628         // Create tun and bind to sockets...
629         //
630
631         if (tun_alloc(tun_name, argv[3], argv[2], TUN_IF_COUNT, fd) != 0) {
632                 fprintf(stderr, "Failed to alloc tun if\n");
633                 return -1;
634         }
635
636         fdr = socket(AF_INET, SOCK_RAW, IPPROTO_TCP);
637         if (fdr < 0) {
638                 fprintf(stderr, "Failed to get raw socket\n");
639                 return -1;
640         }
641
642         int res = do_init();
643         if (res)
644                 return res;
645
646         while (true) {
647                 std::this_thread::sleep_for(std::chrono::seconds(15));
648                 if (!are_server && last_ack_recv < std::chrono::steady_clock::now() - std::chrono::seconds(15)) {
649                         res = do_init();
650                         if (res)
651                                 return res;
652                 }
653         }
654 }