Correct cross-zone CNAME handling in proof generation
[dnssec-prover] / src / query.rs
index adf687a176c87f7fd733a6ffff637ce5ec813d1e..83f4227f4fee55aa60107b48a97de399620705e0 100644 (file)
@@ -1,6 +1,7 @@
 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
 //! resolver.
 
+use std::cmp;
 use std::net::{SocketAddr, TcpStream};
 use std::io::{Read, Write, Error, ErrorKind};
 
@@ -52,7 +53,7 @@ async fn send_query_async(stream: &mut TokioTcpStream, domain: &Name, ty: u16) -
        Ok(())
 }
 
-fn handle_response(resp: &[u8], proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+fn handle_response(resp: &[u8], proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<Name>) -> Result<u32, Error> {
        let mut read: &[u8] = resp;
        if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
        // 2 byte transaction ID
@@ -80,113 +81,143 @@ fn handle_response(resp: &[u8], proof: &mut Vec<u8>) -> Result<Option<RRSig>, Er
        }
 
        // Only read the answers (skip authorities and additional) as that's all we care about.
-       let mut rrsig_opt = None;
+       let mut min_ttl = u32::MAX;
        for _ in 0..answers {
                let (rr, ttl) = emap(parse_wire_packet_rr(&mut read, &resp))?;
                write_rr(&rr, ttl, proof);
-               if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); }
+               min_ttl = cmp::min(min_ttl, ttl);
+               if let RR::RRSig(rrsig) = rr { rrsig_key_names.push(rrsig.key_name); }
        }
-       Ok(rrsig_opt)
+       Ok(min_ttl)
 }
 
-fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<Name>) -> Result<u32, Error> {
        let mut len = [0; 2];
        stream.read_exact(&mut len)?;
        let mut resp = vec![0; u16::from_be_bytes(len) as usize];
        stream.read_exact(&mut resp)?;
-       handle_response(&resp, proof)
+       handle_response(&resp, proof, rrsig_key_names)
 }
 
 #[cfg(feature = "tokio")]
-async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<Name>) -> Result<u32, Error> {
        let mut len = [0; 2];
        stream.read_exact(&mut len).await?;
        let mut resp = vec![0; u16::from_be_bytes(len) as usize];
        stream.read_exact(&mut resp).await?;
-       handle_response(&resp, proof)
+       handle_response(&resp, proof, rrsig_key_names)
 }
 
 macro_rules! build_proof_impl {
        ($stream: ident, $send_query: ident, $read_response: ident $(, $async_ok: tt)?) => { {
-               let mut res = Vec::new();
-               let mut reached_root = false;
-               for i in 0..10 {
-                       let rrsig_opt = $read_response(&mut $stream, &mut res)
+               // We require the initial query to have already gone out, and assume our resolver will
+               // return any CNAMEs all the way to the final record in the response. From there, we just
+               // have to take any RRSIGs in the response and walk them up to the root. We do so
+               // iteratively, sending DNSKEY and DS lookups after every response, deduplicating requests
+               // using `dnskeys_requested`.
+               let mut res = Vec::new(); // The actual proof stream
+               let mut min_ttl = u32::MAX; // Min TTL of any answer record
+               const MAX_REQUESTS: usize = 20;
+               let mut rrsig_key_names = Vec::with_capacity(4); // Last response's RRSIG key_names
+               let mut dnskeys_requested = Vec::with_capacity(MAX_REQUESTS);
+               let mut pending_queries = 1;
+               let mut queries_made = 1;
+               while pending_queries != 0 && queries_made <= MAX_REQUESTS {
+                       let response_min_ttl = $read_response(&mut $stream, &mut res, &mut rrsig_key_names)
                                $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
-                       if let Some(rrsig) = rrsig_opt {
-                               if rrsig.name.as_str() == "." {
-                                       reached_root = true;
-                               } else {
-                                       if i != 0 && rrsig.name == rrsig.key_name {
-                                               $send_query(&mut $stream, &rrsig.key_name, DS::TYPE)
-                                       } else {
-                                               $send_query(&mut $stream, &rrsig.key_name, DnsKey::TYPE)
-                                       }$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
+                       pending_queries -= 1;
+                       min_ttl = cmp::min(min_ttl, response_min_ttl);
+                       rrsig_key_names.sort_unstable();
+                       rrsig_key_names.dedup();
+                       for key_name in rrsig_key_names.drain(..) {
+                               if !dnskeys_requested.contains(&key_name) {
+                                       $send_query(&mut $stream, &key_name, DnsKey::TYPE)
+                                               $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
+                                       pending_queries += 1;
+                                       queries_made += 1;
+                                       dnskeys_requested.push(key_name.clone());
+
+                                       if key_name.as_str() != "." {
+                                               $send_query(&mut $stream, &key_name, DS::TYPE)
+                                                       $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
+                                               pending_queries += 1;
+                                               queries_made += 1;
+                                       }
                                }
                        }
-                       if reached_root { break; }
                }
 
-               if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
-               else { Ok(res) }
+               if queries_made > MAX_REQUESTS {
+                       Err(Error::new(ErrorKind::Other, "Too many requests required"))
+               } else {
+                       Ok((res, min_ttl))
+               }
        } }
 }
 
-fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<Vec<u8>, Error> {
+fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
        let mut stream = TcpStream::connect(resolver)?;
        send_query(&mut stream, domain, ty)?;
        build_proof_impl!(stream, send_query, read_response)
 }
 
 #[cfg(feature = "tokio")]
-async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<Vec<u8>, Error> {
+async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
        let mut stream = TokioTcpStream::connect(resolver).await?;
        send_query_async(&mut stream, domain, ty).await?;
        build_proof_impl!(stream, send_query_async, read_response_async, { Ok::<(), Error>(()) })
 }
 
-/// Builds a DNSSEC proof for an A record by querying a recursive resolver
-pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+/// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as
+/// well as the TTL for the proof provided by the recursive resolver.
+pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, A::TYPE)
 }
 
-/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
-pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof
+/// as well as the TTL for the proof provided by the recursive resolver.
+pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, AAAA::TYPE)
 }
 
-/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
-pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+/// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof
+/// as well as the TTL for the proof provided by the recursive resolver.
+pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, Txt::TYPE)
 }
 
-/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
-pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+/// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof
+/// as well as the TTL for the proof provided by the recursive resolver.
+pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, TLSA::TYPE)
 }
 
 
-/// Builds a DNSSEC proof for an A record by querying a recursive resolver
+/// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as
+/// well as the TTL for the proof provided by the recursive resolver.
 #[cfg(feature = "tokio")]
-pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof_async(resolver, domain, A::TYPE).await
 }
 
-/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
+/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof
+/// as well as the TTL for the proof provided by the recursive resolver.
 #[cfg(feature = "tokio")]
-pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof_async(resolver, domain, AAAA::TYPE).await
 }
 
-/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
+/// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof
+/// as well as the TTL for the proof provided by the recursive resolver.
 #[cfg(feature = "tokio")]
-pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof_async(resolver, domain, Txt::TYPE).await
 }
 
-/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
+/// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof
+/// as well as the TTL for the proof provided by the recursive resolver.
 #[cfg(feature = "tokio")]
-pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<Vec<u8>, Error> {
+pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof_async(resolver, domain, TLSA::TYPE).await
 }
 
@@ -205,7 +236,7 @@ mod tests {
        fn test_cloudflare_txt_query() {
                let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
                let query_name = "cloudflare.com.".try_into().unwrap();
-               let proof = build_txt_proof(sockaddr, &query_name).unwrap();
+               let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();
 
                let mut rrs = parse_rr_stream(&proof).unwrap();
                rrs.shuffle(&mut rand::rngs::OsRng);
@@ -221,7 +252,7 @@ mod tests {
        fn test_sha1_query() {
                let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
                let query_name = "benthecarman.com.".try_into().unwrap();
-               let proof = build_a_proof(sockaddr, &query_name).unwrap();
+               let (proof, _) = build_a_proof(sockaddr, &query_name).unwrap();
 
                let mut rrs = parse_rr_stream(&proof).unwrap();
                rrs.shuffle(&mut rand::rngs::OsRng);
@@ -237,7 +268,7 @@ mod tests {
        fn test_txt_query() {
                let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
                let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
-               let proof = build_txt_proof(sockaddr, &query_name).unwrap();
+               let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();
 
                let mut rrs = parse_rr_stream(&proof).unwrap();
                rrs.shuffle(&mut rand::rngs::OsRng);
@@ -249,6 +280,31 @@ mod tests {
                assert!(verified_rrs.expires > now);
        }
 
+       #[test]
+       fn test_cname_query() {
+               for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
+                       let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
+                       let query_name = "cname_test.matcorallo.com.".try_into().unwrap();
+                       let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();
+
+                       let mut rrs = parse_rr_stream(&proof).unwrap();
+                       rrs.shuffle(&mut rand::rngs::OsRng);
+                       let verified_rrs = verify_rr_stream(&rrs).unwrap();
+                       assert_eq!(verified_rrs.verified_rrs.len(), 2);
+
+                       let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
+                       assert!(verified_rrs.valid_from < now);
+                       assert!(verified_rrs.expires > now);
+
+                       let resolved_rrs = verified_rrs.resolve_name(&query_name);
+                       assert_eq!(resolved_rrs.len(), 1);
+                       if let RR::Txt(txt) = &resolved_rrs[0] {
+                               assert_eq!(txt.name.as_str(), "txt_test.matcorallo.com.");
+                               assert_eq!(txt.data, b"dnssec_prover_test");
+                       } else { panic!(); }
+               }
+       }
+
        #[cfg(feature = "tokio")]
        use tokio_crate as tokio;
 
@@ -257,7 +313,7 @@ mod tests {
        async fn test_txt_query_async() {
                let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
                let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
-               let proof = build_txt_proof_async(sockaddr, &query_name).await.unwrap();
+               let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();
 
                let mut rrs = parse_rr_stream(&proof).unwrap();
                rrs.shuffle(&mut rand::rngs::OsRng);
@@ -268,4 +324,30 @@ mod tests {
                assert!(verified_rrs.valid_from < now);
                assert!(verified_rrs.expires > now);
        }
+
+       #[cfg(feature = "tokio")]
+       #[tokio::test]
+       async fn test_cross_domain_cname_query_async() {
+               for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
+                       let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
+                       let query_name = "wildcard.x_domain_cname_wild.matcorallo.com.".try_into().unwrap();
+                       let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();
+
+                       let mut rrs = parse_rr_stream(&proof).unwrap();
+                       rrs.shuffle(&mut rand::rngs::OsRng);
+                       let verified_rrs = verify_rr_stream(&rrs).unwrap();
+                       assert_eq!(verified_rrs.verified_rrs.len(), 2);
+
+                       let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
+                       assert!(verified_rrs.valid_from < now);
+                       assert!(verified_rrs.expires > now);
+
+                       let resolved_rrs = verified_rrs.resolve_name(&query_name);
+                       assert_eq!(resolved_rrs.len(), 1);
+                       if let RR::Txt(txt) = &resolved_rrs[0] {
+                               assert_eq!(txt.name.as_str(), "matt.user._bitcoin-payment.mattcorallo.com.");
+                               assert!(txt.data.starts_with(b"bitcoin:"));
+                       } else { panic!(); }
+               }
+       }
 }