Avoid overriding $RUSTFLAGS when needed for rustc 1.63
[dnssec-prover] / src / http.rs
1 //! A simple tokio-based HTTP server which serves DNSSEC proofs in RFC 9102 format.
2
3 #![deny(missing_docs)]
4
5 // const_slice_from_raw_parts was stabilized in 1.64, however we support building on 1.63 as well.
6 // Luckily, it seems to work fine in 1.63 with the feature flag (and RUSTC_BOOTSTRAP=1) enabled.
7 #![cfg_attr(all(feature = "validation", rust_1_63), feature(const_slice_from_raw_parts))]
8
9 #![allow(clippy::new_without_default)] // why is this even a lint
10 #![allow(clippy::result_unit_err)] // Why in the hell is this a lint?
11 #![allow(clippy::get_first)] // Sometimes this improves readability
12 #![allow(clippy::needless_lifetimes)] // lifetimes improve readability
13 #![allow(clippy::needless_borrow)] // borrows indicate read-only/non-move
14 #![allow(clippy::too_many_arguments)] // sometimes we don't have an option
15 #![allow(clippy::identity_op)] // sometimes identities improve readability for repeated actions
16 #![allow(clippy::erasing_op)] // sometimes identities improve readability for repeated actions
17
18 extern crate alloc;
19
20 /// The maximum number of requests we will make when building a proof or the maximum number of
21 /// [`rr::RRSig`] sets we'll validate records from when validating proofs.
22 // Note that this is duplicated exactly in src/lib.rs
23 pub const MAX_PROOF_STEPS: usize = 20;
24
25 pub mod rr;
26 pub mod ser;
27 pub mod query;
28
29 #[cfg(feature = "validation")]
30 mod base32;
31 #[cfg(feature = "validation")]
32 mod crypto;
33 #[cfg(feature = "validation")]
34 pub mod validation;
35
36 #[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))]
37 use tokio_crate as tokio;
38
/// Entry point of the stand-alone proof server.
///
/// Configuration comes from two environment variables: `RESOLVER` (socket address of the
/// recursive resolver queries are sent to) and `BIND` (socket address to listen on). Any missing
/// or malformed value, or a failure to bind, aborts start-up with a panic.
#[cfg(feature = "build_server")]
#[tokio::main]
async fn main() {
        // Read and validate RESOLVER before touching BIND so the first misconfiguration reported
        // is always the resolver one.
        let resolver_env = std::env::var("RESOLVER")
                .expect("Please set the RESOLVER env variable to the TCP socket of a recursive DNS resolver");
        let resolver: std::net::SocketAddr = resolver_env.parse()
                .expect("RESOLVER was not a valid socket address");

        let listen_on = std::env::var("BIND")
                .expect("Please set the BIND env variable to a socket address to listen on");
        let bind_result = tokio::net::TcpListener::bind(listen_on).await;
        let listener = bind_result.expect("Failed to bind to socket");

        // Serves requests until the process is killed.
        imp::run_server(listener, resolver).await;
}
52
/// Placeholder entry point used when the binary is built without the `build_server` feature: it
/// only tells the user which feature is missing.
#[cfg(not(feature = "build_server"))]
fn main() {
        panic!("You need to enable the `build_server` feature to use the built-in server");
}
55
#[cfg(any(feature = "build_server", all(feature = "tokio", feature = "validation")))]
mod imp {
        // The HTTP server itself. Only the first line of a request (the request line) is ever
        // looked at; every connection gets exactly one reply and is then closed.
        //
        // Note that this must keep building on rustc 1.63, so no `let ... else` or labeled blocks.
        use super::*;

        use rr::Name;
        use query::*;

        use std::net::SocketAddr;

        use tokio::net::{TcpListener, TcpStream};
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        /// An HTTP error reply as `(status code and reason, plain-text body)`.
        type ErrorReply = (&'static str, &'static str);

        /// The catch-all reply for requests we cannot make sense of.
        const BAD_REQUEST: ErrorReply = ("400 Bad Request", "Bad Request");

        /// The only path we serve, including the `?` which introduces the query string.
        const PATH_PFX: &str = "/dnssecproof?";

        /// Accepts connections on `listener` forever, answering each one on its own task using the
        /// recursive resolver at `resolver_sockaddr` to build proofs.
        ///
        /// # Panics
        ///
        /// Panics if accepting fails for a reason which is not specific to the connection being
        /// accepted (see [`is_per_connection_error`]).
        pub(super) async fn run_server(listener: TcpListener, resolver_sockaddr: SocketAddr) {
                loop {
                        let socket = match listener.accept().await {
                                Ok((socket, _)) => socket,
                                // A client which gave up while queued must not take the server down.
                                Err(e) if is_per_connection_error(&e) => continue,
                                Err(e) => panic!("Failed to accept new TCP connection: {:?}", e),
                        };
                        tokio::spawn(handle_connection(socket, resolver_sockaddr));
                }
        }

        /// Returns true if an `accept` failure only concerns the connection which was being
        /// accepted (e.g. the peer reset it before we got to it) rather than the listener itself.
        fn is_per_connection_error(err: &std::io::Error) -> bool {
                matches!(
                        err.kind(),
                        std::io::ErrorKind::ConnectionRefused
                                | std::io::ErrorKind::ConnectionAborted
                                | std::io::ErrorKind::ConnectionReset
                )
        }

        /// Serves a single connection: reads the request line, builds the requested proof and
        /// writes either the proof or an error reply. Write failures are ignored as there is
        /// nobody left to report them to.
        async fn handle_connection(mut socket: TcpStream, resolver_sockaddr: SocketAddr) {
                let mut buf = [0u8; 4096];
                let buf_len = match read_request_head(&mut socket, &mut buf).await {
                        Ok(Some(len)) => len,
                        // The peer went away (or the read failed), there is nothing to reply to.
                        Ok(None) => return,
                        Err(reply) => { send_error(&mut socket, reply).await; return; },
                };

                let (d, t) = match parse_request(&buf[..buf_len]) {
                        Ok(params) => params,
                        Err(reply) => { send_error(&mut socket, reply).await; return; },
                };

                let query_name = match Name::try_from(d) {
                        Ok(domain) => domain,
                        Err(_) => {
                                let reply = (BAD_REQUEST.0, "Failed to parse domain, make sure it ends with .");
                                send_error(&mut socket, reply).await;
                                return;
                        },
                };

                // The record type is matched case-insensitively.
                let proof_res = match t.to_ascii_uppercase().as_str() {
                        "TXT" => build_txt_proof_async(resolver_sockaddr, &query_name).await,
                        "TLSA" => build_tlsa_proof_async(resolver_sockaddr, &query_name).await,
                        "A" => build_a_proof_async(resolver_sockaddr, &query_name).await,
                        "AAAA" => build_aaaa_proof_async(resolver_sockaddr, &query_name).await,
                        _ => { send_error(&mut socket, BAD_REQUEST).await; return; },
                };
                let (proof, cache_ttl) = match proof_res {
                        Ok(proof) => proof,
                        Err(_) => {
                                let reply = ("404 Not Found", "Failed to generate proof for given domain");
                                send_error(&mut socket, reply).await;
                                return;
                        },
                };

                let _ = socket.write_all(
                        format!(
                                "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/octet-stream\r\nCache-Control: public, max-age={}, s-maxage={}\r\nAccess-Control-Allow-Origin: *\r\n\r\n",
                                proof.len(), cache_ttl, cache_ttl
                        ).as_bytes()
                ).await;
                let _ = socket.write_all(&proof).await;
        }

        /// Reads from `socket` into `buf` until the first CRLF (i.e. the end of the request line)
        /// has been received.
        ///
        /// Returns `Ok(Some(n))` once a CRLF is somewhere in `buf[..n]`, `Ok(None)` if the
        /// connection was closed or failed before that (no reply should be sent), and an error
        /// reply if `buf` filled up without a CRLF showing up.
        async fn read_request_head(socket: &mut TcpStream, buf: &mut [u8]) -> Result<Option<usize>, ErrorReply> {
                let mut buf_pos = 0;
                loop {
                        if buf_pos == buf.len() { return Err((BAD_REQUEST.0, "Request Too Large")); }
                        match socket.read(&mut buf[buf_pos..]).await {
                                Ok(0) => return Ok(None),
                                Ok(len) => {
                                        // Everything before `buf_pos` was already searched, so only
                                        // two-byte windows which include a freshly read byte can hold a new
                                        // CRLF. Starting one byte early covers a CRLF split across reads
                                        // and keeps the total work linear even if bytes trickle in.
                                        let scan_start = buf_pos.saturating_sub(1);
                                        buf_pos += len;
                                        if buf[scan_start..buf_pos].windows(2).any(|window| window == b"\r\n") {
                                                return Ok(Some(buf_pos));
                                        }
                                },
                                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {},
                                Err(_) => return Ok(None),
                        }
                }
        }

        /// Extracts the `d` (domain) and `t` (record type) query parameters from the request line
        /// at the start of `head`.
        ///
        /// `head` is everything received so far and must contain a CRLF. All of it has to be valid
        /// UTF-8, the request has to be a `GET` for [`PATH_PFX`] over HTTP/1.0 or HTTP/1.1, and both
        /// parameters have to be present and non-empty (if one is repeated the last value wins).
        /// Anything else yields the error reply to send back.
        fn parse_request(head: &[u8]) -> Result<(&str, &str), ErrorReply> {
                let text = match std::str::from_utf8(head) {
                        Ok(text) => text,
                        Err(_) => return Err(BAD_REQUEST),
                };
                let request = match text.split_once("\r\n") {
                        Some((line, _)) => line,
                        None => {
                                // The caller only hands us data in which it has seen a CRLF.
                                debug_assert!(false);
                                return Err(BAD_REQUEST);
                        },
                };

                // A request line is exactly `VERB SP PATH SP VERSION`.
                let mut parts = request.split(' ');
                let (verb, path, http_vers) = match (parts.next(), parts.next(), parts.next(), parts.next()) {
                        (Some(verb), Some(path), Some(http_vers), None) => (verb, path, http_vers),
                        _ => return Err(BAD_REQUEST),
                };
                if verb != "GET" { return Err(BAD_REQUEST); }
                if http_vers != "HTTP/1.1" && http_vers != "HTTP/1.0" { return Err(BAD_REQUEST); }

                let query = match path.strip_prefix(PATH_PFX) {
                        Some(query) => query,
                        None => return Err(("404 Not Found", "Not Found")),
                };

                let (mut d, mut t) = ("", "");
                for arg in query.split('&') {
                        match arg.split_once('=') {
                                Some(("d", v)) => d = v,
                                Some(("t", v)) => t = v,
                                // Unknown parameters are tolerated...
                                Some(_) => {},
                                // ...but every parameter has to be of the form `key=value`.
                                None => return Err(BAD_REQUEST),
                        }
                }
                if d.is_empty() || t.is_empty() {
                        return Err((BAD_REQUEST.0, "Missing d or t URI parameters"));
                }
                Ok((d, t))
        }

        /// Writes `reply` to `socket` as a complete plain-text HTTP response, ignoring write
        /// failures.
        async fn send_error(socket: &mut TcpStream, (status, body): ErrorReply) {
                let _ = socket.write_all(format!(
                        "HTTP/1.1 {}\r\nContent-Length: {}\r\nContent-Type: text/plain\r\nAccess-Control-Allow-Origin: *\r\n\r\n{}",
                        status, body.len(), body,
                ).as_bytes()).await;
        }
}
166
#[cfg(all(feature = "tokio", feature = "validation", test))]
mod test {
        // End-to-end tests: each one starts the server, fetches a proof over HTTP and validates
        // it. They need network access as proofs are built by querying public resolvers.
        use super::*;

        use crate::ser::parse_rr_stream;
        use crate::validation::verify_rr_stream;

        use minreq;

        /// Starts a proof server backed by the resolver at `resolver`, requests
        /// `/dnssecproof?<query>` from it and returns how many RRs of the response passed DNSSEC
        /// validation.
        ///
        /// Panics if the server cannot be started, the request does not come back with a 200 or
        /// the returned proof fails to parse or validate.
        ///
        /// `minreq` is a blocking client, so the server task has to be able to make progress on
        /// another thread while we wait for the response - callers must use a multi-threaded
        /// runtime.
        async fn verified_rr_count(resolver: &str, query: &str) -> usize {
                let ns = resolver.parse().unwrap();
                // Port 0 makes the OS pick a free port, so tests can neither clash with each other
                // nor with anything else which happens to be running on this machine.
                let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await
                        .expect("Failed to bind to socket");
                let port = listener.local_addr().expect("Failed to read bound address").port();
                tokio::spawn(imp::run_server(listener, ns));

                let resp = minreq::get(format!("http://127.0.0.1:{}/dnssecproof?{}", port, query))
                        .send().unwrap();

                assert_eq!(resp.status_code, 200);
                let rrs = parse_rr_stream(resp.as_bytes()).unwrap();
                let verified_rrs = verify_rr_stream(&rrs).unwrap();
                verified_rrs.verified_rrs.len()
        }

        #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
        async fn test_lookup() {
                // The oddly-cased `tXt` checks that the record type is matched case-insensitively.
                let verified = verified_rr_count(
                        "8.8.8.8:53", "d=matt.user._bitcoin-payment.mattcorallo.com.&t=tXt"
                ).await;
                assert_eq!(verified, 1);
        }

        #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
        async fn test_lookup_a() {
                let verified = verified_rr_count("9.9.9.9:53", "d=cloudflare.com.&t=a").await;
                assert!(verified >= 1);
        }

        #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
        async fn test_lookup_tlsa() {
                let verified = verified_rr_count(
                        "1.1.1.1:53", "d=_25._tcp.mail.as397444.net.&t=TLSA"
                ).await;
                assert_eq!(verified, 1);
        }
}