From 52e5e29f86bd9faaf68bf97b114fe92d25d9b91d Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 12 Feb 2024 00:03:37 +0000 Subject: [PATCH] Use (and expose) a `ProofBuilder` state machine for proving This will allow us to expose the state machine we use for building proofs in, eg, javascript, allowing the construction of proofs using DoH. --- src/query.rs | 234 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 151 insertions(+), 83 deletions(-) diff --git a/src/query.rs b/src/query.rs index 3b1d217..1ed76d5 100644 --- a/src/query.rs +++ b/src/query.rs @@ -17,10 +17,6 @@ use crate::ser::*; // this constant instead of a random value. const TXID: u16 = 0x4242; -fn emap(v: Result) -> Result { - v.map_err(|_| Error::new(ErrorKind::Other, "Bad Response")) -} - fn build_query(domain: &Name, ty: u16) -> Vec { // TODO: Move to not allocating for the query let mut query = Vec::with_capacity(1024); @@ -40,19 +36,6 @@ fn build_query(domain: &Name, ty: u16) -> Vec { query } -fn send_query(stream: &mut TcpStream, domain: &Name, ty: u16) -> Result<(), Error> { - let query = build_query(domain, ty); - stream.write_all(&query)?; - Ok(()) -} - -#[cfg(feature = "tokio")] -async fn send_query_async(stream: &mut TokioTcpStream, domain: &Name, ty: u16) -> Result<(), Error> { - let query = build_query(domain, ty); - stream.write_all(&query).await?; - Ok(()) -} - #[cfg(fuzzing)] /// Read some input and parse it as if it came from a server, for fuzzing. pub fn fuzz_response(response: &[u8]) { @@ -60,37 +43,37 @@ pub fn fuzz_response(response: &[u8]) { let _ = handle_response(response, &mut proof, &mut names); } -fn handle_response(resp: &[u8], proof: &mut Vec, rrsig_key_names: &mut Vec) -> Result { +fn handle_response(resp: &[u8], proof: &mut Vec, rrsig_key_names: &mut Vec) -> Result { let mut read: &[u8] = resp; - if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); } + if read_u16(&mut read)? != TXID { return Err(()); } // 2 byte transaction ID - let flags = emap(read_u16(&mut read))?; + let flags = read_u16(&mut read)?; if flags & 0b1000_0000_0000_0000 == 0 { - return Err(Error::new(ErrorKind::Other, "Missing response flag")); + return Err(()); } if flags & 0b0111_1010_0000_0111 != 0 { - return Err(Error::new(ErrorKind::Other, "Server indicated error or provided bunk flags")); + return Err(()); } if flags & 0b10_0000 == 0 { - return Err(Error::new(ErrorKind::Other, "Server indicated data could not be authenticated")); + return Err(()); } - let questions = emap(read_u16(&mut read))?; - if questions != 1 { return Err(Error::new(ErrorKind::Other, "server responded to multiple Qs")); } - let answers = emap(read_u16(&mut read))?; - if answers == 0 { return Err(Error::new(ErrorKind::Other, "No answers")); } - let _authorities = emap(read_u16(&mut read))?; - let _additional = emap(read_u16(&mut read))?; + let questions = read_u16(&mut read)?; + if questions != 1 { return Err(()); } + let answers = read_u16(&mut read)?; + if answers == 0 { return Err(()); } + let _authorities = read_u16(&mut read)?; + let _additional = read_u16(&mut read)?; for _ in 0..questions { - emap(read_wire_packet_name(&mut read, resp))?; - emap(read_u16(&mut read))?; // type - emap(read_u16(&mut read))?; // class + read_wire_packet_name(&mut read, resp)?; + read_u16(&mut read)?; // type + read_u16(&mut read)?; // class } // Only read the answers (skip authorities and additional) as that's all we care about. let mut min_ttl = u32::MAX; for _ in 0..answers { - let (rr, ttl) = emap(parse_wire_packet_rr(&mut read, &resp))?; + let (rr, ttl) = parse_wire_packet_rr(&mut read, &resp)?; write_rr(&rr, ttl, proof); min_ttl = cmp::min(min_ttl, ttl); if let RR::RRSig(rrsig) = rr { rrsig_key_names.push(rrsig.key_name); } @@ -98,81 +81,166 @@ fn handle_response(resp: &[u8], proof: &mut Vec, rrsig_key_names: &mut Vec, rrsig_key_names: &mut Vec) -> Result { - let mut len = [0; 2]; - stream.read_exact(&mut len)?; - let mut resp = vec![0; u16::from_be_bytes(len) as usize]; - stream.read_exact(&mut resp)?; - handle_response(&resp, proof, rrsig_key_names) +const MAX_REQUESTS: usize = 10; +/// A simple state machine which will generate a series of queries and process the responses until +/// it has built a DNSSEC proof. +/// +/// A [`ProofBuilder`] driver starts with [`ProofBuilder::new`], fetching the state machine and +/// initial query. As long as [`ProofBuilder::awaiting_responses`] returns true, responses should +/// be read from the resolver. For each query response read from the DNS resolver, +/// [`ProofBuilder::process_response`] should be called, and each fresh query returned should be +/// sent to the resolver. Once [`ProofBuilder::awaiting_responses`] returns false, +/// [`ProofBuilder::finish_proof`] should be called to fetch the resulting proof. +pub struct ProofBuilder { + proof: Vec, + min_ttl: u32, + dnskeys_requested: Vec, + pending_queries: usize, + queries_made: usize, +} + +impl ProofBuilder { + /// Constructs a new [`ProofBuilder`] and an initial query to send to the recursive resolver to + /// begin the proof building process. + /// + /// Given a correctly-functioning resolver the proof will ultimately be able to prove the + /// contents of any records with the given `ty`pe at the given `name` (as long as the given + /// `ty`pe is supported by this library). + /// + /// You can find constants for supported standard types in the [`crate::rr`] module. + pub fn new(name: &Name, ty: u16) -> (ProofBuilder, Vec) { + let initial_query = build_query(name, ty); + (ProofBuilder { + proof: Vec::new(), + min_ttl: u32::MAX, + dnskeys_requested: Vec::with_capacity(MAX_REQUESTS), + pending_queries: 1, + queries_made: 1, + }, initial_query) + } + + /// Returns true as long as further responses are expected from the resolver. + /// + /// As long as this returns true, responses should be read from the resolver and passed to + /// [`Self::process_response`]. Once this returns false, [`Self::finish_proof`] should be used + /// to (possibly) get the final proof. + pub fn awaiting_responses(&self) -> bool { + self.pending_queries > 0 && self.queries_made <= MAX_REQUESTS + } + + /// Processes a query response from the recursive resolver, returning a list of new queries to + /// send to the resolver. + pub fn process_response(&mut self, resp: &[u8]) -> Result>, ()> { + if self.pending_queries == 0 { return Err(()); } + + let mut rrsig_key_names = Vec::new(); + let min_ttl = handle_response(&resp, &mut self.proof, &mut rrsig_key_names)?; + self.min_ttl = cmp::min(self.min_ttl, min_ttl); + self.pending_queries -= 1; + + rrsig_key_names.sort_unstable(); + rrsig_key_names.dedup(); + + let mut new_queries = Vec::with_capacity(2); + for key_name in rrsig_key_names.drain(..) { + if !self.dnskeys_requested.contains(&key_name) { + new_queries.push(build_query(&key_name, DnsKey::TYPE)); + self.pending_queries += 1; + self.queries_made += 1; + self.dnskeys_requested.push(key_name.clone()); + + if key_name.as_str() != "." { + new_queries.push(build_query(&key_name, DS::TYPE)); + self.pending_queries += 1; + self.queries_made += 1; + } + } + } + if self.queries_made <= MAX_REQUESTS { + Ok(new_queries) + } else { + Ok(Vec::new()) + } + } + + /// Finalizes the proof, if one is available, and returns it as well as the TTL that should be + /// used to cache the proof (i.e. the lowest TTL of all records which were used to build the + /// proof). + pub fn finish_proof(self) -> Result<(Vec, u32), ()> { + if self.pending_queries > 0 || self.queries_made > MAX_REQUESTS { + Err(()) + } else { + Ok((self.proof, self.min_ttl)) + } + } +} + +fn send_query(stream: &mut TcpStream, query: &[u8]) -> Result<(), Error> { + stream.write_all(&query)?; + Ok(()) } #[cfg(feature = "tokio")] -async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec, rrsig_key_names: &mut Vec) -> Result { - let mut len = [0; 2]; - stream.read_exact(&mut len).await?; - let mut resp = vec![0; u16::from_be_bytes(len) as usize]; - stream.read_exact(&mut resp).await?; - handle_response(&resp, proof, rrsig_key_names) +async fn send_query_async(stream: &mut TokioTcpStream, query: &[u8]) -> Result<(), Error> { + stream.write_all(&query).await?; + Ok(()) +} + +type MsgBuf = [u8; u16::MAX as usize]; + +fn read_response(stream: &mut TcpStream, response_buf: &mut MsgBuf) -> Result { + let mut len_bytes = [0; 2]; + stream.read_exact(&mut len_bytes)?; + let len = u16::from_be_bytes(len_bytes); + stream.read_exact(&mut response_buf[..len as usize])?; + Ok(len) +} + +#[cfg(feature = "tokio")] +async fn read_response_async(stream: &mut TokioTcpStream, response_buf: &mut MsgBuf) -> Result { + let mut len_bytes = [0; 2]; + stream.read_exact(&mut len_bytes).await?; + let len = u16::from_be_bytes(len_bytes); + stream.read_exact(&mut response_buf[..len as usize]).await?; + Ok(len) } macro_rules! build_proof_impl { - ($stream: ident, $send_query: ident, $read_response: ident $(, $async_ok: tt)?) => { { + ($stream: ident, $send_query: ident, $read_response: ident, $domain: expr, $ty: expr $(, $async_ok: tt)?) => { { // We require the initial query to have already gone out, and assume our resolver will // return any CNAMEs all the way to the final record in the response. From there, we just // have to take any RRSIGs in the response and walk them up to the root. We do so // iteratively, sending DNSKEY and DS lookups after every response, deduplicating requests // using `dnskeys_requested`. - let mut res = Vec::new(); // The actual proof stream - let mut min_ttl = u32::MAX; // Min TTL of any answer record - const MAX_REQUESTS: usize = 20; - let mut rrsig_key_names = Vec::with_capacity(4); // Last response's RRSIG key_names - let mut dnskeys_requested = Vec::with_capacity(MAX_REQUESTS); - let mut pending_queries = 1; - let mut queries_made = 1; - while pending_queries != 0 && queries_made <= MAX_REQUESTS { - let response_min_ttl = $read_response(&mut $stream, &mut res, &mut rrsig_key_names) + let (mut builder, initial_query) = ProofBuilder::new($domain, $ty); + $send_query(&mut $stream, &initial_query) + $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? + let mut response_buf = [0; u16::MAX as usize]; + while builder.awaiting_responses() { + let response_len = $read_response(&mut $stream, &mut response_buf) $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? - pending_queries -= 1; - min_ttl = cmp::min(min_ttl, response_min_ttl); - rrsig_key_names.sort_unstable(); - rrsig_key_names.dedup(); - for key_name in rrsig_key_names.drain(..) { - if !dnskeys_requested.contains(&key_name) { - $send_query(&mut $stream, &key_name, DnsKey::TYPE) - $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? - pending_queries += 1; - queries_made += 1; - dnskeys_requested.push(key_name.clone()); - - if key_name.as_str() != "." { - $send_query(&mut $stream, &key_name, DS::TYPE) - $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? - pending_queries += 1; - queries_made += 1; - } - } + let new_queries = builder.process_response(&response_buf[..response_len as usize]) + .map_err(|()| Error::new(ErrorKind::Other, "Bad response"))?; + for query in new_queries { + $send_query(&mut $stream, &query) + $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? } } - if queries_made > MAX_REQUESTS { - Err(Error::new(ErrorKind::Other, "Too many requests required")) - } else { - Ok((res, min_ttl)) - } + builder.finish_proof() + .map_err(|()| Error::new(ErrorKind::Other, "Too many requests required")) } } } fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec, u32), Error> { let mut stream = TcpStream::connect(resolver)?; - send_query(&mut stream, domain, ty)?; - build_proof_impl!(stream, send_query, read_response) + build_proof_impl!(stream, send_query, read_response, domain, ty) } #[cfg(feature = "tokio")] async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec, u32), Error> { let mut stream = TokioTcpStream::connect(resolver).await?; - send_query_async(&mut stream, domain, ty).await?; - build_proof_impl!(stream, send_query_async, read_response_async, { Ok::<(), Error>(()) }) + build_proof_impl!(stream, send_query_async, read_response_async, domain, ty, { Ok::<(), Error>(()) }) } /// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as -- 2.39.5