Add support for building proofs using a local recursive resolver
[dnssec-prover] / src / query.rs
1 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
2 //! resolver.
3
4 use std::net::{SocketAddr, TcpStream};
5 use std::io::{Read, Write, Error, ErrorKind};
6
7 use crate::write_rr;
8 use crate::rr::*;
9 use crate::ser::*;
10
11 // We don't care about transaction IDs as we're only going to accept signed data. Thus, we use
12 // this constant instead of a random value.
13 const TXID: u16 = 0x4242;
14
15 fn emap<V>(v: Result<V, ()>) -> Result<V, Error> {
16         v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response"))
17 }
18
19 fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> {
20         let mut query = Vec::with_capacity(1024);
21         let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(&domain) + 11;
22         query.extend_from_slice(&query_msg_len.to_be_bytes());
23         query.extend_from_slice(&TXID.to_be_bytes());
24         query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data
25         query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional
26         write_name(&mut query, &domain);
27         query.extend_from_slice(&ty.to_be_bytes());
28         query.extend_from_slice(&1u16.to_be_bytes()); // INternet class
29         query.extend_from_slice(&[0, 0, 0x29]); // . OPT
30         query.extend_from_slice(&0u16.to_be_bytes()); // 0 UDP payload size
31         query.extend_from_slice(&[0, 0]); // EDNS version 0
32         query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
33         query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
34         stream.write_all(&query)?;
35         Ok(())
36 }
37
38 fn read_response(stream: &mut TcpStream) -> Result<Vec<RR>, Error> {
39         let mut len = [0; 2];
40         stream.read_exact(&mut len)?;
41         let mut resp = vec![0; u16::from_be_bytes(len) as usize];
42         stream.read_exact(&mut resp)?;
43
44         let mut read: &[u8] = &resp;
45         if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
46         // 2 byte transaction ID
47         let flags = emap(read_u16(&mut read))?;
48         if flags & 0b1000_0000_0000_0000 == 0 {
49                 return Err(Error::new(ErrorKind::Other, "Missing response flag"));
50         }
51         if flags & 0b0111_1010_0000_0111 != 0 {
52                 return Err(Error::new(ErrorKind::Other, "Server indicated error or provided bunk flags"));
53         }
54         if flags & 0b10_0000 == 0 {
55                 return Err(Error::new(ErrorKind::Other, "Server indicated data could not be authenticated"));
56         }
57         let questions = emap(read_u16(&mut read))?;
58         if questions != 1 { return Err(Error::new(ErrorKind::Other, "server responded to multiple Qs")); }
59         let answers = emap(read_u16(&mut read))?;
60         let _authorities = emap(read_u16(&mut read))?;
61         let _additional = emap(read_u16(&mut read))?;
62
63         for _ in 0..questions {
64                 emap(read_name(&mut read))?;
65                 emap(read_u16(&mut read))?; // type
66                 emap(read_u16(&mut read))?; // class
67         }
68
69         // Only read the answers (skip authorities and additional) as that's all we care about.
70         let mut res = Vec::new();
71         for _ in 0..answers {
72                 res.push(emap(parse_wire_packet_rr(&mut read, &resp))?);
73         }
74         Ok(res)
75 }
76
77 fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
78         let mut stream = TcpStream::connect(resolver)?;
79         let mut res = Vec::new();
80         send_query(&mut stream, domain, ty)?;
81         let mut reached_root = false;
82         for _ in 0..10 {
83                 let resp = read_response(&mut stream)?;
84                 for rr in resp {
85                         write_rr(&rr, 0, &mut res);
86                         if rr.name().as_str() == "." {
87                                 reached_root = true;
88                         } else {
89                                 if let RR::RRSig(rrsig) = rr {
90                                         if rrsig.name == rrsig.key_name {
91                                                 send_query(&mut stream, rrsig.key_name, DS::TYPE)?;
92                                         } else {
93                                                 send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?;
94                                         }
95                                 }
96                         }
97                 }
98                 if reached_root { break; }
99         }
100
101         if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
102         else { Ok(res) }
103 }
104
105 /// Builds a DNSSEC proof for an A record by querying a recursive resolver
106 pub fn build_a_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
107         build_proof(resolver, domain, A::TYPE)
108 }
109
110 /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
111 pub fn build_aaaa_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
112         build_proof(resolver, domain, AAAA::TYPE)
113 }
114
115 /// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
116 pub fn build_txt_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
117         build_proof(resolver, domain, Txt::TYPE)
118 }
119
120 /// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
121 pub fn build_tlsa_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
122         build_proof(resolver, domain, TLSA::TYPE)
123 }
124
125 #[cfg(test)]
126 mod tests {
127         use super::*;
128         use crate::*;
129
130         use rand::seq::SliceRandom;
131
132         use std::net::ToSocketAddrs;
133
134         #[test]
135         fn test_txt_query() {
136                 let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
137                 let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
138                 let proof = build_txt_proof(sockaddr, query_name).unwrap();
139
140                 let mut rrs = parse_rr_stream(&proof).unwrap();
141                 rrs.shuffle(&mut rand::rngs::OsRng);
142                 let verified_rrs = verify_rr_stream(&rrs).unwrap();
143                 assert_eq!(verified_rrs.len(), 1);
144         }
145 }