Accept query names by reference for downstream flexibility
[dnssec-prover] / src / http.rs
index e5cc1c24947cf5945fecf147ec1297aade995f79..f344372e4d72ca1993b237470f4a4b23fc32c514 100644 (file)
@@ -5,11 +5,13 @@
 extern crate alloc;
 
 pub mod rr;
-pub mod validation;
-mod ser;
+pub mod ser;
 pub mod query;
 
-#[cfg(feature = "tokio")]
+#[cfg(feature = "validation")]
+pub mod validation;
+
+#[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))]
 use tokio_crate as tokio;
 
 #[cfg(feature = "build_server")]
@@ -26,7 +28,7 @@ async fn main() {
        imp::run_server(listener, resolver_sockaddr).await;
 }
 
-#[cfg(feature = "tokio")]
+#[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))]
 mod imp {
        use super::*;
 
@@ -36,10 +38,7 @@ mod imp {
        use std::net::SocketAddr;
 
        use tokio::net::TcpListener;
-       use tokio::io::{AsyncWriteExt, AsyncBufReadExt, BufReader};
-
-       use tokio_stream::wrappers::LinesStream;
-       use tokio_stream::StreamExt;
+       use tokio::io::{AsyncReadExt, AsyncWriteExt};
 
        pub(super) async fn run_server(listener: TcpListener, resolver_sockaddr: SocketAddr) {
                loop {
@@ -47,63 +46,86 @@ mod imp {
                        tokio::spawn(async move {
                                let mut response = ("400 Bad Request", "Bad Request");
                                'ret_err: loop { // goto label
-                                       let buf_reader = BufReader::new(&mut socket);
-                                       let mut request_headers = LinesStream::new(buf_reader.lines())
-                                               .take_while(|line_res| if let Ok(line) = line_res { !line.is_empty() } else { false });
-                                       if let Some(Ok(request)) = request_headers.next().await {
-                                               let mut parts = request.split(" ");
-                                               let (verb, path, http_vers);
-                                               if let Some(v) = parts.next() { verb = v; } else { break 'ret_err; }
-                                               if let Some(p) = parts.next() { path = p; } else { break 'ret_err; }
-                                               if let Some(v) = parts.next() { http_vers = v; } else { break 'ret_err; }
-                                               if parts.next().is_some() { break; }
-                                               if verb != "GET" { break; }
-                                               if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; }
-
-                                               const PATH_PFX: &'static str = "/dnssecproof?";
-                                               if !path.starts_with(PATH_PFX) {
-                                                       response = ("404 Not Found", "Not Found");
-                                                       break 'ret_err;
-                                               }
-
-                                               let (mut d, mut t) = ("", "");
-                                               for arg in path[PATH_PFX.len()..].split("&") {
-                                                       if let Some((k, v)) = arg.split_once("=") {
-                                                               if k == "d" {
-                                                                       d = v;
-                                                               } else if k == "t" {
-                                                                       t = v;
+                                       let mut buf = [0; 4096];
+                                       let mut buf_pos = 0;
+                                       'read_req: loop {
+                                               if buf_pos == buf.len() { response.1 = "Request Too Large"; break 'ret_err; }
+                                               let read_res = { socket.read(&mut buf[buf_pos..]).await };
+                                               match read_res {
+                                                       Ok(0) => return,
+                                                       Ok(len) => {
+                                                               buf_pos += len;
+                                                               for window in buf[..buf_pos].windows(2) {
+                                                                       if window == b"\r\n" { break 'read_req; }
                                                                }
-                                                       } else { break 'ret_err; }
+                                                       }
+                                                       Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {},
+                                                       Err(_) => return,
                                                }
-
-                                               if d == "" || t == "" {
-                                                       response = ("500 Bad Request", "Missing d or t URI parameters");
+                                       }
+                                       let request;
+                                       if let Ok(s) = std::str::from_utf8(&buf[..buf_pos]) {
+                                               if let Some((r, _)) = s.split_once("\r\n") {
+                                                       request = r;
+                                               } else {
+                                                       debug_assert!(false);
                                                        break 'ret_err;
                                                }
-                                               let query_name = if let Ok(domain) = Name::try_from(d) { domain } else {
-                                                       response = ("500 Bad Request", "Failed to parse domain, make sure it ends with .");
-                                                       break 'ret_err;
-                                               };
-                                               let proof_res = match t.to_ascii_uppercase().as_str() {
-                                                       "TXT" => build_txt_proof_async(resolver_sockaddr, query_name).await,
-                                                       "TLSA" => build_tlsa_proof_async(resolver_sockaddr, query_name).await,
-                                                       "A" => build_a_proof_async(resolver_sockaddr, query_name).await,
-                                                       "AAAA" => build_aaaa_proof_async(resolver_sockaddr, query_name).await,
-                                                       _ => break 'ret_err,
-                                               };
-                                               let proof = if let Ok(proof) = proof_res { proof } else {
-                                                       response = ("404 Not Found", "Failed to generate proof for given domain");
-                                                       break 'ret_err;
-                                               };
+                                       } else {
+                                               break 'ret_err;
+                                       }
+
+                                       let mut parts = request.split(" ");
+                                       let (verb, path, http_vers);
+                                       if let Some(v) = parts.next() { verb = v; } else { break 'ret_err; }
+                                       if let Some(p) = parts.next() { path = p; } else { break 'ret_err; }
+                                       if let Some(v) = parts.next() { http_vers = v; } else { break 'ret_err; }
+                                       if parts.next().is_some() { break; }
+                                       if verb != "GET" { break; }
+                                       if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; }
+
+                                       const PATH_PFX: &'static str = "/dnssecproof?";
+                                       if !path.starts_with(PATH_PFX) {
+                                               response = ("404 Not Found", "Not Found");
+                                               break 'ret_err;
+                                       }
+
+                                       let (mut d, mut t) = ("", "");
+                                       for arg in path[PATH_PFX.len()..].split("&") {
+                                               if let Some((k, v)) = arg.split_once("=") {
+                                                       if k == "d" {
+                                                               d = v;
+                                                       } else if k == "t" {
+                                                               t = v;
+                                                       }
+                                               } else { break 'ret_err; }
+                                       }
 
-                                               let _ = socket.write_all(
-                                                       format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", proof.len()).as_bytes()
-                                               ).await;
-                                               let _ = socket.write_all(&proof).await;
-                                               return;
+                                       if d == "" || t == "" {
+                                               response.1 = "Missing d or t URI parameters";
+                                               break 'ret_err;
                                        }
-                                       break;
+                                       let query_name = if let Ok(domain) = Name::try_from(d) { domain } else {
+                                               response.1 = "Failed to parse domain, make sure it ends with .";
+                                               break 'ret_err;
+                                       };
+                                       let proof_res = match t.to_ascii_uppercase().as_str() {
+                                               "TXT" => build_txt_proof_async(resolver_sockaddr, &query_name).await,
+                                               "TLSA" => build_tlsa_proof_async(resolver_sockaddr, &query_name).await,
+                                               "A" => build_a_proof_async(resolver_sockaddr, &query_name).await,
+                                               "AAAA" => build_aaaa_proof_async(resolver_sockaddr, &query_name).await,
+                                               _ => break 'ret_err,
+                                       };
+                                       let proof = if let Ok(proof) = proof_res { proof } else {
+                                               response = ("404 Not Found", "Failed to generate proof for given domain");
+                                               break 'ret_err;
+                                       };
+
+                                       let _ = socket.write_all(
+                                               format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", proof.len()).as_bytes()
+                                       ).await;
+                                       let _ = socket.write_all(&proof).await;
+                                       return;
                                }
                                let _ = socket.write_all(format!(
                                        "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\n\r\n{}",
@@ -114,7 +136,7 @@ mod imp {
        }
 }
 
-#[cfg(all(feature = "tokio", test))]
+#[cfg(all(feature = "tokio", feature = "validation", test))]
 mod test {
        use super::*;
 
@@ -137,4 +159,36 @@ mod test {
                let verified_rrs = verify_rr_stream(&rrs).unwrap();
                assert_eq!(verified_rrs.verified_rrs.len(), 1);
        }
+
+       #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+       async fn test_lookup_a() {
+               let ns = "4.4.4.4:53".parse().unwrap();
+               let listener = tokio::net::TcpListener::bind("127.0.0.1:17493").await
+                       .expect("Failed to bind to socket");
+               tokio::spawn(imp::run_server(listener, ns));
+               let resp = minreq::get(
+                       "http://127.0.0.1:17493/dnssecproof?d=cloudflare.com.&t=a"
+               ).send().unwrap();
+
+               assert_eq!(resp.status_code, 200);
+               let rrs = parse_rr_stream(resp.as_bytes()).unwrap();
+               let verified_rrs = verify_rr_stream(&rrs).unwrap();
+               assert_eq!(verified_rrs.verified_rrs.len(), 1);
+       }
+
+       #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+       async fn test_lookup_tlsa() {
+               let ns = "1.1.1.1:53".parse().unwrap();
+               let listener = tokio::net::TcpListener::bind("127.0.0.1:17494").await
+                       .expect("Failed to bind to socket");
+               tokio::spawn(imp::run_server(listener, ns));
+               let resp = minreq::get(
+                       "http://127.0.0.1:17494/dnssecproof?d=_25._tcp.mail.as397444.net.&t=TLSA"
+               ).send().unwrap();
+
+               assert_eq!(resp.status_code, 200);
+               let rrs = parse_rr_stream(resp.as_bytes()).unwrap();
+               let verified_rrs = verify_rr_stream(&rrs).unwrap();
+               assert_eq!(verified_rrs.verified_rrs.len(), 1);
+       }
 }