e5cc1c24947cf5945fecf147ec1297aade995f79
[dnssec-prover] / src / http.rs
1 //! A simple tokio-based HTTP server which serves DNSSEC proofs in RFC 9102 format.
2
3 #![deny(missing_docs)]
4
5 extern crate alloc;
6
7 pub mod rr;
8 pub mod validation;
9 mod ser;
10 pub mod query;
11
12 #[cfg(feature = "tokio")]
13 use tokio_crate as tokio;
14
#[cfg(feature = "build_server")]
#[tokio::main]
/// Starts the proof server using configuration taken from the environment.
///
/// `RESOLVER` must be the TCP socket address of a recursive DNS resolver and `BIND` the socket
/// address to listen on. A missing or malformed value aborts startup with a panic.
async fn main() {
	let resolver_var = std::env::var("RESOLVER")
		.expect("Please set the RESOLVER env variable to the TCP socket of a recursive DNS resolver");
	let resolver_sockaddr: std::net::SocketAddr = resolver_var
		.parse()
		.expect("RESOLVER was not a valid socket address");

	let bind_addr = std::env::var("BIND")
		.expect("Please set the BIND env variable to a socket address to listen on");
	let bound_listener = tokio::net::TcpListener::bind(bind_addr)
		.await
		.expect("Failed to bind to socket");

	imp::run_server(bound_listener, resolver_sockaddr).await;
}
28
#[cfg(feature = "tokio")]
mod imp {
	use super::*;

	use rr::Name;
	use query::*;

	use std::net::SocketAddr;

	use tokio::net::TcpListener;
	use tokio::io::{AsyncWriteExt, AsyncBufReadExt, BufReader};

	use tokio_stream::wrappers::LinesStream;
	use tokio_stream::StreamExt;

	/// Request path (including the `?` which starts the query string) at which proofs are served.
	const PATH_PFX: &str = "/dnssecproof?";

	/// Accepts TCP connections on `listener` forever, answering each with exactly one HTTP
	/// response and then closing it.
	///
	/// Every connection is handled on its own spawned task. Only the request line is inspected.
	/// A request of the form `GET /dnssecproof?d=<domain>&t=<type> HTTP/1.1` (or `HTTP/1.0`),
	/// where `<type>` is `TXT`, `TLSA`, `A` or `AAAA` (case-insensitive), gets a `200 OK` whose
	/// body is the proof built by querying the resolver at `resolver_sockaddr`. Any other request
	/// gets a `text/plain` error response: `404` for an unknown path or a failed proof, `400`
	/// otherwise.
	///
	/// Query parameter values are used verbatim; percent-encoding is not decoded.
	pub(super) async fn run_server(listener: TcpListener, resolver_sockaddr: SocketAddr) {
		loop {
			let mut socket = match listener.accept().await {
				Ok((socket, _)) => socket,
				Err(_) => {
					// A failed accept (e.g. a peer aborting before we accepted it, or running out
					// of file descriptors) must not take down the whole server. Yield so other
					// tasks can make progress (and possibly release resources), then retry.
					tokio::task::yield_now().await;
					continue;
				},
			};
			tokio::spawn(async move {
				// Sent if we `break 'ret_err` below. Overwritten with a more specific
				// status/body where one applies.
				let mut response = ("400 Bad Request", "Bad Request");
				'ret_err: loop { // goto label: `break 'ret_err` jumps to the error response
					let buf_reader = BufReader::new(&mut socket);
					// Header lines, ending at the blank line which terminates the header block
					// (or at the first read error). Only the first (request) line is consumed.
					let mut request_headers = LinesStream::new(buf_reader.lines())
						.take_while(|line_res| matches!(line_res, Ok(line) if !line.is_empty()));
					let request = if let Some(Ok(request)) = request_headers.next().await { request } else {
						break 'ret_err;
					};

					// The request line must be exactly `<verb> <path> <version>`.
					let mut parts = request.split(' ');
					let (verb, path, http_vers) = match (parts.next(), parts.next(), parts.next(), parts.next()) {
						(Some(verb), Some(path), Some(http_vers), None) => (verb, path, http_vers),
						_ => break 'ret_err,
					};
					if verb != "GET" { break 'ret_err; }
					if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; }

					let query_string = if let Some(query_string) = path.strip_prefix(PATH_PFX) { query_string } else {
						response = ("404 Not Found", "Not Found");
						break 'ret_err;
					};

					// `d` is the domain to prove and `t` the record type. Unknown keys are
					// ignored. Any `&`-separated argument without an `=` is rejected.
					let (mut d, mut t) = ("", "");
					for arg in query_string.split('&') {
						if let Some((k, v)) = arg.split_once('=') {
							if k == "d" {
								d = v;
							} else if k == "t" {
								t = v;
							}
						} else { break 'ret_err; }
					}

					if d.is_empty() || t.is_empty() {
						response = ("400 Bad Request", "Missing d or t URI parameters");
						break 'ret_err;
					}
					let query_name = if let Ok(domain) = Name::try_from(d) { domain } else {
						response = ("400 Bad Request", "Failed to parse domain, make sure it ends with .");
						break 'ret_err;
					};
					let proof_res = match t.to_ascii_uppercase().as_str() {
						"TXT" => build_txt_proof_async(resolver_sockaddr, query_name).await,
						"TLSA" => build_tlsa_proof_async(resolver_sockaddr, query_name).await,
						"A" => build_a_proof_async(resolver_sockaddr, query_name).await,
						"AAAA" => build_aaaa_proof_async(resolver_sockaddr, query_name).await,
						// Unsupported record type: use the generic 400 response.
						_ => break 'ret_err,
					};
					let proof = if let Ok(proof) = proof_res { proof } else {
						response = ("404 Not Found", "Failed to generate proof for given domain");
						break 'ret_err;
					};

					// The connection is dropped after this single response, so advertise that
					// with `Connection: close`. Write errors are ignored: there is nobody left to
					// report them to if the client has gone away.
					let _ = socket.write_all(format!(
						"HTTP/1.1 200 OK\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", proof.len()
					).as_bytes()).await;
					let _ = socket.write_all(&proof).await;
					return;
				}
				let _ = socket.write_all(format!(
					"HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n{}",
					response.0, response.1.len(), response.1,
				).as_bytes()).await;
			});
		}
	}
}
116
#[cfg(all(feature = "tokio", test))]
mod test {
	use super::*;

	use crate::validation::{parse_rr_stream, verify_rr_stream};

	use minreq;

	/// Starts the server on an OS-assigned loopback port, using 8.8.8.8:53 as the resolver.
	///
	/// Returns the base URL (`http://127.0.0.1:<port>`) to send requests to. Binding port 0
	/// avoids collisions with other tests or processes that a fixed port would risk.
	async fn start_server() -> String {
		let ns = "8.8.8.8:53".parse().unwrap();
		let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await
			.expect("Failed to bind to socket");
		let addr = listener.local_addr().expect("Failed to get bound socket address");
		tokio::spawn(imp::run_server(listener, ns));
		format!("http://{}", addr)
	}

	/// Fetches a real TXT proof (requires network access to 8.8.8.8) and checks it verifies.
	/// The mixed-case `tXt` exercises case-insensitive record type matching.
	#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
	async fn test_lookup() {
		let base_url = start_server().await;
		let resp = minreq::get(format!(
			"{}/dnssecproof?d=matt.user._bitcoin-payment.mattcorallo.com.&t=tXt", base_url
		)).send().unwrap();

		assert_eq!(resp.status_code, 200);
		let rrs = parse_rr_stream(resp.as_bytes()).unwrap();
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert_eq!(verified_rrs.verified_rrs.len(), 1);
	}

	/// Checks error responses which are produced before any DNS query is made.
	#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
	async fn test_bad_requests() {
		let base_url = start_server().await;

		// Any path other than the proof endpoint is not found.
		let resp = minreq::get(format!("{}/not-a-proof", base_url)).send().unwrap();
		assert_eq!(resp.status_code, 404);

		// Unsupported record types are rejected as bad requests.
		let resp = minreq::get(format!("{}/dnssecproof?d=example.com.&t=MX", base_url)).send().unwrap();
		assert_eq!(resp.status_code, 400);
	}
}