Correct proof building for records at a zone root
authorMatt Corallo <git@bluematt.me>
Tue, 6 Feb 2024 05:45:51 +0000 (05:45 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 6 Feb 2024 05:45:51 +0000 (05:45 +0000)
src/query.rs

index 8ce551955cf0f03497c2918f221c99b50611b383..967c840b83fae6985f97c4a65f1456a0743847d8 100644 (file)
@@ -107,55 +107,43 @@ async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>) -
        handle_response(&resp, proof)
 }
 
-fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
-       let mut stream = TcpStream::connect(resolver)?;
-       let mut res = Vec::new();
-       send_query(&mut stream, domain, ty)?;
-       let mut reached_root = false;
-       for _ in 0..10 {
-               let rrsig_opt = read_response(&mut stream, &mut res)?;
-               if let Some(rrsig) = rrsig_opt {
-                       if rrsig.name.as_str() == "." {
-                               reached_root = true;
-                       } else {
-                               if rrsig.name == rrsig.key_name {
-                                       send_query(&mut stream, rrsig.key_name, DS::TYPE)?;
+macro_rules! build_proof_impl {
+       ($stream: ident, $send_query: ident, $read_response: ident $(, $async_ok: tt)?) => { {
+               let mut res = Vec::new();
+               let mut reached_root = false;
+               for i in 0..10 {
+                       let rrsig_opt = $read_response(&mut $stream, &mut res)
+                               $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
+                       if let Some(rrsig) = rrsig_opt {
+                               if rrsig.name.as_str() == "." {
+                                       reached_root = true;
                                } else {
-                                       send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?;
+                                       if i != 0 && rrsig.name == rrsig.key_name {
+                                               $send_query(&mut $stream, rrsig.key_name, DS::TYPE)
+                                       } else {
+                                               $send_query(&mut $stream, rrsig.key_name, DnsKey::TYPE)
+                                       }$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
                                }
                        }
+                       if reached_root { break; }
                }
-               if reached_root { break; }
-       }
 
-       if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
-       else { Ok(res) }
+               if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
+               else { Ok(res) }
+       } }
+}
+
+fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
+       let mut stream = TcpStream::connect(resolver)?;
+       send_query(&mut stream, domain, ty)?;
+       build_proof_impl!(stream, send_query, read_response)
 }
 
 #[cfg(feature = "tokio")]
 async fn build_proof_async(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
        let mut stream = TokioTcpStream::connect(resolver).await?;
-       let mut res = Vec::new();
        send_query_async(&mut stream, domain, ty).await?;
-       let mut reached_root = false;
-       for _ in 0..10 {
-               let rrsig_opt = read_response_async(&mut stream, &mut res).await?;
-               if let Some(rrsig) = rrsig_opt {
-                       if rrsig.name.as_str() == "." {
-                               reached_root = true;
-                       } else {
-                               if rrsig.name == rrsig.key_name {
-                                       send_query_async(&mut stream, rrsig.key_name, DS::TYPE).await?;
-                               } else {
-                                       send_query_async(&mut stream, rrsig.key_name, DnsKey::TYPE).await?;
-                               }
-                       }
-               }
-               if reached_root { break; }
-       }
-
-       if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
-       else { Ok(res) }
+       build_proof_impl!(stream, send_query_async, read_response_async, { Ok::<(), Error>(()) })
 }
 
 /// Builds a DNSSEC proof for an A record by querying a recursive resolver