Use `bitcoin_hashes` rather than `ring` for hashing
[dnssec-prover] / src / http.rs
1 //! A simple tokio-based HTTP server which serves DNSSEC proofs in RFC 9102 format.
2
3 #![deny(missing_docs)]
4
5 extern crate alloc;
6
7 pub mod rr;
8 pub mod ser;
9 pub mod query;
10
11 #[cfg(feature = "validation")]
12 mod crypto;
13 #[cfg(feature = "validation")]
14 pub mod validation;
15
16 #[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))]
17 use tokio_crate as tokio;
18
/// Entry point for the standalone DNSSEC proof HTTP server.
///
/// Configuration is taken entirely from the environment:
/// * `RESOLVER` - socket address of the recursive DNS resolver proofs are built against.
/// * `BIND` - socket address on which to accept HTTP connections.
///
/// Panics with a descriptive message if either variable is missing, if `RESOLVER` does not
/// parse as a socket address, or if binding the listener fails.
#[cfg(feature = "build_server")]
#[tokio::main]
async fn main() {
	let resolver_env = std::env::var("RESOLVER")
		.expect("Please set the RESOLVER env variable to the TCP socket of a recursive DNS resolver");
	let resolver_sockaddr: std::net::SocketAddr = resolver_env.parse()
		.expect("RESOLVER was not a valid socket address");

	let bind_addr = std::env::var("BIND")
		.expect("Please set the BIND env variable to a socket address to listen on");
	let listener = tokio::net::TcpListener::bind(bind_addr).await.expect("Failed to bind to socket");

	// Serves forever; only returns if the accept loop itself stops.
	imp::run_server(listener, resolver_sockaddr).await;
}
32
#[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))]
mod imp {
	use super::*;

	use rr::Name;
	use query::*;

	use std::net::SocketAddr;

	use tokio::net::TcpListener;
	use tokio::io::{AsyncReadExt, AsyncWriteExt};

	/// Maximum number of bytes buffered while waiting for the end of the HTTP request line.
	const MAX_REQUEST_LEN: usize = 4096;

	/// Path prefix under which proofs are served; everything after it is the query string.
	const PATH_PFX: &str = "/dnssecproof?";

	/// Returns the index of the first `\r\n` in `buf`, if there is one.
	fn find_crlf(buf: &[u8]) -> Option<usize> {
		buf.windows(2).position(|w| w == b"\r\n")
	}

	/// Accepts connections on `listener` forever, answering one request per connection.
	///
	/// Each connection is handled in its own spawned task. Only the HTTP request line is
	/// parsed. A request of the form `GET /dnssecproof?d=<name>&t=<TXT|TLSA|A|AAAA> HTTP/1.x`
	/// is answered with the proof built against `resolver_sockaddr` (served as
	/// `application/octet-stream` with a `Cache-Control` max-age taken from the proof builder).
	/// Any other request gets a short `text/plain` error. The connection is closed after the
	/// single response.
	///
	/// # Panics
	///
	/// Panics if accepting a new TCP connection fails.
	pub(super) async fn run_server(listener: TcpListener, resolver_sockaddr: SocketAddr) {
		loop {
			let (mut socket, _) = listener.accept().await.expect("Failed to accept new TCP connection");
			tokio::spawn(async move {
				let mut response = ("400 Bad Request", "Bad Request");
				'ret_err: loop { // goto label: `break 'ret_err` sends `response` as an error reply
					let mut buf = [0; MAX_REQUEST_LEN];
					let mut buf_pos = 0;
					// Read until the end of the request line; yields the index of its `\r\n`.
					let line_end = loop {
						if buf_pos == buf.len() { response.1 = "Request Too Large"; break 'ret_err; }
						let read_res = { socket.read(&mut buf[buf_pos..]).await };
						match read_res {
							Ok(0) => return,
							Ok(len) => {
								// Only scan the newly-read bytes, plus the last previously-read
								// byte in case a `\r\n` straddles two reads.
								let scan_start = buf_pos.saturating_sub(1);
								buf_pos += len;
								if let Some(idx) = find_crlf(&buf[scan_start..buf_pos]) {
									break scan_start + idx;
								}
							},
							Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {},
							Err(_) => return,
						}
					};
					// Only the request line has to be UTF-8. Header bytes that arrived in the same
					// read may legitimately be non-UTF-8, or be cut off mid-character, and must not
					// cause the request to be rejected.
					let request = if let Ok(s) = std::str::from_utf8(&buf[..line_end]) { s } else {
						break 'ret_err;
					};

					// Request line must be exactly `<verb> <path> <version>`.
					let mut parts = request.split(' ');
					let (verb, path, http_vers) = match (parts.next(), parts.next(), parts.next(), parts.next()) {
						(Some(verb), Some(path), Some(http_vers), None) => (verb, path, http_vers),
						_ => break 'ret_err,
					};
					if verb != "GET" { break 'ret_err; }
					if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { break 'ret_err; }

					let query_str = if let Some(q) = path.strip_prefix(PATH_PFX) { q } else {
						response = ("404 Not Found", "Not Found");
						break 'ret_err;
					};

					// Every `&`-separated argument must be `key=value`; unknown keys are ignored and
					// later duplicates overwrite earlier ones. Values are used as-is (no
					// percent-decoding).
					let (mut d, mut t) = ("", "");
					for arg in query_str.split('&') {
						if let Some((k, v)) = arg.split_once('=') {
							if k == "d" {
								d = v;
							} else if k == "t" {
								t = v;
							}
						} else { break 'ret_err; }
					}

					if d.is_empty() || t.is_empty() {
						response.1 = "Missing d or t URI parameters";
						break 'ret_err;
					}
					let query_name = if let Ok(domain) = Name::try_from(d) { domain } else {
						response.1 = "Failed to parse domain, make sure it ends with .";
						break 'ret_err;
					};
					let proof_res = match t.to_ascii_uppercase().as_str() {
						"TXT" => build_txt_proof_async(resolver_sockaddr, &query_name).await,
						"TLSA" => build_tlsa_proof_async(resolver_sockaddr, &query_name).await,
						"A" => build_a_proof_async(resolver_sockaddr, &query_name).await,
						"AAAA" => build_aaaa_proof_async(resolver_sockaddr, &query_name).await,
						_ => break 'ret_err,
					};
					let (proof, cache_ttl) = if let Ok(proof) = proof_res { proof } else {
						response = ("404 Not Found", "Failed to generate proof for given domain");
						break 'ret_err;
					};

					// We never serve a second request on this connection, so (RFC 9112 section 9.6)
					// every response carries `Connection: close`.
					// NOTE(review): headers after the request line are never read; on some
					// platforms closing a socket with unread input sends RST, which may discard
					// the response at the client. Consider draining to `\r\n\r\n` first.
					let _ = socket.write_all(
						format!(
							"HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/octet-stream\r\nCache-Control: public, max-age={}, s-maxage={}\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n",
							proof.len(), cache_ttl, cache_ttl
						).as_bytes()
					).await;
					let _ = socket.write_all(&proof).await;
					return;
				}
				let _ = socket.write_all(format!(
					"HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\nAccess-Control-Allow-Origin: *\r\nConnection: close\r\n\r\n{}",
					response.0, response.1.len(), response.1,
				).as_bytes()).await;
			});
		}
	}
}
143
#[cfg(all(feature = "tokio", feature = "validation", test))]
mod test {
	use super::*;

	use crate::ser::parse_rr_stream;
	use crate::validation::verify_rr_stream;

	use minreq;

	/// Spins up a proof server on `127.0.0.1:<port>` that resolves through `resolver`.
	///
	/// Sends `GET /dnssecproof?<query>` to it and asserts the reply is a 200. Returns how many
	/// records in the returned proof verified successfully. Requires live network access to
	/// `resolver`.
	async fn fetch_verified_count(port: u16, resolver: &str, query: &str) -> usize {
		let resolver_addr: std::net::SocketAddr = resolver.parse().unwrap();
		let server_socket = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port)).await
			.expect("Failed to bind to socket");
		tokio::spawn(imp::run_server(server_socket, resolver_addr));

		let url = format!("http://127.0.0.1:{}/dnssecproof?{}", port, query);
		let reply = minreq::get(url).send().unwrap();
		assert_eq!(reply.status_code, 200);

		let parsed = parse_rr_stream(reply.as_bytes()).unwrap();
		let verified = verify_rr_stream(&parsed).unwrap();
		verified.verified_rrs.len()
	}

	#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
	async fn test_lookup() {
		// Mixed-case type name exercises the case-insensitive `t` parsing.
		let count = fetch_verified_count(17492, "8.8.8.8:53",
			"d=matt.user._bitcoin-payment.mattcorallo.com.&t=tXt").await;
		assert_eq!(count, 1);
	}

	#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
	async fn test_lookup_a() {
		// The A record set may contain several addresses.
		let count = fetch_verified_count(17493, "9.9.9.9:53", "d=cloudflare.com.&t=a").await;
		assert!(count >= 1);
	}

	#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
	async fn test_lookup_tlsa() {
		let count = fetch_verified_count(17494, "1.1.1.1:53",
			"d=_25._tcp.mail.as397444.net.&t=TLSA").await;
		assert_eq!(count, 1);
	}
}