Generalize with_known_relevant_init_flags
[rust-lightning] / lightning / src / ln / features.rs
index fc2526fe55842bf41022cd67d88fba50a34aa672..381297703e882a08cce4577e01d5bd3f7a9a87af 100644 (file)
@@ -282,32 +282,11 @@ impl InitFeatures {
                }
                self
        }
-}
 
-impl ChannelFeatures {
-       /// Takes the flags that we know how to interpret in an init-context features that are also
-       /// relevant in a channel-context features and creates a channel-context features from them.
-       pub(crate) fn with_known_relevant_init_flags(_init_ctx: &InitFeatures) -> Self {
-               // There are currently no channel flags defined that we understand.
-               Self { flags: Vec::new(), mark: PhantomData, }
-       }
-}
-
-impl NodeFeatures {
-       /// Takes the flags that we know how to interpret in an init-context features that are also
-       /// relevant in a node-context features and creates a node-context features from them.
-       /// Be sure to blank out features that are unknown to us.
-       pub(crate) fn with_known_relevant_init_flags(init_ctx: &InitFeatures) -> Self {
-               use ln::features::sealed::Context;
-               let byte_count = sealed::NodeContext::KNOWN_FEATURE_MASK.len();
-
-               let mut flags = Vec::new();
-               for (i, feature_byte) in init_ctx.flags.iter().enumerate() {
-                       if i < byte_count {
-                               flags.push(feature_byte & sealed::NodeContext::KNOWN_FEATURE_MASK[i]);
-                       }
-               }
-               Self { flags, mark: PhantomData, }
+       /// Converts `InitFeatures` to `Features<C>`. Only known `InitFeatures` relevant to context `C`
+       /// are included in the result.
+       pub(crate) fn to_context<C: sealed::Context>(&self) -> Features<C> {
+               self.to_context_internal()
        }
 }
 
@@ -330,6 +309,21 @@ impl<T: sealed::Context> Features<T> {
                }
        }
 
+       /// Converts `Features<T>` to `Features<C>`. Only known `T` features relevant to context `C` are
+       /// included in the result.
+       fn to_context_internal<C: sealed::Context>(&self) -> Features<C> {
+               let byte_count = C::KNOWN_FEATURE_MASK.len();
+               let mut flags = Vec::new();
+               for (i, byte) in self.flags.iter().enumerate() {
+                       if i < byte_count {
+                               let known_source_features = T::KNOWN_FEATURE_MASK[i];
+                               let known_target_features = C::KNOWN_FEATURE_MASK[i];
+                               flags.push(byte & known_source_features & known_target_features);
+                       }
+               }
+               Features::<C> { flags, mark: PhantomData, }
+       }
+
        #[cfg(test)]
        /// Create a Features given a set of flags, in LE.
        pub fn from_le_bytes(flags: Vec<u8>) -> Features<T> {
@@ -346,15 +340,13 @@ impl<T: sealed::Context> Features<T> {
        }
 
        pub(crate) fn requires_unknown_bits(&self) -> bool {
-               use ln::features::sealed::Context;
-               let byte_count = sealed::InitContext::KNOWN_FEATURE_MASK.len();
-
-               // Bitwise AND-ing with all even bits set except for known features will select unknown
-               // required features.
+               // Bitwise AND-ing with all even bits set except for known features will select required
+               // unknown features.
+               let byte_count = T::KNOWN_FEATURE_MASK.len();
                self.flags.iter().enumerate().any(|(i, &byte)| {
                        let required_features = 0b01_01_01_01;
                        let unknown_features = if i < byte_count {
-                               !sealed::InitContext::KNOWN_FEATURE_MASK[i]
+                               !T::KNOWN_FEATURE_MASK[i]
                        } else {
                                0b11_11_11_11
                        };
@@ -363,14 +355,12 @@ impl<T: sealed::Context> Features<T> {
        }
 
        pub(crate) fn supports_unknown_bits(&self) -> bool {
-               use ln::features::sealed::Context;
-               let byte_count = sealed::InitContext::KNOWN_FEATURE_MASK.len();
-
                // Bitwise AND-ing with all even and odd bits set except for known features will select
-               // unknown features.
+               // both required and optional unknown features.
+               let byte_count = T::KNOWN_FEATURE_MASK.len();
                self.flags.iter().enumerate().any(|(i, &byte)| {
                        let unknown_features = if i < byte_count {
-                               !sealed::InitContext::KNOWN_FEATURE_MASK[i]
+                               !T::KNOWN_FEATURE_MASK[i]
                        } else {
                                0b11_11_11_11
                        };
@@ -416,8 +406,9 @@ impl<T: sealed::UpfrontShutdownScript> Features<T> {
                <T as sealed::UpfrontShutdownScript>::supports_feature(&self.flags)
        }
        #[cfg(test)]
-       pub(crate) fn unset_upfront_shutdown_script(&mut self) {
-               <T as sealed::UpfrontShutdownScript>::clear_bits(&mut self.flags)
+       pub(crate) fn clear_upfront_shutdown_script(mut self) -> Self {
+               <T as sealed::UpfrontShutdownScript>::clear_bits(&mut self.flags);
+               self
        }
 }
 
@@ -478,7 +469,7 @@ impl<T: sealed::Context> Readable for Features<T> {
 
 #[cfg(test)]
 mod tests {
-       use super::{ChannelFeatures, InitFeatures, NodeFeatures, Features};
+       use super::{ChannelFeatures, InitFeatures, NodeFeatures};
 
        #[test]
        fn sanity_test_our_features() {
@@ -520,26 +511,28 @@ mod tests {
        }
 
        #[test]
-       fn test_node_with_known_relevant_init_flags() {
-               // Create an InitFeatures with initial_routing_sync supported.
-               let init_features = InitFeatures::known();
+       fn convert_to_context_with_relevant_flags() {
+               let init_features = InitFeatures::known().clear_upfront_shutdown_script();
                assert!(init_features.initial_routing_sync());
+               assert!(!init_features.supports_upfront_shutdown_script());
 
-               // Attempt to pull out non-node-context feature flags from these InitFeatures.
-               let res = NodeFeatures::with_known_relevant_init_flags(&init_features);
-
+               let node_features: NodeFeatures = init_features.to_context();
                {
-                       // Check that the flags are as expected: optional_data_loss_protect,
-                       // option_upfront_shutdown_script, var_onion_optin, payment_secret, and
-                       // basic_mpp.
-                       assert_eq!(res.flags.len(), 3);
-                       assert_eq!(res.flags[0], 0b00100010);
-                       assert_eq!(res.flags[1], 0b10000010);
-                       assert_eq!(res.flags[2], 0b00000010);
+                       // Check that the flags are as expected:
+                       // - option_data_loss_protect
+                       // - var_onion_optin | payment_secret
+                       // - basic_mpp
+                       assert_eq!(node_features.flags.len(), 3);
+                       assert_eq!(node_features.flags[0], 0b00000010);
+                       assert_eq!(node_features.flags[1], 0b10000010);
+                       assert_eq!(node_features.flags[2], 0b00000010);
                }
 
-               // Check that the initial_routing_sync feature was correctly blanked out.
-               let new_features: InitFeatures = Features::from_le_bytes(res.flags);
-               assert!(!new_features.initial_routing_sync());
+               // Check that cleared flags are kept blank when converting back:
+               // - initial_routing_sync was not applicable to NodeContext
+               // - upfront_shutdown_script was cleared before converting
+               let features: InitFeatures = node_features.to_context_internal();
+               assert!(!features.initial_routing_sync());
+               assert!(!features.supports_upfront_shutdown_script());
        }
 }