From 821d8fdb9880a306ff870c45a9268c0dfff1056e Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 5 Feb 2024 10:03:58 +0000 Subject: [PATCH] Add support for parsing compressed names out of wire packets --- src/rr.rs | 27 ++++++++++++------------ src/ser.rs | 60 ++++++++++++++++++++++++++++-------------------------- 2 files changed, 45 insertions(+), 42 deletions(-) diff --git a/src/rr.rs b/src/rr.rs index a611005..145f0b4 100644 --- a/src/rr.rs +++ b/src/rr.rs @@ -126,7 +126,7 @@ pub(crate) trait StaticRecord : Ord + Sized { const TYPE: u16; fn name(&self) -> &Name; fn write_u16_len_prefixed_data(&self, out: &mut Vec); - fn read_from_data(name: Name, data: &[u8]) -> Result; + fn read_from_data(name: Name, data: &[u8], wire_packet: &[u8]) -> Result; } /// A trait describing a resource record (including the [`RR`] enum). pub trait Record : Ord { @@ -169,7 +169,7 @@ pub struct Txt { impl StaticRecord for Txt { const TYPE: u16 = 16; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, mut data: &[u8]) -> Result { + fn read_from_data(name: Name, mut data: &[u8], _wire_packet: &[u8]) -> Result { let mut parsed_data = Vec::with_capacity(data.len() - 1); while !data.is_empty() { let len = read_u8(&mut data)? as usize; @@ -218,7 +218,7 @@ pub struct TLSA { impl StaticRecord for TLSA { const TYPE: u16 = 52; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, mut data: &[u8]) -> Result { + fn read_from_data(name: Name, mut data: &[u8], _wire_packet: &[u8]) -> Result { Ok(TLSA { name, cert_usage: read_u8(&mut data)?, selector: read_u8(&mut data)?, data_ty: read_u8(&mut data)?, data: data.to_vec(), @@ -245,8 +245,8 @@ pub struct CName { impl StaticRecord for CName { const TYPE: u16 = 5; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, mut data: &[u8]) -> Result { - Ok(CName { name, canonical_name: read_name(&mut data)? }) + fn read_from_data(name: Name, mut data: &[u8], wire_packet: &[u8]) -> Result { + Ok(CName { name, canonical_name: read_wire_packet_name(&mut data, wire_packet)? }) } fn write_u16_len_prefixed_data(&self, out: &mut Vec) { let len: u16 = name_len(&self.canonical_name); @@ -272,7 +272,7 @@ pub struct DnsKey { impl StaticRecord for DnsKey { const TYPE: u16 = 48; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, mut data: &[u8]) -> Result { + fn read_from_data(name: Name, mut data: &[u8], _wire_packet: &[u8]) -> Result { Ok(DnsKey { name, flags: read_u16(&mut data)?, protocol: read_u8(&mut data)?, alg: read_u8(&mut data)?, pubkey: data.to_vec(), @@ -330,7 +330,7 @@ pub struct DS { impl StaticRecord for DS { const TYPE: u16 = 43; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, mut data: &[u8]) -> Result { + fn read_from_data(name: Name, mut data: &[u8], _wire_packet: &[u8]) -> Result { Ok(DS { name, key_tag: read_u16(&mut data)?, alg: read_u8(&mut data)?, digest_type: read_u8(&mut data)?, digest: data.to_vec(), @@ -386,12 +386,13 @@ pub struct RRSig { impl StaticRecord for RRSig { const TYPE: u16 = 46; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, mut data: &[u8]) -> Result { + fn read_from_data(name: Name, mut data: &[u8], wire_packet: &[u8]) -> Result { Ok(RRSig { name, ty: read_u16(&mut data)?, alg: read_u8(&mut data)?, labels: read_u8(&mut data)?, orig_ttl: read_u32(&mut data)?, expiration: read_u32(&mut data)?, inception: read_u32(&mut data)?, - key_tag: read_u16(&mut data)?, key_name: read_name(&mut data)?, + key_tag: read_u16(&mut data)?, + key_name: read_wire_packet_name(&mut data, wire_packet)?, signature: data.to_vec(), }) } @@ -421,7 +422,7 @@ pub struct A { impl StaticRecord for A { const TYPE: u16 = 1; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, data: &[u8]) -> Result { + fn read_from_data(name: Name, data: &[u8], _wire_packet: &[u8]) -> Result { if data.len() != 4 { return Err(()); } let mut address = [0; 4]; address.copy_from_slice(&data); @@ -444,7 +445,7 @@ pub struct AAAA { impl StaticRecord for AAAA { const TYPE: u16 = 28; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, data: &[u8]) -> Result { + fn read_from_data(name: Name, data: &[u8], _wire_packet: &[u8]) -> Result { if data.len() != 16 { return Err(()); } let mut address = [0; 16]; address.copy_from_slice(&data); @@ -472,8 +473,8 @@ pub struct NS { impl StaticRecord for NS { const TYPE: u16 = 2; fn name(&self) -> &Name { &self.name } - fn read_from_data(name: Name, mut data: &[u8]) -> Result { - Ok(NS { name, name_server: read_name(&mut data)? }) + fn read_from_data(name: Name, mut data: &[u8], wire_packet: &[u8]) -> Result { + Ok(NS { name, name_server: read_wire_packet_name(&mut data, wire_packet)? }) } fn write_u16_len_prefixed_data(&self, out: &mut Vec) { out.extend_from_slice(&name_len(&self.name_server).to_be_bytes()); diff --git a/src/ser.rs b/src/ser.rs index c0566c3..e88c9ba 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -27,20 +27,30 @@ pub(crate) fn read_u32(inp: &mut &[u8]) -> Result { Ok(u32::from_be_bytes(bytes)) } -pub(crate) fn read_name(inp: &mut &[u8]) -> Result { - let mut name = String::with_capacity(1024); +fn read_wire_packet_labels(inp: &mut &[u8], wire_packet: &[u8], name: &mut String) -> Result<(), ()> { loop { let len = read_u8(inp)? as usize; if len == 0 { - if name.is_empty() { name += "."; } + if name.is_empty() { *name += "."; } + break; + } else if len >= 0xc0 { + let offs = ((len & !0xc0) << 8) | read_u8(inp)? as usize; + if offs >= wire_packet.len() { return Err(()); } + read_wire_packet_labels(&mut &wire_packet[offs..], wire_packet, name)?; break; } if inp.len() <= len { return Err(()); } - name += core::str::from_utf8(&inp[..len]).map_err(|_| ())?; - name += "."; + *name += core::str::from_utf8(&inp[..len]).map_err(|_| ())?; + *name += "."; *inp = &inp[len..]; - if name.len() > 1024 { return Err(()); } + if name.len() > 255 { return Err(()); } } + Ok(()) +} + +pub(crate) fn read_wire_packet_name(inp: &mut &[u8], wire_packet: &[u8]) -> Result { + let mut name = String::with_capacity(1024); + read_wire_packet_labels(inp, wire_packet, &mut name)?; Ok(name.try_into()?) } @@ -70,8 +80,8 @@ pub(crate) fn name_len(name: &Name) -> u16 { } } -pub(crate) fn parse_rr(inp: &mut &[u8]) -> Result { - let name = read_name(inp)?; +pub(crate) fn parse_wire_packet_rr(inp: &mut &[u8], wire_packet: &[u8]) -> Result { + let name = read_wire_packet_name(inp, wire_packet)?; let ty = read_u16(inp)?; let class = read_u16(inp)?; if class != 1 { return Err(()); } // We only support the INternet @@ -82,31 +92,23 @@ pub(crate) fn parse_rr(inp: &mut &[u8]) -> Result { *inp = &inp[data_len..]; match ty { - A::TYPE => Ok(RR::A(A::read_from_data(name, data)?)), - AAAA::TYPE => Ok(RR::AAAA(AAAA::read_from_data(name, data)?)), - NS::TYPE => Ok(RR::NS(NS::read_from_data(name, data)?)), - Txt::TYPE => { - Ok(RR::Txt(Txt::read_from_data(name, data)?)) - } - CName::TYPE => { - Ok(RR::CName(CName::read_from_data(name, data)?)) - } - TLSA::TYPE => { - Ok(RR::TLSA(TLSA::read_from_data(name, data)?)) - }, - DnsKey::TYPE => { - Ok(RR::DnsKey(DnsKey::read_from_data(name, data)?)) - }, - DS::TYPE => { - Ok(RR::DS(DS::read_from_data(name, data)?)) - }, - RRSig::TYPE => { - Ok(RR::RRSig(RRSig::read_from_data(name, data)?)) - }, + A::TYPE => Ok(RR::A(A::read_from_data(name, data, wire_packet)?)), + AAAA::TYPE => Ok(RR::AAAA(AAAA::read_from_data(name, data, wire_packet)?)), + NS::TYPE => Ok(RR::NS(NS::read_from_data(name, data, wire_packet)?)), + Txt::TYPE => Ok(RR::Txt(Txt::read_from_data(name, data, wire_packet)?)), + CName::TYPE => Ok(RR::CName(CName::read_from_data(name, data, wire_packet)?)), + TLSA::TYPE => Ok(RR::TLSA(TLSA::read_from_data(name, data, wire_packet)?)), + DnsKey::TYPE => Ok(RR::DnsKey(DnsKey::read_from_data(name, data, wire_packet)?)), + DS::TYPE => Ok(RR::DS(DS::read_from_data(name, data, wire_packet)?)), + RRSig::TYPE => Ok(RR::RRSig(RRSig::read_from_data(name, data, wire_packet)?)), _ => Err(()), } } +pub(crate) fn parse_rr(inp: &mut &[u8]) -> Result { + parse_wire_packet_rr(inp, &[]) +} + pub(crate) fn bytes_to_rsa_pk<'a>(pubkey: &'a [u8]) -> Result, ()> { if pubkey.len() <= 3 { return Err(()); } -- 2.39.5