5 #include <sys/socket.h>
7 #include <netinet/in.h>
13 #include <linux/if_tun.h>
21 #if __has_include(<sys/random.h>)
22 #include <sys/random.h>
// Fallback getrandom(): fills `buf` with `len` bytes read from /dev/urandom.
// Presumably the branch used when <sys/random.h> is unavailable (the #else line is
// not visible here). `flags` is accepted for signature compatibility; the visible
// lines do not use it.
// NOTE(review): fopen() may return NULL; the result is passed to fread() unchecked.
// NOTE(review): fread(buf, len, 1, f) returns the number of *items* read (0 or 1),
// not a byte count. The local-port picker compares getrandom()'s result with
// sizeof(local_port_tmp), so confirm the (not visible) return converts `res` to a
// byte count -- or use fread(buf, 1, len, f).
24 int getrandom(void* buf, size_t len, unsigned int flags) {
25 FILE* f = fopen("/dev/urandom", "r");
26 int res = fread(buf, len, 1, f);
// MTU configured on the tun interface (see the "ip link set ... mtu" command below).
// NOTE(review): this macro is #defined again, identically, further down the file.
32 #define PACKET_READ_SIZE 1500
// Creates/attaches a multi-queue TUN interface named `dev` and configures it.
//
// Opens /dev/net/tun `queues` times and issues TUNSETIFF with
// IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE on each fd, so every fd is one queue of the
// same interface. The fds are presumably stored into `fds[i]`; that line is not
// visible.
//
// It then formats "ip" commands to:
//   - set the MTU to PACKET_READ_SIZE,
//   - add `local_ip`/32 to the interface,
//   - bring the interface up,
//   - route `remote_ip`/32 through it.
// On failure, the final loop walks back over the queues opened so far (presumably
// closing them).
//
// @param dev        interface name (copied into ifr_name)
// @param local_ip   address assigned to the interface (main passes argv[3])
// @param remote_ip  peer address routed via the interface (main passes argv[2])
// @param queues     number of queues to open
// @param fds        output array; assumed to hold at least `queues` entries -- TODO confirm
// @return main() treats any non-zero value as failure.
34 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
43 memset(&ifr, 0, sizeof(ifr));
45 /* Flags: IFF_TUN - TUN device (no Ethernet headers)
46 * IFF_TAP - TAP device
48 * IFF_NO_PI - Do not provide packet information
50 ifr.ifr_flags = IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE;
// NOTE(review): a bound of IFNAMSIZ leaves ifr_name unterminated when `dev` has
// IFNAMSIZ or more characters. Since ifr was zeroed above, a bound of IFNAMSIZ - 1
// guarantees termination.
51 strncpy(ifr.ifr_name, dev, IFNAMSIZ);
53 for (i = 0; i < queues; i++) {
54 if((fd = open("/dev/net/tun", O_RDWR)) < 0)
56 err = ioctl(fd, TUNSETIFF, (void *) &ifr);
// NOTE(review): `dev`, `local_ip` and `remote_ip` come straight from argv (see main)
// and are written with unbounded sprintf into `buf`, whose size is not visible here;
// prefer snprintf. These strings form shell command lines (presumably executed via
// system(); that call is not visible), so unvalidated arguments allow command
// injection.
64 sprintf(buf, "ip link set %s mtu %d", dev, PACKET_READ_SIZE);
68 sprintf(buf, "ip addr add %s/32 dev %s", local_ip, dev);
72 sprintf(buf, "ip link set %s up", dev);
76 sprintf(buf, "ip route add %s/32 dev %s", remote_ip, dev);
// Error path: revisit every queue index opened before the failure.
82 for (--i; i >= 0; i--)
// Validates that `buf` (`buf_len` bytes) holds an IPv4 packet carrying protocol
// `expected_type`.
// Visible checks:
//   - the version nibble is 4,
//   - IHL * 4 >= 20,
//   - the header's Total Length field equals `buf_len`,
//   - the protocol byte (offset 9) equals `expected_type`.
// Callers treat a negative return as "reject"; otherwise they use the result as the
// IPv4 header length in bytes (e.g. `buf + header_size` in tcp_to_tun_queue_process).
// NOTE(review): a minimum-length check on `buf_len` must precede the buf[0..9] reads.
// It presumably lives on lines not visible here -- confirm.
87 static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t expected_type) {
93 if ((buf[0] & 0xf0) != (4 << 4)) {
98 uint8_t num_words = buf[0] & 0xf;
99 int header_size = num_words * 4;
100 if (header_size < 20) {
101 fprintf(stderr, "Invalid IPv4 IHL size (%d)\n", header_size);
// Total Length (bytes 2-3, big-endian) must match the number of bytes actually read.
105 if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
106 //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
110 if (buf[9] != expected_type) {
111 fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
// Debug helper: hex-dumps the first `buf_len` bytes of `buf` to stderr, 20 bytes
// (40 hex characters) per line.
118 void print_packet(const unsigned char* buf, ssize_t buf_len) {
119 for (ssize_t i = 0; i < buf_len; ) {
120 for (int j = 0; i < buf_len && j < 20; i++, j++)
121 fprintf(stderr, "%02x", buf[i]);
122 fprintf(stderr, "\n");
// Accumulates an unfolded 32-bit Internet-checksum-style sum over `len` bytes of
// `buf`. finalize_data_checksum() folds it and takes the one's complement.
// Only part of the loop is visible:
//   - a byte is added shifted into the high half of a 16-bit word (presumably the
//     handling of an odd trailing byte);
//   - carries are folded back into the low 16 bits whenever bit 31 becomes set,
//     presumably to keep the accumulator from overflowing.
// NOTE(review): the word byte order used on the missing lines decides whether callers
// must store the result high-byte-first or low-byte-first. Confirm it against
// build_tcp_header and build_ip_header.
126 static uint32_t data_checksum(const unsigned char* buf, size_t len) {
132 sum += (*buf++) << 8;
133 if (sum & 0x80000000)
134 sum = (sum & 0xFFFF) + (sum >> 16);
// Folds the carries of a data_checksum() accumulator into 16 bits and returns the
// one's complement, i.e. the value to store in a checksum field.
// The fold below is presumably repeated until no carry remains; the loop header is
// not visible here.
144 static uint16_t finalize_data_checksum(uint32_t sum) {
146 sum = (sum & 0xFFFF) + (sum >> 16);
148 return (uint16_t)(~sum);
// TCP checksum of the `len`-byte segment at `buff` (header + payload), including the
// IPv4 pseudo-header.
// The source/destination addresses are added as 16-bit halves read in memory order,
// and the protocol number is added as htons(IPPROTO_TCP). The TCP-length term is
// presumably added on the lines not visible here.
// `src_addr`/`dest_addr` are in network byte order (callers pass inet_addr() results).
151 static uint16_t tcp_checksum(const unsigned char* buff, size_t len, in_addr_t src_addr, in_addr_t dest_addr)
153 uint16_t *ip_src = (uint16_t*)&src_addr, *ip_dst = (uint16_t*)&dest_addr;
154 uint32_t sum = data_checksum(buff, len);
160 sum += htons(IPPROTO_TCP);
163 return finalize_data_checksum(sum);
// Sequence-number state shared by the receive path (tcp_to_tun_queue_process) and
// the send path (build_tcp_header):
//   highest_recvd_seq - next SEQ expected from the peer; sent as our ACK number.
//   cur_seq           - SEQ of our next outgoing segment (advanced atomically per send).
166 static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
// Our TCP port. Atomic because it is re-picked when the connection is re-initialised.
167 static std::atomic<uint16_t> local_port(0);
// Peer TCP port.
// NOTE(review): not atomic. A server overwrites it from tcp_to_tun_queue_process (on
// each accepted SYN) while the tun_to_tcp_queue_process threads read it in
// build_tcp_header -- a data race. Consider std::atomic<uint16_t>.
168 static uint16_t remote_port;
// True in "server" mode (argv[8]).
169 static bool are_server;
// Writes a TCP header at `buf` in front of a `len`-byte payload that must already
// follow it, then fills in the checksum over header + payload.
//
// Header layout:
//   - 20 bytes, or 24 when `syn` or `synack` is set (data offset 6, then a NOP and a
//     window-scale option of 14);
//   - ports come from local_port / remote_port;
//   - SEQ comes from cur_seq, which is advanced by 1 for SYN/SYN-ACK and by `len`
//     otherwise;
//   - ACK comes from highest_recvd_seq;
//   - the window field is always 0xffff.
//
// @param buf        destination; needs room for the header plus `len` payload bytes
// @param len        payload length in bytes
// @param syn        non-zero for the client's initial SYN
// @param synack     non-zero for the server's SYN-ACK
// @param src_addr   local IPv4 address (network order), for the pseudo-header
// @param dest_addr  remote IPv4 address (network order), for the pseudo-header
// NOTE(review): local_port is loaded twice below. A concurrent re-pick could mix bytes
// of two ports; load it once into a local.
171 static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
172 buf[0 ] = local_port >> 8; // src port
173 buf[1 ] = local_port; // src port
174 buf[2 ] = remote_port >> 8; // dst port
175 buf[3 ] = remote_port; // dst port
// Atomically reserve this segment's SEQ range; several sender threads call this
// function concurrently.
177 uint32_t seq = cur_seq.fetch_add((syn || synack) ? 1 : len, std::memory_order_acq_rel);
178 buf[4 ] = seq >> (8 * 3); // SEQ
179 buf[5 ] = seq >> (8 * 2); // SEQ
180 buf[6 ] = seq >> (8 * 1); // SEQ
181 buf[7 ] = seq >> (8 * 0); // SEQ
183 uint32_t their_seq = highest_recvd_seq.load(std::memory_order_relaxed);
184 buf[8 ] = their_seq >> (8 * 3); // ACK
185 buf[9 ] = their_seq >> (8 * 2); // ACK
186 buf[10] = their_seq >> (8 * 1); // ACK
187 buf[11] = their_seq >> (8 * 0); // ACK
189 bool longpkt = syn || synack;
190 buf[12] = (longpkt ? 6 : 5) << 4; // data offset
// Flags. The selecting conditions are on lines not visible here; presumably:
// SYN for `syn`, SYN+ACK for `synack`, bare ACK for an empty payload, PSH+ACK
// otherwise.
192 buf[13] = 1 << 1; // SYN
194 buf[13] = (1 << 1) | (1 << 4); // SYN + ACK
196 buf[13] = 1 << 4; // ACK
198 buf[13] = 3 << 3; // PSH + ACK
199 buf[14] = 0xff; // Window Size
200 buf[15] = 0xff; // Window Size
// Checksum must be zero while it is being computed.
202 buf[16] = 0x00; // Checksum
203 buf[17] = 0x00; // Checksum
204 buf[18] = 0x00; // URG Pointer
205 buf[19] = 0x00; // URG Pointer
// Options, presumably written only when longpkt (the condition is not visible).
208 buf[20] = 0x01; // NOP
209 buf[21] = 0x03; // Window Scale
210 buf[22] = 0x03; // Window Scale Option Length
211 buf[23] = 0x0e; // 1GB Window Size (0xffff << 0x0e)
214 uint16_t checksum = tcp_checksum(buf, len + 20 + (longpkt ? 4 : 0), src_addr, dest_addr);
// NOTE(review): stored low byte at offset 16. This is correct only for the word order
// data_checksum uses -- confirm.
215 buf[16] = checksum; // Checksum
216 buf[17] = checksum >> 8; // Checksum
// Writes a 20-byte IPv4 header (no options) at `buf` for a `len`-byte payload that
// follows it. Fields set:
//   - IHL 5, DSCP/ECN 0;
//   - total length len + 20;
//   - identification 0, DF set;
//   - TTL 255;
//   - protocol `proto`;
//   - addresses copied as-is, so they must already be in network byte order.
// Finally it computes the header checksum over the 20 bytes.
219 static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
220 buf[0 ] = (4 << 4) | 5; // IPv4 + IHL of 5 (20 bytes)
221 buf[1 ] = 0; // DSCP 0 + ECN 0
222 buf[2 ] = (len + 20) >> 8; // Length
223 buf[3 ] = (len + 20); // Length
224 memset(buf + 4, 0, 4); // Identification and Fragment 0s
225 buf[6 ] = 1 << 6; // DF bit
226 buf[8 ] = 255; // TTL
227 buf[9 ] = proto; // Protocol Number
228 buf[10] = 0; // Checksum
229 buf[11] = 0; // Checksum
231 memcpy(buf + 12, &src_addr, 4);
232 memcpy(buf + 16, &dest_addr, 4);
234 uint16_t checksum = finalize_data_checksum(data_checksum(buf, 20));
// The high byte goes to offset 11; offset 10 presumably receives the low byte on the
// line not visible here (same order as build_tcp_header).
236 buf[11] = checksum >> 8;
// Lookup table mapping every byte value to its hexadecimal digit value:
//   '0'-'9' -> 0-9, 'A'-'F' and 'a'-'f' -> 10-15, every other byte -> -1.
// Index it with an unsigned char (as hex_to_num does) so all 256 entries are reachable.
240 const signed char p_util_hexdigit[256] =
241 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
242 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
243 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
244 0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
245 -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
246 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
247 -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
248 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
249 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
250 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
251 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
252 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
253 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
254 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
255 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
256 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, };
// Parses the run of hex digits at the start of `buf` into a uint32_t. The input must
// be terminated by a non-hex byte (e.g. NUL, which the table maps to -1).
// Steps:
//   1. Scan forward to the first non-hex byte (presumably by advancing `buf` on the
//      line not visible here).
//   2. Walk backwards two digits at a time, storing each pair into successive bytes
//      of `res` in memory order.
// At most 4 bytes are filled, so any leading digits beyond 8 are ignored.
// NOTE(review): bytes are written through (unsigned char*)&res, so the result equals
// the numeric value only on little-endian hosts.
// NOTE(review): `*buf--` on the first character leaves `buf` one before the start of
// the string, which is undefined behaviour in C++ even if never dereferenced.
// The high-nibble `*buf--` also needs a `buf >= pbegin` guard, presumably on the line
// not visible here.
258 uint32_t hex_to_num(const unsigned char* buf) {
259 const unsigned char* pbegin = buf;
260 while (p_util_hexdigit[*buf] != -1)
264 unsigned char* p1 = (unsigned char*)&res;
265 unsigned char* pend = p1 + 4;
266 while (buf >= pbegin && p1 < pend) {
267 *p1 = p_util_hexdigit[*buf--];
269 *p1 |= ((unsigned char)p_util_hexdigit[*buf--] << 4);
// Number of tun queues. One reader thread and one processing thread are started per
// queue.
277 #define TUN_IF_COUNT 4
// One fd per tun queue, filled in by tun_alloc().
279 static int fd[TUN_IF_COUNT];
// Remote endpoint that raw TCP segments are sent to and accepted from.
280 static struct sockaddr_in dest;
// Addresses, all in network order (inet_addr() results set in main):
//   src                - local address used in the TCP pseudo-header checksum
//   tun_src/tun_dest   - local/remote tun addresses; outer IPIP header of packets
//                        written to the tun
//   ipip_src/ipip_dest - local/remote addresses of the inner IPv4 (UDP) header
281 static in_addr_t src, tun_src, tun_dest, ipip_src, ipip_dest;
// Shared secret (argv[7]). Its two 32-bit halves seed the initial SEQ/ACK numbers, and
// a server only accepts a SYN whose SEQ/ACK fields match it.
282 static uint64_t tcp_init_magic;
// Ring-buffer capacity per queue. One slot stays empty to tell "full" from "empty".
284 #define PENDING_MESSAGES_BUFF_SIZE (0x3000)
285 #define PACKET_READ_SIZE 1500
// Consumers poll an empty ring at this interval (microseconds).
286 #define THREAD_POLL_SLEEP_MICS 50
// Ring of received packets: (source address, packet bytes, packet length).
// Used single-producer / single-consumer: one reader thread fills it and one
// processing thread drains it.
//   - The producer writes slot nextUndefinedMessage, then advances that index.
//   - The consumer reads slot nextPendingMessage, then advances that index.
//   - Both indices wrap at PENDING_MESSAGES_BUFF_SIZE.
// NOTE(review): each instance is roughly 0x3000 * ~1.5 KB, about 18 MB. With one
// tcp->tun queue and TUN_IF_COUNT tun->tcp queues this is about 90 MB of static
// storage.
287 struct MessageQueue {
288 std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
289 std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
290 MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
291 MessageQueue(MessageQueue&& q) =delete;
292 MessageQueue(MessageQueue& q) =delete;
// Packets read from the raw TCP socket, waiting to be written to the tun.
295 static MessageQueue tcp_to_tun_queue;
// Time the last accepted segment from the peer was processed. main() uses it to
// decide whether a client should re-initialise the connection.
// NOTE(review): written by tcp_to_tun_queue_process and read by main() without
// synchronisation (time_point is not atomic) -- a data race.
296 static std::chrono::steady_clock::time_point last_ack_recv;
// Producer thread for tcp_to_tun_queue. Receives packets from the raw TCP socket `fdr`
// (the consumer parses an IPv4 header at the start of each one) and copies them, with
// their source address, into the next free ring slot.
// When the ring is full the packet is presumably dropped (the branch body is not
// visible).
298 static void tcp_to_tun() {
299 unsigned char buf[1500];
300 struct sockaddr_in pkt_src;
301 memset(&pkt_src, 0, sizeof(pkt_src));
// Reset the address length before every recvfrom().
304 socklen_t hostsz = sizeof(pkt_src);
305 ssize_t nread = recvfrom(fdr, buf, sizeof(buf), 0, (struct sockaddr*)&pkt_src, &hostsz);
307 fprintf (stderr, "Failed to read tcp raw sock\n");
// Full when advancing the write index would make it equal to the read index.
311 if (tcp_to_tun_queue.nextPendingMessage == (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
314 auto& new_msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextUndefinedMessage];
315 std::get<0>(new_msg) = pkt_src;
316 memcpy(std::get<1>(new_msg).data(), buf, nread);
317 std::get<2>(new_msg) = nread;
// Publish the slot only after its contents have been written.
319 tcp_to_tun_queue.nextUndefinedMessage = (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
// Consumer thread for tcp_to_tun_queue. Turns TCP segments from the peer back into
// IPIP-encapsulated UDP packets and writes them to the first tun queue (fd[0]).
//
// Per packet:
//   1. Validate the IPv4/TCP headers and require our local port as destination.
//   2. As server, accept a SYN that carries the shared secret.
//   3. Accept only segments from the peer's port and address.
//   4. Update highest_recvd_seq and last_ack_recv.
//   5. As server, answer a SYN with a SYN-ACK.
//   6. Otherwise rebuild the payload in place as inner IPv4 (UDP) + outer IPIP
//      headers, directly in front of the TCP payload.
//
// The body is a do { ... } while (...) loop. Each `continue` jumps to the loop
// condition, which advances nextPendingMessage, so rejected packets are consumed too.
323 static void tcp_to_tun_queue_process() {
// Poll until the producer has published at least one slot.
325 while (tcp_to_tun_queue.nextUndefinedMessage == tcp_to_tun_queue.nextPendingMessage) {
326 std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
329 auto& msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextPendingMessage];
331 const sockaddr_in& pkt_src = std::get<0>(msg);
332 unsigned char* buf = std::get<1>(msg).data();
333 ssize_t nread = std::get<2>(msg);
335 int header_size = check_ip_header(buf, nread, 0x06); // Only support TCP
339 if (nread - header_size < 20) {
340 fprintf(stderr, "Short TCP packet\n");
344 unsigned char* tcp_buf = buf + header_size;
// The TCP destination port (bytes 2-3) must be our port.
346 if (((tcp_buf[2] << 8) | tcp_buf[3]) != local_port) continue;
348 bool syn = tcp_buf[13] & (1 << 1);
349 bool ack = tcp_buf[13] & (1 << 4);
351 if (are_server && syn && !ack) {
352 // We're a server and just got a client
// The client's SYN must carry the shared secret, both fields big-endian:
//   SEQ (bytes 4-7)  = high 32 bits of tcp_init_magic
//   ACK (bytes 8-11) = low 32 bits of tcp_init_magic
353 if (tcp_buf[4 ] != uint8_t(tcp_init_magic >> (7 * 8)) ||
354 tcp_buf[5 ] != uint8_t(tcp_init_magic >> (6 * 8)) ||
355 tcp_buf[6 ] != uint8_t(tcp_init_magic >> (5 * 8)) ||
356 tcp_buf[7 ] != uint8_t(tcp_init_magic >> (4 * 8)) ||
357 tcp_buf[8 ] != uint8_t(tcp_init_magic >> (3 * 8)) ||
358 tcp_buf[9 ] != uint8_t(tcp_init_magic >> (2 * 8)) ||
359 tcp_buf[10] != uint8_t(tcp_init_magic >> (1 * 8)) ||
360 tcp_buf[11] != uint8_t(tcp_init_magic >> (0 * 8)))
363 fprintf(stderr, "Got SYN, sending SYNACK\n");
// Adopt the client's source port as the peer port.
// NOTE(review): this happens before the source-address check below, so a SYN with the
// right secret from another host still changes remote_port.
364 remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
// Everything else must come from the peer's port and address.
368 if (((tcp_buf[0] << 8) | tcp_buf[1]) != remote_port) continue;
369 if (pkt_src.sin_addr.s_addr != dest.sin_addr.s_addr) continue;
// TCP data offset, in 32-bit words.
// NOTE(review): in the visible lines tcp_header_size is not checked against the bytes
// present (nread - header_size). An oversized data offset makes the payload length
// below negative.
371 uint8_t num_words = (tcp_buf[12] & 0xf0) >> 4;
372 int tcp_header_size = num_words * 4;
373 if (tcp_header_size < 20) {
374 fprintf(stderr, "Invalid TCP header size (%d)\n", tcp_header_size);
// Next expected peer SEQ: their SEQ plus 1 for a SYN, otherwise plus the payload
// length.
// NOTE(review): assigned unconditionally, so a reordered or retransmitted older segment
// moves it backwards despite the "highest" name.
378 highest_recvd_seq = ((((uint32_t)tcp_buf[4]) << (8 * 3)) |
379 (((uint32_t)tcp_buf[5]) << (8 * 2)) |
380 (((uint32_t)tcp_buf[6]) << (8 * 1)) |
381 (((uint32_t)tcp_buf[7]) << (8 * 0))) +
382 (syn ? 1 : nread - header_size - tcp_header_size);
385 last_ack_recv = std::chrono::steady_clock::now();
387 if (are_server && syn && !ack) {
// Reuse the received buffer to build and send a 24-byte SYN-ACK.
388 build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
390 ssize_t res = sendto(fdr, tcp_buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
393 fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
// Point tcp_buf 20 bytes before the payload, so the inner IPv4 header ends exactly
// where the payload starts.
396 tcp_buf += tcp_header_size - 20;
398 // Replace TCP with IPv4 header
// Inner header: protocol UDP, from the remote IPIP address to the local one. The
// payload presumably is the original UDP header + data.
399 //build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x01, ipip_dest, ipip_src); // ICMP
400 build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x11, ipip_dest, ipip_src); // UDP
// Outer IPIP (protocol 4) header in the 20 bytes before it. This stays inside `buf`
// because the original IPv4 and TCP headers were each at least 20 bytes.
402 build_ip_header(tcp_buf - 20, nread - tcp_header_size - header_size + 20, 0x04, tun_dest, tun_src);
// NOTE(review): write() result is ignored, so errors and short writes go unnoticed.
404 write(fd[0], tcp_buf - 20, nread - tcp_header_size - header_size + 40);
406 } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
// Counters that hand out queue indices. The i-th tun_to_tcp thread reads fd[i] into
// tun_to_tcp_queue[i], and the i-th tun_to_tcp_queue_process thread drains that queue.
409 static std::atomic_int tun_if_thread(0), tun_if_process_thread(0);
// Set while the TCP connection is being re-initialised; the tun threads check it.
410 static std::atomic_bool pause_tun_read_reinit_tcp(false);
// One ring per tun queue.
411 static MessageQueue tun_to_tcp_queue[TUN_IF_COUNT];
// Reader thread for one tun queue.
// Claims index i via tun_if_thread, reads packets from fd[i] and publishes them into
// tun_to_tcp_queue[i]. Only the packet bytes and length are stored; the source-address
// slot is left untouched (the consumer does not read it).
// Reads that complete during a re-init, and packets arriving while the ring is full,
// are presumably discarded (the branch bodies are not visible).
// NOTE(review): the claimed index is not checked against TUN_IF_COUNT. This relies on
// the spawn loop running at most TUN_IF_COUNT times in total.
413 static void tun_to_tcp() {
414 unsigned char buf[PACKET_READ_SIZE];
416 int thread = tun_if_thread.fetch_add(1);
417 MessageQueue& queue = tun_to_tcp_queue[thread];
420 ssize_t nread = read(fd[thread], buf, sizeof(buf));
421 if (pause_tun_read_reinit_tcp)
425 fprintf (stderr, "Failed to read tun if\n");
// Full when advancing the write index would make it equal to the read index.
429 if (queue.nextPendingMessage == (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
432 auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
433 memcpy(std::get<1>(new_msg).data(), buf, nread);
434 std::get<2>(new_msg) = nread;
// Publish the slot only after its contents have been written.
436 queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
// Consumer thread for tun_to_tcp_queue[i] (i claimed via tun_if_process_thread).
// Takes IPIP packets read from the tun and checks that the inner packet is IPv4/UDP.
// It then overwrites the last 20 bytes of the inner IPv4 header with a TCP header, so
// the inner UDP header + data become the TCP payload, and sends the segment to `dest`
// over the raw TCP socket `fdr`.
// As in tcp_to_tun_queue_process, a `continue` inside the do/while still advances
// nextPendingMessage.
440 static void tun_to_tcp_queue_process() {
441 int thread = tun_if_process_thread.fetch_add(1);
442 MessageQueue& queue = tun_to_tcp_queue[thread];
// Poll until the reader has published at least one slot.
445 while (queue.nextUndefinedMessage == queue.nextPendingMessage) {
446 std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
// Presumably holds off processing while a re-init is in progress (body not visible).
448 if (pause_tun_read_reinit_tcp)
451 auto& msg = queue.messagesPendingRingBuff[queue.nextPendingMessage];
453 unsigned char* buf = std::get<1>(msg).data();
454 ssize_t nread = std::get<2>(msg);
456 int header_size = check_ip_header(buf, nread, 0x04); // Only support IPIP
460 int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x11); // Only support UDP
461 //int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x01); // Only support ICMP
462 if (internal_header_size < 0)
// Require at least the 8-byte UDP header after both IPv4 headers.
465 if (internal_header_size + header_size + 8 > nread) {
466 fprintf(stderr, "Short UDP-in-IPIP packet\n");
// The TCP header occupies the 20 bytes ending where the inner IPv4 header ends.
// The offset is non-negative because header_size >= 20.
470 size_t tcp_start_offset = header_size + internal_header_size - 20;
471 build_tcp_header(buf + tcp_start_offset, nread - tcp_start_offset - 20, 0, 0, src, dest.sin_addr.s_addr);
473 ssize_t res = sendto(fdr, buf + tcp_start_offset, nread - tcp_start_offset, 0, (struct sockaddr*)&dest, sizeof(dest));
476 fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
479 } while ((queue.nextPendingMessage = (queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
// Fragment of the connection (re-)initialisation routine -- presumably do_init(), as
// referenced in main; its signature is not visible. Steps:
//   1. Pick a fresh random local port.
//   2. Reset SEQ/ACK from tcp_init_magic.
//   3. Start the TCP-side threads on first use.
//   4. Perform the SYN / SYN-ACK / ACK exchange.
//   5. Start (or un-pause) the tun-side threads.
// The port pick and handshake are presumably client-only; the guarding conditions are
// not visible.
484 // Send SYN and SYN/ACK
// A non-zero local_port means a connection already existed. Pause the tun threads
// while the handshake is redone.
488 if (local_port) // Doing a re-init
489 pause_tun_read_reinit_tcp = true;
491 uint16_t local_port_tmp = 0;
// Draw random ports until one is >= 1024.
// NOTE(review): getrandom() is called inside assert(). With NDEBUG defined the call is
// compiled out, local_port_tmp stays 0 and this loop never terminates. Call
// getrandom() unconditionally and assert on its saved result instead.
492 while (local_port_tmp < 1024)
493 assert(getrandom(&local_port_tmp, sizeof(local_port_tmp), 0) == sizeof(local_port_tmp));
495 local_port = local_port_tmp;
// Initial ACK and SEQ are the first and second 4 bytes of tcp_init_magic in memory
// order.
// NOTE(review): the server's SYN check compares SEQ with the high 32 bits and ACK with
// the low 32 bits of tcp_init_magic. Memory-order halves only line up with that on a
// little-endian host.
498 uint32_t starting_ack = 0, starting_seq = 0;
499 memcpy(&starting_ack, &tcp_init_magic, 4);
500 memcpy(&starting_seq, ((const unsigned char*)&tcp_init_magic) + 4, 4);
501 highest_recvd_seq = starting_ack;
502 cur_seq = starting_seq;
504 if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
// NOTE(review): these std::thread objects must be detached or joined on lines not
// visible here; destroying a joinable std::thread calls std::terminate.
505 std::thread t(&tcp_to_tun);
506 std::thread t2(&tcp_to_tun_queue_process);
511 unsigned char buf[1500];
// Send the 24-byte SYN (header + window-scale option). Its SEQ/ACK carry the seeds
// derived from the shared secret.
513 build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
514 ssize_t res = sendto(fdr, buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest));
517 fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
// Wait up to 1000 * 10 ms = 10 s for tcp_to_tun_queue_process to record the SYN-ACK,
// which changes highest_recvd_seq.
523 for (i = 0; i < 1000 && highest_recvd_seq == starting_ack; i++)
524 std::this_thread::sleep_for(std::chrono::milliseconds(10));
525 if (i == 1000) // Will come back in 10 seconds
529 fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
// Complete the handshake with a bare 20-byte ACK.
531 build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
532 ssize_t res = sendto(fdr, buf, 20, 0, (struct sockaddr*)&dest, sizeof(dest));
535 fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
540 if (pause_tun_read_reinit_tcp) {
541 pause_tun_read_reinit_tcp = false;
// Start one reader and one processor thread per tun queue (same detach/join caveat as
// above).
543 for (int i = 0; i < TUN_IF_COUNT; i++) {
544 std::thread t3(&tun_to_tcp);
545 std::thread t4(&tun_to_tcp_queue_process);
// Entry point. Positional arguments:
//   argv[1] tun interface name     argv[2] tun remote address   argv[3] tun local address
//   argv[4] IPIP remote address    argv[5] IPIP local address   argv[6] server TCP port
//   argv[7] shared secret (decimal, parsed with atoll)
//   argv[8] "client" or "server"   argv[9] local (source) address
//   argv[10] destination host (required in client mode)
// NOTE(review): all argument validation uses assert(), which is compiled out under
// NDEBUG. Missing arguments then lead to reading argv out of bounds.
554 int main(int argc, char* argv[]) {
555 assert(argc > 1 && "Need tun name");
556 assert(argc > 2 && "Need tun remote host");
557 assert(argc > 3 && "Need tun local host");
558 assert(argc > 4 && "Need ipip remote host");
559 assert(argc > 5 && "Need ipip local host");
560 assert(argc > 6 && "Need server port");
561 assert(argc > 7 && "Need shared secret");
562 assert(argc > 8 && "Need mode (client or server)");
563 assert(argc > 9 && "Need src host");
564 if (std::string(argv[8]) == std::string("client"))
565 assert(argc > 10 && "Need dest host");
567 assert(std::string(argv[8]) == std::string("client") || std::string(argv[8]) == std::string("server"));
568 are_server = (std::string(argv[8]) == std::string("server"));
571 // Parse args into variables
574 char tun_name[IFNAMSIZ];
576 memset(tun_name, 0, sizeof(tun_name));
// NOTE(review): strcpy overflows tun_name when argv[1] has IFNAMSIZ or more characters.
// Reject strlen(argv[1]) >= IFNAMSIZ first (tun_name is already zero-filled).
577 strcpy(tun_name, argv[1]);
// NOTE(review): inet_addr() results below are not checked for INADDR_NONE (malformed
// address).
579 tun_dest = inet_addr(argv[2]);
580 tun_src = inet_addr(argv[3]);
582 ipip_dest = inet_addr(argv[4]);
583 ipip_src = inet_addr(argv[5]);
// Which of the two port assignments applies is presumably chosen by mode (server: our
// port; client: the peer's port); the conditions are not visible. atoi() values
// outside 0-65535 are silently truncated to 16 bits.
586 local_port = atoi(argv[6]);
589 // Get local port in do_init() so that we pick a new one on reload
590 remote_port = atoi(argv[6]);
// NOTE(review): atoll parses a signed value, and out-of-range input is undefined
// behaviour. Secrets above LLONG_MAX cannot be entered; prefer strtoull with error
// checking.
593 tcp_init_magic = atoll(argv[7]);
595 src = inet_addr(argv[9]);
597 memset(&dest, 0, sizeof(dest));
599 dest.sin_family = AF_INET;
// NOTE(review): the asserts only require argv[10] in client mode. If this line also
// runs in server mode with argc == 10, argv[10] is NULL. Whether it is conditional is
// not visible here.
600 dest.sin_addr.s_addr = inet_addr(argv[10]);
604 // Create tun and bind to sockets...
607 if (tun_alloc(tun_name, argv[3], argv[2], TUN_IF_COUNT, fd) != 0) {
608 fprintf(stderr, "Failed to alloc tun if\n");
// Raw TCP socket used for both sending and receiving tunnelled segments.
612 fdr = socket(AF_INET, SOCK_RAW, IPPROTO_TCP);
614 fprintf(stderr, "Failed to get raw socket\n");
// Every 15 s, a client whose last accepted segment is older than 15 s presumably
// re-runs the handshake (the branch body is not visible).
623 std::this_thread::sleep_for(std::chrono::seconds(15));
624 if (!are_server && last_ack_recv < std::chrono::steady_clock::now() - std::chrono::seconds(15)) {