Use (and expose) a `ProofBuilder` state machine for proving
[dnssec-prover] / src / query.rs
1 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
2 //! resolver.
3
4 use std::cmp;
5 use std::net::{SocketAddr, TcpStream};
6 use std::io::{Read, Write, Error, ErrorKind};
7
8 #[cfg(feature = "tokio")]
9 use tokio_crate::net::TcpStream as TokioTcpStream;
10 #[cfg(feature = "tokio")]
11 use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
12
13 use crate::rr::*;
14 use crate::ser::*;
15
// Only signed data is ever accepted, so the transaction ID offers us no protection and we do not
// bother randomizing it: every query uses this fixed value and responses must echo it back.
const TXID: u16 = u16::from_be_bytes([0x42, 0x42]);
19
20 fn build_query(domain: &Name, ty: u16) -> Vec<u8> {
21         // TODO: Move to not allocating for the query
22         let mut query = Vec::with_capacity(1024);
23         let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11;
24         query.extend_from_slice(&query_msg_len.to_be_bytes());
25         query.extend_from_slice(&TXID.to_be_bytes());
26         query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data
27         query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional
28         write_name(&mut query, domain);
29         query.extend_from_slice(&ty.to_be_bytes());
30         query.extend_from_slice(&1u16.to_be_bytes()); // INternet class
31         query.extend_from_slice(&[0, 0, 0x29]); // . OPT
32         query.extend_from_slice(&0u16.to_be_bytes()); // 0 UDP payload size
33         query.extend_from_slice(&[0, 0]); // EDNS version 0
34         query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
35         query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
36         query
37 }
38
/// Fuzzing entry point: runs `response` through the response parser exactly as if a resolver had
/// sent it, discarding whatever comes out.
#[cfg(fuzzing)]
pub fn fuzz_response(response: &[u8]) {
	let mut scratch_proof = Vec::new();
	let mut scratch_key_names = Vec::new();
	let _ = handle_response(response, &mut scratch_proof, &mut scratch_key_names);
}
45
46 fn handle_response(resp: &[u8], proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<Name>) -> Result<u32, ()> {
47         let mut read: &[u8] = resp;
48         if read_u16(&mut read)? != TXID { return Err(()); }
49         // 2 byte transaction ID
50         let flags = read_u16(&mut read)?;
51         if flags & 0b1000_0000_0000_0000 == 0 {
52                 return Err(());
53         }
54         if flags & 0b0111_1010_0000_0111 != 0 {
55                 return Err(());
56         }
57         if flags & 0b10_0000 == 0 {
58                 return Err(());
59         }
60         let questions = read_u16(&mut read)?;
61         if questions != 1 { return Err(()); }
62         let answers = read_u16(&mut read)?;
63         if answers == 0 { return Err(()); }
64         let _authorities = read_u16(&mut read)?;
65         let _additional = read_u16(&mut read)?;
66
67         for _ in 0..questions {
68                 read_wire_packet_name(&mut read, resp)?;
69                 read_u16(&mut read)?; // type
70                 read_u16(&mut read)?; // class
71         }
72
73         // Only read the answers (skip authorities and additional) as that's all we care about.
74         let mut min_ttl = u32::MAX;
75         for _ in 0..answers {
76                 let (rr, ttl) = parse_wire_packet_rr(&mut read, &resp)?;
77                 write_rr(&rr, ttl, proof);
78                 min_ttl = cmp::min(min_ttl, ttl);
79                 if let RR::RRSig(rrsig) = rr { rrsig_key_names.push(rrsig.key_name); }
80         }
81         Ok(min_ttl)
82 }
83
84 const MAX_REQUESTS: usize = 10;
85 /// A simple state machine which will generate a series of queries and process the responses until
86 /// it has built a DNSSEC proof.
87 ///
88 /// A [`ProofBuilder`] driver starts with [`ProofBuilder::new`], fetching the state machine and
89 /// initial query. As long as [`ProofBuilder::awaiting_responses`] returns true, responses should
90 /// be read from the resolver. For each query response read from the DNS resolver,
91 /// [`ProofBuilder::process_response`] should be called, and each fresh query returned should be
92 /// sent to the resolver. Once [`ProofBuilder::awaiting_responses`] returns false,
93 /// [`ProofBuilder::finish_proof`] should be called to fetch the resulting proof.
94 pub struct ProofBuilder {
95         proof: Vec<u8>,
96         min_ttl: u32,
97         dnskeys_requested: Vec<Name>,
98         pending_queries: usize,
99         queries_made: usize,
100 }
101
102 impl ProofBuilder {
103         /// Constructs a new [`ProofBuilder`] and an initial query to send to the recursive resolver to
104         /// begin the proof building process.
105         ///
106         /// Given a correctly-functioning resolver the proof will ultimately be able to prove the
107         /// contents of any records with the given `ty`pe at the given `name` (as long as the given
108         /// `ty`pe is supported by this library).
109         ///
110         /// You can find constants for supported standard types in the [`crate::rr`] module.
111         pub fn new(name: &Name, ty: u16) -> (ProofBuilder, Vec<u8>) {
112                 let initial_query = build_query(name, ty);
113                 (ProofBuilder {
114                         proof: Vec::new(),
115                         min_ttl: u32::MAX,
116                         dnskeys_requested: Vec::with_capacity(MAX_REQUESTS),
117                         pending_queries: 1,
118                         queries_made: 1,
119                 }, initial_query)
120         }
121
122         /// Returns true as long as further responses are expected from the resolver.
123         ///
124         /// As long as this returns true, responses should be read from the resolver and passed to
125         /// [`Self::process_response`]. Once this returns false, [`Self::finish_proof`] should be used
126         /// to (possibly) get the final proof.
127         pub fn awaiting_responses(&self) -> bool {
128                 self.pending_queries > 0 && self.queries_made <= MAX_REQUESTS
129         }
130
131         /// Processes a query response from the recursive resolver, returning a list of new queries to
132         /// send to the resolver.
133         pub fn process_response(&mut self, resp: &[u8]) -> Result<Vec<Vec<u8>>, ()> {
134                 if self.pending_queries == 0 { return Err(()); }
135
136                 let mut rrsig_key_names = Vec::new();
137                 let min_ttl = handle_response(&resp, &mut self.proof, &mut rrsig_key_names)?;
138                 self.min_ttl = cmp::min(self.min_ttl, min_ttl);
139                 self.pending_queries -= 1;
140
141                 rrsig_key_names.sort_unstable();
142                 rrsig_key_names.dedup();
143
144                 let mut new_queries = Vec::with_capacity(2);
145                 for key_name in rrsig_key_names.drain(..) {
146                         if !self.dnskeys_requested.contains(&key_name) {
147                                 new_queries.push(build_query(&key_name, DnsKey::TYPE));
148                                 self.pending_queries += 1;
149                                 self.queries_made += 1;
150                                 self.dnskeys_requested.push(key_name.clone());
151
152                                 if key_name.as_str() != "." {
153                                         new_queries.push(build_query(&key_name, DS::TYPE));
154                                         self.pending_queries += 1;
155                                         self.queries_made += 1;
156                                 }
157                         }
158                 }
159                 if self.queries_made <= MAX_REQUESTS {
160                         Ok(new_queries)
161                 } else {
162                         Ok(Vec::new())
163                 }
164         }
165
166         /// Finalizes the proof, if one is available, and returns it as well as the TTL that should be
167         /// used to cache the proof (i.e. the lowest TTL of all records which were used to build the
168         /// proof).
169         pub fn finish_proof(self) -> Result<(Vec<u8>, u32), ()> {
170                 if self.pending_queries > 0 || self.queries_made > MAX_REQUESTS {
171                         Err(())
172                 } else {
173                         Ok((self.proof, self.min_ttl))
174                 }
175         }
176 }
177
/// Writes an already length-prefixed query to `stream` in full.
///
/// Generic over any [`Write`] implementation, so it accepts a `TcpStream` as well as other
/// transports or in-memory buffers.
fn send_query<W: Write>(stream: &mut W, query: &[u8]) -> Result<(), Error> {
	stream.write_all(query)
}
182
/// Async counterpart of `send_query`: writes the whole length-prefixed `query` to `stream`.
#[cfg(feature = "tokio")]
async fn send_query_async(stream: &mut TokioTcpStream, query: &[u8]) -> Result<(), Error> {
	stream.write_all(query).await
}
188
/// A buffer large enough to hold any DNS-over-TCP message, whose length prefix is a `u16`.
type MsgBuf = [u8; u16::MAX as usize];

/// Reads one DNS-over-TCP message (a two-byte big-endian length followed by that many bytes) from
/// `stream` into the front of `response_buf`, returning the message length.
///
/// Generic over any [`Read`] implementation, so it accepts a `TcpStream` as well as other
/// transports or in-memory buffers.
///
/// # Errors
///
/// Propagates any I/O error from `stream`, including `UnexpectedEof` if the stream ends before
/// the full message has been read.
fn read_response<R: Read>(stream: &mut R, response_buf: &mut MsgBuf) -> Result<u16, Error> {
	let mut len_bytes = [0; 2];
	stream.read_exact(&mut len_bytes)?;
	let len = u16::from_be_bytes(len_bytes);
	stream.read_exact(&mut response_buf[..usize::from(len)])?;
	Ok(len)
}
198
/// Async counterpart of `read_response`: reads one length-prefixed DNS message from `stream` into
/// the front of `response_buf` and returns its length.
#[cfg(feature = "tokio")]
async fn read_response_async(stream: &mut TokioTcpStream, response_buf: &mut MsgBuf) -> Result<u16, Error> {
	let mut prefix = [0u8; 2];
	stream.read_exact(&mut prefix).await?;
	let msg_len = u16::from_be_bytes(prefix);
	let body = &mut response_buf[..usize::from(msg_len)];
	stream.read_exact(body).await?;
	Ok(msg_len)
}
207
/// Drives a [`ProofBuilder`] to completion over the already-connected `$stream`, using
/// `$send_query` / `$read_response` for I/O. Pass the optional trailing `$async_ok` token tree
/// (an `Ok(())` block) for async I/O functions, which makes every I/O call `.await?` first.
macro_rules! build_proof_impl {
	($stream: ident, $send_query: ident, $read_response: ident, $domain: expr, $ty: expr $(, $async_ok: tt)?) => { {
		// Send the initial query, trusting the resolver to return any CNAMEs all the way to the
		// final record in its answer. Every RRSIG seen in a response is then walked up to the
		// root by iteratively issuing DNSKEY and DS lookups; `ProofBuilder` deduplicates them.
		let (mut state, first_query) = ProofBuilder::new($domain, $ty);
		$send_query(&mut $stream, &first_query)
			$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
		let mut buf: MsgBuf = [0; u16::MAX as usize];
		loop {
			if !state.awaiting_responses() { break; }
			let len = $read_response(&mut $stream, &mut buf)
				$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
			let follow_ups = state.process_response(&buf[..len as usize])
				.map_err(|()| Error::new(ErrorKind::Other, "Bad response"))?;
			for follow_up in follow_ups.iter() {
				$send_query(&mut $stream, follow_up)
					$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
			}
		}

		state.finish_proof()
			.map_err(|()| Error::new(ErrorKind::Other, "Too many requests required"))
	} }
}
234
235 fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
236         let mut stream = TcpStream::connect(resolver)?;
237         build_proof_impl!(stream, send_query, read_response, domain, ty)
238 }
239
/// Async counterpart of `build_proof`: connects to `resolver` with tokio and runs the
/// proof-building loop over that connection.
#[cfg(feature = "tokio")]
async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
	let mut conn = TokioTcpStream::connect(resolver).await?;
	build_proof_impl!(conn, send_query_async, read_response_async, domain, ty, { Ok::<(), Error>(()) })
}
245
246 /// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as
247 /// well as the TTL for the proof provided by the recursive resolver.
248 ///
249 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
250 /// module to validate the records contained.
251 pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
252         build_proof(resolver, domain, A::TYPE)
253 }
254
255 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof
256 /// as well as the TTL for the proof provided by the recursive resolver.
257 ///
258 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
259 /// module to validate the records contained.
260 pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
261         build_proof(resolver, domain, AAAA::TYPE)
262 }
263
264 /// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof
265 /// as well as the TTL for the proof provided by the recursive resolver.
266 ///
267 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
268 /// module to validate the records contained.
269 pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
270         build_proof(resolver, domain, Txt::TYPE)
271 }
272
273 /// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof
274 /// as well as the TTL for the proof provided by the recursive resolver.
275 ///
276 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
277 /// module to validate the records contained.
278 pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
279         build_proof(resolver, domain, TLSA::TYPE)
280 }
281
282
/// Asynchronously queries the recursive resolver at `resolver` for the A records of `domain` and
/// assembles a DNSSEC proof for them, returning the serialized proof together with the TTL it may
/// be cached for (the lowest TTL among the records it was built from).
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	let record_type = A::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
292
/// Asynchronously queries the recursive resolver at `resolver` for the AAAA records of `domain`
/// and assembles a DNSSEC proof for them, returning the serialized proof together with the TTL it
/// may be cached for (the lowest TTL among the records it was built from).
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	let record_type = AAAA::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
302
/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver, returning the proof
/// as well as the TTL for the proof provided by the recursive resolver.
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	build_proof_async(resolver, domain, Txt::TYPE).await
}
312
/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver, returning the proof
/// as well as the TTL for the proof provided by the recursive resolver.
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	build_proof_async(resolver, domain, TLSA::TYPE).await
}
322
#[cfg(all(feature = "validation", test))]
mod tests {
	use super::*;
	use crate::validation::*;

	use rand::seq::SliceRandom;

	use std::net::ToSocketAddrs;
	use std::time::SystemTime;

	/// Exercises the public `ProofBuilder` state machine without touching the network.
	#[test]
	fn test_proof_builder_offline() {
		let query_name = "example.com.".try_into().unwrap();
		let (mut builder, query) = ProofBuilder::new(&query_name, Txt::TYPE);

		// The initial query is length-prefixed and carries our fixed transaction ID.
		assert!(query.len() > 4);
		assert_eq!(u16::from_be_bytes([query[0], query[1]]) as usize, query.len() - 2);
		assert_eq!(u16::from_be_bytes([query[2], query[3]]), TXID);
		assert!(builder.awaiting_responses());

		// Malformed responses (wrong transaction ID, cut short, empty) are rejected and do not
		// consume the outstanding query.
		assert!(builder.process_response(&[0, 0, 0x81, 0x20]).is_err());
		assert!(builder.process_response(&TXID.to_be_bytes()).is_err());
		assert!(builder.process_response(&[]).is_err());
		assert!(builder.awaiting_responses());

		// Without any accepted response there is no proof to hand out.
		assert!(builder.finish_proof().is_err());
	}

	#[test]
	fn test_cloudflare_txt_query() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "cloudflare.com.".try_into().unwrap();
		let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert!(verified_rrs.verified_rrs.len() > 1);

		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	#[test]
	fn test_sha1_query() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "benthecarman.com.".try_into().unwrap();
		let (proof, _) = build_a_proof(sockaddr, &query_name).unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert!(verified_rrs.verified_rrs.len() >= 1);

		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	#[test]
	fn test_txt_query() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
		let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert_eq!(verified_rrs.verified_rrs.len(), 1);

		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	#[test]
	fn test_cname_query() {
		for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
			let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
			let query_name = "cname_test.matcorallo.com.".try_into().unwrap();
			let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();

			let mut rrs = parse_rr_stream(&proof).unwrap();
			rrs.shuffle(&mut rand::rngs::OsRng);
			let verified_rrs = verify_rr_stream(&rrs).unwrap();
			assert_eq!(verified_rrs.verified_rrs.len(), 2);

			let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
			assert!(verified_rrs.valid_from < now);
			assert!(verified_rrs.expires > now);

			let resolved_rrs = verified_rrs.resolve_name(&query_name);
			assert_eq!(resolved_rrs.len(), 1);
			if let RR::Txt(txt) = &resolved_rrs[0] {
				assert_eq!(txt.name.as_str(), "txt_test.matcorallo.com.");
				assert_eq!(txt.data, b"dnssec_prover_test");
			} else { panic!(); }
		}
	}

	#[cfg(feature = "tokio")]
	use tokio_crate as tokio;

	#[cfg(feature = "tokio")]
	#[tokio::test]
	async fn test_txt_query_async() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
		let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert_eq!(verified_rrs.verified_rrs.len(), 1);

		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	#[cfg(feature = "tokio")]
	#[tokio::test]
	async fn test_cross_domain_cname_query_async() {
		for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
			let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
			let query_name = "wildcard.x_domain_cname_wild.matcorallo.com.".try_into().unwrap();
			let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();

			let mut rrs = parse_rr_stream(&proof).unwrap();
			rrs.shuffle(&mut rand::rngs::OsRng);
			let verified_rrs = verify_rr_stream(&rrs).unwrap();
			assert_eq!(verified_rrs.verified_rrs.len(), 2);

			let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
			assert!(verified_rrs.valid_from < now);
			assert!(verified_rrs.expires > now);

			let resolved_rrs = verified_rrs.resolve_name(&query_name);
			assert_eq!(resolved_rrs.len(), 1);
			if let RR::Txt(txt) = &resolved_rrs[0] {
				assert_eq!(txt.name.as_str(), "matt.user._bitcoin-payment.mattcorallo.com.");
				assert!(txt.data.starts_with(b"bitcoin:"));
			} else { panic!(); }
		}
	}
}