From 8267bce8419938335557c255642cf8ca5990936e Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 5 Feb 2024 10:02:45 +0000 Subject: [PATCH] Enable querying async using tokio --- Cargo.toml | 3 ++ src/query.rs | 139 +++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 121 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 86ddb6b..64da047 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,12 +12,15 @@ rust-version = "1.60.0" [features] std = [] +tokio = ["tokio_crate/net", "tokio_crate/io-util", "std"] [dependencies] ring = { version = "0.17", default-features = false, features = ["alloc"] } hex_lit = { version = "0.1", default-features = false, features = ["rust_v_1_46"] } +tokio_crate = { package = "tokio", version = "1.0", default-features = false, optional = true } [dev-dependencies] hex-conservative = { version = "0.1", default-features = false, features = ["alloc"] } base64 = "0.21" rand = { version = "0.8", default-features = false, features = ["getrandom"] } +tokio_crate = { package = "tokio", version = "1.0", features = ["rt", "macros"] } diff --git a/src/query.rs b/src/query.rs index 65da2f0..675d82b 100644 --- a/src/query.rs +++ b/src/query.rs @@ -4,6 +4,12 @@ use std::net::{SocketAddr, TcpStream}; use std::io::{Read, Write, Error, ErrorKind}; +#[cfg(feature = "tokio")] +use tokio_crate::net::TcpStream as TokioTcpStream; +#[cfg(feature = "tokio")] +use tokio_crate::io::{AsyncReadExt, AsyncWriteExt}; + + use crate::write_rr; use crate::rr::*; use crate::ser::*; @@ -16,7 +22,8 @@ fn emap(v: Result) -> Result { v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response")) } -fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> { +fn build_query(domain: Name, ty: u16) -> Vec { + // TODO: Move to not allocating for the query let mut query = Vec::with_capacity(1024); let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(&domain) + 11; query.extend_from_slice(&query_msg_len.to_be_bytes()); @@ -31,17 +38,24 @@ fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error query.extend_from_slice(&[0, 0]); // EDNS version 0 query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs query.extend_from_slice(&0u16.to_be_bytes()); // No additional data + query +} + +fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> { + let query = build_query(domain, ty); stream.write_all(&query)?; Ok(()) } -fn read_response(stream: &mut TcpStream) -> Result, Error> { - let mut len = [0; 2]; - stream.read_exact(&mut len)?; - let mut resp = vec![0; u16::from_be_bytes(len) as usize]; - stream.read_exact(&mut resp)?; +#[cfg(feature = "tokio")] +async fn send_query_async(stream: &mut TokioTcpStream, domain: Name, ty: u16) -> Result<(), Error> { + let query = build_query(domain, ty); + stream.write_all(&query).await?; + Ok(()) +} - let mut read: &[u8] = &resp; +fn handle_response(resp: &[u8], proof: &mut Vec) -> Result, Error> { + let mut read: &[u8] = resp; if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); } // 2 byte transaction ID let flags = emap(read_u16(&mut read))?; @@ -61,17 +75,36 @@ fn read_response(stream: &mut TcpStream) -> Result, Error> { let _additional = emap(read_u16(&mut read))?; for _ in 0..questions { - emap(read_name(&mut read))?; + emap(read_wire_packet_name(&mut read, resp))?; emap(read_u16(&mut read))?; // type emap(read_u16(&mut read))?; // class } // Only read the answers (skip authorities and additional) as that's all we care about. - let mut res = Vec::new(); + let mut rrsig_opt = None; for _ in 0..answers { - res.push(emap(parse_wire_packet_rr(&mut read, &resp))?); + let rr = emap(parse_wire_packet_rr(&mut read, &resp))?; + write_rr(&rr, 0, proof); + if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); } } - Ok(res) + Ok(rrsig_opt) +} + +fn read_response(stream: &mut TcpStream, proof: &mut Vec) -> Result, Error> { + let mut len = [0; 2]; + stream.read_exact(&mut len)?; + let mut resp = vec![0; u16::from_be_bytes(len) as usize]; + stream.read_exact(&mut resp)?; + handle_response(&resp, proof) +} + +#[cfg(feature = "tokio")] +async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec) -> Result, Error> { + let mut len = [0; 2]; + stream.read_exact(&mut len).await?; + let mut resp = vec![0; u16::from_be_bytes(len) as usize]; + stream.read_exact(&mut resp).await?; + handle_response(&resp, proof) } fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result, Error> { @@ -80,18 +113,41 @@ fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result, E send_query(&mut stream, domain, ty)?; let mut reached_root = false; for _ in 0..10 { - let resp = read_response(&mut stream)?; - for rr in resp { - write_rr(&rr, 0, &mut res); - if rr.name().as_str() == "." { + let rrsig_opt = read_response(&mut stream, &mut res)?; + if let Some(rrsig) = rrsig_opt { + if rrsig.name.as_str() == "." { reached_root = true; } else { - if let RR::RRSig(rrsig) = rr { - if rrsig.name == rrsig.key_name { - send_query(&mut stream, rrsig.key_name, DS::TYPE)?; - } else { - send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?; - } + if rrsig.name == rrsig.key_name { + send_query(&mut stream, rrsig.key_name, DS::TYPE)?; + } else { + send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?; + } + } + } + if reached_root { break; } + } + + if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) } + else { Ok(res) } +} + +#[cfg(feature = "tokio")] +async fn build_proof_async(resolver: SocketAddr, domain: Name, ty: u16) -> Result, Error> { + let mut stream = TokioTcpStream::connect(resolver).await?; + let mut res = Vec::new(); + send_query_async(&mut stream, domain, ty).await?; + let mut reached_root = false; + for _ in 0..10 { + let rrsig_opt = read_response_async(&mut stream, &mut res).await?; + if let Some(rrsig) = rrsig_opt { + if rrsig.name.as_str() == "." { + reached_root = true; + } else { + if rrsig.name == rrsig.key_name { + send_query_async(&mut stream, rrsig.key_name, DS::TYPE).await?; + } else { + send_query_async(&mut stream, rrsig.key_name, DnsKey::TYPE).await?; } } } @@ -122,6 +178,31 @@ pub fn build_tlsa_proof(resolver: SocketAddr, domain: Name) -> Result, E build_proof(resolver, domain, TLSA::TYPE) } + +/// Builds a DNSSEC proof for an A record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_a_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, A::TYPE).await +} + +/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, AAAA::TYPE).await +} + +/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_txt_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, Txt::TYPE).await +} + +/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, TLSA::TYPE).await +} + #[cfg(test)] mod tests { use super::*; @@ -142,4 +223,20 @@ mod tests { let verified_rrs = verify_rr_stream(&rrs).unwrap(); assert_eq!(verified_rrs.len(), 1); } + + #[cfg(feature = "tokio")] + use tokio_crate as tokio; + + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_txt_query_async() { + let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); + let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap(); + let proof = build_txt_proof_async(sockaddr, query_name).await.unwrap(); + + let mut rrs = parse_rr_stream(&proof).unwrap(); + rrs.shuffle(&mut rand::rngs::OsRng); + let verified_rrs = verify_rr_stream(&rrs).unwrap(); + assert_eq!(verified_rrs.len(), 1); + } } -- 2.39.5