From d71d4d689d86d108a7d1c5a47de7d1202e9648e0 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 8 Feb 2024 23:54:52 +0000 Subject: [PATCH] Set a cache-control header on valid responses --- src/http.rs | 7 ++++-- src/query.rs | 71 +++++++++++++++++++++++++++++++--------------------- 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/http.rs b/src/http.rs index 5312bbb..98aec61 100644 --- a/src/http.rs +++ b/src/http.rs @@ -116,13 +116,16 @@ mod imp { "AAAA" => build_aaaa_proof_async(resolver_sockaddr, &query_name).await, _ => break 'ret_err, }; - let proof = if let Ok(proof) = proof_res { proof } else { + let (proof, cache_ttl) = if let Ok(proof) = proof_res { proof } else { response = ("404 Not Found", "Failed to generate proof for given domain"); break 'ret_err; }; let _ = socket.write_all( - format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/octet-stream\r\nAccess-Control-Allow-Origin: *\r\n\r\n", proof.len()).as_bytes() + format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/octet-stream\r\nCache-Control: public, max-age={}, s-maxage={}\r\nAccess-Control-Allow-Origin: *\r\n\r\n", + proof.len(), cache_ttl, cache_ttl + ).as_bytes() ).await; let _ = socket.write_all(&proof).await; return; diff --git a/src/query.rs b/src/query.rs index adf687a..b73b5b1 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,6 +1,7 @@ //! This module exposes utilities for building DNSSEC proofs by directly querying a recursive //! resolver. +use std::cmp; use std::net::{SocketAddr, TcpStream}; use std::io::{Read, Write, Error, ErrorKind}; @@ -52,7 +53,7 @@ async fn send_query_async(stream: &mut TokioTcpStream, domain: &Name, ty: u16) - Ok(()) } -fn handle_response(resp: &[u8], proof: &mut Vec) -> Result, Error> { +fn handle_response(resp: &[u8], proof: &mut Vec) -> Result, Error> { let mut read: &[u8] = resp; if emap(read_u16(&mut read))? != TXID { return Err(Error::new(ErrorKind::Other, "bad txid")); } // 2 byte transaction ID @@ -81,15 +82,17 @@ fn handle_response(resp: &[u8], proof: &mut Vec) -> Result, Er // Only read the answers (skip authorities and additional) as that's all we care about. let mut rrsig_opt = None; + let mut min_ttl = u32::MAX; for _ in 0..answers { let (rr, ttl) = emap(parse_wire_packet_rr(&mut read, &resp))?; write_rr(&rr, ttl, proof); + min_ttl = cmp::min(min_ttl, ttl); if let RR::RRSig(rrsig) = rr { rrsig_opt = Some(rrsig); } } - Ok(rrsig_opt) + Ok(rrsig_opt.map(|rr| (rr, min_ttl))) } -fn read_response(stream: &mut TcpStream, proof: &mut Vec) -> Result, Error> { +fn read_response(stream: &mut TcpStream, proof: &mut Vec) -> Result, Error> { let mut len = [0; 2]; stream.read_exact(&mut len)?; let mut resp = vec![0; u16::from_be_bytes(len) as usize]; @@ -98,7 +101,7 @@ fn read_response(stream: &mut TcpStream, proof: &mut Vec) -> Result) -> Result, Error> { +async fn read_response_async(stream: &mut TokioTcpStream, proof: &mut Vec) -> Result, Error> { let mut len = [0; 2]; stream.read_exact(&mut len).await?; let mut resp = vec![0; u16::from_be_bytes(len) as usize]; @@ -110,10 +113,12 @@ macro_rules! build_proof_impl { ($stream: ident, $send_query: ident, $read_response: ident $(, $async_ok: tt)?) => { { let mut res = Vec::new(); let mut reached_root = false; + let mut min_ttl = u32::MAX; for i in 0..10 { - let rrsig_opt = $read_response(&mut $stream, &mut res) + let resp_opt = $read_response(&mut $stream, &mut res) $(.await?; $async_ok)??; // Either await?; Ok(())?, or just ? - if let Some(rrsig) = rrsig_opt { + if let Some((rrsig, rrsig_min_ttl)) = resp_opt { + min_ttl = cmp::min(min_ttl, rrsig_min_ttl); if rrsig.name.as_str() == "." { reached_root = true; } else { @@ -128,65 +133,73 @@ macro_rules! build_proof_impl { } if !reached_root { Err(Error::new(ErrorKind::Other, "Too many requests required")) } - else { Ok(res) } + else { Ok((res, min_ttl)) } } } } -fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result, Error> { +fn build_proof(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec, u32), Error> { let mut stream = TcpStream::connect(resolver)?; send_query(&mut stream, domain, ty)?; build_proof_impl!(stream, send_query, read_response) } #[cfg(feature = "tokio")] -async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result, Error> { +async fn build_proof_async(resolver: SocketAddr, domain: &Name, ty: u16) -> Result<(Vec, u32), Error> { let mut stream = TokioTcpStream::connect(resolver).await?; send_query_async(&mut stream, domain, ty).await?; build_proof_impl!(stream, send_query_async, read_response_async, { Ok::<(), Error>(()) }) } -/// Builds a DNSSEC proof for an A record by querying a recursive resolver -pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { +/// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as +/// well as the TTL for the proof provided by the recursive resolver. +pub fn build_a_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, A::TYPE) } -/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver -pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { +/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof +/// as well as the TTL for the proof provided by the recursive resolver. +pub fn build_aaaa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, AAAA::TYPE) } -/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver -pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { +/// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof +/// as well as the TTL for the proof provided by the recursive resolver. +pub fn build_txt_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, Txt::TYPE) } -/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver -pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result, Error> { +/// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof +/// as well as the TTL for the proof provided by the recursive resolver. +pub fn build_tlsa_proof(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof(resolver, domain, TLSA::TYPE) } -/// Builds a DNSSEC proof for an A record by querying a recursive resolver +/// Builds a DNSSEC proof for an A record by querying a recursive resolver, returning the proof as +/// well as the TTL for the proof provided by the recursive resolver. #[cfg(feature = "tokio")] -pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { +pub async fn build_a_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof_async(resolver, domain, A::TYPE).await } -/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver +/// Builds a DNSSEC proof for an AAAA record by querying a recursive resolver, returning the proof +/// as well as the TTL for the proof provided by the recursive resolver. #[cfg(feature = "tokio")] -pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { +pub async fn build_aaaa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof_async(resolver, domain, AAAA::TYPE).await } -/// Builds a DNSSEC proof for a TXT record by querying a recursive resolver +/// Builds a DNSSEC proof for an TXT record by querying a recursive resolver, returning the proof +/// as well as the TTL for the proof provided by the recursive resolver. #[cfg(feature = "tokio")] -pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { +pub async fn build_txt_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof_async(resolver, domain, Txt::TYPE).await } -/// Builds a DNSSEC proof for a TLSA record by querying a recursive resolver +/// Builds a DNSSEC proof for an TLSA record by querying a recursive resolver, returning the proof +/// as well as the TTL for the proof provided by the recursive resolver. #[cfg(feature = "tokio")] -pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result, Error> { +pub async fn build_tlsa_proof_async(resolver: SocketAddr, domain: &Name) -> Result<(Vec, u32), Error> { build_proof_async(resolver, domain, TLSA::TYPE).await } @@ -205,7 +218,7 @@ mod tests { fn test_cloudflare_txt_query() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); let query_name = "cloudflare.com.".try_into().unwrap(); - let proof = build_txt_proof(sockaddr, &query_name).unwrap(); + let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap(); let mut rrs = parse_rr_stream(&proof).unwrap(); rrs.shuffle(&mut rand::rngs::OsRng); @@ -221,7 +234,7 @@ mod tests { fn test_sha1_query() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); let query_name = "benthecarman.com.".try_into().unwrap(); - let proof = build_a_proof(sockaddr, &query_name).unwrap(); + let (proof, _) = build_a_proof(sockaddr, &query_name).unwrap(); let mut rrs = parse_rr_stream(&proof).unwrap(); rrs.shuffle(&mut rand::rngs::OsRng); @@ -237,7 +250,7 @@ mod tests { fn test_txt_query() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap(); - let proof = build_txt_proof(sockaddr, &query_name).unwrap(); + let (proof, _) = build_txt_proof(sockaddr, &query_name).unwrap(); let mut rrs = parse_rr_stream(&proof).unwrap(); rrs.shuffle(&mut rand::rngs::OsRng); @@ -257,7 +270,7 @@ mod tests { async fn test_txt_query_async() { let sockaddr = "8.8.8.8:53".to_socket_addrs().unwrap().next().unwrap(); let query_name = "matt.user._bitcoin-payment.mattcorallo.com.".try_into().unwrap(); - let proof = build_txt_proof_async(sockaddr, &query_name).await.unwrap(); + let (proof, _) = build_txt_proof_async(sockaddr, &query_name).await.unwrap(); let mut rrs = parse_rr_stream(&proof).unwrap(); rrs.shuffle(&mut rand::rngs::OsRng); -- 2.39.5