From 6f9eaca05a5f707cd928a1f6cb97043fed6ee70e Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 12 Feb 2024 00:28:11 +0000 Subject: [PATCH] Avoid allocating for all message buffers, expose querying in no-std --- src/lib.rs | 4 +- src/query.rs | 112 +++++++++++++++++++++++++++++++++++++++++---------- src/ser.rs | 2 + 3 files changed, 94 insertions(+), 24 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b36f054..902f6a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,9 +34,7 @@ extern crate alloc; pub mod rr; pub mod ser; +pub mod query; #[cfg(feature = "validation")] pub mod validation; - -#[cfg(feature = "std")] -pub mod query; diff --git a/src/query.rs b/src/query.rs index 1ed76d5..c2ea75e 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,8 +1,13 @@ //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive //! resolver. -use std::cmp; +use core::{cmp, ops}; +use alloc::vec; +use alloc::vec::Vec; + +#[cfg(feature = "std")] use std::net::{SocketAddr, TcpStream}; +#[cfg(feature = "std")] use std::io::{Read, Write, Error, ErrorKind}; #[cfg(feature = "tokio")] @@ -13,13 +18,74 @@ use tokio_crate::io::{AsyncReadExt, AsyncWriteExt}; use crate::rr::*; use crate::ser::*; +// In testing use a rather small buffer to ensure we hit the allocation paths sometimes. In +// production, we should generally never actually need to go to heap as DNS messages are rarely +// larger than a KiB or two. +#[cfg(test)] +const STACK_BUF_LIMIT: u16 = 32; +#[cfg(not(test))] +const STACK_BUF_LIMIT: u16 = 2048; + +/// A buffer for storing queries and responses. +#[derive(Clone, PartialEq, Eq)] +pub struct QueryBuf { + buf: [u8; STACK_BUF_LIMIT as usize], + heap_buf: Vec, + len: u16, +} +impl QueryBuf { + fn new_zeroed(len: u16) -> Self { + let heap_buf = if len > STACK_BUF_LIMIT { vec![0; len as usize] } else { Vec::new() }; + Self { + buf: [0; STACK_BUF_LIMIT as usize], + heap_buf, + len + } + } + pub(crate) fn extend_from_slice(&mut self, sl: &[u8]) { + let new_len = self.len.saturating_add(sl.len() as u16); + let was_heap = self.len > STACK_BUF_LIMIT; + let is_heap = new_len > STACK_BUF_LIMIT; + if was_heap != is_heap { + self.heap_buf = vec![0; new_len as usize]; + self.heap_buf[..self.len as usize].copy_from_slice(&self.buf[..self.len as usize]); + } + let target = if is_heap { + self.heap_buf.resize(new_len as usize, 0); + &mut self.heap_buf[self.len as usize..] + } else { + &mut self.buf[self.len as usize..new_len as usize] + }; + target.copy_from_slice(sl); + self.len = new_len; + } +} +impl ops::Deref for QueryBuf { + type Target = [u8]; + fn deref(&self) -> &[u8] { + if self.len > STACK_BUF_LIMIT { + &self.heap_buf + } else { + &self.buf[..self.len as usize] + } + } +} +impl ops::DerefMut for QueryBuf { + fn deref_mut(&mut self) -> &mut [u8] { + if self.len > STACK_BUF_LIMIT { + &mut self.heap_buf + } else { + &mut self.buf[..self.len as usize] + } + } +} + // We don't care about transaction IDs as we're only going to accept signed data. Thus, we use // this constant instead of a random value. const TXID: u16 = 0x4242; -fn build_query(domain: &Name, ty: u16) -> Vec { - // TODO: Move to not allocating for the query - let mut query = Vec::with_capacity(1024); +fn build_query(domain: &Name, ty: u16) -> QueryBuf { + let mut query = QueryBuf::new_zeroed(0); let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11; query.extend_from_slice(&query_msg_len.to_be_bytes()); query.extend_from_slice(&TXID.to_be_bytes()); @@ -108,7 +174,7 @@ impl ProofBuilder { /// `ty`pe is supported by this library). /// /// You can find constants for supported standard types in the [`crate::rr`] module. - pub fn new(name: &Name, ty: u16) -> (ProofBuilder, Vec) { + pub fn new(name: &Name, ty: u16) -> (ProofBuilder, QueryBuf) { let initial_query = build_query(name, ty); (ProofBuilder { proof: Vec::new(), @@ -130,7 +196,7 @@ impl ProofBuilder { /// Processes a query response from the recursive resolver, returning a list of new queries to /// send to the resolver. - pub fn process_response(&mut self, resp: &[u8]) -> Result>, ()> { + pub fn process_response(&mut self, resp: &QueryBuf) -> Result, ()> { if self.pending_queries == 0 { return Err(()); } let mut rrsig_key_names = Vec::new(); @@ -175,6 +241,7 @@ impl ProofBuilder { } } +#[cfg(feature = "std")] fn send_query(stream: &mut TcpStream, query: &[u8]) -> Result<(), Error> { stream.write_all(&query)?; Ok(()) @@ -186,25 +253,25 @@ async fn send_query_async(stream: &mut TokioTcpStream, query: &[u8]) -> Result<( Ok(()) } -type MsgBuf = [u8; u16::MAX as usize]; - -fn read_response(stream: &mut TcpStream, response_buf: &mut MsgBuf) -> Result { +#[cfg(feature = "std")] +fn read_response(stream: &mut TcpStream) -> Result { let mut len_bytes = [0; 2]; stream.read_exact(&mut len_bytes)?; - let len = u16::from_be_bytes(len_bytes); - stream.read_exact(&mut response_buf[..len as usize])?; - Ok(len) + let mut buf = QueryBuf::new_zeroed(u16::from_be_bytes(len_bytes)); + stream.read_exact(&mut buf)?; + Ok(buf) } #[cfg(feature = "tokio")] -async fn read_response_async(stream: &mut TokioTcpStream, response_buf: &mut MsgBuf) -> Result { +async fn read_response_async(stream: &mut TokioTcpStream) -> Result { let mut len_bytes = [0; 2]; stream.read_exact(&mut len_bytes).await?; - let len = u16::from_be_bytes(len_bytes); - stream.read_exact(&mut response_buf[..len as usize]).await?; - Ok(len) + let mut buf = QueryBuf::new_zeroed(u16::from_be_bytes(len_bytes)); + stream.read_exact(&mut buf).await?; + Ok(buf) } +#[cfg(feature = "std")] macro_rules! build_proof_impl { ($stream: ident, $send_query: ident, $read_response: ident, $domain: expr, $ty: expr $(, $async_ok: tt)?) => { { // We require the initial query to have already gone out, and assume our resolver will @@ -215,11 +282,10 @@ macro_rules! build_proof_impl { let (mut builder, initial_query) = ProofBuilder::new($domain, $ty); $send_query(&mut $stream, &initial_query) $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? - let mut response_buf = [0; u16::MAX as usize]; while builder.awaiting_responses() { - let response_len = $read_response(&mut $stream, &mut response_buf) + let response = $read_response(&mut $stream) $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? - let new_queries = builder.process_response(&response_buf[..response_len as usize]) + let new_queries = builder.process_response(&response) .map_err(|()| Error::new(ErrorKind::Other, "Bad response"))?; for query in new_queries { $send_query(&mut $stream, &query) @@ -232,6 +298,7 @@ macro_rules! build_proof_impl { } } } +#[cfg(feature = "std")] fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec, u32), Error> { let mut stream = TcpStream::connect(resolver)?; build_proof_impl!(stream, send_query, read_response, domain, ty) @@ -248,6 +315,7 @@ async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Resu /// /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`] /// module to validate the records contained. +#[cfg(feature = "std")] pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, A::TYPE) } @@ -257,6 +325,7 @@ pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u3 /// /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`] /// module to validate the records contained. +#[cfg(feature = "std")] pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, AAAA::TYPE) } @@ -266,6 +335,7 @@ pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, /// /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`] /// module to validate the records contained. +#[cfg(feature = "std")] pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, Txt::TYPE) } @@ -275,6 +345,7 @@ pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, /// /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`] /// module to validate the records contained. +#[cfg(feature = "std")] pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, TLSA::TYPE) } @@ -320,7 +391,7 @@ pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Resu build_proof_async(resolver, domain, TLSA::TYPE).await } -#[cfg(all(feature = "validation", test))] +#[cfg(all(feature = "validation", feature = "std", test))] mod tests { use super::*; use crate::validation::*; @@ -330,7 +401,6 @@ mod tests { use std::net::ToSocketAddrs; use std::time::SystemTime; - #[test] fn test_cloudflare_txt_query() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); diff --git a/src/ser.rs b/src/ser.rs index 8829725..2d32973 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -4,6 +4,7 @@ use alloc::vec::Vec; use alloc::string::String; use crate::rr::*; +use crate::query::QueryBuf; pub(crate) fn read_u8(inp: &mut &[u8]) -> Result { let res = *inp.get(0).ok_or(())?; @@ -58,6 +59,7 @@ pub(crate) fn read_wire_packet_name(inp: &mut &[u8], wire_packet: &[u8]) -> Resu pub(crate) trait Writer { fn write(&mut self, buf: &[u8]); } impl Writer for Vec { fn write(&mut self, buf: &[u8]) { self.extend_from_slice(buf); } } +impl Writer for QueryBuf { fn write(&mut self, buf: &[u8]) { self.extend_from_slice(buf); } } #[cfg(feature = "validation")] impl Writer for ring::digest::Context { fn write(&mut self, buf: &[u8]) { self.update(buf); } } pub(crate) fn write_name(out: &mut W, name: &str) { -- 2.39.5