Return early if we get a response with no answers
[dnssec-prover] / src / query.rs
1 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
2 //! resolver.
3
4 use std::net::{SocketAddr, TcpStream};
5 use std::io::{Read, Write, Error, ErrorKind};
6
7 #[cfg(feature = "tokio")]
8 use tokio_crate::net::TcpStream as TokioTcpStream;
9 #[cfg(feature = "tokio")]
10 use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
11
12 use crate::rr::*;
13 use crate::ser::*;
14
// Every record we keep must later verify against DNSSEC signatures. An unpredictable
// transaction ID would only guard against spoofed answers, which that verification already
// rejects, so a fixed ID (wire bytes 0x42 0x42) is used for every query.
const TXID: u16 = u16::from_be_bytes([0x42, 0x42]);
18
/// Lifts the unit error returned by the wire-parsing helpers into an `io::Error`, so parse
/// failures can be propagated with `?` next to real I/O errors.
fn emap<V>(v: Result<V, ()>) -> Result<V, Error> {
	match v {
		Ok(value) => Ok(value),
		Err(()) => Err(Error::new(ErrorKind::Other, "Bad Response")),
	}
}
22
23 fn build_query(domain: &Name, ty: u16) -> Vec<u8> {
24         // TODO: Move to not allocating for the query
25         let mut query = Vec::with_capacity(1024);
26         let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11;
27         query.extend_from_slice(&query_msg_len.to_be_bytes());
28         query.extend_from_slice(&TXID.to_be_bytes());
29         query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data
30         query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional
31         write_name(&mut query, domain);
32         query.extend_from_slice(&ty.to_be_bytes());
33         query.extend_from_slice(&1u16.to_be_bytes()); // INternet class
34         query.extend_from_slice(&[0, 0, 0x29]); // . OPT
35         query.extend_from_slice(&0u16.to_be_bytes()); // 0 UDP payload size
36         query.extend_from_slice(&[0, 0]); // EDNS version 0
37         query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
38         query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
39         query
40 }
41
42 fn send_query(stream: &mut TcpStream, domain: &Name, ty: u16) -> Result<(), Error> {
43         let query = build_query(domain, ty);
44         stream.write_all(&query)?;
45         Ok(())
46 }
47
/// Async (tokio) counterpart of `send_query`: writes a query for the `ty` records at `domain`
/// to `stream`.
#[cfg(feature = "tokio")]
async fn send_query_async(stream: &mut TokioTcpStream, domain: &Name, ty: u16) -> Result<(), Error> {
	let packet = build_query(domain, ty);
	stream.write_all(&packet).await
}
54
55 fn handle_response(resp: &[u8], proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
56         let mut read: &[u8] = resp;
57         if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
58         // 2 byte transaction ID
59         let flags = emap(read_u16(&mut read))?;
60         if flags & 0b1000_0000_0000_0000 == 0 {
61                 return Err(Error::new(ErrorKind::Other, "Missing response flag"));
62         }
63         if flags & 0b0111_1010_0000_0111 != 0 {
64                 return Err(Error::new(ErrorKind::Other, "Server indicated error or provided bunk flags"));
65         }
66         if flags & 0b10_0000 == 0 {
67                 return Err(Error::new(ErrorKind::Other, "Server indicated data could not be authenticated"));
68         }
69         let questions = emap(read_u16(&mut read))?;
70         if questions != 1 { return Err(Error::new(ErrorKind::Other, "server responded to multiple Qs")); }
71         let answers = emap(read_u16(&mut read))?;
72         if answers == 0 { return Err(Error::new(ErrorKind::Other, "No answers")); }
73         let _authorities = emap(read_u16(&mut read))?;
74         let _additional = emap(read_u16(&mut read))?;
75
76         for _ in 0..questions {
77                 emap(read_wire_packet_name(&mut read, resp))?;
78                 emap(read_u16(&mut read))?; // type
79                 emap(read_u16(&mut read))?; // class
80         }
81
82         // Only read the answers (skip authorities and additional) as that's all we care about.
83         let mut rrsig_opt = None;
84         for _ in 0..answers {
85                 let (rr, ttl) = emap(parse_wire_packet_rr(&mut read, &resp))?;
86                 write_rr(&rr, ttl, proof);
87                 if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); }
88         }
89         Ok(rrsig_opt)
90 }
91
92 fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
93         let mut len = [0; 2];
94         stream.read_exact(&mut len)?;
95         let mut resp = vec![0; u16::from_be_bytes(len) as usize];
96         stream.read_exact(&mut resp)?;
97         handle_response(&resp, proof)
98 }
99
/// Async (tokio) counterpart of `read_response`: reads one length-prefixed DNS message from
/// `stream` and passes it to `handle_response`.
#[cfg(feature = "tokio")]
async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
	let mut len_prefix = [0u8; 2];
	stream.read_exact(&mut len_prefix).await?;
	let msg_len = usize::from(u16::from_be_bytes(len_prefix));
	let mut msg = vec![0u8; msg_len];
	stream.read_exact(&mut msg).await?;
	handle_response(&msg, proof)
}
108
// Shared body of `build_proof` and `build_proof_async`. The initial query must already have
// been sent. Each round reads a response (appending its answers to the proof) and follows that
// response's RRSIG:
// - if the records are signed by their own zone (owner == signer) after the first round, fetch
//   the zone's DS from the parent;
// - otherwise, fetch the signer's DNSKEY.
// The walk stops once a response is signed at the root ("."), after at most 10 responses.
//
// `$async_ok`, when given, turns each trailing `?` into `.await?; $async_ok?` so the same body
// serves the async variant. Evaluates to `Result<Vec<u8>, Error>`; may also `return` an error
// from the enclosing function.
macro_rules! build_proof_impl {
	($stream: ident, $send_query: ident, $read_response: ident $(, $async_ok: tt)?) => { {
		let mut res = Vec::new();
		let mut reached_root = false;
		for i in 0..10 {
			let rrsig_opt = $read_response(&mut $stream, &mut res)
				$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
			let rrsig = match rrsig_opt {
				Some(rrsig) => rrsig,
				// Nothing to follow and no query outstanding. Reading again would just block
				// until the resolver closed the connection, so fail now.
				None => return Err(Error::new(ErrorKind::Other, "Response contained no RRSIG")),
			};
			if rrsig.name.as_str() == "." {
				reached_root = true;
				break;
			}
			let next_ty =
				if i != 0 && rrsig.name == rrsig.key_name { DS::TYPE } else { DnsKey::TYPE };
			$send_query(&mut $stream, &rrsig.key_name, next_ty)
				$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
		}

		if reached_root { Ok(res) }
		else { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
	} }
}
134
135 fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<Vec<u8>, Error> {
136         let mut stream = TcpStream::connect(resolver)?;
137         send_query(&mut stream, domain, ty)?;
138         build_proof_impl!(stream, send_query, read_response)
139 }
140
/// Async (tokio) counterpart of `build_proof`: connects to `resolver`, queries it for the `ty`
/// records at `domain`, and follows the returned signatures up to the root.
#[cfg(feature = "tokio")]
async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<Vec<u8>, Error> {
	let mut conn = TokioTcpStream::connect(resolver).await?;
	send_query_async(&mut conn, domain, ty).await?;
	build_proof_impl!(conn, send_query_async, read_response_async, { Ok::<(), Error>(()) })
}
147
148 /// Builds a DNSSEC proof for an A record by querying a recursive resolver
149 pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
150         build_proof(resolver, domain, A::TYPE)
151 }
152
153 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
154 pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
155         build_proof(resolver, domain, AAAA::TYPE)
156 }
157
158 /// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
159 pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
160         build_proof(resolver, domain, Txt::TYPE)
161 }
162
163 /// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
164 pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
165         build_proof(resolver, domain, TLSA::TYPE)
166 }
167
168
/// Async (tokio) version of [`build_a_proof`]: builds a DNSSEC proof for the A records at
/// `domain` by querying the recursive resolver at `resolver` over TCP.
#[cfg(feature = "tokio")]
pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
	let record_type = A::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
174
/// Async (tokio) version of [`build_aaaa_proof`]: builds a DNSSEC proof for the AAAA records
/// at `domain` by querying the recursive resolver at `resolver` over TCP.
#[cfg(feature = "tokio")]
pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
	let record_type = AAAA::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
180
/// Async (tokio) version of [`build_txt_proof`]: builds a DNSSEC proof for the TXT records at
/// `domain` by querying the recursive resolver at `resolver` over TCP.
#[cfg(feature = "tokio")]
pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
	let record_type = Txt::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
186
/// Async (tokio) version of [`build_tlsa_proof`]: builds a DNSSEC proof for the TLSA records
/// at `domain` by querying the recursive resolver at `resolver` over TCP.
#[cfg(feature = "tokio")]
pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
	let record_type = TLSA::TYPE;
	build_proof_async(resolver, domain, record_type).await
}
192
#[cfg(all(feature = "validation", test))]
mod tests {
	use super::*;
	use crate::validation::*;

	use rand::seq::SliceRandom;

	use std::net::ToSocketAddrs;
	use std::time::SystemTime;

	#[cfg(feature = "tokio")]
	use tokio_crate as tokio;

	/// Address of Google's public recursive resolver, which every live test below queries.
	fn public_resolver() -> std::net::SocketAddr {
		"8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap()
	}

	/// Current wall-clock time in whole seconds since the UNIX epoch.
	fn unix_now() -> u64 {
		SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs()
	}

	#[test]
	fn test_cloudflare_txt_query() {
		let resolver = public_resolver();
		let name = "cloudflare.com.".try_into().unwrap();
		let proof = build_txt_proof(resolver, &name).unwrap();

		// Shuffle so verification cannot depend on the order the records were fetched in.
		let mut records = parse_rr_stream(&proof).unwrap();
		records.shuffle(&mut rand::rngs::OsRng);
		let verified = verify_rr_stream(&records).unwrap();
		assert!(verified.verified_rrs.len() > 1);

		let now = unix_now();
		assert!(verified.valid_from < now);
		assert!(verified.expires > now);
	}

	#[test]
	fn test_txt_query() {
		let resolver = public_resolver();
		let name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
		let proof = build_txt_proof(resolver, &name).unwrap();

		// Shuffle so verification cannot depend on the order the records were fetched in.
		let mut records = parse_rr_stream(&proof).unwrap();
		records.shuffle(&mut rand::rngs::OsRng);
		let verified = verify_rr_stream(&records).unwrap();
		assert_eq!(verified.verified_rrs.len(), 1);

		let now = unix_now();
		assert!(verified.valid_from < now);
		assert!(verified.expires > now);
	}

	#[cfg(feature = "tokio")]
	#[tokio::test]
	async fn test_txt_query_async() {
		let resolver = public_resolver();
		let name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
		let proof = build_txt_proof_async(resolver, &name).await.unwrap();

		// Shuffle so verification cannot depend on the order the records were fetched in.
		let mut records = parse_rr_stream(&proof).unwrap();
		records.shuffle(&mut rand::rngs::OsRng);
		let verified = verify_rr_stream(&records).unwrap();
		assert_eq!(verified.verified_rrs.len(), 1);

		let now = unix_now();
		assert!(verified.valid_from < now);
		assert!(verified.expires > now);
	}
}