Add trivial HTTP server which serves proofs
authorMatt Corallo <git@bluematt.me>
Tue, 6 Feb 2024 04:19:14 +0000 (04:19 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 6 Feb 2024 04:19:14 +0000 (04:19 +0000)
Cargo.toml
src/http.rs [new file with mode: 0644]

index 5e31b018ee966a7669f15c6a837a6e275bc52a3e..c70cdf16021b52eae68772c006877740fda967d8 100644 (file)
@@ -13,14 +13,26 @@ rust-version = "1.60.0"
 [features]
 std = []
 tokio = ["tokio_crate/net", "tokio_crate/io-util", "std"]
+build_server = ["tokio", "tokio-stream", "tokio_crate/rt-multi-thread", "tokio_crate/macros"]
 
 [dependencies]
 ring = { version = "0.17", default-features = false, features = ["alloc"] }
 hex_lit = { version = "0.1", default-features = false, features = ["rust_v_1_46"] }
 tokio_crate = { package = "tokio", version = "1.0", default-features = false, optional = true }
+tokio-stream = { version = "0.1", default-features = false, optional = true, features = ["io-util"] }
 
 [dev-dependencies]
 hex-conservative = { version = "0.1", default-features = false, features = ["alloc"] }
 base64 = "0.21"
 rand = { version = "0.8", default-features = false, features = ["getrandom"] }
-tokio_crate = { package = "tokio", version = "1.0", features = ["rt", "macros"] }
+tokio_crate = { package = "tokio", version = "1.0", features = ["rt", "macros", "net", "rt-multi-thread"] }
+tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] }
+minreq = { version = "2.0" }
+
+[lib]
+name = "dnssec_prover"
+path = "src/lib.rs"
+
+[[bin]]
+name = "http_proof_gen"
+path = "src/http.rs"
diff --git a/src/http.rs b/src/http.rs
new file mode 100644 (file)
index 0000000..e5cc1c2
--- /dev/null
@@ -0,0 +1,140 @@
+//! A simple tokio-based HTTP server which serves DNSSEC proofs in RFC 9102 format.
+
+#![deny(missing_docs)]
+
+extern crate alloc;
+
+pub mod rr;
+pub mod validation;
+mod ser;
+pub mod query;
+
+#[cfg(feature = "tokio")]
+use tokio_crate as tokio;
+
+#[cfg(feature = "build_server")]
+#[tokio::main]
+async fn main() {
+       let resolver_sockaddr = std::env::var("RESOLVER")
+               .expect("Please set the RESOLVER env variable to the TCP socket of a recursive DNS resolver")
+               .parse().expect("RESOLVER was not a valid socket address");
+       let bind_addr = std::env::var("BIND")
+               .expect("Please set the BIND env variable to a socket address to listen on");
+
+       let listener = tokio::net::TcpListener::bind(bind_addr).await
+               .expect("Failed to bind to socket");
+       imp::run_server(listener, resolver_sockaddr).await;
+}
+
+#[cfg(feature = "tokio")]
+mod imp {
+       use super::*;
+
+       use rr::Name;
+       use query::*;
+
+       use std::net::SocketAddr;
+
+       use tokio::net::TcpListener;
+       use tokio::io::{AsyncWriteExt, AsyncBufReadExt, BufReader};
+
+       use tokio_stream::wrappers::LinesStream;
+       use tokio_stream::StreamExt;
+
+       pub(super) async fn run_server(listener: TcpListener, resolver_sockaddr: SocketAddr) {
+               loop {
+                       let (mut socket, _) = listener.accept().await.expect("Failed to accept new TCP connection");
+                       tokio::spawn(async move {
+                               let mut response = ("400 Bad Request", "Bad Request");
+                               'ret_err: loop { // goto label
+                                       let buf_reader = BufReader::new(&mut socket);
+                                       let mut request_headers = LinesStream::new(buf_reader.lines())
+                                               .take_while(|line_res| if let Ok(line) = line_res { !line.is_empty() } else { false });
+                                       if let Some(Ok(request)) = request_headers.next().await {
+                                               let mut parts = request.split(" ");
+                                               let (verb, path, http_vers);
+                                               if let Some(v) = parts.next() { verb = v; } else { break 'ret_err; }
+                                               if let Some(p) = parts.next() { path = p; } else { break 'ret_err; }
+                                               if let Some(v) = parts.next() { http_vers = v; } else { break 'ret_err; }
+                                               if parts.next().is_some() { break; }
+                                               if verb != "GET" { break; }
+                                               if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; }
+
+                                               const PATH_PFX: &'static str = "/dnssecproof?";
+                                               if !path.starts_with(PATH_PFX) {
+                                                       response = ("404 Not Found", "Not Found");
+                                                       break 'ret_err;
+                                               }
+
+                                               let (mut d, mut t) = ("", "");
+                                               for arg in path[PATH_PFX.len()..].split("&") {
+                                                       if let Some((k, v)) = arg.split_once("=") {
+                                                               if k == "d" {
+                                                                       d = v;
+                                                               } else if k == "t" {
+                                                                       t = v;
+                                                               }
+                                                       } else { break 'ret_err; }
+                                               }
+
+                                               if d == "" || t == "" {
+                                                       response = ("500 Bad Request", "Missing d or t URI parameters");
+                                                       break 'ret_err;
+                                               }
+                                               let query_name = if let Ok(domain) = Name::try_from(d) { domain } else {
+                                                       response = ("500 Bad Request", "Failed to parse domain, make sure it ends with .");
+                                                       break 'ret_err;
+                                               };
+                                               let proof_res = match t.to_ascii_uppercase().as_str() {
+                                                       "TXT" => build_txt_proof_async(resolver_sockaddr, query_name).await,
+                                                       "TLSA" => build_tlsa_proof_async(resolver_sockaddr, query_name).await,
+                                                       "A" => build_a_proof_async(resolver_sockaddr, query_name).await,
+                                                       "AAAA" => build_aaaa_proof_async(resolver_sockaddr, query_name).await,
+                                                       _ => break 'ret_err,
+                                               };
+                                               let proof = if let Ok(proof) = proof_res { proof } else {
+                                                       response = ("404 Not Found", "Failed to generate proof for given domain");
+                                                       break 'ret_err;
+                                               };
+
+                                               let _ = socket.write_all(
+                                                       format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", proof.len()).as_bytes()
+                                               ).await;
+                                               let _ = socket.write_all(&proof).await;
+                                               return;
+                                       }
+                                       break;
+                               }
+                               let _ = socket.write_all(format!(
+                                       "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\n\r\n{}",
+                                       response.0, response.1.len(), response.1,
+                               ).as_bytes()).await;
+                       });
+               }
+       }
+}
+
+#[cfg(all(feature = "tokio", test))]
+mod test {
+       use super::*;
+
+       use crate::validation::{parse_rr_stream, verify_rr_stream};
+
+       use minreq;
+
+       #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+       async fn test_lookup() {
+               let ns = "8.8.8.8:53".parse().unwrap();
+               let listener = tokio::net::TcpListener::bind("127.0.0.1:17492").await
+                       .expect("Failed to bind to socket");
+               tokio::spawn(imp::run_server(listener, ns));
+               let resp = minreq::get(
+                       "http://127.0.0.1:17492/dnssecproof?d=matt.user._bitcoin-payment.mattcorallo.com.&t=tXt"
+               ).send().unwrap();
+
+               assert_eq!(resp.status_code, 200);
+               let rrs = parse_rr_stream(resp.as_bytes()).unwrap();
+               let verified_rrs = verify_rr_stream(&rrs).unwrap();
+               assert_eq!(verified_rrs.verified_rrs.len(), 1);
+       }
+}