Add trivial helper method to get the label count in a `Name`
authorMatt Corallo <git@bluematt.me>
Mon, 26 Feb 2024 03:01:59 +0000 (03:01 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 26 Feb 2024 20:44:36 +0000 (20:44 +0000)
src/rr.rs
src/validation.rs

index c3367a642d3f38044b9226ad20e80360a4fb5a92..1755685cabb527075526cb8ec14a28d490117c1a 100644 (file)
--- a/src/rr.rs
+++ b/src/rr.rs
@@ -23,6 +23,27 @@ pub struct Name(String);
 impl Name {
        /// Gets the underlying human-readable domain name
        pub fn as_str(&self) -> &str { &self.0 }
+       /// Gets the number of labels in this name
+       pub fn labels(&self) -> u8 {
+               if self.as_str() == "." {
+                       0
+               } else {
+                       self.as_str().chars().filter(|c| *c == '.').count() as u8
+               }
+       }
+       /// Gets a string containing the last `n` labels in this [`Name`] (which is also a valid name).
+       pub fn trailing_n_labels(&self, n: u8) -> Option<&str> {
+               let labels = self.labels();
+               if n > labels {
+                       None
+               } else if n == labels {
+                       Some(self.as_str())
+               } else if n == 0 {
+                       Some(".")
+               } else {
+                       self.as_str().splitn(labels as usize - n as usize + 1, ".").last()
+               }
+       }
 }
 impl core::ops::Deref for Name {
        type Target = str;
index 9d23ef8e759ad06da98027aadfe82183df2f7c56..cfcb66f9f73c4321a9c6c86275911cab38f609ca 100644 (file)
@@ -93,11 +93,11 @@ where Keys: IntoIterator<Item = &'a DnsKey> {
                        records.sort_unstable();
 
                        for record in records.iter() {
-                               let periods = record.name().as_str().chars().filter(|c| *c == '.').count();
+                               let record_labels = record.name().labels() as usize;
                                let labels = sig.labels.into();
-                               if periods != 1 && periods != labels {
-                                       if periods < labels { return Err(ValidationError::Invalid); }
-                                       let signed_name = record.name().as_str().splitn(periods - labels + 1, ".").last();
+                               if record_labels != labels {
+                                       if record_labels < labels { return Err(ValidationError::Invalid); }
+                                       let signed_name = record.name().trailing_n_labels(sig.labels);
                                        debug_assert!(signed_name.is_some());
                                        if let Some(name) = signed_name {
                                                signed_data.extend_from_slice(b"\x01*");