From 974fe2805f8a34d0fa88c05c2374ed69c51d7674 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 6 Feb 2024 18:03:22 +0000 Subject: [PATCH] Make HTTP server large-req DoS safe rather than using a frontend --- Cargo.toml | 4 +- src/http.rs | 130 ++++++++++++++++++++++++++++++---------------------- 2 files changed, 76 insertions(+), 58 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 32ac430..edbcebd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,20 +13,18 @@ rust-version = "1.60.0" [features] std = [] tokio = ["tokio_crate/net", "tokio_crate/io-util", "std"] -build_server = ["tokio", "tokio-stream", "tokio_crate/rt-multi-thread", "tokio_crate/macros"] +build_server = ["tokio", "tokio_crate/rt-multi-thread", "tokio_crate/macros"] [dependencies] ring = { version = "0.17", default-features = false, features = ["alloc"] } hex_lit = { version = "0.1", default-features = false, features = ["rust_v_1_46"] } tokio_crate = { package = "tokio", version = "1.0", default-features = false, optional = true } -tokio-stream = { version = "0.1", default-features = false, optional = true, features = ["io-util"] } [dev-dependencies] hex-conservative = { version = "0.1", default-features = false, features = ["alloc"] } base64 = "0.21" rand = { version = "0.8", default-features = false, features = ["getrandom"] } tokio_crate = { package = "tokio", version = "1.0", features = ["rt", "macros", "net", "rt-multi-thread"] } -tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] } minreq = { version = "2.0" } [lib] diff --git a/src/http.rs b/src/http.rs index e5cc1c2..53a0b18 100644 --- a/src/http.rs +++ b/src/http.rs @@ -36,10 +36,7 @@ mod imp { use std::net::SocketAddr; use tokio::net::TcpListener; - use tokio::io::{AsyncWriteExt, AsyncBufReadExt, BufReader}; - - use tokio_stream::wrappers::LinesStream; - use tokio_stream::StreamExt; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; pub(super) async fn run_server(listener: TcpListener, resolver_sockaddr: SocketAddr) { loop { @@ -47,63 +44,86 @@ mod imp { tokio::spawn(async move { let mut response = ("400 Bad Request", "Bad Request"); 'ret_err: loop { // goto label - let buf_reader = BufReader::new(&mut socket); - let mut request_headers = LinesStream::new(buf_reader.lines()) - .take_while(|line_res| if let Ok(line) = line_res { !line.is_empty() } else { false }); - if let Some(Ok(request)) = request_headers.next().await { - let mut parts = request.split(" "); - let (verb, path, http_vers); - if let Some(v) = parts.next() { verb = v; } else { break 'ret_err; } - if let Some(p) = parts.next() { path = p; } else { break 'ret_err; } - if let Some(v) = parts.next() { http_vers = v; } else { break 'ret_err; } - if parts.next().is_some() { break; } - if verb != "GET" { break; } - if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; } - - const PATH_PFX: &'static str = "/dnssecproof?"; - if !path.starts_with(PATH_PFX) { - response = ("404 Not Found", "Not Found"); - break 'ret_err; - } - - let (mut d, mut t) = ("", ""); - for arg in path[PATH_PFX.len()..].split("&") { - if let Some((k, v)) = arg.split_once("=") { - if k == "d" { - d = v; - } else if k == "t" { - t = v; + let mut buf = [0; 4096]; + let mut buf_pos = 0; + 'read_req: loop { + if buf_pos == buf.len() { response.1 = "Request Too Large"; break 'ret_err; } + let read_res = { socket.read(&mut buf[buf_pos..]).await }; + match read_res { + Ok(0) => return, + Ok(len) => { + buf_pos += len; + for window in buf[..buf_pos].windows(2) { + if window == b"\r\n" { break 'read_req; } } - } else { break 'ret_err; } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(_) => return, } - - if d == "" || t == "" { - response = ("500 Bad Request", "Missing d or t URI parameters"); + } + let request; + if let Ok(s) = std::str::from_utf8(&buf[..buf_pos]) { + if let Some((r, _)) = s.split_once("\r\n") { + request = r; + } else { + debug_assert!(false); break 'ret_err; } - let query_name = if let Ok(domain) = Name::try_from(d) { domain } else { - response = ("500 Bad Request", "Failed to parse domain, make sure it ends with ."); - break 'ret_err; - }; - let proof_res = match t.to_ascii_uppercase().as_str() { - "TXT" => build_txt_proof_async(resolver_sockaddr, query_name).await, - "TLSA" => build_tlsa_proof_async(resolver_sockaddr, query_name).await, - "A" => build_a_proof_async(resolver_sockaddr, query_name).await, - "AAAA" => build_aaaa_proof_async(resolver_sockaddr, query_name).await, - _ => break 'ret_err, - }; - let proof = if let Ok(proof) = proof_res { proof } else { - response = ("404 Not Found", "Failed to generate proof for given domain"); - break 'ret_err; - }; + } else { + break 'ret_err; + } + + let mut parts = request.split(" "); + let (verb, path, http_vers); + if let Some(v) = parts.next() { verb = v; } else { break 'ret_err; } + if let Some(p) = parts.next() { path = p; } else { break 'ret_err; } + if let Some(v) = parts.next() { http_vers = v; } else { break 'ret_err; } + if parts.next().is_some() { break; } + if verb != "GET" { break; } + if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; } + + const PATH_PFX: &'static str = "/dnssecproof?"; + if !path.starts_with(PATH_PFX) { + response = ("404 Not Found", "Not Found"); + break 'ret_err; + } + + let (mut d, mut t) = ("", ""); + for arg in path[PATH_PFX.len()..].split("&") { + if let Some((k, v)) = arg.split_once("=") { + if k == "d" { + d = v; + } else if k == "t" { + t = v; + } + } else { break 'ret_err; } + } - let _ = socket.write_all( - format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", proof.len()).as_bytes() - ).await; - let _ = socket.write_all(&proof).await; - return; + if d == "" || t == "" { + response.1 = "Missing d or t URI parameters"; + break 'ret_err; } - break; + let query_name = if let Ok(domain) = Name::try_from(d) { domain } else { + response.1 = "Failed to parse domain, make sure it ends with ."; + break 'ret_err; + }; + let proof_res = match t.to_ascii_uppercase().as_str() { + "TXT" => build_txt_proof_async(resolver_sockaddr, query_name).await, + "TLSA" => build_tlsa_proof_async(resolver_sockaddr, query_name).await, + "A" => build_a_proof_async(resolver_sockaddr, query_name).await, + "AAAA" => build_aaaa_proof_async(resolver_sockaddr, query_name).await, + _ => break 'ret_err, + }; + let proof = if let Ok(proof) = proof_res { proof } else { + response = ("404 Not Found", "Failed to generate proof for given domain"); + break 'ret_err; + }; + + let _ = socket.write_all( + format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", proof.len()).as_bytes() + ).await; + let _ = socket.write_all(&proof).await; + return; } let _ = socket.write_all(format!( "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\n\r\n{}", -- 2.30.2