5 #include <sys/socket.h>
7 #include <netinet/in.h>
13 #include <linux/if_tun.h>
// Source of random bytes. When <sys/random.h> is available it is included directly;
// otherwise (presumably under an #else that is not part of this excerpt) a
// getrandom() replacement that reads /dev/urandom is defined below.
22 #if __has_include(<sys/random.h>)
23 #include <sys/random.h>
// Fallback getrandom(): reads `len` bytes from /dev/urandom into `buf` as a single
// fread item; how `flags` is used is not visible in this excerpt.
// NOTE(review): fopen() may return NULL, and on the visible lines the result goes to
// fread() unchecked.
// NOTE(review): fread(buf, len, 1, f) returns an item count (0 or 1), not a byte count.
// Callers assert that getrandom(...) == sizeof(...), so confirm that the lines omitted
// from this excerpt return `len` on success and fclose(f).
25 int getrandom(void* buf, size_t len, unsigned int flags) {
26 FILE* f = fopen("/dev/urandom", "r");
27 int res = fread(buf, len, 1, f);
// pkt_hdr: 8-byte header (size enforced by the static_assert below) that travels directly
// after the TCP header of every tunneled data segment. check_ip_header() and
// build_ip_header() copy its first 6 bytes verbatim from/to IPv4 header bytes 2..7
// (total length, identification, flags/fragment offset). Those functions also set its
// `proto` and `padding` members. Only part of the definition appears in this excerpt.
// NOTE(review): bit-field allocation order (e.g. frag_flags) is implementation-defined,
// so the bit-field cannot portably mirror the on-wire bit layout. The raw-byte memcpy is
// what actually preserves wire order.
34 // Note that ordering of the first three uint16_ts must match the wire!
37 uint16_t frag_flags: 3,
42 static_assert(sizeof(struct pkt_hdr) == 8);
44 // Have to leave room to add a TCP header plus above header (after dropping IP header)
// Headroom reserved in front of every tun read: sizeof(pkt_hdr), i.e. 8 bytes. Together
// with the space of the inner IPv4 header (at least 20 bytes), this lets the 20-byte TCP
// header and the pkt_hdr be written in place in front of the payload
// (see tun_to_tcp_queue_process()).
45 #define PACKET_READ_OFFS sizeof(pkt_hdr)
// Largest tun read such that headroom + data fits in a 1500-byte buffer. This is a size_t
// expression because PACKET_READ_OFFS is a sizeof.
46 #define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS)
// tun_alloc: opens `queues` descriptors on /dev/net/tun and attaches each to the same
// multi-queue TUN interface `dev` (IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE). The stores
// into `fds` are presumably on lines not in this excerpt.
// It then formats `ip` commands that:
//   - set the MTU to PACKET_READ_SIZE - 20,
//   - add local_ip/32,
//   - bring the link up,
//   - route remote_ip/32 via the device.
// The calls that execute these commands (presumably system()), the error returns and the
// terminator of the flags comment are not part of this excerpt. main() treats a non-zero
// return as failure.
48 static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds)
57 memset(&ifr, 0, sizeof(ifr));
59 /* Flags: IFF_TUN - TUN device (no Ethernet headers)
60 * IFF_TAP - TAP device
62 * IFF_NO_PI - Do not provide packet information
64 ifr.ifr_flags = IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE;
// NOTE(review): strncpy() does not NUL-terminate when strlen(dev) >= IFNAMSIZ. Copying at
// most IFNAMSIZ - 1 bytes keeps ifr_name terminated, because ifr was zeroed above.
65 strncpy(ifr.ifr_name, dev, IFNAMSIZ);
// Each iteration opens a fresh queue fd and attaches it to the same interface name.
67 for (i = 0; i < queues; i++) {
68 if((fd = open("/dev/net/tun", O_RDWR)) < 0)
70 err = ioctl(fd, TUNSETIFF, (void *) &ifr);
// NOTE(review): PACKET_READ_SIZE - 20 has type size_t, so "%ld" is a format mismatch;
// "%zu" is the matching conversion.
78 sprintf(buf, "ip link set dev %s mtu %ld", dev, PACKET_READ_SIZE - 20);
// NOTE(review): dev, local_ip and remote_ip come straight from the command line (see
// main()) and are spliced unescaped into shell command text. sprintf() also has no bound
// on buf, whose size is not visible here. snprintf() plus validating the addresses
// (e.g. with inet_pton()) would harden this.
82 sprintf(buf, "ip addr add %s/32 dev %s", local_ip, dev);
86 sprintf(buf, "ip link set %s up", dev);
90 sprintf(buf, "ip route add %s/32 dev %s", remote_ip, dev);
// Error path: walks back over the queue fds opened so far (presumably closing them; the
// loop body is not in this excerpt).
96 for (--i; i >= 0; i--)
// check_ip_header: validates the IPv4 header at `buf`:
//   - the version nibble must be 4;
//   - IHL * 4 must be at least 20;
//   - when expected_type > 0, the protocol byte buf[9] must equal it (callers pass 0x06
//     for TCP-only, or -1 for "any").
// It fills `gen_hdr` with header bytes 2..7 (total length, identification, flags and
// fragment offset) plus the protocol, and zeroes its padding.
// Callers use the result as the IP header length in bytes and reject values below 20.
// The return statements are not part of this excerpt.
// NOTE(review): tcp_to_tun_queue_process() passes gen_hdr == NULL. Confirm that the line
// before the memcpy below (not in this excerpt) guards against that.
// NOTE(review): no minimum-length check on buf_len before reading buf[0]/buf[9] is
// visible here. No header_size <= buf_len check is visible either; callers perform that
// one afterwards.
101 static int check_ip_header(const unsigned char* buf, ssize_t buf_len, int16_t expected_type, struct pkt_hdr *gen_hdr) {
107 if ((buf[0] & 0xf0) != (4 << 4)) {
112 uint8_t num_words = buf[0] & 0xf;
113 int header_size = num_words * 4;
114 if (header_size < 20) {
115 fprintf(stderr, "Invalid IPv4 IHL size (%d)\n", header_size);
// The total-length consistency check that follows is disabled (commented out) in source.
119 /*if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) {
120 //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len);
124 if (expected_type > 0 && buf[9] != expected_type) {
125 fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type);
// Raw byte copy: the pkt_hdr fields keep network byte order.
130 memcpy(gen_hdr, buf + 2, 6); // len, id, frag offset
131 gen_hdr->proto = buf[9];
132 gen_hdr->padding = 0;
// print_packet: debug helper that hex-dumps buf_len bytes of buf to stderr, 20 bytes per
// line. The inner loop advances the shared index i, which is why the outer for-loop has no
// increment expression.
138 void print_packet(const unsigned char* buf, ssize_t buf_len) {
139 for (ssize_t i = 0; i < buf_len; ) {
140 for (int j = 0; i < buf_len && j < 20; i++, j++)
141 fprintf(stderr, "%02x", buf[i]);
142 fprintf(stderr, "\n");
// data_checksum: accumulates the bytes of buf into a 32-bit ones'-complement partial sum,
// which the IP and TCP checksum code in this file feeds to finalize_data_checksum().
// The visible line adds a byte shifted into the high half of a 16-bit word. Whenever
// bit 31 becomes set, the carries are folded back into the low 16 bits so the accumulator
// does not overflow.
// The loop structure, the low-byte/odd-length handling and the return statement are not
// part of this excerpt.
146 static uint32_t data_checksum(const unsigned char* buf, size_t len) {
152 sum += (*buf++) << 8;
153 if (sum & 0x80000000)
154 sum = (sum & 0xFFFF) + (sum >> 16);
// finalize_data_checksum: folds a 32-bit partial sum from data_checksum() to 16 bits
// (end-around carry) and returns its ones' complement, which is the value to store in a
// checksum field.
// A single fold can itself produce a carry. Whether the fold below repeats depends on
// the preceding line (presumably a loop header), which is not in this excerpt.
164 static uint16_t finalize_data_checksum(uint32_t sum) {
166 sum = (sum & 0xFFFF) + (sum >> 16);
168 return (uint16_t)(~sum);
// tcp_checksum: TCP checksum over `len` bytes of `buff` (header + payload), plus a
// pseudo-header. The pseudo-header is built from src_addr/dest_addr (in_addr_t values in
// network byte order, as produced by inet_addr()) and IPPROTO_TCP. The lines that add the
// address halves and the segment length are not in this excerpt.
// NOTE(review): reading an in_addr_t through uint16_t* violates strict aliasing. Copying
// the two halves out with memcpy is the well-defined equivalent.
// NOTE(review): the visible line in data_checksum() adds bytes shifted left by 8
// (high-byte-first words). The ip_src/ip_dst halves and htons(IPPROTO_TCP), by contrast,
// are in network memory order. On a little-endian host these two conventions are
// byte-swapped relative to each other. Confirm that the omitted lines convert
// consistently, and that build_tcp_header() stores the result in the matching order.
171 static uint16_t tcp_checksum(const unsigned char* buff, size_t len, in_addr_t src_addr, in_addr_t dest_addr)
173 uint16_t *ip_src = (uint16_t*)&src_addr, *ip_dst = (uint16_t*)&dest_addr;
174 uint32_t sum = data_checksum(buff, len);
180 sum += htons(IPPROTO_TCP);
183 return finalize_data_checksum(sum);
// Session state shared by the sender and receiver threads.
// highest_recvd_seq: value written into the ACK field of every header we build. It is
//   updated from the peer's sequence number plus segment length in
//   tcp_to_tun_queue_process().
// cur_seq: our next sequence number. build_tcp_header() reserves sequence space with
//   fetch_add, so concurrent senders get distinct ranges.
186 static std::atomic<uint32_t> highest_recvd_seq(0), cur_seq(0);
// local_port: our TCP source port, either taken from argv in main() or picked at random
//   on (re-)init.
187 static std::atomic<uint16_t> local_port(0);
// remote_port: the peer's TCP port.
// NOTE(review): this is a plain variable. The receive thread assigns it when a server
//   accepts a SYN, while sender threads read it in build_tcp_header(). Making it a
//   std::atomic<uint16_t>, like local_port, would remove the data race.
188 static uint16_t remote_port;
// are_server: true when started in "server" mode; set once in main().
189 static bool are_server;
// timestamps_magic: the 64-bit shared secret in big-endian byte order. SYNs carry it as
//   the TCP timestamp option payload, and the server compares it to authenticate SYNs.
190 static uint64_t timestamps_magic;
// build_tcp_header: writes a TCP header into buf for a segment carrying `len` payload
// bytes. The payload must already follow the header in buf, because the checksum covers
// len + hdrlen bytes. The checksum is computed with src_addr/dest_addr and stored in the
// header.
// Header shapes:
//   - syn    -> SYN, 36-byte header (window-scale and timestamp options);
//   - synack -> SYN+ACK, 24-byte header (window-scale option);
//   - otherwise -> ACK or PSH+ACK with a bare 20-byte header.
// The exact branch conditions are on lines not in this excerpt.
// Callers send the return value as the byte count for header-only segments; the return
// statement is not in this excerpt.
192 static int build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) {
193 buf[0 ] = local_port >> 8; // src port
194 buf[1 ] = local_port; // src port
195 buf[2 ] = remote_port >> 8; // dst port
196 buf[3 ] = remote_port; // dst port
// Reserve sequence space atomically: SYN and SYN+ACK consume one number, data consumes len.
198 uint32_t seq = cur_seq.fetch_add((syn || synack) ? 1 : len, std::memory_order_acq_rel);
199 buf[4 ] = seq >> (8 * 3); // SEQ
200 buf[5 ] = seq >> (8 * 2); // SEQ
201 buf[6 ] = seq >> (8 * 1); // SEQ
202 buf[7 ] = seq >> (8 * 0); // SEQ
// Acknowledge everything accepted from the peer so far.
204 uint32_t their_seq = highest_recvd_seq.load(std::memory_order_relaxed);
205 buf[8 ] = their_seq >> (8 * 3); // ACK
206 buf[9 ] = their_seq >> (8 * 2); // ACK
207 buf[10] = their_seq >> (8 * 1); // ACK
208 buf[11] = their_seq >> (8 * 0); // ACK
// 36 = 20 + 4 (NOP + window scale) + 12 (2 NOPs + timestamps); 24 = 20 + 4.
210 unsigned char hdrlen = syn ? 36 : (synack ? 24 : 20);
211 buf[12] = (hdrlen/4) << 4; // data offset
213 buf[13] = 1 << 1; // SYN
215 buf[13] = (1 << 1) | (1 << 4); // SYN + ACK
217 buf[13] = 1 << 4; // ACK
219 buf[13] = 3 << 3; // PSH + ACK
220 buf[14] = 0xff; // Window Size
221 buf[15] = 0xff; // Window Size
223 buf[16] = 0x00; // Checksum
224 buf[17] = 0x00; // Checksum
225 buf[18] = 0x00; // URG Pointer
226 buf[19] = 0x00; // URG Pointer
229 buf[20] = 0x01; // NOP
230 buf[21] = 0x03; // Window Scale
231 buf[22] = 0x03; // Window Scale Option Length
232 buf[23] = 0x0e; // 1GB Window Size (0xffff << 0x0e)
235 buf[24] = 0x01; // NOP
236 buf[25] = 0x01; // NOP
237 buf[26] = 8; // Timestamp
238 buf[27] = 10; // Timestamp Option Length
// The secret fills both 4-byte timestamp fields (TSval and TSecr).
// NOTE(review): `×tamps_magic` looks like `&timestamps_magic` mangled by an HTML-entity
// decode (the `&times` prefix became a multiplication sign). As written it is not a valid
// identifier and will not compile; restore `&timestamps_magic`.
239 memcpy(buf + 28, ×tamps_magic, 8);
// NOTE(review): the checksum's low byte is stored first. data_checksum() assembles
// high-byte-first words, so confirm this order matches the wire format.
242 uint16_t checksum = tcp_checksum(buf, len + hdrlen, src_addr, dest_addr);
243 buf[16] = checksum; // Checksum
244 buf[17] = checksum >> 8; // Checksum
// build_ip_header: writes a 20-byte IPv4 header (IHL 5, no options) into buf and computes
// its header checksum. The header uses:
//   - the length, identification and flags/offset bytes and protocol carried in `hdr`;
//   - TTL 255 and DSCP/ECN 0;
//   - src_addr/dest_addr, copied as-is (network byte order).
// NOTE(review): the visible line in data_checksum() places the first byte of each word
// in the high half, so the checksum's high byte would belong in buf[10]. The visible line
// below instead stores checksum >> 8 into buf[11]. The companion store (original line
// 262) is not in this excerpt; confirm the pair writes the bytes in the intended order.
249 static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) {
250 buf[0 ] = (4 << 4) | 5; // IPv4 + IHL of 5 (20 bytes)
251 buf[1 ] = 0; // DSCP 0 + ECN 0
252 memcpy(buf + 2, &hdr, 6); // length, identification, flags, and offset
253 buf[8 ] = 255; // TTL
254 buf[9 ] = hdr.proto; // Protocol Number
255 buf[10] = 0; // Checksum
256 buf[11] = 0; // Checksum
258 memcpy(buf + 12, &src_addr, 4);
259 memcpy(buf + 16, &dest_addr, 4);
// Checksum is computed over the header with its checksum field zeroed (above).
261 uint16_t checksum = finalize_data_checksum(data_checksum(buf, 20));
263 buf[11] = checksum >> 8;
// Lookup table from a byte value to its hexadecimal digit value:
//   '0'-'9' -> 0-9; 'A'-'F' and 'a'-'f' -> 10-15; every other byte -> -1.
// Each row covers 16 consecutive byte values, starting at 0x00. Index it with an
// unsigned char so that all 256 entries are reachable.
266 const signed char p_util_hexdigit[256] =
267 { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
268 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
269 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
// 0x30-0x3f: '0'..'9'
270 0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
// 0x40-0x4f: 'A'..'F' at 0x41-0x46
271 -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
272 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
// 0x60-0x6f: 'a'..'f' at 0x61-0x66
273 -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
274 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
275 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
276 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
277 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
278 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
279 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
280 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
281 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
282 -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, };
// hex_to_num: parses the run of hex digits that starts at buf (ending at the first byte
// that p_util_hexdigit maps to -1) into a uint32_t.
// After finding the end of the run, it presumably walks backwards from the last digit.
// Each digit pair is packed into one byte of `res` through a byte pointer, starting at
// res's lowest address. At most 4 bytes (8 digits) are filled.
// Because of that packing, the numeric result equals the hex value only on a
// little-endian host. Parts of the body (the declaration of res, the pointer
// adjustments, the return) are not in this excerpt.
// NOTE(review): the loop condition `buf >= pbegin` can only become false after buf has
// been decremented below the start of the input. Forming a pointer before the start of an
// array is undefined behaviour in C++; an index-based loop avoids it. The high-nibble
// read also needs its own bounds check for odd digit counts; that check would be on a
// line not in this excerpt.
284 uint32_t hex_to_num(const unsigned char* buf) {
285 const unsigned char* pbegin = buf;
286 while (p_util_hexdigit[*buf] != -1)
290 unsigned char* p1 = (unsigned char*)&res;
291 unsigned char* pend = p1 + 4;
292 while (buf >= pbegin && p1 < pend) {
293 *p1 = p_util_hexdigit[*buf--];
295 *p1 |= ((unsigned char)p_util_hexdigit[*buf--] << 4);
// Number of tun queues. Each queue gets its own tun reader thread and processing thread.
303 #define TUN_IF_COUNT 4
// Tun queue descriptors, filled by tun_alloc().
305 static int fd[TUN_IF_COUNT];
// Peer address used for every raw-socket sendto().
306 static struct sockaddr_in dest;
// src: our outer (physical) IP address. tun_src / tun_dest: the local and remote
// addresses on the tunnel interface.
307 static in_addr_t src, tun_src, tun_dest;
// Upper 32 bits of the shared secret. The server expects this value in the ACK field of
// the client's SYN.
308 static uint32_t starting_ack;
// Slots per MessageQueue ring buffer; one slot always stays empty to tell "full" from
// "empty".
310 #define PENDING_MESSAGES_BUFF_SIZE (0x800)
// Sleep between polls while a queue consumer finds its queue empty.
311 #define THREAD_POLL_SLEEP_MICS 250
// MessageQueue: fixed-size ring buffer of (source address, packet bytes, length) slots.
// Each instance in this file has one producer thread and one consumer thread.
//   - The producer fills slot nextUndefinedMessage and then advances that index.
//   - The consumer processes slot nextPendingMessage and then advances that index.
//   - The queue is empty when the indices are equal, and full when advancing the producer
//     index would make them equal.
// The default (seq_cst) atomic index stores are what publish a slot's contents to the
// other thread.
// Copy and move construction are deleted (the std::atomic members are not copyable
// either).
// NOTE(review): each slot is roughly 1.5 KB, so one queue is about 3 MB of static storage,
// and the five queues in this file total about 15 MB.
// The closing brace of the struct is not part of this excerpt.
312 struct MessageQueue {
313 std::tuple<sockaddr_in, std::array<unsigned char, PACKET_READ_SIZE + PACKET_READ_OFFS>, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE];
314 std::atomic<uint16_t> nextPendingMessage, nextUndefinedMessage;
315 MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {}
316 MessageQueue(MessageQueue&& q) =delete;
317 MessageQueue(MessageQueue& q) =delete;
// Queue from the raw-socket reader (tcp_to_tun) to its processor (tcp_to_tun_queue_process).
320 static MessageQueue tcp_to_tun_queue;
// When the last segment that passed the peer checks was processed. main() compares it
// against a 15 s window to decide whether the client should re-initialize.
// NOTE(review): the processing thread writes it and main() reads it without
// synchronization, which is a data race. An atomic tick count would avoid that.
321 static std::chrono::steady_clock::time_point last_ack_recv;
// tcp_to_tun: producer thread for tcp_to_tun_queue. It receives complete IPv4 packets
// (header included) from the raw TCP socket fdr, which is declared outside this excerpt.
// Each packet is copied, together with its source address, into the next free ring slot.
// When the queue is full the packet is presumably dropped; the statement after the full
// check is not in this excerpt.
// NOTE(review): every packet is received into a stack buffer and then memcpy'd into the
// slot. Receiving directly into the free slot's buffer would save one copy per packet.
323 static void tcp_to_tun() {
324 unsigned char buf[PACKET_READ_SIZE + PACKET_READ_OFFS];
325 struct sockaddr_in pkt_src;
326 memset(&pkt_src, 0, sizeof(pkt_src));
329 socklen_t hostsz = sizeof(pkt_src);
330 ssize_t nread = recvfrom(fdr, buf, sizeof(buf), 0, (struct sockaddr*)&pkt_src, &hostsz);
332 fprintf (stderr, "Failed to read tcp raw sock\n");
// Full when advancing the producer index would catch up with the consumer.
336 if (tcp_to_tun_queue.nextPendingMessage == (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
339 auto& new_msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextUndefinedMessage];
340 std::get<0>(new_msg) = pkt_src;
341 memcpy(std::get<1>(new_msg).data(), buf, nread);
342 std::get<2>(new_msg) = nread;
// Publish the slot only after its contents are written.
344 tcp_to_tun_queue.nextUndefinedMessage = (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
// tcp_to_tun_queue_process: consumer for tcp_to_tun_queue. For each queued packet it:
//   1. validates the outer IPv4 header (TCP only) and the minimum TCP length;
//   2. drops segments not addressed to local_port;
//   3. in server mode, for a SYN without ACK, authenticates the SYN: the ACK field must
//      equal starting_ack and a timestamp option must carry timestamps_magic. It then
//      adopts the SYN's source port as remote_port;
//   4. drops segments from any other source port or address, and validates the TCP data
//      offset;
//   5. records the peer's next sequence number and the receive time;
//   6. answers an accepted SYN with a SYN+ACK built in place;
//   7. for data segments, strips the outer headers and the pkt_hdr, rebuilds an inner
//      IPv4 header in place, and writes the packet to the first tun queue (fd[0]).
// The do/while condition advances nextPendingMessage and is always true. As a result,
// every `continue` consumes the current slot, and the loop never exits.
348 static void tcp_to_tun_queue_process() {
350 while (tcp_to_tun_queue.nextUndefinedMessage == tcp_to_tun_queue.nextPendingMessage) {
351 std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
354 auto& msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextPendingMessage];
356 const sockaddr_in& pkt_src = std::get<0>(msg);
357 unsigned char* buf = std::get<1>(msg).data();
358 ssize_t nread = std::get<2>(msg);
360 int header_size = check_ip_header(buf, nread, 0x06, NULL); // Only support TCP
364 if (nread - header_size < 20) {
365 fprintf(stderr, "Short TCP packet\n");
369 unsigned char* tcp_buf = buf + header_size;
// Destination port must be ours.
371 if (((tcp_buf[2] << 8) | tcp_buf[3]) != local_port) continue;
373 bool syn = tcp_buf[13] & (1 << 1);
374 bool ack = tcp_buf[13] & (1 << 4);
376 if (are_server && syn && !ack) {
// The client puts starting_ack (upper half of the secret) in its SYN's ACK field.
377 uint32_t expected_ack = htobe32(starting_ack);
378 if (memcmp(tcp_buf + 8, &expected_ack, 4))
// A genuine SYN from a peer running this code has a 36-byte TCP header.
380 if (nread < 36 + header_size) { continue; }
381 // We're a server and just got a client...walk options until we find timestamps
382 const unsigned char* opt_buf = tcp_buf + 20;
383 bool found_magic = false;
384 while (!found_magic && opt_buf < buf + nread) {
386 case 1: opt_buf += 1; break;
387 case 2: opt_buf += 4; break;
388 case 3: opt_buf += 3; break;
389 // SACK should never appear
391 if (opt_buf + 10 <= buf + nread) {
// NOTE(review): `×tamps_magic` looks like `&timestamps_magic` mangled by an HTML-entity
// decode (the `&times` prefix became a multiplication sign); as written it does not
// compile.
392 if (!memcmp(opt_buf + 2, ×tamps_magic, 8)) {
// Any unrecognized option kind ends the walk.
398 default: opt_buf = buf + nread;
401 if (!found_magic) continue;
403 fprintf(stderr, "Got SYN, sending SYNACK\n");
404 remote_port = (tcp_buf[0] << 8) | tcp_buf[1];
408 if (((tcp_buf[0] << 8) | tcp_buf[1]) != remote_port) continue;
409 if (pkt_src.sin_addr.s_addr != dest.sin_addr.s_addr) continue;
411 uint8_t num_words = (tcp_buf[12] & 0xf0) >> 4;
412 int tcp_header_size = num_words * 4;
413 if (tcp_header_size < 20) {
414 fprintf(stderr, "Invalid TCP header size (%d)\n", tcp_header_size);
// Next expected peer sequence number: SEQ + 1 for a SYN, otherwise SEQ + payload length.
// NOTE(review): assigned from the latest segment on the visible lines, so a reordered or
// retransmitted segment could move the ACK point backwards. Some neighbouring lines are
// not in this excerpt.
418 highest_recvd_seq = ((((uint32_t)tcp_buf[4]) << (8 * 3)) |
419 (((uint32_t)tcp_buf[5]) << (8 * 2)) |
420 (((uint32_t)tcp_buf[6]) << (8 * 1)) |
421 (((uint32_t)tcp_buf[7]) << (8 * 0))) +
422 (syn ? 1 : nread - header_size - tcp_header_size);
425 last_ack_recv = std::chrono::steady_clock::now();
427 if (are_server && syn && !ack) {
// The SYN+ACK header is built in place over the received segment.
428 int len = build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr);
430 ssize_t res = sendto(fdr, tcp_buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
433 fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err));
436 if (nread < (ssize_t)(tcp_header_size + header_size + sizeof(struct pkt_hdr))) continue;
438 memcpy(&hdr, tcp_buf + tcp_header_size, sizeof(struct pkt_hdr));
// Shift so that a 20-byte IPv4 header written at tcp_buf ends exactly where pkt_hdr ended,
// i.e. immediately before the inner payload.
439 tcp_buf += tcp_header_size + sizeof(struct pkt_hdr) - 20; // IPv4 header is 20 bytes
441 // Replace TCP + pkt_hdr with IPv4 header
442 build_ip_header(tcp_buf, hdr, tun_dest, tun_src);
// NOTE(review): by the visible arithmetic, the rebuilt packet is
//   20 + (nread - header_size - tcp_header_size - 8)
//   = nread - tcp_header_size - header_size + 12 bytes,
// yet 16 more bytes are passed to write(). That reads past the received data, and for a
// packet near 1500 bytes past the slot's buffer. Confirm against original line 443, which
// is not in this excerpt. The result of write() is also unchecked.
444 write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8);
446 } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
// Start-up counters that hand each tun reader thread and each tun processing thread its
// queue index.
449 static std::atomic_int tun_if_thread(0), tun_if_process_thread(0);
// Set while the TCP session is being re-established; the tun threads test it (the action
// they take is on lines not in this excerpt).
450 static std::atomic_bool pause_tun_read_reinit_tcp(false);
// One ring buffer per tun queue, from tun_to_tcp (producer) to tun_to_tcp_queue_process (consumer).
451 static MessageQueue tun_to_tcp_queue[TUN_IF_COUNT];
// tun_to_tcp: producer thread for one tun queue. It claims the next queue index, reads
// packets from fd[index], and copies each into that queue's next free slot at offset
// PACKET_READ_OFFS. The offset leaves headroom for building the TCP header and pkt_hdr in
// place.
// When the queue is full the packet is presumably dropped; the statement is not in this
// excerpt.
// NOTE(review): the index comes from an ever-increasing counter and is not range-checked.
// fd[] and tun_to_tcp_queue[] have TUN_IF_COUNT entries, so this relies on no more than
// TUN_IF_COUNT of these threads ever being started, including across re-inits.
453 static void tun_to_tcp() {
454 unsigned char buf[PACKET_READ_SIZE];
456 int thread = tun_if_thread.fetch_add(1);
457 MessageQueue& queue = tun_to_tcp_queue[thread];
460 ssize_t nread = read(fd[thread], buf, sizeof(buf));
461 if (pause_tun_read_reinit_tcp)
465 fprintf (stderr, "Failed to read tun if\n");
469 if (queue.nextPendingMessage == (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE)
472 auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage];
473 memcpy(std::get<1>(new_msg).data() + PACKET_READ_OFFS, buf, nread);
474 std::get<2>(new_msg) = nread;
476 queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE;
// tun_to_tcp_queue_process: consumer for one tun queue. For each packet read from the tun
// device it:
//   1. validates the IPv4 header, capturing its length, id, fragment fields and protocol
//      into a pkt_hdr;
//   2. rejects packets that would not fit in 1500 bytes once re-encapsulated;
//   3. writes a 20-byte TCP header plus the pkt_hdr in place, over the end of the
//      packet's own IPv4 header and the reserved headroom, so the inner payload is never
//      copied;
//   4. sends the segment to `dest` over the raw socket.
// It uses the same always-true do/while as the other consumer, so every `continue` also
// consumes the current slot.
480 static void tun_to_tcp_queue_process() {
481 int thread = tun_if_process_thread.fetch_add(1);
482 MessageQueue& queue = tun_to_tcp_queue[thread];
485 while (queue.nextUndefinedMessage == queue.nextPendingMessage) {
486 std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS));
488 if (pause_tun_read_reinit_tcp)
491 auto& msg = queue.messagesPendingRingBuff[queue.nextPendingMessage];
493 unsigned char* buf = std::get<1>(msg).data();
494 ssize_t nread = std::get<2>(msg);
497 int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, -1, &hdr); // Any packet type is okay
498 if (header_size < 20)
501 if (header_size > nread) {
502 fprintf(stderr, "Short packet\n");
// The limit 1500 - 20 - 8 = 1472 equals the tun MTU that tun_alloc() sets
// (PACKET_READ_SIZE - 20).
506 if ((size_t)nread > 1500 - 20 - sizeof(struct pkt_hdr)) { // Packets must fit in 1500 bytes with a TCP and packet hdr
507 fprintf(stderr, "Long packet\n");
// Place the TCP header, then the pkt_hdr, immediately before the inner payload, which
// starts at buf + PACKET_READ_OFFS + header_size.
511 unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size - 20 - sizeof(struct pkt_hdr);
512 memcpy(tcp_start + 20, &hdr, sizeof(struct pkt_hdr));
// The TCP payload length counts the pkt_hdr plus the inner payload.
513 build_tcp_header(tcp_start, nread - header_size + sizeof(struct pkt_hdr), 0, 0, src, dest.sin_addr.s_addr);
515 ssize_t res = sendto(fdr, tcp_start, nread + 20 + sizeof(struct pkt_hdr) - header_size, 0, (struct sockaddr*)&dest, sizeof(dest));
518 fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err));
521 } while ((queue.nextPendingMessage = (queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true);
// (Re-)establishes the TCP-looking session. This is presumably the body of do_init(); the
// function header is not part of this excerpt. Steps:
//   1. On a re-init, pause tun traffic.
//   2. Pick a random local port >= 1024 and a random initial sequence number.
//   3. Reset the ACK point to starting_ack, and start the raw-socket threads on the first
//      call.
//   4. Send a SYN (its ACK field carries starting_ack and its timestamp option carries
//      the secret).
//   5. Wait up to 1000 x 10 ms for highest_recvd_seq to move, which is taken as the
//      SYN+ACK.
//   6. Send the ACK, clear the pause flag, and spawn TUN_IF_COUNT tun reader/processor
//      pairs. Whether the spawn is skipped on re-init depends on a line not in this
//      excerpt.
// NOTE(review): both getrandom() calls below sit inside assert(). When NDEBUG is defined
// the calls are compiled out: local_port_tmp stays 0, so the port loop never terminates,
// and starting_seq is used uninitialized. Call getrandom() unconditionally and check its
// result separately.
// NOTE(review): the std::thread objects t..t4 must be detached or joined before they go
// out of scope, otherwise std::terminate is called. Those statements are not in this
// excerpt.
526 // Send SYN and SYN/ACK
// NOTE(review): main() assigns local_port from argv[4] on one path, which would make the
// first call look like a re-init here. Confirm that this is intended or that the path is
// client-only.
530 if (local_port) // Doing a re-init
531 pause_tun_read_reinit_tcp = true;
533 uint16_t local_port_tmp = 0;
534 while (local_port_tmp < 1024)
535 assert(getrandom(&local_port_tmp, sizeof(local_port_tmp), 0) == sizeof(local_port_tmp));
537 local_port = local_port_tmp;
540 highest_recvd_seq = starting_ack;
541 uint32_t starting_seq;
542 assert(getrandom(&starting_seq, sizeof(starting_seq), 0) == sizeof(starting_seq));
543 cur_seq = starting_seq;
545 if (!pause_tun_read_reinit_tcp) { // Not doing a re-init
546 std::thread t(&tcp_to_tun);
547 std::thread t2(&tcp_to_tun_queue_process);
552 unsigned char buf[1500];
554 int len = build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr);
555 ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
558 fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err));
// Any accepted segment that updates highest_recvd_seq ends the wait (about 10 s maximum).
564 for (i = 0; i < 1000 && highest_recvd_seq == starting_ack; i++)
565 std::this_thread::sleep_for(std::chrono::milliseconds(10));
566 if (i == 1000) // Will come back in 10 seconds
570 fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n");
572 int len = build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr);
573 ssize_t res = sendto(fdr, buf, len, 0, (struct sockaddr*)&dest, sizeof(dest));
576 fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err));
581 if (pause_tun_read_reinit_tcp) {
582 pause_tun_read_reinit_tcp = false;
584 for (int i = 0; i < TUN_IF_COUNT; i++) {
585 std::thread t3(&tun_to_tcp);
586 std::thread t4(&tun_to_tcp_queue_process);
// main: parses the command line:
//   argv[1] tun interface name     argv[2] remote tunnel IP   argv[3] local tunnel IP
//   argv[4] server TCP port        argv[5] 64-bit shared secret (decimal)
//   argv[6] "client" or "server"   argv[7] outer source IP
//   argv[8] outer destination IP (required for client)
// It then creates the multi-queue tun device and the raw TCP socket. The visible tail
// wakes every 15 s; a client whose last accepted segment is older than 15 s presumably
// re-initializes the session (the body of that if is beyond this excerpt).
// NOTE(review): all argument validation is done with assert(), which is compiled out when
// NDEBUG is defined. Missing arguments then become NULL dereferences (argv[argc] is
// NULL).
595 int main(int argc, char* argv[]) {
596 assert(argc > 1 && "Need tun name");
597 assert(argc > 2 && "Need tun remote host");
598 assert(argc > 3 && "Need tun local host");
599 assert(argc > 4 && "Need server port");
600 assert(argc > 5 && "Need shared secret");
601 assert(argc > 6 && "Need mode (client or server)");
602 assert(argc > 7 && "Need src host");
603 if (std::string(argv[6]) == std::string("client"))
604 assert(argc > 8 && "Need dest host");
606 assert(std::string(argv[6]) == std::string("client") || std::string(argv[6]) == std::string("server"));
607 are_server = (std::string(argv[6]) == std::string("server"));
610 // Parse args into variables
613 char tun_name[IFNAMSIZ];
615 memset(tun_name, 0, sizeof(tun_name));
// NOTE(review): strcpy overflows tun_name when argv[1] is IFNAMSIZ or more characters
// long. Check the length, or copy at most IFNAMSIZ - 1 bytes.
616 strcpy(tun_name, argv[1]);
618 tun_dest = inet_addr(argv[2]);
619 tun_src = inet_addr(argv[3]);
// The port goes to local_port on one path and remote_port on the other. The branch
// selecting between them is on lines not in this excerpt. atoi() applies no range check
// before the value is narrowed to 16 bits.
622 local_port = atoi(argv[4]);
625 // Get local port in do_init() so that we pick a new one on reload
626 remote_port = atoi(argv[4]);
// NOTE(review): atoll() returns a signed value, and its behaviour is undefined when the
// input is out of range. Secrets above LLONG_MAX need strtoull().
629 uint64_t tcp_init_magic = atoll(argv[5]);
// The upper 32 bits become the ACK number expected in the SYN. All 64 bits, in big-endian
// order, become the timestamp-option payload.
630 starting_ack = tcp_init_magic >> 32;
631 timestamps_magic = htobe64(tcp_init_magic);
633 src = inet_addr(argv[7]);
635 memset(&dest, 0, sizeof(dest));
637 dest.sin_family = AF_INET;
// NOTE(review): argv[8] is only required in client mode. Confirm that the omitted line
// before this one guards this read for servers.
638 dest.sin_addr.s_addr = inet_addr(argv[8]);
642 // Create tun and bind to sockets...
645 if (tun_alloc(tun_name, argv[3], argv[2], TUN_IF_COUNT, fd) != 0) {
646 fprintf(stderr, "Failed to alloc tun if\n");
650 fdr = socket(AF_INET, SOCK_RAW, IPPROTO_TCP);
652 fprintf(stderr, "Failed to get raw socket\n");
661 std::this_thread::sleep_for(std::chrono::seconds(15));
662 if (!are_server && last_ack_recv < std::chrono::steady_clock::now() - std::chrono::seconds(15)) {