675d82b0badc6d708af4ac653e4366d3626a53ed
[dnssec-prover] / src / query.rs
1 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
2 //! resolver.
3
4 use std::net::{SocketAddr, TcpStream};
5 use std::io::{Read, Write, Error, ErrorKind};
6
7 #[cfg(feature = "tokio")]
8 use tokio_crate::net::TcpStream as TokioTcpStream;
9 #[cfg(feature = "tokio")]
10 use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
11
12
13 use crate::write_rr;
14 use crate::rr::*;
15 use crate::ser::*;
16
17 // We don't care about transaction IDs as we're only going to accept signed data. Thus, we use
18 // this constant instead of a random value.
19 const TXID: u16 = 0x4242;
20
21 fn emap<V>(v: Result<V, ()>) -> Result<V, Error> {
22         v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response"))
23 }
24
25 fn build_query(domain: Name, ty: u16) -> Vec<u8> {
26         // TODO: Move to not allocating for the query
27         let mut query = Vec::with_capacity(1024);
28         let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(&domain) + 11;
29         query.extend_from_slice(&query_msg_len.to_be_bytes());
30         query.extend_from_slice(&TXID.to_be_bytes());
31         query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data
32         query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional
33         write_name(&mut query, &domain);
34         query.extend_from_slice(&ty.to_be_bytes());
35         query.extend_from_slice(&1u16.to_be_bytes()); // INternet class
36         query.extend_from_slice(&[0, 0, 0x29]); // . OPT
37         query.extend_from_slice(&0u16.to_be_bytes()); // 0 UDP payload size
38         query.extend_from_slice(&[0, 0]); // EDNS version 0
39         query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
40         query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
41         query
42 }
43
44 fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> {
45         let query = build_query(domain, ty);
46         stream.write_all(&query)?;
47         Ok(())
48 }
49
50 #[cfg(feature = "tokio")]
51 async fn send_query_async(stream: &mut TokioTcpStream, domain: Name, ty: u16) -> Result<(), Error> {
52         let query = build_query(domain, ty);
53         stream.write_all(&query).await?;
54         Ok(())
55 }
56
57 fn handle_response(resp: &[u8], proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
58         let mut read: &[u8] = resp;
59         if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
60         // 2 byte transaction ID
61         let flags = emap(read_u16(&mut read))?;
62         if flags & 0b1000_0000_0000_0000 == 0 {
63                 return Err(Error::new(ErrorKind::Other, "Missing response flag"));
64         }
65         if flags & 0b0111_1010_0000_0111 != 0 {
66                 return Err(Error::new(ErrorKind::Other, "Server indicated error or provided bunk flags"));
67         }
68         if flags & 0b10_0000 == 0 {
69                 return Err(Error::new(ErrorKind::Other, "Server indicated data could not be authenticated"));
70         }
71         let questions = emap(read_u16(&mut read))?;
72         if questions != 1 { return Err(Error::new(ErrorKind::Other, "server responded to multiple Qs")); }
73         let answers = emap(read_u16(&mut read))?;
74         let _authorities = emap(read_u16(&mut read))?;
75         let _additional = emap(read_u16(&mut read))?;
76
77         for _ in 0..questions {
78                 emap(read_wire_packet_name(&mut read, resp))?;
79                 emap(read_u16(&mut read))?; // type
80                 emap(read_u16(&mut read))?; // class
81         }
82
83         // Only read the answers (skip authorities and additional) as that's all we care about.
84         let mut rrsig_opt = None;
85         for _ in 0..answers {
86                 let rr = emap(parse_wire_packet_rr(&mut read, &resp))?;
87                 write_rr(&rr, 0, proof);
88                 if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); }
89         }
90         Ok(rrsig_opt)
91 }
92
93 fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
94         let mut len = [0; 2];
95         stream.read_exact(&mut len)?;
96         let mut resp = vec![0; u16::from_be_bytes(len) as usize];
97         stream.read_exact(&mut resp)?;
98         handle_response(&resp, proof)
99 }
100
101 #[cfg(feature = "tokio")]
102 async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
103         let mut len = [0; 2];
104         stream.read_exact(&mut len).await?;
105         let mut resp = vec![0; u16::from_be_bytes(len) as usize];
106         stream.read_exact(&mut resp).await?;
107         handle_response(&resp, proof)
108 }
109
110 fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
111         let mut stream = TcpStream::connect(resolver)?;
112         let mut res = Vec::new();
113         send_query(&mut stream, domain, ty)?;
114         let mut reached_root = false;
115         for _ in 0..10 {
116                 let rrsig_opt = read_response(&mut stream, &mut res)?;
117                 if let Some(rrsig) = rrsig_opt {
118                         if rrsig.name.as_str() == "." {
119                                 reached_root = true;
120                         } else {
121                                 if rrsig.name == rrsig.key_name {
122                                         send_query(&mut stream, rrsig.key_name, DS::TYPE)?;
123                                 } else {
124                                         send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?;
125                                 }
126                         }
127                 }
128                 if reached_root { break; }
129         }
130
131         if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
132         else { Ok(res) }
133 }
134
135 #[cfg(feature = "tokio")]
136 async fn build_proof_async(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
137         let mut stream = TokioTcpStream::connect(resolver).await?;
138         let mut res = Vec::new();
139         send_query_async(&mut stream, domain, ty).await?;
140         let mut reached_root = false;
141         for _ in 0..10 {
142                 let rrsig_opt = read_response_async(&mut stream, &mut res).await?;
143                 if let Some(rrsig) = rrsig_opt {
144                         if rrsig.name.as_str() == "." {
145                                 reached_root = true;
146                         } else {
147                                 if rrsig.name == rrsig.key_name {
148                                         send_query_async(&mut stream, rrsig.key_name, DS::TYPE).await?;
149                                 } else {
150                                         send_query_async(&mut stream, rrsig.key_name, DnsKey::TYPE).await?;
151                                 }
152                         }
153                 }
154                 if reached_root { break; }
155         }
156
157         if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
158         else { Ok(res) }
159 }
160
161 /// Builds a DNSSEC proof for an A record by querying a recursive resolver
162 pub fn build_a_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
163         build_proof(resolver, domain, A::TYPE)
164 }
165
166 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
167 pub fn build_aaaa_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
168         build_proof(resolver, domain, AAAA::TYPE)
169 }
170
171 /// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
172 pub fn build_txt_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
173         build_proof(resolver, domain, Txt::TYPE)
174 }
175
176 /// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
177 pub fn build_tlsa_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
178         build_proof(resolver, domain, TLSA::TYPE)
179 }
180
181
182 /// Builds a DNSSEC proof for an A record by querying a recursive resolver
183 #[cfg(feature = "tokio")]
184 pub async fn build_a_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
185         build_proof_async(resolver, domain, A::TYPE).await
186 }
187
188 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
189 #[cfg(feature = "tokio")]
190 pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
191         build_proof_async(resolver, domain, AAAA::TYPE).await
192 }
193
194 /// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
195 #[cfg(feature = "tokio")]
196 pub async fn build_txt_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
197         build_proof_async(resolver, domain, Txt::TYPE).await
198 }
199
200 /// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
201 #[cfg(feature = "tokio")]
202 pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
203         build_proof_async(resolver, domain, TLSA::TYPE).await
204 }
205
206 #[cfg(test)]
207 mod tests {
208         use super::*;
209         use crate::*;
210
211         use rand::seq::SliceRandom;
212
213         use std::net::ToSocketAddrs;
214
215         #[test]
216         fn test_txt_query() {
217                 let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
218                 let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
219                 let proof = build_txt_proof(sockaddr, query_name).unwrap();
220
221                 let mut rrs = parse_rr_stream(&proof).unwrap();
222                 rrs.shuffle(&mut rand::rngs::OsRng);
223                 let verified_rrs = verify_rr_stream(&rrs).unwrap();
224                 assert_eq!(verified_rrs.len(), 1);
225         }
226
227         #[cfg(feature = "tokio")]
228         use tokio_crate as tokio;
229
230         #[cfg(feature = "tokio")]
231         #[tokio::test]
232         async fn test_txt_query_async() {
233                 let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
234                 let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
235                 let proof = build_txt_proof_async(sockaddr, query_name).await.unwrap();
236
237                 let mut rrs = parse_rr_stream(&proof).unwrap();
238                 rrs.shuffle(&mut rand::rngs::OsRng);
239                 let verified_rrs = verify_rr_stream(&rrs).unwrap();
240                 assert_eq!(verified_rrs.len(), 1);
241         }
242 }