Enable querying async using tokio
authorMatt Corallo <git@bluematt.me>
Mon, 5 Feb 2024 10:02:45 +0000 (10:02 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 5 Feb 2024 10:04:10 +0000 (10:04 +0000)
Cargo.toml
src/query.rs

index 86ddb6b0f88a89b422c2870a2a4d143fd4f24731..64da0475f701892b8163bda52c195e18f3b941ee 100644 (file)
@@ -12,12 +12,15 @@ rust-version = "1.60.0"
 
 [features]
 std = []
+tokio = ["tokio_crate/net", "tokio_crate/io-util", "std"]
 
 [dependencies]
 ring = { version = "0.17", default-features = false, features = ["alloc"] }
 hex_lit = { version = "0.1", default-features = false, features = ["rust_v_1_46"] }
+tokio_crate = { package = "tokio", version = "1.0", default-features = false, optional = true }
 
 [dev-dependencies]
 hex-conservative = { version = "0.1", default-features = false, features = ["alloc"] }
 base64 = "0.21"
 rand = { version = "0.8", default-features = false, features = ["getrandom"] }
+tokio_crate = { package = "tokio", version = "1.0", features = ["rt", "macros"] }
index 65da2f0f25e19d5bc6ff8370715c48c47aa5420a..675d82b0badc6d708af4ac653e4366d3626a53ed 100644 (file)
@@ -4,6 +4,12 @@
 use std::net::{SocketAddr, TcpStream};
 use std::io::{Read, Write, Error, ErrorKind};
 
+#[cfg(feature = "tokio")]
+use tokio_crate::net::TcpStream as TokioTcpStream;
+#[cfg(feature = "tokio")]
+use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
+
+
 use crate::write_rr;
 use crate::rr::*;
 use crate::ser::*;
@@ -16,7 +22,8 @@ fn emap<V>(v: Result<V, ()>) -> Result<V, Error> {
        v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response"))
 }
 
-fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> {
+fn build_query(domain: Name, ty: u16) -> Vec<u8> {
+       // TODO: Move to not allocating for the query
        let mut query = Vec::with_capacity(1024);
        let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(&domain) + 11;
        query.extend_from_slice(&query_msg_len.to_be_bytes());
@@ -31,17 +38,24 @@ fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error
        query.extend_from_slice(&[0, 0]); // EDNS version 0
        query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
        query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
+       query
+}
+
+fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> {
+       let query = build_query(domain, ty);
        stream.write_all(&query)?;
        Ok(())
 }
 
-fn read_response(stream: &mut TcpStream) -> Result<Vec<RR>, Error> {
-       let mut len = [0; 2];
-       stream.read_exact(&mut len)?;
-       let mut resp = vec![0; u16::from_be_bytes(len) as usize];
-       stream.read_exact(&mut resp)?;
+#[cfg(feature = "tokio")]
+async fn send_query_async(stream: &mut TokioTcpStream, domain: Name, ty: u16) -> Result<(), Error> {
+       let query = build_query(domain, ty);
+       stream.write_all(&query).await?;
+       Ok(())
+}
 
-       let mut read: &[u8] = &resp;
+fn handle_response(resp: &[u8], proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+       let mut read: &[u8] = resp;
        if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
        // 2 byte transaction ID
        let flags = emap(read_u16(&mut read))?;
@@ -61,17 +75,36 @@ fn read_response(stream: &mut TcpStream) -> Result<Vec<RR>, Error> {
        let _additional = emap(read_u16(&mut read))?;
 
        for _ in 0..questions {
-               emap(read_name(&mut read))?;
+               emap(read_wire_packet_name(&mut read, resp))?;
                emap(read_u16(&mut read))?; // type
                emap(read_u16(&mut read))?; // class
        }
 
        // Only read the answers (skip authorities and additional) as that's all we care about.
-       let mut res = Vec::new();
+       let mut rrsig_opt = None;
        for _ in 0..answers {
-               res.push(emap(parse_wire_packet_rr(&mut read, &resp))?);
+               let rr = emap(parse_wire_packet_rr(&mut read, &resp))?;
+               write_rr(&rr, 0, proof);
+               if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); }
        }
-       Ok(res)
+       Ok(rrsig_opt)
+}
+
+fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+       let mut len = [0; 2];
+       stream.read_exact(&mut len)?;
+       let mut resp = vec![0; u16::from_be_bytes(len) as usize];
+       stream.read_exact(&mut resp)?;
+       handle_response(&resp, proof)
+}
+
+#[cfg(feature = "tokio")]
+async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+       let mut len = [0; 2];
+       stream.read_exact(&mut len).await?;
+       let mut resp = vec![0; u16::from_be_bytes(len) as usize];
+       stream.read_exact(&mut resp).await?;
+       handle_response(&resp, proof)
 }
 
 fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
@@ -80,18 +113,41 @@ fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, E
        send_query(&mut stream, domain, ty)?;
        let mut reached_root = false;
        for _ in 0..10 {
-               let resp = read_response(&mut stream)?;
-               for rr in resp {
-                       write_rr(&rr, 0, &mut res);
-                       if rr.name().as_str() == "." {
+               let rrsig_opt = read_response(&mut stream, &mut res)?;
+               if let Some(rrsig) = rrsig_opt {
+                       if rrsig.name.as_str() == "." {
                                reached_root = true;
                        } else {
-                               if let RR::RRSig(rrsig) = rr {
-                                       if rrsig.name == rrsig.key_name {
-                                               send_query(&mut stream, rrsig.key_name, DS::TYPE)?;
-                                       } else {
-                                               send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?;
-                                       }
+                               if rrsig.name == rrsig.key_name {
+                                       send_query(&mut stream, rrsig.key_name, DS::TYPE)?;
+                               } else {
+                                       send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?;
+                               }
+                       }
+               }
+               if reached_root { break; }
+       }
+
+       if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
+       else { Ok(res) }
+}
+
+#[cfg(feature = "tokio")]
+async fn build_proof_async(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
+       let mut stream = TokioTcpStream::connect(resolver).await?;
+       let mut res = Vec::new();
+       send_query_async(&mut stream, domain, ty).await?;
+       let mut reached_root = false;
+       for _ in 0..10 {
+               let rrsig_opt = read_response_async(&mut stream, &mut res).await?;
+               if let Some(rrsig) = rrsig_opt {
+                       if rrsig.name.as_str() == "." {
+                               reached_root = true;
+                       } else {
+                               if rrsig.name == rrsig.key_name {
+                                       send_query_async(&mut stream, rrsig.key_name, DS::TYPE).await?;
+                               } else {
+                                       send_query_async(&mut stream, rrsig.key_name, DnsKey::TYPE).await?;
                                }
                        }
                }
@@ -122,6 +178,31 @@ pub fn build_tlsa_proof(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, E
        build_proof(resolver, domain, TLSA::TYPE)
 }
 
+
+/// Builds a DNSSEC proof for an A record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_a_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+       build_proof_async(resolver, domain, A::TYPE).await
+}
+
+/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+       build_proof_async(resolver, domain, AAAA::TYPE).await
+}
+
+/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_txt_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+       build_proof_async(resolver, domain, Txt::TYPE).await
+}
+
+/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+       build_proof_async(resolver, domain, TLSA::TYPE).await
+}
+
 #[cfg(test)]
 mod tests {
        use super::*;
@@ -142,4 +223,20 @@ mod tests {
                let verified_rrs = verify_rr_stream(&rrs).unwrap();
                assert_eq!(verified_rrs.len(), 1);
        }
+
+       #[cfg(feature = "tokio")]
+       use tokio_crate as tokio;
+
+       #[cfg(feature = "tokio")]
+       #[tokio::test]
+       async fn test_txt_query_async() {
+               let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
+               let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
+               let proof = build_txt_proof_async(sockaddr, query_name).await.unwrap();
+
+               let mut rrs = parse_rr_stream(&proof).unwrap();
+               rrs.shuffle(&mut rand::rngs::OsRng);
+               let verified_rrs = verify_rr_stream(&rrs).unwrap();
+               assert_eq!(verified_rrs.len(), 1);
+       }
 }