Avoid allocating for all message buffers, expose querying in no-std
authorMatt Corallo <git@bluematt.me>
Mon, 12 Feb 2024 00:28:11 +0000 (00:28 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 12 Feb 2024 00:28:11 +0000 (00:28 +0000)
src/lib.rs
src/query.rs
src/ser.rs

index b36f05483bd0545f57988ad56a5516b59078601b..902f6a35853966c26ed8b99214863cad0d6c3052 100644 (file)
@@ -34,9 +34,7 @@ extern crate alloc;
 
 pub mod rr;
 pub mod ser;
+pub mod query;
 
 #[cfg(feature = "validation")]
 pub mod validation;
-
-#[cfg(feature = "std")]
-pub mod query;
index 1ed76d578b97201fe80af8e87963619960c0aaae..c2ea75e4de7a9fe96b408c7f5bf1ea90ffbd6a8e 100644 (file)
@@ -1,8 +1,13 @@
 //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive
 //! resolver.
 
-use std::cmp;
+use core::{cmp, ops};
+use alloc::vec;
+use alloc::vec::Vec;
+
+#[cfg(feature = "std")]
 use std::net::{SocketAddr, TcpStream};
+#[cfg(feature = "std")]
 use std::io::{Read, Write, Error, ErrorKind};
 
 #[cfg(feature = "tokio")]
@@ -13,13 +18,74 @@ use tokio_crate::io::{AsyncReadExt, AsyncWriteExt};
 use crate::rr::*;
 use crate::ser::*;
 
+// In testing use a rather small buffer to ensure we hit the allocation paths sometimes. In
+// production, we should generally never actually need to go to heap as DNS messages are rarely
+// larger than a KiB or two.
+#[cfg(test)]
+const STACK_BUF_LIMIT: u16 = 32;
+#[cfg(not(test))]
+const STACK_BUF_LIMIT: u16 = 2048;
+
+/// A buffer for storing queries and responses.
+#[derive(Clone, PartialEq, Eq)]
+pub struct QueryBuf {
+       buf: [u8; STACK_BUF_LIMIT as usize],
+       heap_buf: Vec<u8>,
+       len: u16,
+}
+impl QueryBuf {
+       fn new_zeroed(len: u16) -> Self {
+               let heap_buf = if len > STACK_BUF_LIMIT { vec![0; len as usize] } else { Vec::new() };
+               Self {
+                       buf: [0; STACK_BUF_LIMIT as usize],
+                       heap_buf,
+                       len
+               }
+       }
+       pub(crate) fn extend_from_slice(&mut self, sl: &[u8]) {
+               let new_len = self.len.saturating_add(sl.len() as u16);
+               let was_heap = self.len > STACK_BUF_LIMIT;
+               let is_heap = new_len > STACK_BUF_LIMIT;
+               if was_heap != is_heap {
+                       self.heap_buf = vec![0; new_len as usize];
+                       self.heap_buf[..self.len as usize].copy_from_slice(&self.buf[..self.len as usize]);
+               }
+               let target = if is_heap {
+                       self.heap_buf.resize(new_len as usize, 0);
+                       &mut self.heap_buf[self.len as usize..]
+               } else {
+                       &mut self.buf[self.len as usize..new_len as usize]
+               };
+               target.copy_from_slice(sl);
+               self.len = new_len;
+       }
+}
+impl ops::Deref for QueryBuf {
+       type Target = [u8];
+       fn deref(&self) -> &[u8] {
+               if self.len > STACK_BUF_LIMIT {
+                       &self.heap_buf
+               } else {
+                       &self.buf[..self.len as usize]
+               }
+       }
+}
+impl ops::DerefMut for QueryBuf {
+       fn deref_mut(&mut self) -> &mut [u8] {
+               if self.len > STACK_BUF_LIMIT {
+                       &mut self.heap_buf
+               } else {
+                       &mut self.buf[..self.len as usize]
+               }
+       }
+}
+
 // We don't care about transaction IDs as we're only going to accept signed data. Thus, we use
 // this constant instead of a random value.
 const TXID: u16 = 0x4242;
 
-fn build_query(domain: &Name, ty: u16) -> Vec<u8> {
-       // TODO: Move to not allocating for the query
-       let mut query = Vec::with_capacity(1024);
+fn build_query(domain: &Name, ty: u16) -> QueryBuf {
+       let mut query = QueryBuf::new_zeroed(0);
        let query_msg_len: u16 = 2 + 2 + 8 + 2 + 2 + name_len(domain) + 11;
        query.extend_from_slice(&query_msg_len.to_be_bytes());
        query.extend_from_slice(&TXID.to_be_bytes());
@@ -108,7 +174,7 @@ impl ProofBuilder {
        /// `ty`pe is supported by this library).
        ///
        /// You can find constants for supported standard types in the [`crate::rr`] module.
-       pub fn new(name: &Name, ty: u16) -> (ProofBuilder, Vec<u8>) {
+       pub fn new(name: &Name, ty: u16) -> (ProofBuilder, QueryBuf) {
                let initial_query = build_query(name, ty);
                (ProofBuilder {
                        proof: Vec::new(),
@@ -130,7 +196,7 @@ impl ProofBuilder {
 
        /// Processes a query response from the recursive resolver, returning a list of new queries to
        /// send to the resolver.
-       pub fn process_response(&mut self, resp: &[u8]) -> Result<Vec<Vec<u8>>, ()> {
+       pub fn process_response(&mut self, resp: &QueryBuf) -> Result<Vec<QueryBuf>, ()> {
                if self.pending_queries == 0 { return Err(()); }
 
                let mut rrsig_key_names = Vec::new();
@@ -175,6 +241,7 @@ impl ProofBuilder {
        }
 }
 
+#[cfg(feature = "std")]
 fn send_query(stream: &mut TcpStream, query: &[u8]) -> Result<(), Error> {
        stream.write_all(&query)?;
        Ok(())
@@ -186,25 +253,25 @@ async fn send_query_async(stream: &mut TokioTcpStream, query: &[u8]) -> Result<(
        Ok(())
 }
 
-type MsgBuf = [u8; u16::MAX as usize];
-
-fn read_response(stream: &mut TcpStream, response_buf: &mut MsgBuf) -> Result<u16, Error> {
+#[cfg(feature = "std")]
+fn read_response(stream: &mut TcpStream) -> Result<QueryBuf, Error> {
        let mut len_bytes = [0; 2];
        stream.read_exact(&mut len_bytes)?;
-       let len = u16::from_be_bytes(len_bytes);
-       stream.read_exact(&mut response_buf[..len as usize])?;
-       Ok(len)
+       let mut buf = QueryBuf::new_zeroed(u16::from_be_bytes(len_bytes));
+       stream.read_exact(&mut buf)?;
+       Ok(buf)
 }
 
 #[cfg(feature = "tokio")]
-async fn read_response_async(stream: &mut TokioTcpStream, response_buf: &mut MsgBuf) -> Result<u16, Error> {
+async fn read_response_async(stream: &mut TokioTcpStream) -> Result<QueryBuf, Error> {
        let mut len_bytes = [0; 2];
        stream.read_exact(&mut len_bytes).await?;
-       let len = u16::from_be_bytes(len_bytes);
-       stream.read_exact(&mut response_buf[..len as usize]).await?;
-       Ok(len)
+       let mut buf = QueryBuf::new_zeroed(u16::from_be_bytes(len_bytes));
+       stream.read_exact(&mut buf).await?;
+       Ok(buf)
 }
 
+#[cfg(feature = "std")]
 macro_rules! build_proof_impl {
        ($stream: ident, $send_query: ident, $read_response: ident, $domain: expr, $ty: expr $(, $async_ok: tt)?) => { {
                // We require the initial query to have already gone out, and assume our resolver will
@@ -215,11 +282,10 @@ macro_rules! build_proof_impl {
                let (mut builder, initial_query) = ProofBuilder::new($domain, $ty);
                $send_query(&mut $stream, &initial_query)
                        $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
-               let mut response_buf = [0; u16::MAX as usize];
                while builder.awaiting_responses() {
-                       let response_len = $read_response(&mut $stream, &mut response_buf)
+                       let response = $read_response(&mut $stream)
                                $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ?
-                       let new_queries = builder.process_response(&response_buf[..response_len as usize])
+                       let new_queries = builder.process_response(&response)
                                .map_err(|()| Error::new(ErrorKind::Other, "Bad response"))?;
                        for query in new_queries {
                                $send_query(&mut $stream, &query)
@@ -232,6 +298,7 @@ macro_rules! build_proof_impl {
        } }
 }
 
+#[cfg(feature = "std")]
 fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec<u8>, u32), Error> {
        let mut stream = TcpStream::connect(resolver)?;
        build_proof_impl!(stream, send_query, read_response, domain, ty)
@@ -248,6 +315,7 @@ async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Resu
 ///
 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
 /// module to validate the records contained.
+#[cfg(feature = "std")]
 pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, A::TYPE)
 }
@@ -257,6 +325,7 @@ pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u3
 ///
 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
 /// module to validate the records contained.
+#[cfg(feature = "std")]
 pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, AAAA::TYPE)
 }
@@ -266,6 +335,7 @@ pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>,
 ///
 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
 /// module to validate the records contained.
+#[cfg(feature = "std")]
 pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, Txt::TYPE)
 }
@@ -275,6 +345,7 @@ pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>,
 ///
 /// Note that this proof is NOT verified in any way, you need to use the [`crate::validation`]
 /// module to validate the records contained.
+#[cfg(feature = "std")]
 pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec<u8>, u32), Error> {
        build_proof(resolver, domain, TLSA::TYPE)
 }
@@ -320,7 +391,7 @@ pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Resu
        build_proof_async(resolver, domain, TLSA::TYPE).await
 }
 
-#[cfg(all(feature = "validation", test))]
+#[cfg(all(feature = "validation", feature = "std", test))]
 mod tests {
        use super::*;
        use crate::validation::*;
@@ -330,7 +401,6 @@ mod tests {
        use std::net::ToSocketAddrs;
        use std::time::SystemTime;
 
-
        #[test]
        fn test_cloudflare_txt_query() {
                let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap();
index 8829725dc139f0d0dca485e622e8afb21648b2e5..2d329733b3fc5e8b4016668baab0503d11c0a621 100644 (file)
@@ -4,6 +4,7 @@ use alloc::vec::Vec;
 use alloc::string::String;
 
 use crate::rr::*;
+use crate::query::QueryBuf;
 
 pub(crate) fn read_u8(inp: &mut &[u8]) -> Result<u8, ()> {
        let res = *inp.get(0).ok_or(())?;
@@ -58,6 +59,7 @@ pub(crate) fn read_wire_packet_name(inp: &mut &[u8], wire_packet: &[u8]) -> Resu
 
 pub(crate) trait Writer { fn write(&mut self, buf: &[u8]); }
 impl Writer for Vec<u8> { fn write(&mut self, buf: &[u8]) { self.extend_from_slice(buf); } }
+impl Writer for QueryBuf { fn write(&mut self, buf: &[u8]) { self.extend_from_slice(buf); } }
 #[cfg(feature = "validation")]
 impl Writer for ring::digest::Context { fn write(&mut self, buf: &[u8]) { self.update(buf); } }
 pub(crate) fn write_name<W: Writer>(out: &mut W, name: &str) {