From 91c19b82695c29881ab5b809c681be786b942f09 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 26 Feb 2024 03:01:59 +0000 Subject: [PATCH] Add trivial helper method to get the label count in a `Name` --- src/rr.rs | 21 +++++++++++++++++++++ src/validation.rs | 8 ++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/rr.rs b/src/rr.rs index c3367a6..1755685 100644 --- a/src/rr.rs +++ b/src/rr.rs @@ -23,6 +23,27 @@ pub struct Name(String); impl Name { /// Gets the underlying human-readable domain name pub fn as_str(&self) -> &str { &self.0 } + /// Gets the number of labels in this name + pub fn labels(&self) -> u8 { + if self.as_str() == "." { + 0 + } else { + self.as_str().chars().filter(|c| *c == '.').count() as u8 + } + } + /// Gets a string containing the last `n` labels in this [`Name`] (which is also a valid name). + pub fn trailing_n_labels(&self, n: u8) -> Option<&str> { + let labels = self.labels(); + if n > labels { + None + } else if n == labels { + Some(self.as_str()) + } else if n == 0 { + Some(".") + } else { + self.as_str().splitn(labels as usize - n as usize + 1, ".").last() + } + } } impl core::ops::Deref for Name { type Target = str; diff --git a/src/validation.rs b/src/validation.rs index 9d23ef8..cfcb66f 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -93,11 +93,11 @@ where Keys: IntoIterator { records.sort_unstable(); for record in records.iter() { - let periods = record.name().as_str().chars().filter(|c| *c == '.').count(); + let record_labels = record.name().labels() as usize; let labels = sig.labels.into(); - if periods != 1 && periods != labels { - if periods < labels { return Err(ValidationError::Invalid); } - let signed_name = record.name().as_str().splitn(periods - labels + 1, ".").last(); + if record_labels != labels { + if record_labels < labels { return Err(ValidationError::Invalid); } + let signed_name = record.name().trailing_n_labels(sig.labels); debug_assert!(signed_name.is_some()); if let Some(name) = signed_name { signed_data.extend_from_slice(b"\x01*"); -- 2.39.5