From a50b5198e38aabd23e5f22d7ddba4aee3b856e49 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Wed, 7 Feb 2024 05:23:42 +0000 Subject: [PATCH] Accept query names by reference for downstream flexibility --- src/http.rs | 8 ++++---- src/query.rs | 40 ++++++++++++++++++++-------------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/http.rs b/src/http.rs index 944bea1..f344372 100644 --- a/src/http.rs +++ b/src/http.rs @@ -110,10 +110,10 @@ mod imp { break 'ret_err; }; let proof_res = match t.to_ascii_uppercase().as_str() { - "TXT" => build_txt_proof_async(resolver_sockaddr, query_name).await, - "TLSA" => build_tlsa_proof_async(resolver_sockaddr, query_name).await, - "A" => build_a_proof_async(resolver_sockaddr, query_name).await, - "AAAA" => build_aaaa_proof_async(resolver_sockaddr, query_name).await, + "TXT" => build_txt_proof_async(resolver_sockaddr, &query_name).await, + "TLSA" => build_tlsa_proof_async(resolver_sockaddr, &query_name).await, + "A" => build_a_proof_async(resolver_sockaddr, &query_name).await, + "AAAA" => build_aaaa_proof_async(resolver_sockaddr, &query_name).await, _ => break 'ret_err, }; let proof = if let Ok(proof) = proof_res { proof } else { diff --git a/src/query.rs b/src/query.rs index 27f2160..d81e58a 100644 --- a/src/query.rs +++ b/src/query.rs @@ -20,15 +20,15 @@ fn emap(v: Result) -> Result { v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response")) } -fn build_query(domain: Name, ty: u16) -> Vec { +fn build_query(domain: &Name, ty: u16) -> Vec { // TODO: Move to not allocating for the query let mut query = Vec::with_capacity(1024); - let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(&domain) + 11; + let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11; query.extend_from_slice(&query_msg_len.to_be_bytes()); query.extend_from_slice(&TXID.to_be_bytes()); query.extend_from_slice(&[0x01, 0x20]); // Flags: Recursive, Authenticated Data query.extend_from_slice(&[0, 1, 0, 0, 0, 0, 0, 1]); // One question, One additional - write_name(&mut query, &domain); + write_name(&mut query, domain); query.extend_from_slice(&ty.to_be_bytes()); query.extend_from_slice(&1u16.to_be_bytes()); // INternet class query.extend_from_slice(&[0, 0, 0x29]); // . OPT @@ -39,14 +39,14 @@ fn build_query(domain: Name, ty: u16) -> Vec { query } -fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> { +fn send_query(stream: &mut TcpStream, domain: &Name, ty: u16) -> Result<(), Error> { let query = build_query(domain, ty); stream.write_all(&query)?; Ok(()) } #[cfg(feature = "tokio")] -async fn send_query_async(stream: &mut TokioTcpStream, domain: Name, ty: u16) -> Result<(), Error> { +async fn send_query_async(stream: &mut TokioTcpStream, domain: &Name, ty: u16) -> Result<(), Error> { let query = build_query(domain, ty); stream.write_all(&query).await?; Ok(()) @@ -117,9 +117,9 @@ macro_rules! build_proof_impl { reached_root = true; } else { if i != 0 && rrsig.name == rrsig.key_name { - $send_query(&mut $stream, rrsig.key_name, DS::TYPE) + $send_query(&mut $stream, &rrsig.key_name, DS::TYPE) } else { - $send_query(&mut $stream, rrsig.key_name, DnsKey::TYPE) + $send_query(&mut $stream, &rrsig.key_name, DnsKey::TYPE) }$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? } } @@ -131,61 +131,61 @@ macro_rules! build_proof_impl { } } } -fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result, Error> { +fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result, Error> { let mut stream = TcpStream::connect(resolver)?; send_query(&mut stream, domain, ty)?; build_proof_impl!(stream, send_query, read_response) } #[cfg(feature = "tokio")] -async fn build_proof_async(resolver: SocketAddr, domain: Name, ty: u16) -> Result, Error> { +async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result, Error> { let mut stream = TokioTcpStream::connect(resolver).await?; send_query_async(&mut stream, domain, ty).await?; build_proof_impl!(stream, send_query_async, read_response_async, { Ok::<(), Error>(()) }) } /// Builds a DNSSEC proof for an A record by querying a recursive resolver -pub fn build_a_proof(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof(resolver, domain, A::TYPE) } /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver -pub fn build_aaaa_proof(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof(resolver, domain, AAAA::TYPE) } /// Builds a DNSSEC proof for a TXT record by querying a recursive resolver -pub fn build_txt_proof(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof(resolver, domain, Txt::TYPE) } /// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver -pub fn build_tlsa_proof(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof(resolver, domain, TLSA::TYPE) } /// Builds a DNSSEC proof for an A record by querying a recursive resolver #[cfg(feature = "tokio")] -pub async fn build_a_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof_async(resolver, domain, A::TYPE).await } /// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver #[cfg(feature = "tokio")] -pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof_async(resolver, domain, AAAA::TYPE).await } /// Builds a DNSSEC proof for a TXT record by querying a recursive resolver #[cfg(feature = "tokio")] -pub async fn build_txt_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof_async(resolver, domain, Txt::TYPE).await } /// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver #[cfg(feature = "tokio")] -pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { +pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { build_proof_async(resolver, domain, TLSA::TYPE).await } @@ -204,7 +204,7 @@ mod tests { fn test_cloudflare_txt_query() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); let query_name = "cloudflare.com.".try_into().unwrap(); - let proof = build_txt_proof(sockaddr, query_name).unwrap(); + let proof = build_txt_proof(sockaddr, &query_name).unwrap(); let mut rrs = parse_rr_stream(&proof).unwrap(); rrs.shuffle(&mut rand::rngs::OsRng); @@ -220,7 +220,7 @@ mod tests { fn test_txt_query() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap(); - let proof = build_txt_proof(sockaddr, query_name).unwrap(); + let proof = build_txt_proof(sockaddr, &query_name).unwrap(); let mut rrs = parse_rr_stream(&proof).unwrap(); rrs.shuffle(&mut rand::rngs::OsRng); @@ -240,7 +240,7 @@ mod tests { async fn test_txt_query_async() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap(); - let proof = build_txt_proof_async(sockaddr, query_name).await.unwrap(); + let proof = build_txt_proof_async(sockaddr, &query_name).await.unwrap(); let mut rrs = parse_rr_stream(&proof).unwrap(); rrs.shuffle(&mut rand::rngs::OsRng); -- 2.30.2