Auto-configure tun interface
[tunudptotcp] / main.cpp
1 #include <fcntl.h>
2 #include <string.h>
3 #include <stdio.h>
4 #include <unistd.h>
5 #include <sys/socket.h>
6 #include <sys/ioctl.h>
7 #include <netinet/in.h>
8 #include <errno.h>
9 #include <arpa/inet.h>
10 #include <stdint.h>
11 #include <stdlib.h>
12 #include <linux/if.h>
13 #include <linux/if_tun.h>
14 #include <assert.h>
15
16 #include <atomic>
17 #include <chrono>
18 #include <thread>
19 #include <string>
20
21 #if __has_include(<sys/random.h>)
22 #include <sys/random.h>
23 #else
24 int getrandom(void* buf, size_t len, unsigned int flags) {
25         FILE* f = fopen("/dev/urandom", "r");
26         int res = fread(buf, len, 1, f);
27         fclose(f);
28         return res;
29 }
30 #endif
31
32 #define PACKET_READ_SIZE 1500
33
34 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
35 {
36         struct ifreq ifr;
37         int fd, err, i;
38         char buf[1024];
39
40         if (!dev)
41                 return -1;
42
43         memset(&ifr, 0, sizeof(ifr));
44
45         /* Flags: IFF_TUN   - TUN device (no Ethernet headers) 
46         *         IFF_TAP   - TAP device  
47         *
48         *         IFF_NO_PI - Do not provide packet information  
49         */ 
50         ifr.ifr_flags = IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE;
51         strncpy(ifr.ifr_name, dev, IFNAMSIZ);
52
53         for (i = 0; i < queues; i++) {
54                 if((fd = open("/dev/net/tun", O_RDWR)) < 0)
55                         goto err;
56                 err = ioctl(fd, TUNSETIFF, (void *) &ifr);
57                 if (err) {
58                         close(fd);
59                         goto err;
60                 }
61                 fds[i] = fd;
62         }
63
64         sprintf(buf, "ip link set %s mtu %d", dev, PACKET_READ_SIZE);
65         err = system(buf);
66         if (err) goto err;
67
68         sprintf(buf, "ip addr add %s/32 dev %s", local_ip, dev);
69         err = system(buf);
70         if (err) goto err;
71
72         sprintf(buf, "ip link set %s up", dev);
73         err = system(buf);
74         if (err) goto err;
75
76         sprintf(buf, "ip route add %s/32 dev %s", remote_ip, dev);
77         err = system(buf);
78         if (err) goto err;
79
80         return 0;
81 err:
82         for (--i; i >= 0; i--)
83                 close(fds[i]);
84         return err;
85 }
86
87 static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t expected_type) {
88         if (buf_len < 20) {
89                 // < size than IPv4?
90                 return -1;
91         }
92
93         if ((buf[0] & 0xf0) != (4 << 4)) {
94                 // Only support IPv4
95                 return -1;
96         }
97
98         uint8_t num_words = buf[0] & 0xf;
99         int header_size = num_words * 4;
100         if (header_size < 20) {
101                 fprintf(stderr, "Invalid IPv4 IHL size (%d)\n", header_size);
102                 return -1;
103         }
104
105         if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
106                 //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
107                 return -1;
108         }
109
110         if (buf[9] != expected_type) {
111                 fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
112                 return -1;
113         }
114
115         return header_size;
116 }
117
118 void print_packet(const unsigned char* buf, ssize_t buf_len) {
119         for (ssize_t i = 0; i < buf_len; ) {
120                 for (int j = 0; i < buf_len && j < 20; i++, j++)
121                         fprintf(stderr, "%02x", buf[i]);
122                 fprintf(stderr, "\n");
123         }
124 }
125
126 static uint32_t data_checksum(const unsigned char* buf, size_t len) {
127         uint32_t sum = 0;
128
129         while (len > 1)
130         {
131                 sum += *buf++;
132                 sum += (*buf++) << 8;
133                 if (sum & 0x80000000)
134                         sum = (sum & 0xFFFF) + (sum >> 16);
135                 len -= 2;
136         }
137
138         if (len & 1)
139                 sum += *buf;
140
141         return sum;
142 }
143
144 static uint16_t finalize_data_checksum(uint32_t sum) {
145         while (sum >> 16)
146                 sum = (sum & 0xFFFF) + (sum >> 16);
147
148         return (uint16_t)(~sum);
149 }
150
151 static uint16_t tcp_checksum(const unsigned char* buff, size_t len, in_addr_t src_addr, in_addr_t dest_addr)
152 {
153         uint16_t *ip_src = (uint16_t*)&src_addr, *ip_dst = (uint16_t*)&dest_addr;
154         uint32_t sum = data_checksum(buff, len);
155
156         sum += *(ip_src++);
157         sum += *ip_src;
158         sum += *(ip_dst++);
159         sum += *ip_dst;
160         sum += htons(IPPROTO_TCP);
161         sum += htons(len);
162
163         return finalize_data_checksum(sum);
164 }
165
166 static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
167 static std::atomic<uint16_t> local_port(0);
168 static uint16_t remote_port;
169 static bool are_server;
170
171 static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
172         buf[0 ] = local_port >> 8;         // src port
173         buf[1 ] = local_port;              // src port
174         buf[2 ] = remote_port >> 8;        // dst port
175         buf[3 ] = remote_port;             // dst port
176
177         uint32_t seq = cur_seq.fetch_add((syn || synack) ? 1 : len, std::memory_order_acq_rel);
178         buf[4 ] = seq >> (8 * 3);          // SEQ
179         buf[5 ] = seq >> (8 * 2);          // SEQ
180         buf[6 ] = seq >> (8 * 1);          // SEQ
181         buf[7 ] = seq >> (8 * 0);          // SEQ
182
183         uint32_t their_seq = highest_recvd_seq.load(std::memory_order_relaxed);
184         buf[8 ] = their_seq >> (8 * 3);    // ACK
185         buf[9 ] = their_seq >> (8 * 2);    // ACK
186         buf[10] = their_seq >> (8 * 1);    // ACK
187         buf[11] = their_seq >> (8 * 0);    // ACK
188
189         bool longpkt = syn || synack;
190         buf[12] = (longpkt ? 6 : 5) << 4;  // data offset
191         if (syn)
192                 buf[13] = 1 << 1;              // SYN
193         else if (synack)
194                 buf[13] = (1 << 1) | (1 << 4); // SYN + ACK
195         else if (len == 0)
196                 buf[13] = 1 << 4;              // ACK
197         else
198                 buf[13] = 3 << 3;              // PSH + ACK
199         buf[14] = 0xff;                    // Window Size
200         buf[15] = 0xff;                    // Window Size
201
202         buf[16] = 0x00;                    // Checksum
203         buf[17] = 0x00;                    // Checksum
204         buf[18] = 0x00;                    // URG Pointer
205         buf[19] = 0x00;                    // URG Pointer
206
207         if (longpkt) {
208         buf[20] = 0x01;                    // NOP
209         buf[21] = 0x03;                    // Window Scale
210         buf[22] = 0x03;                    // Window Scale Option Length
211         buf[23] = 0x0e;                    // 1GB Window Size (0xffff << 0x0e)
212         }
213
214         uint16_t checksum = tcp_checksum(buf, len + 20 + (longpkt ? 4 : 0), src_addr, dest_addr);
215         buf[16] = checksum;                // Checksum
216         buf[17] = checksum >> 8;           // Checksum
217 }
218
219 static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
220         buf[0 ] = (4 << 4) | 5;    // IPv4 + IHL of 5 (20 bytes)
221         buf[1 ] = 0;               // DSCP 0 + ECN 0
222         buf[2 ] = (len + 20) >> 8; // Length
223         buf[3 ] = (len + 20);      // Length
224         memset(buf + 4, 0, 4);     // Identification and Fragment 0s
225         buf[6 ] = 1 << 6;          // DF bit
226         buf[8 ] = 255;             // TTL
227         buf[9 ] = proto;           // Protocol Number
228         buf[10] = 0;               // Checksum
229         buf[11] = 0;               // Checksum
230
231         memcpy(buf + 12, &src_addr, 4);
232         memcpy(buf + 16, &dest_addr, 4);
233
234         uint16_t checksum = finalize_data_checksum(data_checksum(buf, 20));
235         buf[10] = checksum;
236         buf[11] = checksum >> 8;
237 }
238
239
240 const signed char p_util_hexdigit[256] =
241 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
242   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
243   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
244   0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
245   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
246   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
247   -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
248   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
249   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
250   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
251   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
252   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
253   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
254   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
255   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
256   -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, };
257
258 uint32_t hex_to_num(const unsigned char* buf) {
259         const unsigned char* pbegin = buf;
260         while (p_util_hexdigit[*buf] != -1)
261                 buf++;
262         buf--;
263         uint32_t res = 0;
264         unsigned char* p1 = (unsigned char*)&res;
265         unsigned char* pend = p1 + 4;
266         while (buf >= pbegin && p1 < pend) {
267                 *p1 = p_util_hexdigit[*buf--];
268                 if (buf >= pbegin) {
269                         *p1 |= ((unsigned char)p_util_hexdigit[*buf--] << 4);
270                         p1++;
271                 }
272         }
273
274         return res;
275 }
276
277 #define TUN_IF_COUNT 4
278 static int fdr;
279 static int fd[TUN_IF_COUNT];
280 static struct sockaddr_in dest;
281 static in_addr_t src, tun_src, tun_dest, ipip_src, ipip_dest;
282 static uint64_t tcp_init_magic;
283
284 #define PENDING_MESSAGES_BUFF_SIZE (0x3000)
285 #define PACKET_READ_SIZE 1500
286 #define THREAD_POLL_SLEEP_MICS 50
287 struct MessageQueue {
288         std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
289         std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
290         MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
291         MessageQueue(MessageQueue&& q) =delete;
292         MessageQueue(MessageQueue& q) =delete;
293 };
294
295 static MessageQueue tcp_to_tun_queue;
296 static std::chrono::steady_clock::time_point last_ack_recv;
297
298 static void tcp_to_tun() {
299         unsigned char buf[1500];
300         struct sockaddr_in pkt_src;
301         memset(&pkt_src, 0, sizeof(pkt_src));
302
303         while (1) {
304                 socklen_t hostsz = sizeof(pkt_src);
305                 ssize_t nread = recvfrom(fdr, buf, sizeof(buf), 0, (struct sockaddr*)&pkt_src, &hostsz);
306                 if (nread < 0) {
307                         fprintf (stderr, "Failed to read tcp raw sock\n");
308                         exit(-1);
309                 }
310
311                 if (tcp_to_tun_queue.nextPendingMessage == (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
312                         continue;
313
314                 auto& new_msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextUndefinedMessage];
315                 std::get<0>(new_msg) = pkt_src;
316                 memcpy(std::get<1>(new_msg).data(), buf, nread);
317                 std::get<2>(new_msg) = nread;
318
319                 tcp_to_tun_queue.nextUndefinedMessage = (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
320         }
321 }
322
323 static void tcp_to_tun_queue_process() {
324         do {
325                 while (tcp_to_tun_queue.nextUndefinedMessage == tcp_to_tun_queue.nextPendingMessage) {
326                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
327                 }
328
329                 auto& msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextPendingMessage];
330
331                 const sockaddr_in& pkt_src = std::get<0>(msg);
332                 unsigned char* buf = std::get<1>(msg).data();
333                 ssize_t nread = std::get<2>(msg);
334
335                 int header_size = check_ip_header(buf, nread, 0x06); // Only support TCP
336                 if (header_size < 0)
337                         continue;
338
339                 if (nread - header_size < 20) {
340                         fprintf(stderr, "Short TCP packet\n");
341                         continue;
342                 }
343
344                 unsigned char* tcp_buf = buf + header_size;
345
346                 if (((tcp_buf[2] << 8) | tcp_buf[3]) != local_port) continue;
347
348                 bool syn = tcp_buf[13] & (1 << 1);
349                 bool ack = tcp_buf[13] & (1 << 4);
350
351                 if (are_server && syn && !ack) {
352                         // We're a server and just got a client
353                         if (tcp_buf[4 ] != uint8_t(tcp_init_magic >> (7 * 8)) ||
354                             tcp_buf[5 ] != uint8_t(tcp_init_magic >> (6 * 8)) ||
355                             tcp_buf[6 ] != uint8_t(tcp_init_magic >> (5 * 8)) ||
356                             tcp_buf[7 ] != uint8_t(tcp_init_magic >> (4 * 8)) ||
357                             tcp_buf[8 ] != uint8_t(tcp_init_magic >> (3 * 8)) ||
358                             tcp_buf[9 ] != uint8_t(tcp_init_magic >> (2 * 8)) ||
359                             tcp_buf[10] != uint8_t(tcp_init_magic >> (1 * 8)) ||
360                             tcp_buf[11] != uint8_t(tcp_init_magic >> (0 * 8)))
361                                 continue;
362
363                         fprintf(stderr, "Got SYN, sending SYNACK\n");
364                         remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
365                         dest = pkt_src;
366                 }
367
368                 if (((tcp_buf[0] << 8) | tcp_buf[1]) != remote_port) continue;
369                 if (pkt_src.sin_addr.s_addr != dest.sin_addr.s_addr) continue;
370
371                 uint8_t num_words = (tcp_buf[12] & 0xf0) >> 4;
372                 int tcp_header_size = num_words * 4;
373                 if (tcp_header_size < 20) {
374                         fprintf(stderr, "Invalid TCP header size (%d)\n", tcp_header_size);
375                         continue;
376                 }
377
378                 highest_recvd_seq = ((((uint32_t)tcp_buf[4]) << (8 * 3)) |
379                                      (((uint32_t)tcp_buf[5]) << (8 * 2)) |
380                                      (((uint32_t)tcp_buf[6]) << (8 * 1)) |
381                                      (((uint32_t)tcp_buf[7]) << (8 * 0))) +
382                                      (syn ? 1 : nread - header_size - tcp_header_size);
383
384                 if (ack)
385                         last_ack_recv = std::chrono::steady_clock::now();
386
387                 if (are_server && syn && !ack) {
388                         build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
389
390                         ssize_t res = sendto(fdr, tcp_buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
391                         if (res < 0) {
392                                 int err = errno;
393                                 fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
394                         }
395                 } else if (!syn) {
396                         tcp_buf += tcp_header_size - 20;
397
398                         // Replace TCP with IPv4 header
399                         //build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x01, ipip_dest, ipip_src); // ICMP
400                         build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x11, ipip_dest, ipip_src); // UDP
401                         // Add IPIP header
402                         build_ip_header(tcp_buf - 20, nread - tcp_header_size - header_size + 20, 0x04, tun_dest, tun_src);
403
404                         write(fd[0], tcp_buf - 20, nread - tcp_header_size - header_size + 40);
405                 }
406         } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
407 }
408
409 static std::atomic_int tun_if_thread(0), tun_if_process_thread(0);
410 static std::atomic_bool pause_tun_read_reinit_tcp(false);
411 static MessageQueue tun_to_tcp_queue[TUN_IF_COUNT];
412
413 static void tun_to_tcp() {
414         unsigned char buf[PACKET_READ_SIZE];
415
416         int thread = tun_if_thread.fetch_add(1);
417         MessageQueue& queue = tun_to_tcp_queue[thread];
418
419         while (1) {
420                 ssize_t nread = read(fd[thread], buf, sizeof(buf));
421                 if (pause_tun_read_reinit_tcp)
422                         continue;
423
424                 if (nread < 0) {
425                         fprintf (stderr, "Failed to read tun if\n");
426                         continue;
427                 }
428
429                 if (queue.nextPendingMessage == (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
430                         continue;
431
432                 auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
433                 memcpy(std::get<1>(new_msg).data(), buf, nread);
434                 std::get<2>(new_msg) = nread;
435
436                 queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
437         }
438 }
439
440 static void tun_to_tcp_queue_process() {
441         int thread = tun_if_process_thread.fetch_add(1);
442         MessageQueue& queue = tun_to_tcp_queue[thread];
443
444         do {
445                 while (queue.nextUndefinedMessage == queue.nextPendingMessage) {
446                         std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
447                 }
448                 if (pause_tun_read_reinit_tcp)
449                         continue;
450
451                 auto& msg = queue.messagesPendingRingBuff[queue.nextPendingMessage];
452
453                 unsigned char* buf = std::get<1>(msg).data();
454                 ssize_t nread = std::get<2>(msg);
455
456                 int header_size = check_ip_header(buf, nread, 0x04); // Only support IPIP
457                 if (header_size < 0)
458                         continue;
459
460                 int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x11); // Only support UDP
461                 //int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x01); // Only support ICMP
462                 if (internal_header_size < 0)
463                         continue;
464
465                 if (internal_header_size + header_size + 8 > nread) {
466                         fprintf(stderr, "Short UDP-in-IPIP packet\n");
467                         continue;
468                 }
469
470                 size_t tcp_start_offset = header_size + internal_header_size - 20;
471                 build_tcp_header(buf + tcp_start_offset, nread - tcp_start_offset - 20, 0, 0, src, dest.sin_addr.s_addr);
472
473                 ssize_t res = sendto(fdr, buf + tcp_start_offset, nread - tcp_start_offset, 0, (struct sockaddr*)&dest, sizeof(dest));
474                 if (res < 0) {
475                         int err = errno;
476                         fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
477                 }
478
479         } while ((queue.nextPendingMessage = (queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
480 }
481
482 int do_init() {
483         //
484         // Send SYN and SYN/ACK
485         //
486
487         if (!are_server) {
488                 if (local_port) // Doing a re-init
489                         pause_tun_read_reinit_tcp = true;
490
491                 uint16_t local_port_tmp = 0;
492                 while (local_port_tmp < 1024)
493                         assert(getrandom(&local_port_tmp, sizeof(local_port_tmp), 0) == sizeof(local_port_tmp));
494
495                 local_port = local_port_tmp;
496         }
497
498         uint32_t starting_ack = 0, starting_seq = 0;
499         memcpy(&starting_ack, &tcp_init_magic, 4);
500         memcpy(&starting_seq, ((const unsigned char*)&tcp_init_magic) + 4, 4);
501         highest_recvd_seq = starting_ack;
502         cur_seq = starting_seq;
503
504         if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
505                 std::thread t(&tcp_to_tun);
506                 std::thread t2(&tcp_to_tun_queue_process);
507                 t.detach();
508                 t2.detach();
509         }
510
511         unsigned char buf[1500];
512         if (!are_server) {
513                 build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
514                 ssize_t res = sendto(fdr, buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
515                 if (res < 0) {
516                         int err = errno;
517                         fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
518                         return -1;
519                 }
520         }
521
522         int i;
523         for (i = 0; i < 1000 && highest_recvd_seq == starting_ack; i++)
524                 std::this_thread::sleep_for(std::chrono::milliseconds(10));
525         if (i == 1000) // Will come back in 10 seconds
526                 return 0;
527
528         if (!are_server) {
529                 fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
530
531                 build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
532                 ssize_t res = sendto(fdr, buf, 20, 0, (struct sockaddr*)&dest, sizeof(dest));
533                 if (res < 0) {
534                         int err = errno;
535                         fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
536                         return -1;
537                 }
538         }
539
540         if (pause_tun_read_reinit_tcp) {
541                 pause_tun_read_reinit_tcp = false;
542         } else {
543                 for (int i = 0; i < TUN_IF_COUNT; i++) {
544                         std::thread t3(&tun_to_tcp);
545                         std::thread t4(&tun_to_tcp_queue_process);
546                         t3.detach();
547                         t4.detach();
548                 }
549         }
550
551         return 0;
552 }
553
554 int main(int argc, char* argv[]) {
555         assert(argc > 1 && "Need tun name");
556         assert(argc > 2 && "Need tun remote host");
557         assert(argc > 3 && "Need tun local host");
558         assert(argc > 4 && "Need ipip remote host");
559         assert(argc > 5 && "Need ipip local host");
560         assert(argc > 6 && "Need server port");
561         assert(argc > 7 && "Need shared secret");
562         assert(argc > 8 && "Need mode (client or server)");
563         assert(argc > 9 && "Need src host");
564         if (std::string(argv[8]) == std::string("client"))
565                 assert(argc > 10 && "Need dest host");
566
567         assert(std::string(argv[8]) == std::string("client") || std::string(argv[8]) == std::string("server"));
568         are_server = (std::string(argv[8]) == std::string("server"));
569
570         //
571         // Parse args into variables
572         //
573
574         char tun_name[IFNAMSIZ];
575
576         memset(tun_name, 0, sizeof(tun_name));
577         strcpy(tun_name, argv[1]);
578
579         tun_dest = inet_addr(argv[2]);
580         tun_src = inet_addr(argv[3]);
581
582         ipip_dest = inet_addr(argv[4]);
583         ipip_src = inet_addr(argv[5]);
584
585         if (are_server) {
586                 local_port = atoi(argv[6]);
587                 remote_port = 0;
588         } else {
589                 // Get local port in do_init() so that we pick a new one on reload
590                 remote_port = atoi(argv[6]);
591         }
592
593         tcp_init_magic = atoll(argv[7]);
594
595         src = inet_addr(argv[9]);
596
597         memset(&dest, 0, sizeof(dest));
598         if (!are_server) {
599                 dest.sin_family = AF_INET;
600                 dest.sin_addr.s_addr = inet_addr(argv[10]);
601         }
602
603         //
604         // Create tun and bind to sockets...
605         //
606
607         if (tun_alloc(tun_name, argv[3], argv[2], TUN_IF_COUNT, fd) != 0) {
608                 fprintf(stderr, "Failed to alloc tun if\n");
609                 return -1;
610         }
611
612         fdr = socket(AF_INET, SOCK_RAW, IPPROTO_TCP);
613         if (fdr < 0) {
614                 fprintf(stderr, "Failed to get raw socket\n");
615                 return -1;
616         }
617
618         int res = do_init();
619         if (res)
620                 return res;
621
622         while (true) {
623                 std::this_thread::sleep_for(std::chrono::seconds(15));
624                 if (!are_server && last_ack_recv < std::chrono::steady_clock::now() - std::chrono::seconds(15)) {
625                         res = do_init();
626                         if (res)
627                                 return res;
628                 }
629         }
630 }