use std::net::{SocketAddr, TcpStream};
use std::io::{Read, Write, Error, ErrorKind};
+#[cfg(feature = "tokio")]
+use tokio_crate::net::TcpStream as TokioTcpStream;
+#[cfg(feature = "tokio")]
+use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
+
+
use crate::write_rr;
use crate::rr::*;
use crate::ser::*;
v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response"))
}
-fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> {
+fn build_query(domain: Name, ty: u16) -> Vec<u8> {
+ // TODO: Move to not allocating for the query
let mut query = Vec::with_capacity(1024);
let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(&domain) + 11;
query.extend_from_slice(&query_msg_len.to_be_bytes());
query.extend_from_slice(&[0, 0]); // EDNS version 0
query.extend_from_slice(&0x8000u16.to_be_bytes()); // Accept DNSSEC RRs
query.extend_from_slice(&0u16.to_be_bytes()); // No additional data
+ query
+}
+
+fn send_query(stream: &mut TcpStream, domain: Name, ty: u16) -> Result<(), Error> {
+ let query = build_query(domain, ty);
stream.write_all(&query)?;
Ok(())
}
-fn read_response(stream: &mut TcpStream) -> Result<Vec<RR>, Error> {
- let mut len = [0; 2];
- stream.read_exact(&mut len)?;
- let mut resp = vec![0; u16::from_be_bytes(len) as usize];
- stream.read_exact(&mut resp)?;
+#[cfg(feature = "tokio")]
+async fn send_query_async(stream: &mut TokioTcpStream, domain: Name, ty: u16) -> Result<(), Error> {
+ let query = build_query(domain, ty);
+ stream.write_all(&query).await?;
+ Ok(())
+}
- let mut read: &[u8] = &resp;
+fn handle_response(resp: &[u8], proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+ let mut read: &[u8] = resp;
if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); }
// 2 byte transaction ID
let flags = emap(read_u16(&mut read))?;
let _additional = emap(read_u16(&mut read))?;
for _ in 0..questions {
- emap(read_name(&mut read))?;
+ emap(read_wire_packet_name(&mut read, resp))?;
emap(read_u16(&mut read))?; // type
emap(read_u16(&mut read))?; // class
}
// Only read the answers (skip authorities and additional) as that's all we care about.
- let mut res = Vec::new();
+ let mut rrsig_opt = None;
for _ in 0..answers {
- res.push(emap(parse_wire_packet_rr(&mut read, &resp))?);
+ let rr = emap(parse_wire_packet_rr(&mut read, &resp))?;
+ write_rr(&rr, 0, proof);
+ if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); }
}
- Ok(res)
+ Ok(rrsig_opt)
+}
+
+fn read_response(stream: &mut TcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+ let mut len = [0; 2];
+ stream.read_exact(&mut len)?;
+ let mut resp = vec![0; u16::from_be_bytes(len) as usize];
+ stream.read_exact(&mut resp)?;
+ handle_response(&resp, proof)
+}
+
+#[cfg(feature = "tokio")]
+async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec<u8>) -> Result<Option<RRSig>, Error> {
+ let mut len = [0; 2];
+ stream.read_exact(&mut len).await?;
+ let mut resp = vec![0; u16::from_be_bytes(len) as usize];
+ stream.read_exact(&mut resp).await?;
+ handle_response(&resp, proof)
}
fn build_proof(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
send_query(&mut stream, domain, ty)?;
let mut reached_root = false;
for _ in 0..10 {
- let resp = read_response(&mut stream)?;
- for rr in resp {
- write_rr(&rr, 0, &mut res);
- if rr.name().as_str() == "." {
+ let rrsig_opt = read_response(&mut stream, &mut res)?;
+ if let Some(rrsig) = rrsig_opt {
+ if rrsig.name.as_str() == "." {
reached_root = true;
} else {
- if let RR::RRSig(rrsig) = rr {
- if rrsig.name == rrsig.key_name {
- send_query(&mut stream, rrsig.key_name, DS::TYPE)?;
- } else {
- send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?;
- }
+ if rrsig.name == rrsig.key_name {
+ send_query(&mut stream, rrsig.key_name, DS::TYPE)?;
+ } else {
+ send_query(&mut stream, rrsig.key_name, DnsKey::TYPE)?;
+ }
+ }
+ }
+ if reached_root { break; }
+ }
+
+ if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) }
+ else { Ok(res) }
+}
+
+#[cfg(feature = "tokio")]
+async fn build_proof_async(resolver: SocketAddr, domain: Name, ty: u16) -> Result<Vec<u8>, Error> {
+ let mut stream = TokioTcpStream::connect(resolver).await?;
+ let mut res = Vec::new();
+ send_query_async(&mut stream, domain, ty).await?;
+ let mut reached_root = false;
+ for _ in 0..10 {
+ let rrsig_opt = read_response_async(&mut stream, &mut res).await?;
+ if let Some(rrsig) = rrsig_opt {
+ if rrsig.name.as_str() == "." {
+ reached_root = true;
+ } else {
+ if rrsig.name == rrsig.key_name {
+ send_query_async(&mut stream, rrsig.key_name, DS::TYPE).await?;
+ } else {
+ send_query_async(&mut stream, rrsig.key_name, DnsKey::TYPE).await?;
}
}
}
build_proof(resolver, domain, TLSA::TYPE)
}
+
+/// Builds a DNSSEC proof for an A record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_a_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+ build_proof_async(resolver, domain, A::TYPE).await
+}
+
+/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+ build_proof_async(resolver, domain, AAAA::TYPE).await
+}
+
+/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_txt_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+ build_proof_async(resolver, domain, Txt::TYPE).await
+}
+
+/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver
+#[cfg(feature = "tokio")]
+pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: Name) -> Result<Vec<u8>, Error> {
+ build_proof_async(resolver, domain, TLSA::TYPE).await
+}
+
#[cfg(test)]
mod tests {
use super::*;
let verified_rrs = verify_rr_stream(&rrs).unwrap();
assert_eq!(verified_rrs.len(), 1);
}
+
+ #[cfg(feature = "tokio")]
+ use tokio_crate as tokio;
+
+ #[cfg(feature = "tokio")]
+ #[tokio::test]
+ async fn test_txt_query_async() {
+ let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
+ let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap();
+ let proof = build_txt_proof_async(sockaddr, query_name).await.unwrap();
+
+ let mut rrs = parse_rr_stream(&proof).unwrap();
+ rrs.shuffle(&mut rand::rngs::OsRng);
+ let verified_rrs = verify_rr_stream(&rrs).unwrap();
+ assert_eq!(verified_rrs.len(), 1);
+ }
}