Add fuzzing of DNS server response parsing
[dnssec-prover] / src / query.rs
1 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
2 //! resolver.
3
4 use std::cmp;
5 use std::net::{SocketAddr, TcpStream};
6 use std::io::{Read, Write, Error, ErrorKind};
7
8 #[cfg(feature = "tokio")]
9 use tokio_crate::net::TcpStream as TokioTcpStream;
10 #[cfg(feature = "tokio")]
11 use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
12
13 use crate::rr::*;
14 use crate::ser::*;
15
// Fixed transaction ID placed in every query. A random TXID is unnecessary because only
// DNSSEC-signed data is accepted from the resolver, so spoofed answers gain nothing.
const TXID: u16 = 0x4242;
19
/// Converts a unit-error parse result into an `std::io::Error` (kind `Other`, message
/// "Bad Response") so parse failures can be propagated with `?` alongside I/O errors.
fn emap<V>(v: Result<V, ()>) -> Result<V, Error> {
	match v {
		Ok(value) => Ok(value),
		Err(()) => Err(Error::new(ErrorKind::Other, "Bad Response")),
	}
}
23
24 fn build_query(domain: &Name, ty: u16) -> Vec<u8> {
25         // TODO: Move to not allocating for the query
26         let mut query = Vec::with_capacity(1024);
27         let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11;
28         query.extend_from_slice(&query_msg_len.to_be_bytes());
29         query.extend_from_slice(&TXID.to_be_bytes());
30         query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data
31         query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional
32         write_name(&mut query, domain);
33         query.extend_from_slice(&ty.to_be_bytes());
34         query.extend_from_slice(&1u16.to_be_bytes()); // INternet class
35         query.extend_from_slice(&[0, 0, 0x29]); // . OPT
36         query.extend_from_slice(&0u16.to_be_bytes()); // 0 UDP payload size
37         query.extend_from_slice(&[0, 0]); // EDNS version 0
38         query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
39         query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
40         query
41 }
42
43 fn send_query(stream: &mut TcpStream, domain: &Name, ty: u16) -> Result<(), Error> {
44         let query = build_query(domain, ty);
45         stream.write_all(&query)?;
46         Ok(())
47 }
48
/// Async counterpart of `send_query`: builds the query for `domain`/`ty` and writes all of its
/// bytes to the tokio `stream`.
#[cfg(feature = "tokio")]
async fn send_query_async(stream: &mut TokioTcpStream, domain: &Name, ty: u16) -> Result<(), Error> {
	let bytes = build_query(domain, ty);
	stream.write_all(&bytes).await
}
55
#[cfg(fuzzing)]
/// Runs arbitrary bytes through the response parser, as though they were a DNS message (without
/// its TCP length prefix) received from the resolver. Used as a fuzzing entry point.
pub fn fuzz_response(response: &[u8]) {
	let mut proof = Vec::new();
	let mut rrsig_key_names = Vec::new();
	// Most fuzz inputs are invalid, so a parse error is the expected outcome and is discarded.
	let _ = handle_response(response, &mut proof, &mut rrsig_key_names);
}
62
63 fn handle_response(resp: &[u8], proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<Name>) -> Result<u32, Error> {
64         let mut read: &[u8] = resp;
65         if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
66         // 2 byte transaction ID
67         let flags = emap(read_u16(&mut read))?;
68         if flags & 0b1000_0000_0000_0000 == 0 {
69                 return Err(Error::new(ErrorKind::Other, "Missing response flag"));
70         }
71         if flags & 0b0111_1010_0000_0111 != 0 {
72                 return Err(Error::new(ErrorKind::Other, "Server indicated error or provided bunk flags"));
73         }
74         if flags & 0b10_0000 == 0 {
75                 return Err(Error::new(ErrorKind::Other, "Server indicated data could not be authenticated"));
76         }
77         let questions = emap(read_u16(&mut read))?;
78         if questions != 1 { return Err(Error::new(ErrorKind::Other, "server responded to multiple Qs")); }
79         let answers = emap(read_u16(&mut read))?;
80         if answers == 0 { return Err(Error::new(ErrorKind::Other, "No answers")); }
81         let _authorities = emap(read_u16(&mut read))?;
82         let _additional = emap(read_u16(&mut read))?;
83
84         for _ in 0..questions {
85                 emap(read_wire_packet_name(&mut read, resp))?;
86                 emap(read_u16(&mut read))?; // type
87                 emap(read_u16(&mut read))?; // class
88         }
89
90         // Only read the answers (skip authorities and additional) as that's all we care about.
91         let mut min_ttl = u32::MAX;
92         for _ in 0..answers {
93                 let (rr, ttl) = emap(parse_wire_packet_rr(&mut read, &resp))?;
94                 write_rr(&rr, ttl, proof);
95                 min_ttl = cmp::min(min_ttl, ttl);
96                 if let RR::RRSig(rrsig) = rr { rrsig_key_names.push(rrsig.key_name); }
97         }
98         Ok(min_ttl)
99 }
100
101 fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<Name>) -> Result<u32, Error> {
102         let mut len = [0; 2];
103         stream.read_exact(&mut len)?;
104         let mut resp = vec![0; u16::from_be_bytes(len) as usize];
105         stream.read_exact(&mut resp)?;
106         handle_response(&resp, proof, rrsig_key_names)
107 }
108
/// Async counterpart of `read_response`: reads one length-prefixed DNS message from the tokio
/// `stream` and passes it to `handle_response`.
#[cfg(feature = "tokio")]
async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<Name>) -> Result<u32, Error> {
	// Over TCP every DNS message is preceded by its length as a big-endian u16.
	let mut len_bytes = [0u8; 2];
	stream.read_exact(&mut len_bytes).await?;
	let msg_len = usize::from(u16::from_be_bytes(len_bytes));

	let mut msg = vec![0u8; msg_len];
	stream.read_exact(&mut msg).await?;
	handle_response(&msg, proof, rrsig_key_names)
}
117
// Shared body of `build_proof` and `build_proof_async`.
//
// Reads responses from `$stream` and appends each answer record to the proof. For every RRSIG
// signer name not yet requested, it sends a DNSKEY query, plus a DS query unless the name is the
// root ("."). It repeats until no queries are outstanding.
//
// `$send_query` / `$read_response` are the sync or async I/O helpers.
// - With `$async_ok` supplied (the async caller passes `{ Ok::<(), Error>(()) }`), each helper
//   call expands to `call.await?; $async_ok?;`.
// - Without it, each call expands to just `call?;`.
//
// Evaluates to `Ok((proof_bytes, min_ttl))`. It evaluates to an error if any I/O or parse step
// fails, or if more than `MAX_REQUESTS` queries (including the caller's initial one) are needed.
macro_rules! build_proof_impl {
	($stream: ident, $send_query: ident, $read_response: ident $(, $async_ok: tt)?) => { {
		// We require the initial query to have already gone out, and assume our resolver will
		// return any CNAMEs all the way to the final record in the response. From there, we just
		// have to take any RRSIGs in the response and walk them up to the root. We do so
		// iteratively, sending DNSKEY and DS lookups after every response, deduplicating requests
		// using `dnskeys_requested`.
		let mut res = Vec::new(); // The actual proof stream
		let mut min_ttl = u32::MAX; // Min TTL of any answer record
		const MAX_REQUESTS: usize = 20;
		let mut rrsig_key_names = Vec::with_capacity(4); // Last response's RRSIG key_names
		let mut dnskeys_requested = Vec::with_capacity(MAX_REQUESTS);
		// Both counters start at 1 to account for the initial query the caller already sent.
		let mut pending_queries = 1;
		let mut queries_made = 1;
		// Several queries may be written before their responses are read. Each iteration consumes
		// exactly one response from the same connection.
		while pending_queries != 0 && queries_made <= MAX_REQUESTS {
			let response_min_ttl = $read_response(&mut $stream, &mut res, &mut rrsig_key_names)
				$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
			pending_queries -= 1;
			min_ttl = cmp::min(min_ttl, response_min_ttl);
			// Collapse duplicate signer names within this response before issuing queries.
			rrsig_key_names.sort_unstable();
			rrsig_key_names.dedup();
			for key_name in rrsig_key_names.drain(..) {
				if !dnskeys_requested.contains(&key_name) {
					$send_query(&mut $stream, &key_name, DnsKey::TYPE)
						$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
					pending_queries += 1;
					queries_made += 1;
					dnskeys_requested.push(key_name.clone());

					// The root has no parent zone, so no DS record is requested for it.
					if key_name.as_str() != "." {
						$send_query(&mut $stream, &key_name, DS::TYPE)
							$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
						pending_queries += 1;
						queries_made += 1;
					}
				}
			}
		}

		// The loop also exits once the request budget is exceeded, possibly with responses still
		// outstanding; that case is reported as an error rather than returning a partial proof.
		if queries_made > MAX_REQUESTS {
			Err(Error::new(ErrorKind::Other, "Too many requests required"))
		} else {
			Ok((res, min_ttl))
		}
	} }
}
164
165 fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
166         let mut stream = TcpStream::connect(resolver)?;
167         send_query(&mut stream, domain, ty)?;
168         build_proof_impl!(stream, send_query, read_response)
169 }
170
/// Async counterpart of `build_proof`: connects to `resolver` with tokio, sends the initial `ty`
/// query for `domain`, then runs the shared proof-building loop with the async I/O helpers.
#[cfg(feature = "tokio")]
async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
	let mut conn = TokioTcpStream::connect(resolver).await?;
	send_query_async(&mut conn, domain, ty).await?;
	build_proof_impl!(conn, send_query_async, read_response_async, { Ok::<(), Error>(()) })
}
177
178 /// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as
179 /// well as the TTL for the proof provided by the recursive resolver.
180 ///
181 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
182 /// module to validate the records contained.
183 pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
184         build_proof(resolver, domain, A::TYPE)
185 }
186
187 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof
188 /// as well as the TTL for the proof provided by the recursive resolver.
189 ///
190 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
191 /// module to validate the records contained.
192 pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
193         build_proof(resolver, domain, AAAA::TYPE)
194 }
195
196 /// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof
197 /// as well as the TTL for the proof provided by the recursive resolver.
198 ///
199 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
200 /// module to validate the records contained.
201 pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
202         build_proof(resolver, domain, Txt::TYPE)
203 }
204
205 /// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof
206 /// as well as the TTL for the proof provided by the recursive resolver.
207 ///
208 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
209 /// module to validate the records contained.
210 pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
211         build_proof(resolver, domain, TLSA::TYPE)
212 }
213
214
/// Async version of `build_a_proof`: builds a DNSSEC proof for the A records of `domain` by
/// querying the recursive resolver at `resolver`.
///
/// Returns the serialized proof and the smallest TTL of any answer record the resolver returned
/// while building it.
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	let record_type = A::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
224
/// Async version of `build_aaaa_proof`: builds a DNSSEC proof for the AAAA records of `domain` by
/// querying the recursive resolver at `resolver`.
///
/// Returns the serialized proof and the smallest TTL of any answer record the resolver returned
/// while building it.
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	let record_type = AAAA::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
234
/// Async version of `build_txt_proof`: builds a DNSSEC proof for the TXT records of `domain` by
/// querying the recursive resolver at `resolver`.
///
/// Returns the serialized proof and the smallest TTL of any answer record the resolver returned
/// while building it.
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	let record_type = Txt::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
244
/// Async version of `build_tlsa_proof`: builds a DNSSEC proof for the TLSA records of `domain` by
/// querying the recursive resolver at `resolver`.
///
/// Returns the serialized proof and the smallest TTL of any answer record the resolver returned
/// while building it.
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
#[cfg(feature = "tokio")]
pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
	let record_type = TLSA::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
254
// End-to-end tests: build proofs against live public resolvers (8.8.8.8, 1.1.1.1, 9.9.9.9), then
// parse and verify them with the `validation` module. They need outbound network access. They
// also depend on the DNS records currently published for the queried names.
#[cfg(all(feature = "validation", test))]
mod tests {
	use super::*;
	use crate::validation::*;

	use rand::seq::SliceRandom;

	use std::net::ToSocketAddrs;
	use std::time::SystemTime;


	// TXT proof for cloudflare.com.; expects more than one verified record.
	#[test]
	fn test_cloudflare_txt_query() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "cloudflare.com.".try_into().unwrap();
		let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		// Shuffle so verification is exercised independently of the order records were received.
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert!(verified_rrs.verified_rrs.len() > 1);

		// The proof's signature validity window must contain the current time.
		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	// A-record proof for benthecarman.com.; per the test name, presumably a zone whose chain
	// involves a SHA-1-based DNSSEC algorithm (not verifiable from this file).
	#[test]
	fn test_sha1_query() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "benthecarman.com.".try_into().unwrap();
		let (proof, _) = build_a_proof(sockaddr, &query_name).unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert!(verified_rrs.verified_rrs.len() >= 1);

		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	// TXT proof for a name expected to have exactly one TXT record.
	#[test]
	fn test_txt_query() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
		let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert_eq!(verified_rrs.verified_rrs.len(), 1);

		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	// Against three resolvers: the proof must include both the CNAME and the TXT record it
	// points to (two verified records), and resolve_name must follow the CNAME to the TXT data.
	#[test]
	fn test_cname_query() {
		for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
			let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
			let query_name = "cname_test.matcorallo.com.".try_into().unwrap();
			let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();

			let mut rrs = parse_rr_stream(&proof).unwrap();
			rrs.shuffle(&mut rand::rngs::OsRng);
			let verified_rrs = verify_rr_stream(&rrs).unwrap();
			assert_eq!(verified_rrs.verified_rrs.len(), 2);

			let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
			assert!(verified_rrs.valid_from < now);
			assert!(verified_rrs.expires > now);

			let resolved_rrs = verified_rrs.resolve_name(&query_name);
			assert_eq!(resolved_rrs.len(), 1);
			if let RR::Txt(txt) = &resolved_rrs[0] {
				assert_eq!(txt.name.as_str(), "txt_test.matcorallo.com.");
				assert_eq!(txt.data, b"dnssec_prover_test");
			} else { panic!(); }
		}
	}

	#[cfg(feature = "tokio")]
	use tokio_crate as tokio;

	// Async variant of test_txt_query.
	#[cfg(feature = "tokio")]
	#[tokio::test]
	async fn test_txt_query_async() {
		let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
		let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
		let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();

		let mut rrs = parse_rr_stream(&proof).unwrap();
		rrs.shuffle(&mut rand::rngs::OsRng);
		let verified_rrs = verify_rr_stream(&rrs).unwrap();
		assert_eq!(verified_rrs.verified_rrs.len(), 1);

		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
		assert!(verified_rrs.valid_from < now);
		assert!(verified_rrs.expires > now);
	}

	// Async: a CNAME from one domain to a TXT record in another must resolve, against each of
	// three resolvers.
	#[cfg(feature = "tokio")]
	#[tokio::test]
	async fn test_cross_domain_cname_query_async() {
		for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
			let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
			let query_name = "wildcard.x_domain_cname_wild.matcorallo.com.".try_into().unwrap();
			let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();

			let mut rrs = parse_rr_stream(&proof).unwrap();
			rrs.shuffle(&mut rand::rngs::OsRng);
			let verified_rrs = verify_rr_stream(&rrs).unwrap();
			assert_eq!(verified_rrs.verified_rrs.len(), 2);

			let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
			assert!(verified_rrs.valid_from < now);
			assert!(verified_rrs.expires > now);

			let resolved_rrs = verified_rrs.resolve_name(&query_name);
			assert_eq!(resolved_rrs.len(), 1);
			if let RR::Txt(txt) = &resolved_rrs[0] {
				assert_eq!(txt.name.as_str(), "matt.user._bitcoin-payment.mattcorallo.com.");
				assert!(txt.data.starts_with(b"bitcoin:"));
			} else { panic!(); }
		}
	}
}