X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=src%2Fquery.rs;h=675d82b0badc6d708af4ac653e4366d3626a53ed;hb=8267bce8419938335557c255642cf8ca5990936e;hp=65da2f0f25e19d5bc6ff8370715c48c47aa5420a;hpb=f0f3fa43e9a566dc16df104611ac9b5293691e04;p=dnssec-prover diff --git a/src/query.rs b/src/query.rs index 65da2f0..675d82b 100644 --- a/src/query.rs +++ b/src/query.rs @@ -4,6 +4,12 @@ use std::net::{SocketAddr, TcpStream}; use std::io::{Read, Write, Error, ErrorKind}; +#[cfg(feature = "tokio")] +use tokio_crate::net::TcpStream as TokioTcpStream; +#[cfg(feature = "tokio")] +use tokio_crate::io::{AsyncReadExt, AsyncWriteExt}; + + use crate::write_rr; use crate::rr::*; use crate::ser::*; @@ -16,7 +22,8 @@ fn emap(v: Result) -> Result { v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response")) } -fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> { +fn build_query(domain: Name, ty: u16) -> Vec { + // TODO: Move to not allocating for the query let mut query = Vec::with_capacity(1024); let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(&domain) + 11; query.extend_from_slice(&query_msg_len.to_be_bytes()); @@ -31,17 +38,24 @@ fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error query.extend_from_slice(&[0, 0]); // EDNS version 0 query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs query.extend_from_slice(&0u16.to_be_bytes()); // No additional data + query +} + +fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> { + let query = build_query(domain, ty); stream.write_all(&query)?; Ok(()) } -fn read_response(stream: &mut TcpStream) -> Result, Error> { - let mut len = [0; 2]; - stream.read_exact(&mut len)?; - let mut resp = vec![0; u16::from_be_bytes(len) as usize]; - stream.read_exact(&mut resp)?; +#[cfg(feature = "tokio")] +async fn send_query_async(stream: &mut TokioTcpStream, domain: Name, ty: u16) -> Result<(), Error> { + let query = build_query(domain, ty); + stream.write_all(&query).await?; + Ok(()) +} - let mut read: &[u8] = &resp; +fn handle_response(resp: &[u8], proof: &mut Vec) -> Result, Error> { + let mut read: &[u8] = resp; if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); } // 2 byte transaction ID let flags = emap(read_u16(&mut read))?; @@ -61,17 +75,36 @@ fn read_response(stream: &mut TcpStream) -> Result, Error> { let _additional = emap(read_u16(&mut read))?; for _ in 0..questions { - emap(read_name(&mut read))?; + emap(read_wire_packet_name(&mut read, resp))?; emap(read_u16(&mut read))?; // type emap(read_u16(&mut read))?; // class } // Only read the answers (skip authorities and additional) as that's all we care about. - let mut res = Vec::new(); + let mut rrsig_opt = None; for _ in 0..answers { - res.push(emap(parse_wire_packet_rr(&mut read, &resp))?); + let rr = emap(parse_wire_packet_rr(&mut read, &resp))?; + write_rr(&rr, 0, proof); + if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); } } - Ok(res) + Ok(rrsig_opt) +} + +fn read_response(stream: &mut TcpStream, proof: &mut Vec) -> Result, Error> { + let mut len = [0; 2]; + stream.read_exact(&mut len)?; + let mut resp = vec![0; u16::from_be_bytes(len) as usize]; + stream.read_exact(&mut resp)?; + handle_response(&resp, proof) +} + +#[cfg(feature = "tokio")] +async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec) -> Result, Error> { + let mut len = [0; 2]; + stream.read_exact(&mut len).await?; + let mut resp = vec![0; u16::from_be_bytes(len) as usize]; + stream.read_exact(&mut resp).await?; + handle_response(&resp, proof) } fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result, Error> { @@ -80,18 +113,41 @@ fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result, E send_query(&mut stream, domain, ty)?; let mut reached_root = false; for _ in 0..10 { - let resp = read_response(&mut stream)?; - for rr in resp { - write_rr(&rr, 0, &mut res); - if rr.name().as_str() == "." { + let rrsig_opt = read_response(&mut stream, &mut res)?; + if let Some(rrsig) = rrsig_opt { + if rrsig.name.as_str() == "." { reached_root = true; } else { - if let RR::RRSig(rrsig) = rr { - if rrsig.name == rrsig.key_name { - send_query(&mut stream, rrsig.key_name, DS::TYPE)?; - } else { - send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?; - } + if rrsig.name == rrsig.key_name { + send_query(&mut stream, rrsig.key_name, DS::TYPE)?; + } else { + send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?; + } + } + } + if reached_root { break; } + } + + if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) } + else { Ok(res) } +} + +#[cfg(feature = "tokio")] +async fn build_proof_async(resolver: SocketAddr, domain: Name, ty: u16) -> Result, Error> { + let mut stream = TokioTcpStream::connect(resolver).await?; + let mut res = Vec::new(); + send_query_async(&mut stream, domain, ty).await?; + let mut reached_root = false; + for _ in 0..10 { + let rrsig_opt = read_response_async(&mut stream, &mut res).await?; + if let Some(rrsig) = rrsig_opt { + if rrsig.name.as_str() == "." { + reached_root = true; + } else { + if rrsig.name == rrsig.key_name { + send_query_async(&mut stream, rrsig.key_name, DS::TYPE).await?; + } else { + send_query_async(&mut stream, rrsig.key_name, DnsKey::TYPE).await?; } } } @@ -122,6 +178,31 @@ pub fn build_tlsa_proof(resolver: SocketAddr, domain: Name) -> Result, E build_proof(resolver, domain, TLSA::TYPE) } + +/// Builds a DNSSEC proof for an A record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_a_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, A::TYPE).await +} + +/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, AAAA::TYPE).await +} + +/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_txt_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, Txt::TYPE).await +} + +/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver +#[cfg(feature = "tokio")] +pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: Name) -> Result, Error> { + build_proof_async(resolver, domain, TLSA::TYPE).await +} + #[cfg(test)] mod tests { use super::*; @@ -142,4 +223,20 @@ mod tests { let verified_rrs = verify_rr_stream(&rrs).unwrap(); assert_eq!(verified_rrs.len(), 1); } + + #[cfg(feature = "tokio")] + use tokio_crate as tokio; + + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_txt_query_async() { + let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); + let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap(); + let proof = build_txt_proof_async(sockaddr, query_name).await.unwrap(); + + let mut rrs = parse_rr_stream(&proof).unwrap(); + rrs.shuffle(&mut rand::rngs::OsRng); + let verified_rrs = verify_rr_stream(&rrs).unwrap(); + assert_eq!(verified_rrs.len(), 1); + } }