Add some basic sanity tests for feature flags
[rust-lightning] / lightning / src / ln / features.rs
index eee534f71087215ff1f00e71be32e314ab44d623..61a862c0c55c71bb62975fc8dd4360a1e3f7f5bf 100644 (file)
@@ -29,6 +29,10 @@ mod sealed { // You should just use the type aliases instead.
        pub trait UpfrontShutdownScript: Context {}
        impl UpfrontShutdownScript for InitContext {}
        impl UpfrontShutdownScript for NodeContext {}
+
+       pub trait VariableLengthOnion: Context {}
+       impl VariableLengthOnion for InitContext {}
+       impl VariableLengthOnion for NodeContext {}
 }
 
 /// Tracks the set of features which a node implements, templated by the context in which it
@@ -69,7 +73,7 @@ impl InitFeatures {
        /// Create a Features with the features we support
        pub fn supported() -> InitFeatures {
                InitFeatures {
-                       flags: vec![2 | 1 << 5],
+                       flags: vec![2 | 1 << 5, 1 << (9-8)],
                        mark: PhantomData,
                }
        }
@@ -132,14 +136,14 @@ impl NodeFeatures {
        #[cfg(not(feature = "fuzztarget"))]
        pub(crate) fn supported() -> NodeFeatures {
                NodeFeatures {
-                       flags: vec![2 | 1 << 5],
+                       flags: vec![2 | 1 << 5, 1 << (9-8)],
                        mark: PhantomData,
                }
        }
        #[cfg(feature = "fuzztarget")]
        pub fn supported() -> NodeFeatures {
                NodeFeatures {
-                       flags: vec![2 | 1 << 5],
+                       flags: vec![2 | 1 << 5, 1 << (9-8)],
                        mark: PhantomData,
                }
        }
@@ -182,13 +186,30 @@ impl<T: sealed::Context> Features<T> {
 
        pub(crate) fn requires_unknown_bits(&self) -> bool {
                self.flags.iter().enumerate().any(|(idx, &byte)| {
-                       ( idx != 0 && (byte & 0x55) != 0 ) || ( idx == 0 && (byte & 0x14) != 0 )
+                       (match idx {
+                               // Unknown bits are even bits which we don't understand, we list ones which we do
+                               // here:
+                               // unknown, upfront_shutdown_script, unknown (actually initial_routing_sync, but it
+                               // is only valid as an optional feature), and data_loss_protect:
+                               0 => (byte & 0b01000100),
+                               // unknown, unknown, unknown, var_onion_optin:
+                               1 => (byte & 0b01010100),
+                               // fallback, all even bits set:
+                               _ => (byte & 0b01010101),
+                       }) != 0
                })
        }
 
        pub(crate) fn supports_unknown_bits(&self) -> bool {
                self.flags.iter().enumerate().any(|(idx, &byte)| {
-                       ( idx != 0 && byte != 0 ) || ( idx == 0 && (byte & 0xc4) != 0 )
+                       (match idx {
+                               // unknown, upfront_shutdown_script, initial_routing_sync (is only valid as an
+                               // optional feature), and data_loss_protect:
+                               0 => (byte & 0b11000100),
+                               // unknown, unknown, unknown, var_onion_optin:
+                               1 => (byte & 0b11111100),
+                               _ => byte,
+                       }) != 0
                })
        }
 
@@ -232,6 +253,12 @@ impl<T: sealed::UpfrontShutdownScript> Features<T> {
        }
 }
 
+impl<T: sealed::VariableLengthOnion> Features<T> {
+       pub(crate) fn supports_variable_length_onion(&self) -> bool {
+               self.flags.len() > 1 && (self.flags[1] & 3) != 0
+       }
+}
+
 impl<T: sealed::InitialRoutingSync> Features<T> {
        pub(crate) fn initial_routing_sync(&self) -> bool {
                self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
@@ -266,3 +293,41 @@ impl<R: ::std::io::Read, T: sealed::Context> Readable<R> for Features<T> {
                })
        }
 }
+
+#[cfg(test)]
+mod tests {
+       use super::{ChannelFeatures, InitFeatures, NodeFeatures};
+
+       #[test]
+       fn sanity_test_our_features() {
+               assert!(!ChannelFeatures::supported().requires_unknown_bits());
+               assert!(!ChannelFeatures::supported().supports_unknown_bits());
+               assert!(!InitFeatures::supported().requires_unknown_bits());
+               assert!(!InitFeatures::supported().supports_unknown_bits());
+               assert!(!NodeFeatures::supported().requires_unknown_bits());
+               assert!(!NodeFeatures::supported().supports_unknown_bits());
+
+               assert!(InitFeatures::supported().supports_upfront_shutdown_script());
+               assert!(NodeFeatures::supported().supports_upfront_shutdown_script());
+
+               assert!(InitFeatures::supported().supports_data_loss_protect());
+               assert!(NodeFeatures::supported().supports_data_loss_protect());
+
+               assert!(InitFeatures::supported().supports_variable_length_onion());
+               assert!(NodeFeatures::supported().supports_variable_length_onion());
+
+               let mut init_features = InitFeatures::supported();
+               init_features.set_initial_routing_sync();
+               assert!(!init_features.requires_unknown_bits());
+               assert!(!init_features.supports_unknown_bits());
+       }
+
+       #[test]
+       fn sanity_test_unkown_bits_testing() {
+               let mut features = ChannelFeatures::supported();
+               features.set_require_unknown_bits();
+               assert!(features.requires_unknown_bits());
+               features.clear_require_unknown_bits();
+               assert!(!features.requires_unknown_bits());
+       }
+}