X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;ds=sidebyside;f=src%2Fhttp.rs;h=073ba66b50ce701ec10cdffb14e0ef24401a219f;hb=8b836e295ec81c50907fe1f60ca43c82df792025;hp=e5cc1c24947cf5945fecf147ec1297aade995f79;hpb=c89e9c90894f383527f9eb407a8dc1864b931b9c;p=dnssec-prover diff --git a/src/http.rs b/src/http.rs index e5cc1c2..073ba66 100644 --- a/src/http.rs +++ b/src/http.rs @@ -5,11 +5,13 @@ extern crate alloc; pub mod rr; -pub mod validation; -mod ser; +pub mod ser; pub mod query; -#[cfg(feature = "tokio")] +#[cfg(feature = "validation")] +pub mod validation; + +#[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))] use tokio_crate as tokio; #[cfg(feature = "build_server")] @@ -26,7 +28,7 @@ async fn main() { imp::run_server(listener, resolver_sockaddr).await; } -#[cfg(feature = "tokio")] +#[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))] mod imp { use super::*; @@ -36,10 +38,7 @@ mod imp { use std::net::SocketAddr; use tokio::net::TcpListener; - use tokio::io::{AsyncWriteExt, AsyncBufReadExt, BufReader}; - - use tokio_stream::wrappers::LinesStream; - use tokio_stream::StreamExt; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; pub(super) async fn run_server(listener: TcpListener, resolver_sockaddr: SocketAddr) { loop { @@ -47,66 +46,92 @@ mod imp { tokio::spawn(async move { let mut response = ("400 Bad Request", "Bad Request"); 'ret_err: loop { // goto label - let buf_reader = BufReader::new(&mut socket); - let mut request_headers = LinesStream::new(buf_reader.lines()) - .take_while(|line_res| if let Ok(line) = line_res { !line.is_empty() } else { false }); - if let Some(Ok(request)) = request_headers.next().await { - let mut parts = request.split(" "); - let (verb, path, http_vers); - if let Some(v) = parts.next() { verb = v; } else { break 'ret_err; } - if let Some(p) = parts.next() { path = p; } else { break 'ret_err; } - if let Some(v) = parts.next() { http_vers = v; } else { break 'ret_err; } - if parts.next().is_some() { break; } - if verb != "GET" { break; } - if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; } - - const PATH_PFX: &'static str = "/dnssecproof?"; - if !path.starts_with(PATH_PFX) { - response = ("404 Not Found", "Not Found"); - break 'ret_err; - } - - let (mut d, mut t) = ("", ""); - for arg in path[PATH_PFX.len()..].split("&") { - if let Some((k, v)) = arg.split_once("=") { - if k == "d" { - d = v; - } else if k == "t" { - t = v; + let mut buf = [0; 4096]; + let mut buf_pos = 0; + 'read_req: loop { + if buf_pos == buf.len() { response.1 = "Request Too Large"; break 'ret_err; } + let read_res = { socket.read(&mut buf[buf_pos..]).await }; + match read_res { + Ok(0) => return, + Ok(len) => { + buf_pos += len; + for window in buf[..buf_pos].windows(2) { + if window == b"\r\n" { break 'read_req; } } - } else { break 'ret_err; } + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}, + Err(_) => return, } - - if d == "" || t == "" { - response = ("500 Bad Request", "Missing d or t URI parameters"); + } + let request; + if let Ok(s) = std::str::from_utf8(&buf[..buf_pos]) { + if let Some((r, _)) = s.split_once("\r\n") { + request = r; + } else { + debug_assert!(false); break 'ret_err; } - let query_name = if let Ok(domain) = Name::try_from(d) { domain } else { - response = ("500 Bad Request", "Failed to parse domain, make sure it ends with ."); - break 'ret_err; - }; - let proof_res = match t.to_ascii_uppercase().as_str() { - "TXT" => build_txt_proof_async(resolver_sockaddr, query_name).await, - "TLSA" => build_tlsa_proof_async(resolver_sockaddr, query_name).await, - "A" => build_a_proof_async(resolver_sockaddr, query_name).await, - "AAAA" => build_aaaa_proof_async(resolver_sockaddr, query_name).await, - _ => break 'ret_err, - }; - let proof = if let Ok(proof) = proof_res { proof } else { - response = ("404 Not Found", "Failed to generate proof for given domain"); - break 'ret_err; - }; + } else { + break 'ret_err; + } + + let mut parts = request.split(" "); + let (verb, path, http_vers); + if let Some(v) = parts.next() { verb = v; } else { break 'ret_err; } + if let Some(p) = parts.next() { path = p; } else { break 'ret_err; } + if let Some(v) = parts.next() { http_vers = v; } else { break 'ret_err; } + if parts.next().is_some() { break; } + if verb != "GET" { break; } + if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; } + + const PATH_PFX: &'static str = "/dnssecproof?"; + if !path.starts_with(PATH_PFX) { + response = ("404 Not Found", "Not Found"); + break 'ret_err; + } + + let (mut d, mut t) = ("", ""); + for arg in path[PATH_PFX.len()..].split("&") { + if let Some((k, v)) = arg.split_once("=") { + if k == "d" { + d = v; + } else if k == "t" { + t = v; + } + } else { break 'ret_err; } + } - let _ = socket.write_all( - format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", proof.len()).as_bytes() - ).await; - let _ = socket.write_all(&proof).await; - return; + if d == "" || t == "" { + response.1 = "Missing d or t URI parameters"; + break 'ret_err; } - break; + let query_name = if let Ok(domain) = Name::try_from(d) { domain } else { + response.1 = "Failed to parse domain, make sure it ends with ."; + break 'ret_err; + }; + let proof_res = match t.to_ascii_uppercase().as_str() { + "TXT" => build_txt_proof_async(resolver_sockaddr, &query_name).await, + "TLSA" => build_tlsa_proof_async(resolver_sockaddr, &query_name).await, + "A" => build_a_proof_async(resolver_sockaddr, &query_name).await, + "AAAA" => build_aaaa_proof_async(resolver_sockaddr, &query_name).await, + _ => break 'ret_err, + }; + let (proof, cache_ttl) = if let Ok(proof) = proof_res { proof } else { + response = ("404 Not Found", "Failed to generate proof for given domain"); + break 'ret_err; + }; + + let _ = socket.write_all( + format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/octet-stream\r\nCache-Control: public, max-age={}, s-maxage={}\r\nAccess-Control-Allow-Origin: *\r\n\r\n", + proof.len(), cache_ttl, cache_ttl + ).as_bytes() + ).await; + let _ = socket.write_all(&proof).await; + return; } let _ = socket.write_all(format!( - "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\n\r\n{}", + "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\nAccess-Control-Allow-Origin: *\r\n\r\n{}", response.0, response.1.len(), response.1, ).as_bytes()).await; }); @@ -114,11 +139,12 @@ mod imp { } } -#[cfg(all(feature = "tokio", test))] +#[cfg(all(feature = "tokio", feature = "validation", test))] mod test { use super::*; - use crate::validation::{parse_rr_stream, verify_rr_stream}; + use crate::ser::parse_rr_stream; + use crate::validation::verify_rr_stream; use minreq; @@ -137,4 +163,36 @@ mod test { let verified_rrs = verify_rr_stream(&rrs).unwrap(); assert_eq!(verified_rrs.verified_rrs.len(), 1); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_lookup_a() { + let ns = "9.9.9.9:53".parse().unwrap(); + let listener = tokio::net::TcpListener::bind("127.0.0.1:17493").await + .expect("Failed to bind to socket"); + tokio::spawn(imp::run_server(listener, ns)); + let resp = minreq::get( + "http://127.0.0.1:17493/dnssecproof?d=cloudflare.com.&t=a" + ).send().unwrap(); + + assert_eq!(resp.status_code, 200); + let rrs = parse_rr_stream(resp.as_bytes()).unwrap(); + let verified_rrs = verify_rr_stream(&rrs).unwrap(); + assert!(verified_rrs.verified_rrs.len() >= 1); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_lookup_tlsa() { + let ns = "1.1.1.1:53".parse().unwrap(); + let listener = tokio::net::TcpListener::bind("127.0.0.1:17494").await + .expect("Failed to bind to socket"); + tokio::spawn(imp::run_server(listener, ns)); + let resp = minreq::get( + "http://127.0.0.1:17494/dnssecproof?d=_25._tcp.mail.as397444.net.&t=TLSA" + ).send().unwrap(); + + assert_eq!(resp.status_code, 200); + let rrs = parse_rr_stream(resp.as_bytes()).unwrap(); + let verified_rrs = verify_rr_stream(&rrs).unwrap(); + assert_eq!(verified_rrs.verified_rrs.len(), 1); + } }