From 340b163ec6af224bb7ed8cb7dced608621a4dd28 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Wed, 16 Mar 2022 18:03:58 +0000 Subject: [PATCH] Use a custom header to support all packets and, importantly, frags --- main.cpp | 107 ++++++++++++++++++++++++------------------------------- 1 file changed, 47 insertions(+), 60 deletions(-) diff --git a/main.cpp b/main.cpp index e070706..c116e89 100644 --- a/main.cpp +++ b/main.cpp @@ -29,7 +29,19 @@ int getrandom(void* buf, size_t len, unsigned int flags) { } #endif -#define PACKET_READ_OFFS (20-16) +struct pkt_hdr { + // Note that ordering of the first three uint16_ts must match the wire! + uint16_t len; + uint16_t id; + uint16_t frag_flags: 3, + offset : 13; + uint8_t proto; + uint8_t padding; +}; +static_assert(sizeof(struct pkt_hdr) == 8); + +// Have to leave room to add a TCP header plus above header (after dropping IP header) +#define PACKET_READ_OFFS sizeof(pkt_hdr) #define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS) static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int queues, int *fds) @@ -62,7 +74,7 @@ static int tun_alloc(char *dev, const char* local_ip, const char* remote_ip, int fds[i] = fd; } - sprintf(buf, "ip link set %s mtu %d", dev, PACKET_READ_SIZE); + sprintf(buf, "ip link set dev %s mtu %ld", dev, PACKET_READ_SIZE - 20); err = system(buf); if (err) goto err; @@ -85,7 +97,7 @@ err: return err; } -static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t expected_type) { +static int check_ip_header(const unsigned char* buf, ssize_t buf_len, int16_t expected_type, struct pkt_hdr *gen_hdr) { if (buf_len < 20) { // < size than IPv4? return -1; @@ -103,31 +115,23 @@ static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t ex return -1; } - if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) { + /*if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) { //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len); return -1; - } + }*/ - if (buf[9] != expected_type) { + if (expected_type > 0 && buf[9] != expected_type) { fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type); return -1; } - return header_size; -} - -#define UDP_PORTS 4242 -static int check_udp_header(const unsigned char* buf, ssize_t buf_len) { - if (buf_len < 8) { - return -1; + if (gen_hdr) { + memcpy(gen_hdr, buf + 2, 6); // len, id, frag offset + gen_hdr->proto = buf[9]; + gen_hdr->padding = 0; } - if (((((uint16_t)buf[0]) << 8) | (buf[1])) != UDP_PORTS) { - return -1; - } - if (((((uint16_t)buf[2]) << 8) | (buf[3])) != UDP_PORTS) { - return -1; - } - return 8; + + return header_size; } void print_packet(const unsigned char* buf, ssize_t buf_len) { @@ -231,16 +235,12 @@ static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int syna buf[17] = checksum >> 8; // Checksum } -// len is data length, not including headers! -static void build_ip_header(unsigned char* buf, uint16_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) { +static void build_ip_header(unsigned char* buf, struct pkt_hdr hdr, const in_addr_t& src_addr, const in_addr_t& dest_addr) { buf[0 ] = (4 << 4) | 5; // IPv4 + IHL of 5 (20 bytes) buf[1 ] = 0; // DSCP 0 + ECN 0 - buf[2 ] = (len + 20) >> 8; // Length - buf[3 ] = (len + 20); // Length - memset(buf + 4, 0, 4); // Identification and Fragment 0s - buf[6 ] = 1 << 6; // DF bit + memcpy(buf + 2, &hdr, 6); // length, identification, flags, and offset buf[8 ] = 255; // TTL - buf[9 ] = proto; // Protocol Number + buf[9 ] = hdr.proto; // Protocol Number buf[10] = 0; // Checksum buf[11] = 0; // Checksum @@ -252,21 +252,6 @@ static void build_ip_header(unsigned char* buf, uint16_t len, uint8_t proto, con buf[11] = checksum >> 8; } -// length is data length, not including headers -static void build_udp_header(unsigned char* buf, uint16_t len) { - buf[0] = UDP_PORTS >> 8; - buf[1] = UDP_PORTS & 0xff; - buf[2] = UDP_PORTS >> 8; - buf[3] = UDP_PORTS & 0xff; - buf[4] = (len + 8) >> 8; - buf[5] = (len + 8); - - // Checksum 0 for now - buf[6] = 0; - buf[7] = 0; -} - - const signed char p_util_hexdigit[256] = { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, @@ -311,9 +296,7 @@ static struct sockaddr_in dest; static in_addr_t src, tun_src, tun_dest; static uint64_t tcp_init_magic; -#define PENDING_MESSAGES_BUFF_SIZE (0x3000) -#define PACKET_READ_OFFS (20-16) -#define PACKET_READ_SIZE (1500 - PACKET_READ_OFFS) +#define PENDING_MESSAGES_BUFF_SIZE (0x800) #define THREAD_POLL_SLEEP_MICS 250 struct MessageQueue { std::tuple, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE]; @@ -327,7 +310,7 @@ static MessageQueue tcp_to_tun_queue; static std::chrono::steady_clock::time_point last_ack_recv; static void tcp_to_tun() { - unsigned char buf[1500]; + unsigned char buf[PACKET_READ_SIZE + PACKET_READ_OFFS]; struct sockaddr_in pkt_src; memset(&pkt_src, 0, sizeof(pkt_src)); @@ -363,7 +346,7 @@ static void tcp_to_tun_queue_process() { unsigned char* buf = std::get<1>(msg).data(); ssize_t nread = std::get<2>(msg); - int header_size = check_ip_header(buf, nread, 0x06); // Only support TCP + int header_size = check_ip_header(buf, nread, 0x06, NULL); // Only support TCP if (header_size < 0) continue; @@ -424,11 +407,13 @@ static void tcp_to_tun_queue_process() { fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err)); } } else if (!syn) { - tcp_buf += tcp_header_size - 20 - 8; // IPv4 header is 20 bytes + UDP header is 8 bytes + if (nread < (ssize_t)(tcp_header_size + header_size + sizeof(struct pkt_hdr))) continue; + struct pkt_hdr hdr; + memcpy(&hdr, tcp_buf + tcp_header_size, sizeof(struct pkt_hdr)); + tcp_buf += tcp_header_size + sizeof(struct pkt_hdr) - 20; // IPv4 header is 20 bytes - // Replace TCP with IPv4 header - build_udp_header(tcp_buf + 20, nread - tcp_header_size - header_size); - build_ip_header(tcp_buf, nread - tcp_header_size - header_size + 8, 0x11, tun_dest, tun_src); // UDP + // Replace TCP + pkt_hdr with IPv4 header + build_ip_header(tcp_buf, hdr, tun_dest, tun_src); write(fd[0], tcp_buf, nread - tcp_header_size - header_size + 20 + 8); } @@ -482,24 +467,26 @@ static void tun_to_tcp_queue_process() { unsigned char* buf = std::get<1>(msg).data(); ssize_t nread = std::get<2>(msg); - int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, 0x11); // Only support UDP - if (header_size < 0) + struct pkt_hdr hdr; + int header_size = check_ip_header(buf + PACKET_READ_OFFS, nread, -1, &hdr); // Any paket type is okay + if (header_size < 20) continue; - if (header_size + 8 > nread) { - fprintf(stderr, "Short UDP packet\n"); + if (header_size > nread) { + fprintf(stderr, "Short packet\n"); continue; } - if (check_udp_header(buf + PACKET_READ_OFFS + header_size, nread - header_size) != 8) { - fprintf(stderr, "Bad UDP header\n"); + if ((size_t)nread > 1500 - 20 - sizeof(struct pkt_hdr)) { // Packets must fit in 1500 bytes with a TCP and packet hdr + fprintf(stderr, "Long packet\n"); continue; } - unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size + 8 - 20; - build_tcp_header(tcp_start, nread - header_size - 8, 0, 0, src, dest.sin_addr.s_addr); + unsigned char* tcp_start = buf + PACKET_READ_OFFS + header_size - 20 - sizeof(struct pkt_hdr); + memcpy(tcp_start + 20, &hdr, sizeof(struct pkt_hdr)); + build_tcp_header(tcp_start, nread - header_size + sizeof(struct pkt_hdr), 0, 0, src, dest.sin_addr.s_addr); - ssize_t res = sendto(fdr, tcp_start, nread + 20 - 8 - header_size, 0, (struct sockaddr*)&dest, sizeof(dest)); + ssize_t res = sendto(fdr, tcp_start, nread + 20 + sizeof(struct pkt_hdr) - header_size, 0, (struct sockaddr*)&dest, sizeof(dest)); if (res < 0) { int err = errno; fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err)); -- 2.30.2