//! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
//! resolver.
-use std::cmp;
+use core::{cmp, ops};
+use alloc::vec;
+use alloc::vec::Vec;
+
+#[cfg(feature = "std")]
use std::net::{SocketAddr, TcpStream};
+#[cfg(feature = "std")]
use std::io::{Read, Write, Error, ErrorKind};
#[cfg(feature = "tokio")]
use crate::rr::*;
use crate::ser::*;
+// In testing use a rather small buffer to ensure we hit the allocation paths sometimes. In
+// production, we should generally never actually need to go to heap as DNS messages are rarely
+// larger than a KiB or two.
+#[cfg(test)]
+const STACK_BUF_LIMIT: u16 = 32;
+#[cfg(not(test))]
+const STACK_BUF_LIMIT: u16 = 2048;
+
+/// A buffer for storing queries and responses.
+#[derive(Clone, PartialEq, Eq)]
+pub struct QueryBuf {
+ buf: [u8; STACK_BUF_LIMIT as usize],
+ heap_buf: Vec<u8>,
+ len: u16,
+}
+impl QueryBuf {
+ fn new_zeroed(len: u16) -> Self {
+ let heap_buf = if len > STACK_BUF_LIMIT { vec![0; len as usize] } else { Vec::new() };
+ Self {
+ buf: [0; STACK_BUF_LIMIT as usize],
+ heap_buf,
+ len
+ }
+ }
+ pub(crate) fn extend_from_slice(&mut self, sl: &[u8]) {
+ let new_len = self.len.saturating_add(sl.len() as u16);
+ let was_heap = self.len > STACK_BUF_LIMIT;
+ let is_heap = new_len > STACK_BUF_LIMIT;
+ if was_heap != is_heap {
+ self.heap_buf = vec![0; new_len as usize];
+ self.heap_buf[..self.len as usize].copy_from_slice(&self.buf[..self.len as usize]);
+ }
+ let target = if is_heap {
+ self.heap_buf.resize(new_len as usize, 0);
+ &mut self.heap_buf[self.len as usize..]
+ } else {
+ &mut self.buf[self.len as usize..new_len as usize]
+ };
+ target.copy_from_slice(sl);
+ self.len = new_len;
+ }
+}
+impl ops::Deref for QueryBuf {
+ type Target = [u8];
+ fn deref(&self) -> &[u8] {
+ if self.len > STACK_BUF_LIMIT {
+ &self.heap_buf
+ } else {
+ &self.buf[..self.len as usize]
+ }
+ }
+}
+impl ops::DerefMut for QueryBuf {
+ fn deref_mut(&mut self) -> &mut [u8] {
+ if self.len > STACK_BUF_LIMIT {
+ &mut self.heap_buf
+ } else {
+ &mut self.buf[..self.len as usize]
+ }
+ }
+}
+
// We don't care about transaction IDs as we're only going to accept signed data. Thus, we use
// this constant instead of a random value.
const TXID: u16 = 0x4242;
-fn build_query(domain: &Name, ty: u16) -> Vec<u8> {
- // TODO: Move to not allocating for the query
- let mut query = Vec::with_capacity(1024);
+fn build_query(domain: &Name, ty: u16) -> QueryBuf {
+ let mut query = QueryBuf::new_zeroed(0);
let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11;
query.extend_from_slice(&query_msg_len.to_be_bytes());
query.extend_from_slice(&TXID.to_be_bytes());
/// `ty`pe is supported by this library).
///
/// You can find constants for supported standard types in the [`crate::rr`] module.
- pub fn new(name: &Name, ty: u16) -> (ProofBuilder, Vec<u8>) {
+ pub fn new(name: &Name, ty: u16) -> (ProofBuilder, QueryBuf) {
let initial_query = build_query(name, ty);
(ProofBuilder {
proof: Vec::new(),
/// Processes a query response from the recursive resolver, returning a list of new queries to
/// send to the resolver.
- pub fn process_response(&mut self, resp: &[u8]) -> Result<Vec<Vec<u8>>, ()> {
+ pub fn process_response(&mut self, resp: &QueryBuf) -> Result<Vec<QueryBuf>, ()> {
if self.pending_queries == 0 { return Err(()); }
let mut rrsig_key_names = Vec::new();
}
}
+#[cfg(feature = "std")]
fn send_query(stream: &mut TcpStream, query: &[u8]) -> Result<(), Error> {
stream.write_all(&query)?;
Ok(())
Ok(())
}
-type MsgBuf = [u8; u16::MAX as usize];
-
-fn read_response(stream: &mut TcpStream, response_buf: &mut MsgBuf) -> Result<u16, Error> {
+#[cfg(feature = "std")]
+fn read_response(stream: &mut TcpStream) -> Result<QueryBuf, Error> {
let mut len_bytes = [0; 2];
stream.read_exact(&mut len_bytes)?;
- let len = u16::from_be_bytes(len_bytes);
- stream.read_exact(&mut response_buf[..len as usize])?;
- Ok(len)
+ let mut buf = QueryBuf::new_zeroed(u16::from_be_bytes(len_bytes));
+ stream.read_exact(&mut buf)?;
+ Ok(buf)
}
#[cfg(feature = "tokio")]
-async fn read_response_async(stream: &mut TokioTcpStream, response_buf: &mut MsgBuf) -> Result<u16, Error> {
+async fn read_response_async(stream: &mut TokioTcpStream) -> Result<QueryBuf, Error> {
let mut len_bytes = [0; 2];
stream.read_exact(&mut len_bytes).await?;
- let len = u16::from_be_bytes(len_bytes);
- stream.read_exact(&mut response_buf[..len as usize]).await?;
- Ok(len)
+ let mut buf = QueryBuf::new_zeroed(u16::from_be_bytes(len_bytes));
+ stream.read_exact(&mut buf).await?;
+ Ok(buf)
}
+#[cfg(feature = "std")]
macro_rules! build_proof_impl {
($stream: ident, $send_query: ident, $read_response: ident, $domain: expr, $ty: expr $(, $async_ok: tt)?) => { {
// We require the initial query to have already gone out, and assume our resolver will
let (mut builder, initial_query) = ProofBuilder::new($domain, $ty);
$send_query(&mut $stream, &initial_query)
$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
- let mut response_buf = [0; u16::MAX as usize];
while builder.awaiting_responses() {
- let response_len = $read_response(&mut $stream, &mut response_buf)
+ let response = $read_response(&mut $stream)
$(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
- let new_queries = builder.process_response(&response_buf[..response_len as usize])
+ let new_queries = builder.process_response(&response)
.map_err(|()| Error::new(ErrorKind::Other, "Bad response"))?;
for query in new_queries {
$send_query(&mut $stream, &query)
} }
}
+#[cfg(feature = "std")]
fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
let mut stream = TcpStream::connect(resolver)?;
build_proof_impl!(stream, send_query, read_response, domain, ty)
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
+#[cfg(feature = "std")]
pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
build_proof(resolver, domain, A::TYPE)
}
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
+#[cfg(feature = "std")]
pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
build_proof(resolver, domain, AAAA::TYPE)
}
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
+#[cfg(feature = "std")]
pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
build_proof(resolver, domain, Txt::TYPE)
}
///
/// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
/// module to validate the records contained.
+#[cfg(feature = "std")]
pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
build_proof(resolver, domain, TLSA::TYPE)
}
build_proof_async(resolver, domain, TLSA::TYPE).await
}
-#[cfg(all(feature = "validation", test))]
+#[cfg(all(feature = "validation", feature = "std", test))]
mod tests {
use super::*;
use crate::validation::*;
use std::net::ToSocketAddrs;
use std::time::SystemTime;
-
#[test]
fn test_cloudflare_txt_query() {
let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();