From 37e5828ccf746baa10adc7ade9a9349986979f29 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Wed, 16 Mar 2022 04:53:27 +0000 Subject: [PATCH 1/1] Remove the need for an ipip link to reduce packet size (duh, mtu) --- main.cpp | 104 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 40 deletions(-) diff --git a/main.cpp b/main.cpp index eff9b46..c5e73e2 100644 --- a/main.cpp +++ b/main.cpp @@ -29,7 +29,8 @@ int getrandom(void* buf, size_t len, unsigned int flags) { } #endif -#define PACKET_READ_SIZE 1500 +#define PACKET_READ_OFFS (20-16) +#define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS) static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds) { @@ -115,6 +116,20 @@ static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t ex return header_size; } +#define UDP_PORTS 4242 +static int check_udp_header(const unsigned char* buf, ssize_t buf_len) { + if (buf_len < 8) { + return -1; + } + if (((((uint16_t)buf[0]) << 8) | (buf[1])) != UDP_PORTS) { + return -1; + } + if (((((uint16_t)buf[2]) << 8) | (buf[3])) != UDP_PORTS) { + return -1; + } + return 8; +} + void print_packet(const unsigned char* buf, ssize_t buf_len) { for (ssize_t i = 0; i < buf_len; ) { for (int j = 0; i < buf_len && j < 20; i++, j++) @@ -216,7 +231,8 @@ static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int syna buf[17] = checksum >> 8; // Checksum } -static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) { +// len is data length, not including headers! +static void build_ip_header(unsigned char* buf, uint16_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) { buf[0 ] = (4 << 4) | 5; // IPv4 + IHL of 5 (20 bytes) buf[1 ] = 0; // DSCP 0 + ECN 0 buf[2 ] = (len + 20) >> 8; // Length @@ -236,6 +252,20 @@ static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, con buf[11] = checksum >> 8; } +// length is data length, not including headers +static void build_udp_header(unsigned char* buf, uint16_t len) { + buf[0] = UDP_PORTS >> 8; + buf[1] = UDP_PORTS & 0xff; + buf[2] = UDP_PORTS >> 8; + buf[3] = UDP_PORTS & 0xff; + buf[4] = (len + 8) >> 8; + buf[5] = (len + 8); + + // Checksum 0 for now + buf[6] = 0; + buf[7] = 0; +} + const signed char p_util_hexdigit[256] = { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, @@ -278,14 +308,15 @@ uint32_t hex_to_num(const unsigned char* buf) { static int fdr; static int fd[TUN_IF_COUNT]; static struct sockaddr_in dest; -static in_addr_t src, tun_src, tun_dest, ipip_src, ipip_dest; +static in_addr_t src, tun_src, tun_dest; static uint64_t tcp_init_magic; #define PENDING_MESSAGES_BUFF_SIZE (0x3000) -#define PACKET_READ_SIZE 1500 +#define PACKET_READ_OFFS (20-16) +#define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS) #define THREAD_POLL_SLEEP_MICS 50 struct MessageQueue { - std::tuple, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE]; + std::tuple, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE]; std::atomic nextPendingMessage, nextUndefinedMessage; MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {} MessageQueue(MessageQueue&& q) =delete; @@ -393,15 +424,13 @@ static void tcp_to_tun_queue_process() { fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err)); } } else if (!syn) { - tcp_buf += tcp_header_size - 20; + tcp_buf += tcp_header_size - 20 - 8; // IPv4 header is 20 bytes + UDP header is 8 bytes // Replace TCP with IPv4 header - //build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x01, ipip_dest, ipip_src); // ICMP - build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x11, ipip_dest, ipip_src); // UDP - // Add IPIP header - build_ip_header(tcp_buf - 20, nread - tcp_header_size - header_size + 20, 0x04, tun_dest, tun_src); + build_udp_header(tcp_buf + 20, nread - tcp_header_size - header_size); + build_ip_header(tcp_buf, nread - tcp_header_size - header_size + 8, 0x11, tun_dest, tun_src); // UDP - write(fd[0], tcp_buf - 20, nread - tcp_header_size - header_size + 40); + write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8); } } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true); } @@ -430,7 +459,7 @@ static void tun_to_tcp() { continue; auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage]; - memcpy(std::get<1>(new_msg).data(), buf, nread); + memcpy(std::get<1>(new_msg).data() + PACKET_READ_OFFS, buf, nread); std::get<2>(new_msg) = nread; queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE; @@ -453,24 +482,24 @@ static void tun_to_tcp_queue_process() { unsigned char* buf = std::get<1>(msg).data(); ssize_t nread = std::get<2>(msg); - int header_size = check_ip_header(buf, nread, 0x04); // Only support IPIP + int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, 0x11); // Only support UDP if (header_size < 0) continue; - int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x11); // Only support UDP - //int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x01); // Only support ICMP - if (internal_header_size < 0) + if (header_size + 8 > nread) { + fprintf(stderr, "Short UDP packet\n"); continue; + } - if (internal_header_size + header_size + 8 > nread) { - fprintf(stderr, "Short UDP-in-IPIP packet\n"); + if (check_udp_header(buf + PACKET_READ_OFFS + header_size, nread - header_size) != 8) { + fprintf(stderr, "Bad UDP header\n"); continue; } - size_t tcp_start_offset = header_size + internal_header_size - 20; - build_tcp_header(buf + tcp_start_offset, nread - tcp_start_offset - 20, 0, 0, src, dest.sin_addr.s_addr); + unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size + 8 - 20; + build_tcp_header(tcp_start, nread - header_size - 8, 0, 0, src, dest.sin_addr.s_addr); - ssize_t res = sendto(fdr, buf + tcp_start_offset, nread - tcp_start_offset, 0, (struct sockaddr*)&dest, sizeof(dest)); + ssize_t res = sendto(fdr, tcp_start, nread + 20 - 8 - header_size, 0, (struct sockaddr*)&dest, sizeof(dest)); if (res < 0) { int err = errno; fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err)); @@ -555,17 +584,15 @@ int main(int argc, char* argv[]) { assert(argc > 1 && "Need tun name"); assert(argc > 2 && "Need tun remote host"); assert(argc > 3 && "Need tun local host"); - assert(argc > 4 && "Need ipip remote host"); - assert(argc > 5 && "Need ipip local host"); - assert(argc > 6 && "Need server port"); - assert(argc > 7 && "Need shared secret"); - assert(argc > 8 && "Need mode (client or server)"); - assert(argc > 9 && "Need src host"); - if (std::string(argv[8]) == std::string("client")) - assert(argc > 10 && "Need dest host"); - - assert(std::string(argv[8]) == std::string("client") || std::string(argv[8]) == std::string("server")); - are_server = (std::string(argv[8]) == std::string("server")); + assert(argc > 4 && "Need server port"); + assert(argc > 5 && "Need shared secret"); + assert(argc > 6 && "Need mode (client or server)"); + assert(argc > 7 && "Need src host"); + if (std::string(argv[6]) == std::string("client")) + assert(argc > 8 && "Need dest host"); + + assert(std::string(argv[6]) == std::string("client") || std::string(argv[6]) == std::string("server")); + are_server = (std::string(argv[6]) == std::string("server")); // // Parse args into variables @@ -579,25 +606,22 @@ int main(int argc, char* argv[]) { tun_dest = inet_addr(argv[2]); tun_src = inet_addr(argv[3]); - ipip_dest = inet_addr(argv[4]); - ipip_src = inet_addr(argv[5]); - if (are_server) { - local_port = atoi(argv[6]); + local_port = atoi(argv[4]); remote_port = 0; } else { // Get local port in do_init() so that we pick a new one on reload - remote_port = atoi(argv[6]); + remote_port = atoi(argv[4]); } - tcp_init_magic = atoll(argv[7]); + tcp_init_magic = atoll(argv[5]); - src = inet_addr(argv[9]); + src = inet_addr(argv[7]); memset(&dest, 0, sizeof(dest)); if (!are_server) { dest.sin_family = AF_INET; - dest.sin_addr.s_addr = inet_addr(argv[10]); + dest.sin_addr.s_addr = inet_addr(argv[8]); } // -- 2.30.2