From eea0d3040741e1c98d7ddd0db310885bf1ae0de5 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 3 Jan 2019 17:26:22 -0500 Subject: [PATCH] initial checkin on new host --- main.cpp | 607 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 607 insertions(+) create mode 100644 main.cpp diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000..3322298 --- /dev/null +++ b/main.cpp @@ -0,0 +1,607 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#if __has_include() +#include +#else +int getrandom(void* buf, size_t len, unsigned int flags) { + FILE* f = fopen("/dev/urandom", "r"); + int res = fread(buf, len, 1, f); + fclose(f); + return res; +} +#endif + + +static int tun_alloc(char *dev, int queues, int *fds) +{ + struct ifreq ifr; + int fd, err, i; + + if (!dev) + return -1; + + memset(&ifr, 0, sizeof(ifr)); + + /* Flags: IFF_TUN - TUN device (no Ethernet headers) + * IFF_TAP - TAP device + * + * IFF_NO_PI - Do not provide packet information + */ + ifr.ifr_flags = IFF_TUN | IFF_NO_PI | IFF_MULTI_QUEUE; + strncpy(ifr.ifr_name, dev, IFNAMSIZ); + + for (i = 0; i < queues; i++) { + if((fd = open("/dev/net/tun", O_RDWR)) < 0) + goto err; + err = ioctl(fd, TUNSETIFF, (void *) &ifr); + if (err) { + close(fd); + goto err; + } + fds[i] = fd; + } + return 0; +err: + for (--i; i >= 0; i--) + close(fds[i]); + return err; +} + +static int check_ip_header(const unsigned char* buf, ssize_t buf_len, uint8_t expected_type) { + if (buf_len < 20) { + // < size than IPv4? + return -1; + } + + if ((buf[0] & 0xf0) != (4 << 4)) { + // Only support IPv4 + return -1; + } + + uint8_t num_words = buf[0] & 0xf; + int header_size = num_words * 4; + if (header_size < 20) { + fprintf(stderr, "Invalid IPv4 IHL size (%d)\n", header_size); + return -1; + } + + if ((((uint16_t)buf[2]) << 8 | buf[3]) != buf_len) { + //fprintf(stderr, "Packet len %u != %ld\n", ((uint16_t)buf[2]) << 8 | buf[3], buf_len); + return -1; + } + + if (buf[9] != expected_type) { + fprintf(stderr, "Packet type %u, not %u\n", buf[9], expected_type); + return -1; + } + + return header_size; +} + +void print_packet(const unsigned char* buf, ssize_t buf_len) { + for (ssize_t i = 0; i < buf_len; ) { + for (int j = 0; i < buf_len && j < 20; i++, j++) + fprintf(stderr, "%02x", buf[i]); + fprintf(stderr, "\n"); + } +} + +static uint32_t data_checksum(const unsigned char* buf, size_t len) { + uint32_t sum = 0; + + while (len > 1) + { + sum += *buf++; + sum += (*buf++) << 8; + if (sum & 0x80000000) + sum = (sum & 0xFFFF) + (sum >> 16); + len -= 2; + } + + if (len & 1) + sum += *buf; + + return sum; +} + +static uint16_t finalize_data_checksum(uint32_t sum) { + while (sum >> 16) + sum = (sum & 0xFFFF) + (sum >> 16); + + return (uint16_t)(~sum); +} + +static uint16_t tcp_checksum(const unsigned char* buff, size_t len, in_addr_t src_addr, in_addr_t dest_addr) +{ + uint16_t *ip_src = (uint16_t*)&src_addr, *ip_dst = (uint16_t*)&dest_addr; + uint32_t sum = data_checksum(buff, len); + + sum += *(ip_src++); + sum += *ip_src; + sum += *(ip_dst++); + sum += *ip_dst; + sum += htons(IPPROTO_TCP); + sum += htons(len); + + return finalize_data_checksum(sum); +} + +static std::atomic highest_recvd_seq(0), cur_seq(0); +static std::atomic local_port(0); +static uint16_t remote_port; +static bool are_server; + +static void build_tcp_header(unsigned char* buf, uint32_t len, int syn, int synack, in_addr_t src_addr, in_addr_t dest_addr) { + buf[0 ] = local_port >> 8; // src port + buf[1 ] = local_port; // src port + buf[2 ] = remote_port >> 8; // dst port + buf[3 ] = remote_port; // dst port + + uint32_t seq = cur_seq.fetch_add((syn || synack) ? 1 : len, std::memory_order_acq_rel); + buf[4 ] = seq >> (8 * 3); // SEQ + buf[5 ] = seq >> (8 * 2); // SEQ + buf[6 ] = seq >> (8 * 1); // SEQ + buf[7 ] = seq >> (8 * 0); // SEQ + + uint32_t their_seq = highest_recvd_seq.load(std::memory_order_relaxed); + buf[8 ] = their_seq >> (8 * 3); // ACK + buf[9 ] = their_seq >> (8 * 2); // ACK + buf[10] = their_seq >> (8 * 1); // ACK + buf[11] = their_seq >> (8 * 0); // ACK + + bool longpkt = syn || synack; + buf[12] = (longpkt ? 6 : 5) << 4; // data offset + if (syn) + buf[13] = 1 << 1; // SYN + else if (synack) + buf[13] = (1 << 1) | (1 << 4); // SYN + ACK + else if (len == 0) + buf[13] = 1 << 4; // ACK + else + buf[13] = 3 << 3; // PSH + ACK + buf[14] = 0xff; // Window Size + buf[15] = 0xff; // Window Size + + buf[16] = 0x00; // Checksum + buf[17] = 0x00; // Checksum + buf[18] = 0x00; // URG Pointer + buf[19] = 0x00; // URG Pointer + + if (longpkt) { + buf[20] = 0x01; // NOP + buf[21] = 0x03; // Window Scale + buf[22] = 0x03; // Window Scale Option Length + buf[23] = 0x0e; // 1GB Window Size (0xffff << 0x0e) + } + + uint16_t checksum = tcp_checksum(buf, len + 20 + (longpkt ? 4 : 0), src_addr, dest_addr); + buf[16] = checksum; // Checksum + buf[17] = checksum >> 8; // Checksum +} + +static void build_ip_header(unsigned char* buf, uint32_t len, uint8_t proto, const in_addr_t& src_addr, const in_addr_t& dest_addr) { + buf[0 ] = (4 << 4) | 5; // IPv4 + IHL of 5 (20 bytes) + buf[1 ] = 0; // DSCP 0 + ECN 0 + buf[2 ] = (len + 20) >> 8; // Length + buf[3 ] = (len + 20); // Length + memset(buf + 4, 0, 4); // Identification and Fragment 0s + buf[6 ] = 1 << 6; // DF bit + buf[8 ] = 255; // TTL + buf[9 ] = proto; // Protocol Number + buf[10] = 0; // Checksum + buf[11] = 0; // Checksum + + memcpy(buf + 12, &src_addr, 4); + memcpy(buf + 16, &dest_addr, 4); + + uint16_t checksum = finalize_data_checksum(data_checksum(buf, 20)); + buf[10] = checksum; + buf[11] = checksum >> 8; +} + + +const signed char p_util_hexdigit[256] = +{ -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + 0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1, + -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, }; + +uint32_t hex_to_num(const char* buf) { + const char* pbegin = buf; + while (p_util_hexdigit[*buf] != -1) + buf++; + buf--; + uint32_t res = 0; + unsigned char* p1 = (unsigned char*)&res; + unsigned char* pend = p1 + 4; + while (buf >= pbegin && p1 < pend) { + *p1 = p_util_hexdigit[*buf--]; + if (buf >= pbegin) { + *p1 |= ((unsigned char)p_util_hexdigit[*buf--] << 4); + p1++; + } + } + + return res; +} + +#define TUN_IF_COUNT 4 +static int fdr; +static int fd[TUN_IF_COUNT]; +static struct sockaddr_in dest; +static in_addr_t src, tun_src, tun_dest, ipip_src, ipip_dest; +static const uint64_t tcp_init_magic = 0x1badcafedeadbeefULL; + +#define PENDING_MESSAGES_BUFF_SIZE (0x3000) +#define PACKET_READ_SIZE 1500 +#define THREAD_POLL_SLEEP_MICS 50 +struct MessageQueue { + std::tuple, ssize_t> messagesPendingRingBuff[PENDING_MESSAGES_BUFF_SIZE]; + std::atomic nextPendingMessage, nextUndefinedMessage; + MessageQueue() : nextPendingMessage(0), nextUndefinedMessage(0) {} + MessageQueue(MessageQueue&& q) =delete; + MessageQueue(MessageQueue& q) =delete; +}; + +static MessageQueue tcp_to_tun_queue; +static std::chrono::steady_clock::time_point last_ack_recv; + +static void tcp_to_tun() { + unsigned char buf[1500]; + struct sockaddr_in pkt_src; + memset(&pkt_src, 0, sizeof(pkt_src)); + + while (1) { + socklen_t hostsz = sizeof(pkt_src); + ssize_t nread = recvfrom(fdr, buf, sizeof(buf), 0, (struct sockaddr*)&pkt_src, &hostsz); + if (nread < 0) { + fprintf (stderr, "Failed to read tcp raw sock\n"); + exit(-1); + } + + if (tcp_to_tun_queue.nextPendingMessage == (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) + continue; + + auto& new_msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextUndefinedMessage]; + std::get<0>(new_msg) = pkt_src; + memcpy(std::get<1>(new_msg).data(), buf, nread); + std::get<2>(new_msg) = nread; + + tcp_to_tun_queue.nextUndefinedMessage = (tcp_to_tun_queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE; + } +} + +static void tcp_to_tun_queue_process() { + do { + while (tcp_to_tun_queue.nextUndefinedMessage == tcp_to_tun_queue.nextPendingMessage) { + std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS)); + } + + auto& msg = tcp_to_tun_queue.messagesPendingRingBuff[tcp_to_tun_queue.nextPendingMessage]; + + const sockaddr_in& pkt_src = std::get<0>(msg); + unsigned char* buf = std::get<1>(msg).data(); + ssize_t nread = std::get<2>(msg); + + int header_size = check_ip_header(buf, nread, 0x06); // Only support TCP + if (header_size < 0) + continue; + + if (nread - header_size < 20) { + fprintf(stderr, "Short TCP packet\n"); + continue; + } + + unsigned char* tcp_buf = buf + header_size; + + if (((tcp_buf[2] << 8) | tcp_buf[3]) != local_port) continue; + + bool syn = tcp_buf[13] & (1 << 1); + bool ack = tcp_buf[13] & (1 << 4); + + if (are_server && syn && !ack) { + // We're a server and just got a client + if (tcp_buf[4 ] != uint8_t(tcp_init_magic >> (7 * 8)) || + tcp_buf[5 ] != uint8_t(tcp_init_magic >> (6 * 8)) || + tcp_buf[6 ] != uint8_t(tcp_init_magic >> (5 * 8)) || + tcp_buf[7 ] != uint8_t(tcp_init_magic >> (4 * 8)) || + tcp_buf[8 ] != uint8_t(tcp_init_magic >> (3 * 8)) || + tcp_buf[9 ] != uint8_t(tcp_init_magic >> (2 * 8)) || + tcp_buf[10] != uint8_t(tcp_init_magic >> (1 * 8)) || + tcp_buf[11] != uint8_t(tcp_init_magic >> (0 * 8))) + continue; + + fprintf(stderr, "Got SYN, sending SYNACK\n"); + remote_port = (tcp_buf[0] << 8) | tcp_buf[1]; + dest = pkt_src; + } + + if (((tcp_buf[0] << 8) | tcp_buf[1]) != remote_port) continue; + if (pkt_src.sin_addr.s_addr != dest.sin_addr.s_addr) continue; + + uint8_t num_words = (tcp_buf[12] & 0xf0) >> 4; + int tcp_header_size = num_words * 4; + if (tcp_header_size < 20) { + fprintf(stderr, "Invalid TCP header size (%d)\n", tcp_header_size); + continue; + } + + highest_recvd_seq = ((((uint32_t)tcp_buf[4]) << (8 * 3)) | + (((uint32_t)tcp_buf[5]) << (8 * 2)) | + (((uint32_t)tcp_buf[6]) << (8 * 1)) | + (((uint32_t)tcp_buf[7]) << (8 * 0))) + + (syn ? 1 : nread - header_size - tcp_header_size); + + if (ack) + last_ack_recv = std::chrono::steady_clock::now(); + + if (are_server && syn && !ack) { + build_tcp_header(tcp_buf, 0, 0, 1, src, dest.sin_addr.s_addr); + + ssize_t res = sendto(fdr, tcp_buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest)); + if (res < 0) { + int err = errno; + fprintf(stderr, "Failed to send SYNACK with err %d (%s)\n", err, strerror(err)); + } + } else if (!syn) { + tcp_buf += tcp_header_size - 20; + + // Replace TCP with IPv4 header + //build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x01, ipip_dest, ipip_src); // ICMP + build_ip_header(tcp_buf, nread - tcp_header_size - header_size, 0x11, ipip_dest, ipip_src); // UDP + // Add IPIP header + build_ip_header(tcp_buf - 20, nread - tcp_header_size - header_size + 20, 0x04, tun_dest, tun_src); + + write(fd[0], tcp_buf - 20, nread - tcp_header_size - header_size + 40); + } + } while ((tcp_to_tun_queue.nextPendingMessage = (tcp_to_tun_queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true); +} + +static std::atomic_int tun_if_thread(0), tun_if_process_thread(0); +static std::atomic_bool pause_tun_read_reinit_tcp(false); +static MessageQueue tun_to_tcp_queue[TUN_IF_COUNT]; + +static void tun_to_tcp() { + unsigned char buf[PACKET_READ_SIZE]; + + int thread = tun_if_thread.fetch_add(1); + MessageQueue& queue = tun_to_tcp_queue[thread]; + + while (1) { + ssize_t nread = read(fd[thread], buf, sizeof(buf)); + if (pause_tun_read_reinit_tcp) + continue; + + if (nread < 0) { + fprintf (stderr, "Failed to read tun if\n"); + continue; + } + + if (queue.nextPendingMessage == (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) + continue; + + auto& new_msg = queue.messagesPendingRingBuff[queue.nextUndefinedMessage]; + memcpy(std::get<1>(new_msg).data(), buf, nread); + std::get<2>(new_msg) = nread; + + queue.nextUndefinedMessage = (queue.nextUndefinedMessage + 1) % PENDING_MESSAGES_BUFF_SIZE; + } +} + +static void tun_to_tcp_queue_process() { + int thread = tun_if_process_thread.fetch_add(1); + MessageQueue& queue = tun_to_tcp_queue[thread]; + + do { + while (queue.nextUndefinedMessage == queue.nextPendingMessage) { + std::this_thread::sleep_for(std::chrono::microseconds(THREAD_POLL_SLEEP_MICS)); + } + if (pause_tun_read_reinit_tcp) + continue; + + auto& msg = queue.messagesPendingRingBuff[queue.nextPendingMessage]; + + unsigned char* buf = std::get<1>(msg).data(); + ssize_t nread = std::get<2>(msg); + + int header_size = check_ip_header(buf, nread, 0x04); // Only support IPIP + if (header_size < 0) + continue; + + int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x11); // Only support UDP + //int internal_header_size = check_ip_header(buf + header_size, nread - header_size, 0x01); // Only support ICMP + if (internal_header_size < 0) + continue; + + if (internal_header_size + header_size + 8 > nread) { + fprintf(stderr, "Short UDP-in-IPIP packet\n"); + continue; + } + + size_t tcp_start_offset = header_size + internal_header_size - 20; + build_tcp_header(buf + tcp_start_offset, nread - tcp_start_offset - 20, 0, 0, src, dest.sin_addr.s_addr); + + ssize_t res = sendto(fdr, buf + tcp_start_offset, nread - tcp_start_offset, 0, (struct sockaddr*)&dest, sizeof(dest)); + if (res < 0) { + int err = errno; + fprintf(stderr, "Failed to send with err %d (%s)\n", err, strerror(err)); + } + + } while ((queue.nextPendingMessage = (queue.nextPendingMessage + 1) % PENDING_MESSAGES_BUFF_SIZE) || true); +} + +int do_init() { + // + // Send SYN and SYN/ACK + // + + if (!are_server) { + if (local_port) // Doing a re-init + pause_tun_read_reinit_tcp = true; + + uint16_t local_port_tmp = 0; + while (local_port_tmp < 1024) + assert(getrandom(&local_port_tmp, sizeof(local_port_tmp), 0) == sizeof(local_port_tmp)); + + local_port = local_port_tmp; + } + + uint32_t starting_ack = 0, starting_seq = 0; + memcpy(&starting_ack, &tcp_init_magic, 4); + memcpy(&starting_seq, ((const unsigned char*)&tcp_init_magic) + 4, 4); + highest_recvd_seq = starting_ack; + cur_seq = starting_seq; + + if (!pause_tun_read_reinit_tcp) { // Not doing a re-init + std::thread t(&tcp_to_tun); + std::thread t2(&tcp_to_tun_queue_process); + t.detach(); + t2.detach(); + } + + unsigned char buf[1500]; + if (!are_server) { + build_tcp_header(buf, 0, 1, 0, src, dest.sin_addr.s_addr); + ssize_t res = sendto(fdr, buf, 20 + 4, 0, (struct sockaddr*)&dest, sizeof(dest)); + if (res < 0) { + int err = errno; + fprintf(stderr, "Failed to send initial SYN with err %d (%s)\n", err, strerror(err)); + return -1; + } + } + + int i; + for (i = 0; i < 1000 && highest_recvd_seq == starting_ack; i++) + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + if (i == 1000) // Will come back in 10 seconds + return 0; + + if (!are_server) { + fprintf(stderr, "Got SYNACK, sending ACK and starting tun listen\n"); + + build_tcp_header(buf, 0, 0, 0, src, dest.sin_addr.s_addr); + ssize_t res = sendto(fdr, buf, 20, 0, (struct sockaddr*)&dest, sizeof(dest)); + if (res < 0) { + int err = errno; + fprintf(stderr, "Failed to send initial ACK with err %d (%s)\n", err, strerror(err)); + return -1; + } + } + + if (pause_tun_read_reinit_tcp) { + pause_tun_read_reinit_tcp = false; + } else { + for (int i = 0; i < TUN_IF_COUNT; i++) { + std::thread t3(&tun_to_tcp); + std::thread t4(&tun_to_tcp_queue_process); + t3.detach(); + t4.detach(); + } + } + + return 0; +} + +int main(int argc, char* argv[]) { + assert(argc > 1 && "Need tun name"); + assert(argc > 2 && "Need tun remote host"); + assert(argc > 3 && "Need tun local host"); + assert(argc > 4 && "Need ipip remote host"); + assert(argc > 5 && "Need ipip local host"); + assert(argc > 6 && "Need server port"); + assert(argc > 7 && "Need mode (client or server)"); + assert(argc > 8 && "Need src host"); + if (std::string(argv[7]) == std::string("client")) + assert(argc > 9 && "Need dest host"); + + assert(std::string(argv[7]) == std::string("client") || std::string(argv[7]) == std::string("server")); + are_server = (std::string(argv[7]) == std::string("server")); + + // + // Parse args into variables + // + + char tun_name[IFNAMSIZ]; + + memset(tun_name, 0, sizeof(tun_name)); + strcpy(tun_name, argv[1]); + + tun_dest = inet_addr(argv[2]); + tun_src = inet_addr(argv[3]); + + ipip_dest = inet_addr(argv[4]); + ipip_src = inet_addr(argv[5]); + + if (are_server) { + local_port = atoi(argv[6]); + remote_port = 0; + } else { + // Get local port in do_init() so that we pick a new one on reload + remote_port = atoi(argv[6]); + } + + src = inet_addr(argv[8]); + + memset(&dest, 0, sizeof(dest)); + if (!are_server) { + dest.sin_family = AF_INET; + dest.sin_addr.s_addr = inet_addr(argv[9]); + } + + // + // Create tun and bind to sockets... + // + + if (tun_alloc(tun_name, TUN_IF_COUNT, fd) < 0) { + fprintf(stderr, "Failed to alloc tun if\n"); + return -1; + } + + fdr = socket(AF_INET, SOCK_RAW, IPPROTO_TCP); + if (fdr < 0) { + fprintf(stderr, "Failed to get raw socket\n"); + return -1; + } + + int res = do_init(); + if (res) + return res; + + while (true) { + std::this_thread::sleep_for(std::chrono::seconds(15)); + if (!are_server && last_ack_recv < std::chrono::steady_clock::now() - std::chrono::seconds(15)) { + res = do_init(); + if (res) + return res; + } + } +} -- 2.30.2