Set a cache-control header on valid responses
[dnssec-prover] / src / query.rs
1 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
2 //! resolver.
3
4 use std::cmp;
5 use std::net::{SocketAddr, TcpStream};
6 use std::io::{Read, Write, Error, ErrorKind};
7
8 #[cfg(feature = "tokio")]
9 use tokio_crate::net::TcpStream as TokioTcpStream;
10 #[cfg(feature = "tokio")]
11 use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
12
13 use crate::rr::*;
14 use crate::ser::*;
15
16 // We don't care about transaction IDs as we're only going to accept signed data. Thus, we use
17 // this constant instead of a random value.
18 const TXID: u16 = 0x4242;
19
20 fn emap<V>(v: Result<V, ()>) -> Result<V, Error> {
21         v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response"))
22 }
23
24 fn build_query(domain: &Name, ty: u16) -> Vec<u8> {
25         // TODO: Move to not allocating for the query
26         let mut query = Vec::with_capacity(1024);
27         let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11;
28         query.extend_from_slice(&query_msg_len.to_be_bytes());
29         query.extend_from_slice(&TXID.to_be_bytes());
30         query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data
31         query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional
32         write_name(&mut query, domain);
33         query.extend_from_slice(&ty.to_be_bytes());
34         query.extend_from_slice(&1u16.to_be_bytes()); // INternet class
35         query.extend_from_slice(&[0, 0, 0x29]); // . OPT
36         query.extend_from_slice(&0u16.to_be_bytes()); // 0 UDP payload size
37         query.extend_from_slice(&[0, 0]); // EDNS version 0
38         query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
39         query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
40         query
41 }
42
43 fn send_query(stream: &mut TcpStream, domain: &Name, ty: u16) -> Result<(), Error> {
44         let query = build_query(domain, ty);
45         stream.write_all(&query)?;
46         Ok(())
47 }
48
49 #[cfg(feature = "tokio")]
50 async fn send_query_async(stream: &mut TokioTcpStream, domain: &Name, ty: u16) -> Result<(), Error> {
51         let query = build_query(domain, ty);
52         stream.write_all(&query).await?;
53         Ok(())
54 }
55
56 fn handle_response(resp: &[u8], proof: &mut Vec<u8>) -> Result<Option<(RRSig, u32)>, Error> {
57         let mut read: &[u8] = resp;
58         if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
59         // 2 byte transaction ID
60         let flags = emap(read_u16(&mut read))?;
61         if flags & 0b1000_0000_0000_0000 == 0 {
62                 return Err(Error::new(ErrorKind::Other, "Missing response flag"));
63         }
64         if flags & 0b0111_1010_0000_0111 != 0 {
65                 return Err(Error::new(ErrorKind::Other, "Server indicated error or provided bunk flags"));
66         }
67         if flags & 0b10_0000 == 0 {
68                 return Err(Error::new(ErrorKind::Other, "Server indicated data could not be authenticated"));
69         }
70         let questions = emap(read_u16(&mut read))?;
71         if questions != 1 { return Err(Error::new(ErrorKind::Other, "server responded to multiple Qs")); }
72         let answers = emap(read_u16(&mut read))?;
73         if answers == 0 { return Err(Error::new(ErrorKind::Other, "No answers")); }
74         let _authorities = emap(read_u16(&mut read))?;
75         let _additional = emap(read_u16(&mut read))?;
76
77         for _ in 0..questions {
78                 emap(read_wire_packet_name(&mut read, resp))?;
79                 emap(read_u16(&mut read))?; // type
80                 emap(read_u16(&mut read))?; // class
81         }
82
83         // Only read the answers (skip authorities and additional) as that's all we care about.
84         let mut rrsig_opt = None;
85         let mut min_ttl = u32::MAX;
86         for _ in 0..answers {
87                 let (rr, ttl) = emap(parse_wire_packet_rr(&mut read, &resp))?;
88                 write_rr(&rr, ttl, proof);
89                 min_ttl = cmp::min(min_ttl, ttl);
90                 if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); }
91         }
92         Ok(rrsig_opt.map(|rr| (rr, min_ttl)))
93 }
94
95 fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>) -> Result<Option<(RRSig, u32)>, Error> {
96         let mut len = [0; 2];
97         stream.read_exact(&mut len)?;
98         let mut resp = vec![0; u16::from_be_bytes(len) as usize];
99         stream.read_exact(&mut resp)?;
100         handle_response(&resp, proof)
101 }
102
103 #[cfg(feature = "tokio")]
104 async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>) -> Result<Option<(RRSig, u32)>, Error> {
105         let mut len = [0; 2];
106         stream.read_exact(&mut len).await?;
107         let mut resp = vec![0; u16::from_be_bytes(len) as usize];
108         stream.read_exact(&mut resp).await?;
109         handle_response(&resp, proof)
110 }
111
112 macro_rules! build_proof_impl {
113         ($stream: ident, $send_query: ident, $read_response: ident $(, $async_ok: tt)?) => { {
114                 let mut res = Vec::new();
115                 let mut reached_root = false;
116                 let mut min_ttl = u32::MAX;
117                 for i in 0..10 {
118                         let resp_opt = $read_response(&mut $stream, &mut res)
119                                 $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
120                         if let Some((rrsig, rrsig_min_ttl)) = resp_opt {
121                                 min_ttl = cmp::min(min_ttl, rrsig_min_ttl);
122                                 if rrsig.name.as_str() == "." {
123                                         reached_root = true;
124                                 } else {
125                                         if i != 0 && rrsig.name == rrsig.key_name {
126                                                 $send_query(&mut $stream, &rrsig.key_name, DS::TYPE)
127                                         } else {
128                                                 $send_query(&mut $stream, &rrsig.key_name, DnsKey::TYPE)
129                                         }$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
130                                 }
131                         }
132                         if reached_root { break; }
133                 }
134
135                 if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
136                 else { Ok((res, min_ttl)) }
137         } }
138 }
139
140 fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
141         let mut stream = TcpStream::connect(resolver)?;
142         send_query(&mut stream, domain, ty)?;
143         build_proof_impl!(stream, send_query, read_response)
144 }
145
146 #[cfg(feature = "tokio")]
147 async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
148         let mut stream = TokioTcpStream::connect(resolver).await?;
149         send_query_async(&mut stream, domain, ty).await?;
150         build_proof_impl!(stream, send_query_async, read_response_async, { Ok::<(), Error>(()) })
151 }
152
153 /// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as
154 /// well as the TTL for the proof provided by the recursive resolver.
155 pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
156         build_proof(resolver, domain, A::TYPE)
157 }
158
159 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof
160 /// as well as the TTL for the proof provided by the recursive resolver.
161 pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
162         build_proof(resolver, domain, AAAA::TYPE)
163 }
164
165 /// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof
166 /// as well as the TTL for the proof provided by the recursive resolver.
167 pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
168         build_proof(resolver, domain, Txt::TYPE)
169 }
170
171 /// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof
172 /// as well as the TTL for the proof provided by the recursive resolver.
173 pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
174         build_proof(resolver, domain, TLSA::TYPE)
175 }
176
177
178 /// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as
179 /// well as the TTL for the proof provided by the recursive resolver.
180 #[cfg(feature = "tokio")]
181 pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
182         build_proof_async(resolver, domain, A::TYPE).await
183 }
184
185 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof
186 /// as well as the TTL for the proof provided by the recursive resolver.
187 #[cfg(feature = "tokio")]
188 pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
189         build_proof_async(resolver, domain, AAAA::TYPE).await
190 }
191
192 /// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof
193 /// as well as the TTL for the proof provided by the recursive resolver.
194 #[cfg(feature = "tokio")]
195 pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
196         build_proof_async(resolver, domain, Txt::TYPE).await
197 }
198
199 /// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof
200 /// as well as the TTL for the proof provided by the recursive resolver.
201 #[cfg(feature = "tokio")]
202 pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
203         build_proof_async(resolver, domain, TLSA::TYPE).await
204 }
205
206 #[cfg(all(feature = "validation", test))]
207 mod tests {
208         use super::*;
209         use crate::validation::*;
210
211         use rand::seq::SliceRandom;
212
213         use std::net::ToSocketAddrs;
214         use std::time::SystemTime;
215
216
217         #[test]
218         fn test_cloudflare_txt_query() {
219                 let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
220                 let query_name = "cloudflare.com.".try_into().unwrap();
221                 let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();
222
223                 let mut rrs = parse_rr_stream(&proof).unwrap();
224                 rrs.shuffle(&mut rand::rngs::OsRng);
225                 let verified_rrs = verify_rr_stream(&rrs).unwrap();
226                 assert!(verified_rrs.verified_rrs.len() > 1);
227
228                 let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
229                 assert!(verified_rrs.valid_from < now);
230                 assert!(verified_rrs.expires > now);
231         }
232
233         #[test]
234         fn test_sha1_query() {
235                 let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
236                 let query_name = "benthecarman.com.".try_into().unwrap();
237                 let (proof, _) = build_a_proof(sockaddr, &query_name).unwrap();
238
239                 let mut rrs = parse_rr_stream(&proof).unwrap();
240                 rrs.shuffle(&mut rand::rngs::OsRng);
241                 let verified_rrs = verify_rr_stream(&rrs).unwrap();
242                 assert!(verified_rrs.verified_rrs.len() >= 1);
243
244                 let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
245                 assert!(verified_rrs.valid_from < now);
246                 assert!(verified_rrs.expires > now);
247         }
248
249         #[test]
250         fn test_txt_query() {
251                 let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
252                 let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
253                 let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();
254
255                 let mut rrs = parse_rr_stream(&proof).unwrap();
256                 rrs.shuffle(&mut rand::rngs::OsRng);
257                 let verified_rrs = verify_rr_stream(&rrs).unwrap();
258                 assert_eq!(verified_rrs.verified_rrs.len(), 1);
259
260                 let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
261                 assert!(verified_rrs.valid_from < now);
262                 assert!(verified_rrs.expires > now);
263         }
264
265         #[cfg(feature = "tokio")]
266         use tokio_crate as tokio;
267
268         #[cfg(feature = "tokio")]
269         #[tokio::test]
270         async fn test_txt_query_async() {
271                 let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
272                 let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
273                 let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();
274
275                 let mut rrs = parse_rr_stream(&proof).unwrap();
276                 rrs.shuffle(&mut rand::rngs::OsRng);
277                 let verified_rrs = verify_rr_stream(&rrs).unwrap();
278                 assert_eq!(verified_rrs.verified_rrs.len(), 1);
279
280                 let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
281                 assert!(verified_rrs.valid_from < now);
282                 assert!(verified_rrs.expires > now);
283         }
284 }