Clean up and better comment math somewhat further
[dnssec-prover] / src / query.rs
index c2ea75e4de7a9fe96b408c7f5bf1ea90ffbd6a8e..bb724020ea98f792f1631ca20d89616415d9aa75 100644 (file)
@@ -21,9 +21,9 @@ use crate::ser::*;
 // In testing use a rather small buffer to ensure we hit the allocation paths sometimes. In
 // production, we should generally never actually need to go to heap as DNS messages are rarely
 // larger than a KiB or two.
-#[cfg(test)]
+#[cfg(any(test, fuzzing))]
 const STACK_BUF_LIMIT: u16 = 32;
-#[cfg(not(test))]
+#[cfg(not(any(test, fuzzing)))]
 const STACK_BUF_LIMIT: u16 = 2048;
 
 /// A buffer for storing queries and responses.
@@ -34,7 +34,8 @@ pub struct QueryBuf {
        len: u16,
 }
 impl QueryBuf {
-       fn new_zeroed(len: u16) -> Self {
+       /// Generates a new buffer of the given length, consisting of all zeros.
+       pub fn new_zeroed(len: u16) -> Self {
                let heap_buf = if len > STACK_BUF_LIMIT { vec![0; len as usize] } else { Vec::new() };
                Self {
                        buf: [0; STACK_BUF_LIMIT as usize],
@@ -42,7 +43,11 @@ impl QueryBuf {
                        len
                }
        }
-       pub(crate) fn extend_from_slice(&mut self, sl: &[u8]) {
+       /// Extends the size of this buffer by appending the given slice.
+       ///
+       /// If the total length of this buffer exceeds [`u16::MAX`] after appending, the buffer's state
+       /// is undefined, however pushing data beyond [`u16::MAX`] will not panic.
+       pub fn extend_from_slice(&mut self, sl: &[u8]) {
                let new_len = self.len.saturating_add(sl.len() as u16);
                let was_heap = self.len > STACK_BUF_LIMIT;
                let is_heap = new_len > STACK_BUF_LIMIT;
@@ -59,6 +64,14 @@ impl QueryBuf {
                target.copy_from_slice(sl);
                self.len = new_len;
        }
+       /// Converts this query into its bytes on the heap
+       pub fn into_vec(self) -> Vec<u8> {
+               if self.len > STACK_BUF_LIMIT {
+                       self.heap_buf
+               } else {
+                       self.buf[..self.len as usize].to_vec()
+               }
+       }
 }
 impl ops::Deref for QueryBuf {
        type Target = [u8];
@@ -80,14 +93,12 @@ impl ops::DerefMut for QueryBuf {
        }
 }
 
-// We don't care about transaction IDs as we're only going to accept signed data. Thus, we use
-// this constant instead of a random value.
-const TXID: u16 = 0x4242;
+// We don't care about transaction IDs as we're only going to accept signed data.
+// Further, if we're querying over DoH, the RFC says we SHOULD use a transaction ID of 0 here.
+const TXID: u16 = 0;
 
 fn build_query(domain: &Name, ty: u16) -> QueryBuf {
        let mut query = QueryBuf::new_zeroed(0);
-       let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11;
-       query.extend_from_slice(&query_msg_len.to_be_bytes());
        query.extend_from_slice(&TXID.to_be_bytes());
        query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data
        query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional
@@ -127,7 +138,7 @@ fn handle_response(resp: &[u8], proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<N
        if questions != 1 { return Err(()); }
        let answers = read_u16(&mut read)?;
        if answers == 0 { return Err(()); }
-       let _authorities = read_u16(&mut read)?;
+       let authorities = read_u16(&mut read)?;
        let _additional = read_u16(&mut read)?;
 
        for _ in 0..questions {
@@ -136,7 +147,7 @@ fn handle_response(resp: &[u8], proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<N
                read_u16(&mut read)?; // class
        }
 
-       // Only read the answers (skip authorities and additional) as that's all we care about.
+       // Only read the answers and NSEC records in authorities, skipping additional entirely.
        let mut min_ttl = u32::MAX;
        for _ in 0..answers {
                let (rr, ttl) = parse_wire_packet_rr(&mut read, &resp)?;
@@ -144,9 +155,43 @@ fn handle_response(resp: &[u8], proof: &mut Vec<u8>, rrsig_key_names: &mut Vec<N
                min_ttl = cmp::min(min_ttl, ttl);
                if let RR::RRSig(rrsig) = rr { rrsig_key_names.push(rrsig.key_name); }
        }
+
+       for _ in 0..authorities {
+               // Only include records from the authority section if they are NSEC/3 (or signatures
+               // thereover). We don't care about NS records here.
+               let (rr, ttl) = parse_wire_packet_rr(&mut read, &resp)?;
+               match &rr {
+                       RR::RRSig(rrsig) => {
+                               if rrsig.ty != NSec::TYPE && rrsig.ty != NSec3::TYPE {
+                                       continue;
+                               }
+                       },
+                       RR::NSec(_)|RR::NSec3(_) => {},
+                       _ => continue,
+               }
+               write_rr(&rr, ttl, proof);
+               min_ttl = cmp::min(min_ttl, ttl);
+               if let RR::RRSig(rrsig) = rr { rrsig_key_names.push(rrsig.key_name); }
+       }
+
        Ok(min_ttl)
 }
 
+#[cfg(fuzzing)]
+/// Read a stream of responses and handle them it as if they came from a server, for fuzzing.
+pub fn fuzz_proof_builder(mut response_stream: &[u8]) {
+       let (mut builder, _) = ProofBuilder::new(&"example.com.".try_into().unwrap(), Txt::TYPE);
+       while builder.awaiting_responses() {
+               let len = if let Ok(len) = read_u16(&mut response_stream) { len } else { return };
+               let mut buf = QueryBuf::new_zeroed(len);
+               if response_stream.len() < len as usize { return; }
+               buf.copy_from_slice(&response_stream[..len as usize]);
+               response_stream = &response_stream[len as usize..];
+               let _ = builder.process_response(&buf);
+       }
+       let _ = builder.finish_proof();
+}
+
 const MAX_REQUESTS: usize = 10;
 /// A simple state machine which will generate a series of queries and process the responses until
 /// it has built a DNSSEC proof.
@@ -157,6 +202,12 @@ const MAX_REQUESTS: usize = 10;
 /// [`ProofBuilder::process_response`] should be called, and each fresh query returned should be
 /// sent to the resolver. Once [`ProofBuilder::awaiting_responses`] returns false,
 /// [`ProofBuilder::finish_proof`] should be called to fetch the resulting proof.
+///
+/// To build a DNSSEC proof using a DoH server, take each [`QueryBuf`], encode it as base64url, and
+/// make a query to `https://doh-server/endpoint?dns=base64url_encoded_query` with an `Accept`
+/// header of `application/dns-message`. Each response, in raw binary, can be fed directly into
+/// [`ProofBuilder::process_response`].
+#[derive(Clone)]
 pub struct ProofBuilder {
        proof: Vec<u8>,
        min_ttl: u32,
@@ -243,12 +294,14 @@ impl ProofBuilder {
 
 #[cfg(feature = "std")]
 fn send_query(stream: &mut TcpStream, query: &[u8]) -> Result<(), Error> {
+       stream.write_all(&(query.len() as u16).to_be_bytes())?;
        stream.write_all(&query)?;
        Ok(())
 }
 
 #[cfg(feature = "tokio")]
 async fn send_query_async(stream: &mut TokioTcpStream, query: &[u8]) -> Result<(), Error> {
+       stream.write_all(&(query.len() as u16).to_be_bytes()).await?;
        stream.write_all(&query).await?;
        Ok(())
 }
@@ -453,7 +506,7 @@ mod tests {
        fn test_cname_query() {
                for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
                        let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
-                       let query_name = "cname_test.matcorallo.com.".try_into().unwrap();
+                       let query_name = "cname_test.dnssec_proof_tests.bitcoin.ninja.".try_into().unwrap();
                        let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap();
 
                        let mut rrs = parse_rr_stream(&proof).unwrap();
@@ -468,7 +521,7 @@ mod tests {
                        let resolved_rrs = verified_rrs.resolve_name(&query_name);
                        assert_eq!(resolved_rrs.len(), 1);
                        if let RR::Txt(txt) = &resolved_rrs[0] {
-                               assert_eq!(txt.name.as_str(), "txt_test.matcorallo.com.");
+                               assert_eq!(txt.name.as_str(), "txt_test.dnssec_proof_tests.bitcoin.ninja.");
                                assert_eq!(txt.data, b"dnssec_prover_test");
                        } else { panic!(); }
                }
@@ -499,7 +552,7 @@ mod tests {
        async fn test_cross_domain_cname_query_async() {
                for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
                        let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
-                       let query_name = "wildcard.x_domain_cname_wild.matcorallo.com.".try_into().unwrap();
+                       let query_name = "wildcard.x_domain_cname_wild.dnssec_proof_tests.bitcoin.ninja.".try_into().unwrap();
                        let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();
 
                        let mut rrs = parse_rr_stream(&proof).unwrap();
@@ -519,4 +572,31 @@ mod tests {
                        } else { panic!(); }
                }
        }
+
+       #[cfg(feature = "tokio")]
+       #[tokio::test]
+       async fn test_dname_wildcard_query_async() {
+               for resolver in ["1.1.1.1:53", "8.8.8.8:53", "9.9.9.9:53"] {
+                       let sockaddr = resolver.to_socket_addrs().unwrap().next().unwrap();
+                       let query_name = "wildcard_a.wildcard_b.dname_test.dnssec_proof_tests.bitcoin.ninja.".try_into().unwrap();
+                       let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap();
+
+                       let mut rrs = parse_rr_stream(&proof).unwrap();
+                       rrs.shuffle(&mut rand::rngs::OsRng);
+                       let verified_rrs = verify_rr_stream(&rrs).unwrap();
+                       assert_eq!(verified_rrs.verified_rrs.len(), 3);
+
+                       let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
+                       assert!(verified_rrs.valid_from < now);
+                       assert!(verified_rrs.expires > now);
+
+                       let resolved_rrs = verified_rrs.resolve_name(&query_name);
+                       assert_eq!(resolved_rrs.len(), 1);
+                       if let RR::Txt(txt) = &resolved_rrs[0] {
+                               assert_eq!(txt.name.as_str(), "cname.wildcard_test.dnssec_proof_tests.bitcoin.ninja.");
+                               assert_eq!(txt.data, b"wildcard_test");
+                       } else { panic!(); }
+               }
+       }
+
 }