Advance self blinded payment paths
[rust-lightning] / lightning / src / ln / features.rs
index 04ce90445f994b6216f1a9c58ee7fcccbfcd8501..51c608c1a6b47231762bbb35627032dcf33d19bc 100644 (file)
 //! [BOLT #9]: https://github.com/lightning/bolts/blob/master/09-features.md
 //! [messages]: crate::ln::msgs
 
-use crate::{io, io_extras};
+#[allow(unused_imports)]
 use crate::prelude::*;
+
+use crate::{io, io_extras};
 use core::{cmp, fmt};
 use core::borrow::Borrow;
 use core::hash::{Hash, Hasher};
 use core::marker::PhantomData;
 
-use bitcoin::bech32;
-use bitcoin::bech32::{Base32Len, FromBase32, ToBase32, u5, WriteBase32};
+use bech32::{Base32Len, FromBase32, ToBase32, u5, WriteBase32};
 use crate::ln::msgs::DecodeError;
 use crate::util::ser::{Readable, WithoutLength, Writeable, Writer};
 
 mod sealed {
+       #[allow(unused_imports)]
        use crate::prelude::*;
        use crate::ln::features::Features;
 
@@ -440,6 +442,9 @@ mod sealed {
                set_unknown_feature_required, supports_unknown_test_feature, requires_unknown_test_feature);
 }
 
+const ANY_REQUIRED_FEATURES_MASK: u8 = 0b01_01_01_01;
+const ANY_OPTIONAL_FEATURES_MASK: u8 = 0b10_10_10_10;
+
 /// Tracks the set of features which a node implements, templated by the context in which it
 /// appears.
 ///
@@ -613,8 +618,8 @@ impl ChannelTypeFeatures {
                // ChannelTypeFeatures must only contain required bits, so we OR the required forms of all
                // optional bits and then AND out the optional ones.
                for byte in ret.flags.iter_mut() {
-                       *byte |= (*byte & 0b10_10_10_10) >> 1;
-                       *byte &= 0b01_01_01_01;
+                       *byte |= (*byte & ANY_OPTIONAL_FEATURES_MASK) >> 1;
+                       *byte &= ANY_REQUIRED_FEATURES_MASK;
                }
                ret
        }
@@ -759,7 +764,7 @@ impl<T: sealed::Context> Features<T> {
        }
 
        pub(crate) fn supports_any_optional_bits(&self) -> bool {
-               self.flags.iter().any(|&byte| (byte & 0b10_10_10_10) != 0)
+               self.flags.iter().any(|&byte| (byte & ANY_OPTIONAL_FEATURES_MASK) != 0)
        }
 
        /// Returns true if this `Features` object contains required features unknown by `other`.
@@ -767,20 +772,30 @@ impl<T: sealed::Context> Features<T> {
                // Bitwise AND-ing with all even bits set except for known features will select required
                // unknown features.
                self.flags.iter().enumerate().any(|(i, &byte)| {
-                       const REQUIRED_FEATURES: u8 = 0b01_01_01_01;
-                       const OPTIONAL_FEATURES: u8 = 0b10_10_10_10;
-                       let unknown_features = if i < other.flags.len() {
-                               // Form a mask similar to !T::KNOWN_FEATURE_MASK only for `other`
-                               !(other.flags[i]
-                                       | ((other.flags[i] >> 1) & REQUIRED_FEATURES)
-                                       | ((other.flags[i] << 1) & OPTIONAL_FEATURES))
-                       } else {
-                               0b11_11_11_11
-                       };
-                       (byte & (REQUIRED_FEATURES & unknown_features)) != 0
+                       let unknown_features = unset_features_mask_at_position(other, i);
+                       (byte & (ANY_REQUIRED_FEATURES_MASK & unknown_features)) != 0
                })
        }
 
+       pub(crate) fn required_unknown_bits_from(&self, other: &Self) -> Vec<usize> {
+               let mut unknown_bits = Vec::new();
+
+               // Bitwise AND-ing with all even bits set except for known features will select required
+               // unknown features.
+               self.flags.iter().enumerate().for_each(|(i, &byte)| {
+                       let unknown_features = unset_features_mask_at_position(other, i);
+                       if byte & unknown_features != 0 {
+                               for bit in (0..8).step_by(2) {
+                                       if ((byte & unknown_features) >> bit) & 1 == 1 {
+                                               unknown_bits.push(i * 8 + bit);
+                                       }
+                               }
+                       }
+               });
+
+               unknown_bits
+       }
+
        /// Returns true if this `Features` object contains unknown feature flags which are set as
        /// "required".
        pub fn requires_unknown_bits(&self) -> bool {
@@ -788,13 +803,12 @@ impl<T: sealed::Context> Features<T> {
                // unknown features.
                let byte_count = T::KNOWN_FEATURE_MASK.len();
                self.flags.iter().enumerate().any(|(i, &byte)| {
-                       let required_features = 0b01_01_01_01;
                        let unknown_features = if i < byte_count {
                                !T::KNOWN_FEATURE_MASK[i]
                        } else {
                                0b11_11_11_11
                        };
-                       (byte & (required_features & unknown_features)) != 0
+                       (byte & (ANY_REQUIRED_FEATURES_MASK & unknown_features)) != 0
                })
        }
 
@@ -1015,10 +1029,21 @@ impl<T: sealed::Context> Readable for WithoutLength<Features<T>> {
        }
 }
 
+pub(crate) fn unset_features_mask_at_position<T: sealed::Context>(other: &Features<T>, index: usize) -> u8 {
+       if index < other.flags.len() {
+               // Form a mask similar to !T::KNOWN_FEATURE_MASK only for `other`
+               !(other.flags[index]
+                       | ((other.flags[index] >> 1) & ANY_REQUIRED_FEATURES_MASK)
+                       | ((other.flags[index] << 1) & ANY_OPTIONAL_FEATURES_MASK))
+       } else {
+               0b11_11_11_11
+       }
+}
+
 #[cfg(test)]
 mod tests {
        use super::{ChannelFeatures, ChannelTypeFeatures, InitFeatures, Bolt11InvoiceFeatures, NodeFeatures, OfferFeatures, sealed};
-       use bitcoin::bech32::{Base32Len, FromBase32, ToBase32, u5};
+       use bech32::{Base32Len, FromBase32, ToBase32, u5};
        use crate::util::ser::{Readable, WithoutLength, Writeable};
 
        #[test]
@@ -1031,11 +1056,24 @@ mod tests {
                features.set_unknown_feature_required();
                assert!(features.requires_unknown_bits());
                assert!(features.supports_unknown_bits());
+               assert_eq!(features.required_unknown_bits_from(&ChannelFeatures::empty()), vec![123456788]);
 
                let mut features = ChannelFeatures::empty();
                features.set_unknown_feature_optional();
                assert!(!features.requires_unknown_bits());
                assert!(features.supports_unknown_bits());
+               assert_eq!(features.required_unknown_bits_from(&ChannelFeatures::empty()), vec![]);
+
+               let mut features = ChannelFeatures::empty();
+               features.set_unknown_feature_required();
+               features.set_custom_bit(123456786).unwrap();
+               assert!(features.requires_unknown_bits());
+               assert!(features.supports_unknown_bits());
+               assert_eq!(features.required_unknown_bits_from(&ChannelFeatures::empty()), vec![123456786, 123456788]);
+
+               let mut limiter = ChannelFeatures::empty();
+               limiter.set_unknown_feature_optional();
+               assert_eq!(features.required_unknown_bits_from(&limiter), vec![123456786]);
        }
 
        #[test]