Add missing <array> inclusion, which is needed for modern libstd++
[tunudptotcp] / main.cpp
1 #include <fcntl.h>
2 #include <string.h>
3 #include <stdio.h>
4 #include <unistd.h>
5 #include <sys/socket.h>
6 #include <sys/ioctl.h>
7 #include <netinet/in.h>
8 #include <errno.h>
9 #include <arpa/inet.h>
10 #include <stdint.h>
11 #include <stdlib.h>
12 #include <linux/if.h>
13 #include <linux/if_tun.h>
14 #include <assert.h>
15
16 #include <array>
17 #include <atomic>
18 #include <chrono>
19 #include <thread>
20 #include <string>
21
22 #if __has_include(<sys/random.h>)
23 #include <sys/random.h>
24 #else
25 int getrandom(void* buf, size_t len, unsigned int flags) {
26         FILE* f = fopen("/dev/urandom", "r");
27         int res = fread(buf, len, 1, f);
28         fclose(f);
29         return res;
30 }
31 #endif
32
33 struct pkt_hdr {
34         // Note that ordering of the first three uint16_ts must match the wire!
35         uint16_t len;
36         uint16_t id;
37         uint16_t frag_flags: 3,
38                  offset    : 13;
39         uint8_t proto;
40         uint8_t padding;
41 };
42 static_assert(sizeof(struct pkt_hdr) == 8);
43
44 // Have to leave room to add a TCP header plus above header (after dropping IP header)
45 #define PACKET_READ_OFFS sizeof(pkt_hdr)
46 #define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
47
48 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
49 {
50         struct ifreq ifr;
51         int fd, err, i;
52         char buf[1024];
53
54         if (!dev)
55                 return -1;
56
57         memset(&ifr, 0, sizeof(ifr));
58
59         /* Flags: IFF_TUN   - TUN device (no Ethernet headers) 
60         *         IFF_TAP   - TAP device  
61         *
62         *         IFF_NO_PI - Do not provide packet information  
63         */ 
64         ifr.ifr_flags = IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE;
65         strncpy(ifr.ifr_name, dev, IFNAMSIZ);
66
67         for (i = 0; i < queues; i++) {
68                 if((fd = open("/dev/net/tun", O_RDWR)) < 0)
69                         goto err;
70                 err = ioctl(fd, TUNSETIFF, (void *) &ifr);
71                 if (err) {
72                         close(fd);
73                         goto err;
74                 }
75                 fds[i] = fd;
76         }
77
78         sprintf(buf, "ip link set dev %s mtu %ld", dev, PACKET_READ_SIZE - 20);
79         err = system(buf);
80         if (err) goto err;
81
82         sprintf(buf, "ip addr add %s/32 dev %s", local_ip, dev);
83         err = system(buf);
84         if (err) goto err;
85
86         sprintf(buf, "ip link set %s up", dev);
87         err = system(buf);
88         if (err) goto err;
89
90         sprintf(buf, "ip route add %s/32 dev %s", remote_ip, dev);
91         err = system(buf);
92         if (err) goto err;
93
94         return 0;
95 err:
96         for (--i; i >= 0; i--)
97                 close(fds[i]);
98         return err;
99 }
100
101 static int check_ip_header(const unsigned char* buf, ssize_t buf_len, int16_t expected_type, struct pkt_hdr *gen_hdr) {
102         if (buf_len < 20) {
103                 // < size than IPv4?
104                 return -1;
105         }
106
107         if ((buf[0] & 0xf0) != (4 << 4)) {
108                 // Only support IPv4
109                 return -1;
110         }
111
112         uint8_t num_words = buf[0] & 0xf;
113         int header_size = num_words * 4;
114         if (header_size < 20) {
115                 fprintf(stderr, "Invalid IPv4 IHL size (%d)\n", header_size);
116                 return -1;
117         }
118
119         /*if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
120                 //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
121                 return -1;
122         }*/
123
124         if (expected_type > 0 && buf[9] != expected_type) {
125                 fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
126                 return -1;
127         }
128
129         if (gen_hdr) {
130                 memcpy(gen_hdr, buf + 2, 6); // len, id, frag offset
131                 gen_hdr->proto = buf[9];
132                 gen_hdr->padding = 0;
133         }
134
135         return header_size;
136 }
137
138 void print_packet(const unsigned char* buf, ssize_t buf_len) {
139         for (ssize_t i = 0; i < buf_len; ) {
140                 for (int j = 0; i < buf_len && j < 20; i++, j++)
141                         fprintf(stderr, "%02x", buf[i]);
142                 fprintf(stderr, "\n");
143         }
144 }
145
146 static uint32_t data_checksum(const unsigned char* buf, size_t len) {
147         uint32_t sum = 0;
148
149         while (len > 1)
150         {
151                 sum += *buf++;
152                 sum += (*buf++) << 8;
153                 if (sum & 0x80000000)
154                         sum = (sum & 0xFFFF) + (sum >> 16);
155                 len -= 2;
156         }
157
158         if (len & 1)
159                 sum += *buf;
160
161         return sum;
162 }
163
164 static uint16_t finalize_data_checksum(uint32_t sum) {
165         while (sum >> 16)
166                 sum = (sum & 0xFFFF) + (sum >> 16);
167
168         return (uint16_t)(~sum);
169 }
170
171 static uint16_t tcp_checksum(const unsigned char* buff, size_t len, in_addr_t src_addr, in_addr_t dest_addr)
172 {
173         uint16_t *ip_src = (uint16_t*)&src_addr, *ip_dst = (uint16_t*)&dest_addr;
174         uint32_t sum = data_checksum(buff, len);
175
176         sum += *(ip_src++);
177         sum += *ip_src;
178         sum += *(ip_dst++);
179         sum += *ip_dst;
180         sum += htons(IPPROTO_TCP);
181         sum += htons(len);
182
183         return finalize_data_checksum(sum);
184 }
185
186 static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
187 static std::atomic<uint16_t> local_port(0);
188 static uint16_t remote_port;
189 static bool are_server;
190 static uint64_t timestamps_magic;
191
192 static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
193         buf[0 ] = local_port >> 8;         // src port
194         buf[1 ] = local_port;              // src port
195         buf[2 ] = remote_port >> 8;        // dst port
196         buf[3 ] = remote_port;             // dst port
197
198         uint32_t seq = cur_seq.fetch_add((syn || synack) ? 1 : len, std::memory_order_acq_rel);
199         buf[4 ] = seq >> (8 * 3);          // SEQ
200         buf[5 ] = seq >> (8 * 2);          // SEQ
201         buf[6 ] = seq >> (8 * 1);          // SEQ
202         buf[7 ] = seq >> (8 * 0);          // SEQ
203
204         uint32_t their_seq = highest_recvd_seq.load(std::memory_order_relaxed);
205         buf[8 ] = their_seq >> (8 * 3);    // ACK
206         buf[9 ] = their_seq >> (8 * 2);    // ACK
207         buf[10] = their_seq >> (8 * 1);    // ACK
208         buf[11] = their_seq >> (8 * 0);    // ACK
209
210         unsigned char hdrlen = syn ? 36 : (synack ? 24 : 20);
211         buf[12] = (hdrlen/4) << 4;  // data offset
212         if (syn)
213                 buf[13] = 1 << 1;              // SYN
214         else if (synack)
215                 buf[13] = (1 << 1) | (1 << 4); // SYN + ACK
216         else if (len == 0)
217                 buf[13] = 1 << 4;              // ACK
218         else
219                 buf[13] = 3 << 3;              // PSH + ACK
220         buf[14] = 0xff;                    // Window Size
221         buf[15] = 0xff;                    // Window Size
222
223         buf[16] = 0x00;                    // Checksum
224         buf[17] = 0x00;                    // Checksum
225         buf[18] = 0x00;                    // URG Pointer
226         buf[19] = 0x00;                    // URG Pointer
227
228         if (syn || synack) {
229         buf[20] = 0x01;                    // NOP
230         buf[21] = 0x03;                    // Window Scale
231         buf[22] = 0x03;                    // Window Scale Option Length
232         buf[23] = 0x0e;                    // 1GB Window Size (0xffff << 0x0e)
233         }
234         if (syn) {
235         buf[24] = 0x01;                    // NOP
236         buf[25] = 0x01;                    // NOP
237         buf[26] = 8;                       // Timestamp
238         buf[27] = 10;                      // Timestamp Option Length
239         memcpy(buf + 28, &timestamps_magic, 8);
240         }
241
242         uint16_t checksum = tcp_checksum(buf, len + hdrlen, src_addr, dest_addr);
243         buf[16] = checksum;                // Checksum
244         buf[17] = checksum >> 8;           // Checksum
245
246         return hdrlen;
247 }
248
249 static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
250         buf[0 ] = (4 << 4) | 5;    // IPv4 + IHL of 5 (20 bytes)
251         buf[1 ] = 0;               // DSCP 0 + ECN 0
252         memcpy(buf + 2, &hdr, 6);  // length, identification, flags, and offset
253         buf[8 ] = 255;             // TTL
254         buf[9 ] = hdr.proto;       // Protocol Number
255         buf[10] = 0;               // Checksum
256         buf[11] = 0;               // Checksum
257
258         memcpy(buf + 12, &src_addr, 4);
259         memcpy(buf + 16, &dest_addr, 4);
260
261         uint16_t checksum = finalize_data_checksum(data_checksum(buf, 20));
262         buf[10] = checksum;
263         buf[11] = checksum >> 8;
264 }
265
266 const signed char p_util_hexdigit[256] =
267 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
268   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
269   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
270   0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
271   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
272   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
273   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
274   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
275   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
276   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
277   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
278   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
279   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
280   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
281   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
282   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, };
283
284 uint32_t hex_to_num(const unsigned char* buf) {
285         const unsigned char* pbegin = buf;
286         while (p_util_hexdigit[*buf] != -1)
287                 buf++;
288         buf--;
289         uint32_t res = 0;
290         unsigned char* p1 = (unsigned char*)&res;
291         unsigned char* pend = p1 + 4;
292         while (buf >= pbegin && p1 < pend) {
293                 *p1 = p_util_hexdigit[*buf--];
294                 if (buf >= pbegin) {
295                         *p1 |= ((unsigned char)p_util_hexdigit[*buf--] << 4);
296                         p1++;
297                 }
298         }
299
300         return res;
301 }
302
303 #define TUN_IF_COUNT 4
304 static int fdr;
305 static int fd[TUN_IF_COUNT];
306 static struct sockaddr_in dest;
307 static in_addr_t src, tun_src, tun_dest;
308 static uint32_t starting_ack;
309
310 #define PENDING_MESSAGES_BUFF_SIZE (0x800)
311 #define THREAD_POLL_SLEEP_MICS 250
312 struct MessageQueue {
313         std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE + PACKET_READ_OFFS>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
314         std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
315         MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
316         MessageQueue(MessageQueue&& q) =delete;
317         MessageQueue(MessageQueue& q) =delete;
318 };
319
320 static MessageQueue tcp_to_tun_queue;
321 static std::chrono::steady_clock::time_point last_ack_recv;
322
323 static void tcp_to_tun() {
324         unsigned char buf[PACKET_READ_SIZE + PACKET_READ_OFFS];
325         struct sockaddr_in pkt_src;
326         memset(&pkt_src, 0, sizeof(pkt_src));
327
328         while (1) {
329                 socklen_t hostsz = sizeof(pkt_src);
330                 ssize_t nread = recvfrom(fdr, buf, sizeof(buf), 0, (struct sockaddr*)&pkt_src, &hostsz);
331                 if (nread < 0) {
332                         fprintf (stderr, "Failed to read tcp raw sock\n");
333                         exit(-1);
334                 }
335
336                 if (tcp_to_tun_queue.nextPendingMessage == (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
337                         continue;
338
339                 auto& new_msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextUndefinedMessage];
340                 std::get<0>(new_msg) = pkt_src;
341                 memcpy(std::get<1>(new_msg).data(), buf, nread);
342                 std::get<2>(new_msg) = nread;
343
344                 tcp_to_tun_queue.nextUndefinedMessage = (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
345         }
346 }
347
348 static void tcp_to_tun_queue_process() {
349         do {
350                 while (tcp_to_tun_queue.nextUndefinedMessage == tcp_to_tun_queue.nextPendingMessage) {
351                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
352                 }
353
354                 auto& msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextPendingMessage];
355
356                 const sockaddr_in& pkt_src = std::get<0>(msg);
357                 unsigned char* buf = std::get<1>(msg).data();
358                 ssize_t nread = std::get<2>(msg);
359
360                 int header_size = check_ip_header(buf, nread, 0x06, NULL); // Only support TCP
361                 if (header_size < 0)
362                         continue;
363
364                 if (nread - header_size < 20) {
365                         fprintf(stderr, "Short TCP packet\n");
366                         continue;
367                 }
368
369                 unsigned char* tcp_buf = buf + header_size;
370
371                 if (((tcp_buf[2] << 8) | tcp_buf[3]) != local_port) continue;
372
373                 bool syn = tcp_buf[13] & (1 << 1);
374                 bool ack = tcp_buf[13] & (1 << 4);
375
376                 if (are_server && syn && !ack) {
377                         uint32_t expected_ack = htobe32(starting_ack);
378                         if (memcmp(tcp_buf + 8, &expected_ack, 4))
379                                 continue;
380                         if (nread < 36 + header_size) { continue; }
381                         // We're a server and just got a client...walk options until we find timestamps
382                         const unsigned char* opt_buf = tcp_buf + 20;
383                         bool found_magic = false;
384                         while (!found_magic && opt_buf < buf + nread) {
385                                 switch (*opt_buf) {
386                                         case 1: opt_buf += 1; break;
387                                         case 2: opt_buf += 4; break;
388                                         case 3: opt_buf += 3; break;
389                                         // SACK should never appear
390                                         case 8:
391                                                 if (opt_buf + 10 <= buf + nread) {
392                                                         if (!memcmp(opt_buf + 2, &timestamps_magic, 8)) {
393                                                                 found_magic = true;
394                                                                 break;
395                                                         }
396                                                 }
397                                                 // Fall through
398                                         default: opt_buf = buf + nread;
399                                 }
400                         }
401                         if (!found_magic) continue;
402
403                         fprintf(stderr, "Got SYN, sending SYNACK\n");
404                         remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
405                         dest = pkt_src;
406                 }
407
408                 if (((tcp_buf[0] << 8) | tcp_buf[1]) != remote_port) continue;
409                 if (pkt_src.sin_addr.s_addr != dest.sin_addr.s_addr) continue;
410
411                 uint8_t num_words = (tcp_buf[12] & 0xf0) >> 4;
412                 int tcp_header_size = num_words * 4;
413                 if (tcp_header_size < 20) {
414                         fprintf(stderr, "Invalid TCP header size (%d)\n", tcp_header_size);
415                         continue;
416                 }
417
418                 highest_recvd_seq = ((((uint32_t)tcp_buf[4]) << (8 * 3)) |
419                                      (((uint32_t)tcp_buf[5]) << (8 * 2)) |
420                                      (((uint32_t)tcp_buf[6]) << (8 * 1)) |
421                                      (((uint32_t)tcp_buf[7]) << (8 * 0))) +
422                                      (syn ? 1 : nread - header_size - tcp_header_size);
423
424                 if (ack)
425                         last_ack_recv = std::chrono::steady_clock::now();
426
427                 if (are_server && syn && !ack) {
428                         int len = build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
429
430                         ssize_t res = sendto(fdr, tcp_buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
431                         if (res < 0) {
432                                 int err = errno;
433                                 fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
434                         }
435                 } else if (!syn) {
436                         if (nread < (ssize_t)(tcp_header_size + header_size + sizeof(struct pkt_hdr))) continue;
437                         struct pkt_hdr hdr;
438                         memcpy(&hdr, tcp_buf + tcp_header_size, sizeof(struct pkt_hdr));
439                         tcp_buf += tcp_header_size + sizeof(struct pkt_hdr) - 20; // IPv4 header is 20 bytes
440
441                         // Replace TCP + pkt_hdr with IPv4 header
442                         build_ip_header(tcp_buf, hdr, tun_dest, tun_src);
443
444                         write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8);
445                 }
446         } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
447 }
448
449 static std::atomic_int tun_if_thread(0), tun_if_process_thread(0);
450 static std::atomic_bool pause_tun_read_reinit_tcp(false);
451 static MessageQueue tun_to_tcp_queue[TUN_IF_COUNT];
452
453 static void tun_to_tcp() {
454         unsigned char buf[PACKET_READ_SIZE];
455
456         int thread = tun_if_thread.fetch_add(1);
457         MessageQueue& queue = tun_to_tcp_queue[thread];
458
459         while (1) {
460                 ssize_t nread = read(fd[thread], buf, sizeof(buf));
461                 if (pause_tun_read_reinit_tcp)
462                         continue;
463
464                 if (nread < 0) {
465                         fprintf (stderr, "Failed to read tun if\n");
466                         continue;
467                 }
468
469                 if (queue.nextPendingMessage == (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
470                         continue;
471
472                 auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
473                 memcpy(std::get<1>(new_msg).data() + PACKET_READ_OFFS, buf, nread);
474                 std::get<2>(new_msg) = nread;
475
476                 queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
477         }
478 }
479
480 static void tun_to_tcp_queue_process() {
481         int thread = tun_if_process_thread.fetch_add(1);
482         MessageQueue& queue = tun_to_tcp_queue[thread];
483
484         do {
485                 while (queue.nextUndefinedMessage == queue.nextPendingMessage) {
486                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
487                 }
488                 if (pause_tun_read_reinit_tcp)
489                         continue;
490
491                 auto& msg = queue.messagesPendingRingBuff[queue.nextPendingMessage];
492
493                 unsigned char* buf = std::get<1>(msg).data();
494                 ssize_t nread = std::get<2>(msg);
495
496                 struct pkt_hdr hdr;
497                 int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, -1, &hdr); // Any paket type is okay
498                 if (header_size < 20)
499                         continue;
500
501                 if (header_size > nread) {
502                         fprintf(stderr, "Short packet\n");
503                         continue;
504                 }
505
506                 if ((size_t)nread > 1500 - 20 - sizeof(struct pkt_hdr)) { // Packets must fit in 1500 bytes with a TCP and packet hdr
507                         fprintf(stderr, "Long packet\n");
508                         continue;
509                 }
510
511                 unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size - 20 - sizeof(struct pkt_hdr);
512                 memcpy(tcp_start + 20, &hdr, sizeof(struct pkt_hdr));
513                 build_tcp_header(tcp_start, nread - header_size + sizeof(struct pkt_hdr), 0, 0, src, dest.sin_addr.s_addr);
514
515                 ssize_t res = sendto(fdr, tcp_start, nread + 20 + sizeof(struct pkt_hdr) - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
516                 if (res < 0) {
517                         int err = errno;
518                         fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
519                 }
520
521         } while ((queue.nextPendingMessage = (queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
522 }
523
524 int do_init() {
525         //
526         // Send SYN and SYN/ACK
527         //
528
529         if (!are_server) {
530                 if (local_port) // Doing a re-init
531                         pause_tun_read_reinit_tcp = true;
532
533                 uint16_t local_port_tmp = 0;
534                 while (local_port_tmp < 1024)
535                         assert(getrandom(&local_port_tmp, sizeof(local_port_tmp), 0) == sizeof(local_port_tmp));
536
537                 local_port = local_port_tmp;
538         }
539
540         highest_recvd_seq = starting_ack;
541         uint32_t starting_seq;
542         assert(getrandom(&starting_seq, sizeof(starting_seq), 0) == sizeof(starting_seq));
543         cur_seq = starting_seq;
544
545         if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
546                 std::thread t(&tcp_to_tun);
547                 std::thread t2(&tcp_to_tun_queue_process);
548                 t.detach();
549                 t2.detach();
550         }
551
552         unsigned char buf[1500];
553         if (!are_server) {
554                 int len = build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
555                 ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
556                 if (res < 0) {
557                         int err = errno;
558                         fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
559                         return -1;
560                 }
561         }
562
563         int i;
564         for (i = 0; i < 1000 && highest_recvd_seq == starting_ack; i++)
565                 std::this_thread::sleep_for(std::chrono::milliseconds(10));
566         if (i == 1000) // Will come back in 10 seconds
567                 return 0;
568
569         if (!are_server) {
570                 fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
571
572                 int len = build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
573                 ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
574                 if (res < 0) {
575                         int err = errno;
576                         fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
577                         return -1;
578                 }
579         }
580
581         if (pause_tun_read_reinit_tcp) {
582                 pause_tun_read_reinit_tcp = false;
583         } else {
584                 for (int i = 0; i < TUN_IF_COUNT; i++) {
585                         std::thread t3(&tun_to_tcp);
586                         std::thread t4(&tun_to_tcp_queue_process);
587                         t3.detach();
588                         t4.detach();
589                 }
590         }
591
592         return 0;
593 }
594
595 int main(int argc, char* argv[]) {
596         assert(argc > 1 && "Need tun name");
597         assert(argc > 2 && "Need tun remote host");
598         assert(argc > 3 && "Need tun local host");
599         assert(argc > 4 && "Need server port");
600         assert(argc > 5 && "Need shared secret");
601         assert(argc > 6 && "Need mode (client or server)");
602         assert(argc > 7 && "Need src host");
603         if (std::string(argv[6]) == std::string("client"))
604                 assert(argc > 8 && "Need dest host");
605
606         assert(std::string(argv[6]) == std::string("client") || std::string(argv[6]) == std::string("server"));
607         are_server = (std::string(argv[6]) == std::string("server"));
608
609         //
610         // Parse args into variables
611         //
612
613         char tun_name[IFNAMSIZ];
614
615         memset(tun_name, 0, sizeof(tun_name));
616         strcpy(tun_name, argv[1]);
617
618         tun_dest = inet_addr(argv[2]);
619         tun_src = inet_addr(argv[3]);
620
621         if (are_server) {
622                 local_port = atoi(argv[4]);
623                 remote_port = 0;
624         } else {
625                 // Get local port in do_init() so that we pick a new one on reload
626                 remote_port = atoi(argv[4]);
627         }
628
629         uint64_t tcp_init_magic = atoll(argv[5]);
630         starting_ack = tcp_init_magic >> 32;
631         timestamps_magic = htobe64(tcp_init_magic);
632
633         src = inet_addr(argv[7]);
634
635         memset(&dest, 0, sizeof(dest));
636         if (!are_server) {
637                 dest.sin_family = AF_INET;
638                 dest.sin_addr.s_addr = inet_addr(argv[8]);
639         }
640
641         //
642         // Create tun and bind to sockets...
643         //
644
645         if (tun_alloc(tun_name, argv[3], argv[2], TUN_IF_COUNT, fd) != 0) {
646                 fprintf(stderr, "Failed to alloc tun if\n");
647                 return -1;
648         }
649
650         fdr = socket(AF_INET, SOCK_RAW, IPPROTO_TCP);
651         if (fdr < 0) {
652                 fprintf(stderr, "Failed to get raw socket\n");
653                 return -1;
654         }
655
656         int res = do_init();
657         if (res)
658                 return res;
659
660         while (true) {
661                 std::this_thread::sleep_for(std::chrono::seconds(15));
662                 if (!are_server && last_ack_recv < std::chrono::steady_clock::now() - std::chrono::seconds(15)) {
663                         res = do_init();
664                         if (res)
665                                 return res;
666                 }
667         }
668 }