Sanity check that known features are not required
[rust-lightning] / lightning / src / ln / features.rs
index 0dd5fa3049364a853c99ce34fe69d226a8dab689..b65b9b1b245e0a8922a4b69a8dbebfe2f76a500e 100644 (file)
@@ -166,6 +166,13 @@ mod sealed {
                                /// [`BYTE_OFFSET`]: #associatedconstant.BYTE_OFFSET
                                const OPTIONAL_MASK: u8 = 1 << (Self::ODD_BIT - 8 * Self::BYTE_OFFSET);
 
+                               /// Returns whether the feature is required by the given flags.
+                               #[inline]
+                               fn requires_feature(flags: &Vec<u8>) -> bool {
+                                       flags.len() > Self::BYTE_OFFSET &&
+                                               (flags[Self::BYTE_OFFSET] & Self::REQUIRED_MASK) != 0
+                               }
+
                                /// Returns whether the feature is supported by the given flags.
                                #[inline]
                                fn supports_feature(flags: &Vec<u8>) -> bool {
@@ -173,6 +180,16 @@ mod sealed {
                                                (flags[Self::BYTE_OFFSET] & (Self::REQUIRED_MASK | Self::OPTIONAL_MASK)) != 0
                                }
 
+                               /// Sets the feature's required (even) bit in the given flags.
+                               #[inline]
+                               fn set_required_bit(flags: &mut Vec<u8>) {
+                                       if flags.len() <= Self::BYTE_OFFSET {
+                                               flags.resize(Self::BYTE_OFFSET + 1, 0u8);
+                                       }
+
+                                       flags[Self::BYTE_OFFSET] |= Self::REQUIRED_MASK;
+                               }
+
                                /// Sets the feature's optional (odd) bit in the given flags.
                                #[inline]
                                fn set_optional_bit(flags: &mut Vec<u8>) {
@@ -191,6 +208,10 @@ mod sealed {
                                                flags[Self::BYTE_OFFSET] &= !Self::REQUIRED_MASK;
                                                flags[Self::BYTE_OFFSET] &= !Self::OPTIONAL_MASK;
                                        }
+
+                                       let last_non_zero_byte = flags.iter().rposition(|&byte| byte != 0);
+                                       let size = if let Some(offset) = last_non_zero_byte { offset + 1 } else { 0 };
+                                       flags.resize(size, 0u8);
                                }
                        }
 
@@ -219,6 +240,30 @@ mod sealed {
                "Feature flags for `payment_secret`.");
        define_feature!(17, BasicMPP, [InitContext, NodeContext],
                "Feature flags for `basic_mpp`.");
+
+       #[cfg(test)]
+       define_context!(TestingContext {
+               required_features: [
+                       // Byte 0
+                       ,
+                       // Byte 1
+                       ,
+                       // Byte 2
+                       UnknownFeature,
+               ],
+               optional_features: [
+                       // Byte 0
+                       ,
+                       // Byte 1
+                       ,
+                       // Byte 2
+                       ,
+               ],
+       });
+
+       #[cfg(test)]
+       define_feature!(23, UnknownFeature, [TestingContext],
+               "Feature flags for an unknown feature used in testing.");
 }
 
 /// Tracks the set of features which a node implements, templated by the context in which it
@@ -282,6 +327,12 @@ impl InitFeatures {
                }
                self
        }
+
+       /// Converts `InitFeatures` to `Features<C>`. Only known `InitFeatures` relevant to context `C`
+       /// are included in the result.
+       pub(crate) fn to_context<C: sealed::Context>(&self) -> Features<C> {
+               self.to_context_internal()
+       }
 }
 
 impl<T: sealed::Context> Features<T> {
@@ -303,18 +354,19 @@ impl<T: sealed::Context> Features<T> {
                }
        }
 
-       /// Takes the flags that we know how to interpret in an init-context features that are also
-       /// relevant in a node-context features and creates a node-context features from them.
-       /// Be sure to blank out features that are unknown to us.
-       pub(crate) fn with_known_relevant_init_flags(init_ctx: &InitFeatures) -> Self {
-               let byte_count = T::KNOWN_FEATURE_MASK.len();
+       /// Converts `Features<T>` to `Features<C>`. Only known `T` features relevant to context `C` are
+       /// included in the result.
+       fn to_context_internal<C: sealed::Context>(&self) -> Features<C> {
+               let byte_count = C::KNOWN_FEATURE_MASK.len();
                let mut flags = Vec::new();
-               for (i, feature_byte) in init_ctx.flags.iter().enumerate() {
+               for (i, byte) in self.flags.iter().enumerate() {
                        if i < byte_count {
-                               flags.push(feature_byte & T::KNOWN_FEATURE_MASK[i]);
+                               let known_source_features = T::KNOWN_FEATURE_MASK[i];
+                               let known_target_features = C::KNOWN_FEATURE_MASK[i];
+                               flags.push(byte & known_source_features & known_target_features);
                        }
                }
-               Self { flags, mark: PhantomData, }
+               Features::<C> { flags, mark: PhantomData, }
        }
 
        #[cfg(test)]
@@ -368,43 +420,51 @@ impl<T: sealed::Context> Features<T> {
        }
 
        #[cfg(test)]
-       pub(crate) fn set_require_unknown_bits(&mut self) {
-               let newlen = cmp::max(3, self.flags.len());
-               self.flags.resize(newlen, 0u8);
-               self.flags[2] |= 0x40;
+       pub(crate) fn set_required_unknown_bits(&mut self) {
+               <sealed::TestingContext as sealed::UnknownFeature>::set_required_bit(&mut self.flags);
        }
 
        #[cfg(test)]
-       pub(crate) fn clear_require_unknown_bits(&mut self) {
-               let newlen = cmp::max(3, self.flags.len());
-               self.flags.resize(newlen, 0u8);
-               self.flags[2] &= !0x40;
-               if self.flags.len() == 3 && self.flags[2] == 0 {
-                       self.flags.resize(2, 0u8);
-               }
-               if self.flags.len() == 2 && self.flags[1] == 0 {
-                       self.flags.resize(1, 0u8);
-               }
+       pub(crate) fn set_optional_unknown_bits(&mut self) {
+               <sealed::TestingContext as sealed::UnknownFeature>::set_optional_bit(&mut self.flags);
+       }
+
+       #[cfg(test)]
+       pub(crate) fn clear_unknown_bits(&mut self) {
+               <sealed::TestingContext as sealed::UnknownFeature>::clear_bits(&mut self.flags);
        }
 }
 
 impl<T: sealed::DataLossProtect> Features<T> {
+       #[cfg(test)]
+       pub(crate) fn requires_data_loss_protect(&self) -> bool {
+               <T as sealed::DataLossProtect>::requires_feature(&self.flags)
+       }
        pub(crate) fn supports_data_loss_protect(&self) -> bool {
                <T as sealed::DataLossProtect>::supports_feature(&self.flags)
        }
 }
 
 impl<T: sealed::UpfrontShutdownScript> Features<T> {
+       #[cfg(test)]
+       pub(crate) fn requires_upfront_shutdown_script(&self) -> bool {
+               <T as sealed::UpfrontShutdownScript>::requires_feature(&self.flags)
+       }
        pub(crate) fn supports_upfront_shutdown_script(&self) -> bool {
                <T as sealed::UpfrontShutdownScript>::supports_feature(&self.flags)
        }
        #[cfg(test)]
-       pub(crate) fn unset_upfront_shutdown_script(&mut self) {
-               <T as sealed::UpfrontShutdownScript>::clear_bits(&mut self.flags)
+       pub(crate) fn clear_upfront_shutdown_script(mut self) -> Self {
+               <T as sealed::UpfrontShutdownScript>::clear_bits(&mut self.flags);
+               self
        }
 }
 
 impl<T: sealed::VariableLengthOnion> Features<T> {
+       #[cfg(test)]
+       pub(crate) fn requires_variable_length_onion(&self) -> bool {
+               <T as sealed::VariableLengthOnion>::requires_feature(&self.flags)
+       }
        pub(crate) fn supports_variable_length_onion(&self) -> bool {
                <T as sealed::VariableLengthOnion>::supports_feature(&self.flags)
        }
@@ -420,16 +480,24 @@ impl<T: sealed::InitialRoutingSync> Features<T> {
 }
 
 impl<T: sealed::PaymentSecret> Features<T> {
-       #[allow(dead_code)]
+       #[cfg(test)]
+       pub(crate) fn requires_payment_secret(&self) -> bool {
+               <T as sealed::PaymentSecret>::requires_feature(&self.flags)
+       }
        // Note that we never need to test this since what really matters is the invoice - iff the
        // invoice provides a payment_secret, we assume that we can use it (ie that the recipient
        // supports payment_secret).
+       #[allow(dead_code)]
        pub(crate) fn supports_payment_secret(&self) -> bool {
                <T as sealed::PaymentSecret>::supports_feature(&self.flags)
        }
 }
 
 impl<T: sealed::BasicMPP> Features<T> {
+       #[cfg(test)]
+       pub(crate) fn requires_basic_mpp(&self) -> bool {
+               <T as sealed::BasicMPP>::requires_feature(&self.flags)
+       }
        // We currently never test for this since we don't actually *generate* multipath routes.
        #[allow(dead_code)]
        pub(crate) fn supports_basic_mpp(&self) -> bool {
@@ -461,10 +529,10 @@ impl<T: sealed::Context> Readable for Features<T> {
 
 #[cfg(test)]
 mod tests {
-       use super::{ChannelFeatures, InitFeatures, NodeFeatures, Features};
+       use super::{ChannelFeatures, InitFeatures, NodeFeatures};
 
        #[test]
-       fn sanity_test_our_features() {
+       fn sanity_test_known_features() {
                assert!(!ChannelFeatures::known().requires_unknown_bits());
                assert!(!ChannelFeatures::known().supports_unknown_bits());
                assert!(!InitFeatures::known().requires_unknown_bits());
@@ -474,18 +542,28 @@ mod tests {
 
                assert!(InitFeatures::known().supports_upfront_shutdown_script());
                assert!(NodeFeatures::known().supports_upfront_shutdown_script());
+               assert!(!InitFeatures::known().requires_upfront_shutdown_script());
+               assert!(!NodeFeatures::known().requires_upfront_shutdown_script());
 
                assert!(InitFeatures::known().supports_data_loss_protect());
                assert!(NodeFeatures::known().supports_data_loss_protect());
+               assert!(!InitFeatures::known().requires_data_loss_protect());
+               assert!(!NodeFeatures::known().requires_data_loss_protect());
 
                assert!(InitFeatures::known().supports_variable_length_onion());
                assert!(NodeFeatures::known().supports_variable_length_onion());
+               assert!(!InitFeatures::known().requires_variable_length_onion());
+               assert!(!NodeFeatures::known().requires_variable_length_onion());
 
                assert!(InitFeatures::known().supports_payment_secret());
                assert!(NodeFeatures::known().supports_payment_secret());
+               assert!(!InitFeatures::known().requires_payment_secret());
+               assert!(!NodeFeatures::known().requires_payment_secret());
 
                assert!(InitFeatures::known().supports_basic_mpp());
                assert!(NodeFeatures::known().supports_basic_mpp());
+               assert!(!InitFeatures::known().requires_basic_mpp());
+               assert!(!NodeFeatures::known().requires_basic_mpp());
 
                let mut init_features = InitFeatures::known();
                assert!(init_features.initial_routing_sync());
@@ -494,35 +572,47 @@ mod tests {
        }
 
        #[test]
-       fn sanity_test_unkown_bits_testing() {
-               let mut features = ChannelFeatures::known();
-               features.set_require_unknown_bits();
+       fn sanity_test_unknown_bits() {
+               let mut features = ChannelFeatures::empty();
+               assert!(!features.requires_unknown_bits());
+               assert!(!features.supports_unknown_bits());
+
+               features.set_required_unknown_bits();
                assert!(features.requires_unknown_bits());
-               features.clear_require_unknown_bits();
+               assert!(features.supports_unknown_bits());
+
+               features.clear_unknown_bits();
                assert!(!features.requires_unknown_bits());
+               assert!(!features.supports_unknown_bits());
+
+               features.set_optional_unknown_bits();
+               assert!(!features.requires_unknown_bits());
+               assert!(features.supports_unknown_bits());
        }
 
        #[test]
-       fn test_node_with_known_relevant_init_flags() {
-               // Create an InitFeatures with initial_routing_sync supported.
-               let init_features = InitFeatures::known();
+       fn convert_to_context_with_relevant_flags() {
+               let init_features = InitFeatures::known().clear_upfront_shutdown_script();
                assert!(init_features.initial_routing_sync());
+               assert!(!init_features.supports_upfront_shutdown_script());
 
-               // Attempt to pull out non-node-context feature flags from these InitFeatures.
-               let res = NodeFeatures::with_known_relevant_init_flags(&init_features);
-
+               let node_features: NodeFeatures = init_features.to_context();
                {
-                       // Check that the flags are as expected: optional_data_loss_protect,
-                       // option_upfront_shutdown_script, var_onion_optin, payment_secret, and
-                       // basic_mpp.
-                       assert_eq!(res.flags.len(), 3);
-                       assert_eq!(res.flags[0], 0b00100010);
-                       assert_eq!(res.flags[1], 0b10000010);
-                       assert_eq!(res.flags[2], 0b00000010);
+                       // Check that the flags are as expected:
+                       // - option_data_loss_protect
+                       // - var_onion_optin | payment_secret
+                       // - basic_mpp
+                       assert_eq!(node_features.flags.len(), 3);
+                       assert_eq!(node_features.flags[0], 0b00000010);
+                       assert_eq!(node_features.flags[1], 0b10000010);
+                       assert_eq!(node_features.flags[2], 0b00000010);
                }
 
-               // Check that the initial_routing_sync feature was correctly blanked out.
-               let new_features: InitFeatures = Features::from_le_bytes(res.flags);
-               assert!(!new_features.initial_routing_sync());
+               // Check that cleared flags are kept blank when converting back:
+               // - initial_routing_sync was not applicable to NodeContext
+               // - upfront_shutdown_script was cleared before converting
+               let features: InitFeatures = node_features.to_context_internal();
+               assert!(!features.initial_routing_sync());
+               assert!(!features.supports_upfront_shutdown_script());
        }
 }