Set initial_routing_sync in InitFeatures
[rust-lightning] / lightning / src / ln / features.rs
index 61a862c0c55c71bb62975fc8dd4360a1e3f7f5bf..f6912662497fb33f712838f40c907b8cf65e4235 100644 (file)
@@ -33,6 +33,14 @@ mod sealed { // You should just use the type aliases instead.
        pub trait VariableLengthOnion: Context {}
        impl VariableLengthOnion for InitContext {}
        impl VariableLengthOnion for NodeContext {}
+
+       pub trait PaymentSecret: Context {}
+       impl PaymentSecret for InitContext {}
+       impl PaymentSecret for NodeContext {}
+
+       pub trait BasicMPP: Context {}
+       impl BasicMPP for InitContext {}
+       impl BasicMPP for NodeContext {}
 }
 
 /// Tracks the set of features which a node implements, templated by the context in which it
@@ -73,7 +81,7 @@ impl InitFeatures {
        /// Create a Features with the features we support
        pub fn supported() -> InitFeatures {
                InitFeatures {
-                       flags: vec![2 | 1 << 5, 1 << (9-8)],
+                       flags: vec![2 | 1 << 3 | 1 << 5, 1 << (9-8) | 1 << (15 - 8), 1 << (17 - 8*2)],
                        mark: PhantomData,
                }
        }
@@ -136,25 +144,33 @@ impl NodeFeatures {
        #[cfg(not(feature = "fuzztarget"))]
        pub(crate) fn supported() -> NodeFeatures {
                NodeFeatures {
-                       flags: vec![2 | 1 << 5, 1 << (9-8)],
+                       flags: vec![2 | 1 << 5, 1 << (9 - 8) | 1 << (15 - 8), 1 << (17 - 8*2)],
                        mark: PhantomData,
                }
        }
        #[cfg(feature = "fuzztarget")]
        pub fn supported() -> NodeFeatures {
                NodeFeatures {
-                       flags: vec![2 | 1 << 5, 1 << (9-8)],
+                       flags: vec![2 | 1 << 5, 1 << (9 - 8) | 1 << (15 - 8), 1 << (17 - 8*2)],
                        mark: PhantomData,
                }
        }
 
        /// Takes the flags that we know how to interpret in an init-context features that are also
        /// relevant in a node-context features and creates a node-context features from them.
+       /// Be sure to blank out features that are unknown to us.
        pub(crate) fn with_known_relevant_init_flags(init_ctx: &InitFeatures) -> Self {
                let mut flags = Vec::new();
-               if init_ctx.flags.len() > 0 {
-                       // Pull out data_loss_protect and upfront_shutdown_script (bits 0, 1, 4, and 5)
-                       flags.push(init_ctx.flags.last().unwrap() & 0b00110011);
+               for (i, feature_byte)in init_ctx.flags.iter().enumerate() {
+                       match i {
+                               // Blank out initial_routing_sync (feature bits 2/3), gossip_queries (6/7),
+                               // gossip_queries_ex (10/11), option_static_remotekey (12/13), and
+                               // option_support_large_channel (16/17)
+                               0 => flags.push(feature_byte & 0b00110011),
+                               1 => flags.push(feature_byte & 0b11000011),
+                               2 => flags.push(feature_byte & 0b00000011),
+                               _ => (),
+                       }
                }
                Self { flags, mark: PhantomData, }
        }
@@ -192,8 +208,10 @@ impl<T: sealed::Context> Features<T> {
                                // unknown, upfront_shutdown_script, unknown (actually initial_routing_sync, but it
                                // is only valid as an optional feature), and data_loss_protect:
                                0 => (byte & 0b01000100),
-                               // unknown, unknown, unknown, var_onion_optin:
-                               1 => (byte & 0b01010100),
+                               // payment_secret, unknown, unknown, var_onion_optin:
+                               1 => (byte & 0b00010100),
+                               // unknown, unknown, unknown, basic_mpp:
+                               2 => (byte & 0b01010100),
                                // fallback, all even bits set:
                                _ => (byte & 0b01010101),
                        }) != 0
@@ -206,8 +224,10 @@ impl<T: sealed::Context> Features<T> {
                                // unknown, upfront_shutdown_script, initial_routing_sync (is only valid as an
                                // optional feature), and data_loss_protect:
                                0 => (byte & 0b11000100),
-                               // unknown, unknown, unknown, var_onion_optin:
-                               1 => (byte & 0b11111100),
+                               // payment_secret, unknown, unknown, var_onion_optin:
+                               1 => (byte & 0b00111100),
+                               // unknown, unknown, unknown, basic_mpp:
+                               2 => (byte & 0b11111100),
                                _ => byte,
                        }) != 0
                })
@@ -221,16 +241,19 @@ impl<T: sealed::Context> Features<T> {
 
        #[cfg(test)]
        pub(crate) fn set_require_unknown_bits(&mut self) {
-               let newlen = cmp::max(2, self.flags.len());
+               let newlen = cmp::max(3, self.flags.len());
                self.flags.resize(newlen, 0u8);
-               self.flags[1] |= 0x40;
+               self.flags[2] |= 0x40;
        }
 
        #[cfg(test)]
        pub(crate) fn clear_require_unknown_bits(&mut self) {
-               let newlen = cmp::max(2, self.flags.len());
+               let newlen = cmp::max(3, self.flags.len());
                self.flags.resize(newlen, 0u8);
-               self.flags[1] &= !0x40;
+               self.flags[2] &= !0x40;
+               if self.flags.len() == 3 && self.flags[2] == 0 {
+                       self.flags.resize(2, 0u8);
+               }
                if self.flags.len() == 2 && self.flags[1] == 0 {
                        self.flags.resize(1, 0u8);
                }
@@ -249,7 +272,7 @@ impl<T: sealed::UpfrontShutdownScript> Features<T> {
        }
        #[cfg(test)]
        pub(crate) fn unset_upfront_shutdown_script(&mut self) {
-               self.flags[0] ^= 1 << 5;
+               self.flags[0] &= !(1 << 5);
        }
 }
 
@@ -263,15 +286,31 @@ impl<T: sealed::InitialRoutingSync> Features<T> {
        pub(crate) fn initial_routing_sync(&self) -> bool {
                self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
        }
-       pub(crate) fn set_initial_routing_sync(&mut self) {
-               if self.flags.len() == 0 {
-                       self.flags.resize(1, 1 << 3);
-               } else {
-                       self.flags[0] |= 1 << 3;
+       pub(crate) fn clear_initial_routing_sync(&mut self) {
+               if self.flags.len() > 0 {
+                       self.flags[0] &= !(1 << 3);
                }
        }
 }
 
+impl<T: sealed::PaymentSecret> Features<T> {
+       #[allow(dead_code)]
+       // Note that we never need to test this since what really matters is the invoice - iff the
+       // invoice provides a payment_secret, we assume that we can use it (ie that the recipient
+       // supports payment_secret).
+       pub(crate) fn supports_payment_secret(&self) -> bool {
+               self.flags.len() > 1 && (self.flags[1] & (3 << (14-8))) != 0
+       }
+}
+
+impl<T: sealed::BasicMPP> Features<T> {
+       // We currently never test for this since we don't actually *generate* multipath routes.
+       #[allow(dead_code)]
+       pub(crate) fn supports_basic_mpp(&self) -> bool {
+               self.flags.len() > 2 && (self.flags[2] & (3 << (16-8*2))) != 0
+       }
+}
+
 impl<T: sealed::Context> Writeable for Features<T> {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
                w.size_hint(self.flags.len() + 2);
@@ -283,8 +322,8 @@ impl<T: sealed::Context> Writeable for Features<T> {
        }
 }
 
-impl<R: ::std::io::Read, T: sealed::Context> Readable<R> for Features<T> {
-       fn read(r: &mut R) -> Result<Self, DecodeError> {
+impl<T: sealed::Context> Readable for Features<T> {
+       fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, DecodeError> {
                let mut flags: Vec<u8> = Readable::read(r)?;
                flags.reverse(); // Swap to little-endian
                Ok(Self {
@@ -296,7 +335,7 @@ impl<R: ::std::io::Read, T: sealed::Context> Readable<R> for Features<T> {
 
 #[cfg(test)]
 mod tests {
-       use super::{ChannelFeatures, InitFeatures, NodeFeatures};
+       use super::{ChannelFeatures, InitFeatures, NodeFeatures, Features};
 
        #[test]
        fn sanity_test_our_features() {
@@ -316,10 +355,16 @@ mod tests {
                assert!(InitFeatures::supported().supports_variable_length_onion());
                assert!(NodeFeatures::supported().supports_variable_length_onion());
 
+               assert!(InitFeatures::supported().supports_payment_secret());
+               assert!(NodeFeatures::supported().supports_payment_secret());
+
+               assert!(InitFeatures::supported().supports_basic_mpp());
+               assert!(NodeFeatures::supported().supports_basic_mpp());
+
                let mut init_features = InitFeatures::supported();
-               init_features.set_initial_routing_sync();
-               assert!(!init_features.requires_unknown_bits());
-               assert!(!init_features.supports_unknown_bits());
+               assert!(init_features.initial_routing_sync());
+               init_features.clear_initial_routing_sync();
+               assert!(!init_features.initial_routing_sync());
        }
 
        #[test]
@@ -330,4 +375,28 @@ mod tests {
                features.clear_require_unknown_bits();
                assert!(!features.requires_unknown_bits());
        }
+
+       #[test]
+       fn test_node_with_known_relevant_init_flags() {
+               // Create an InitFeatures with initial_routing_sync supported.
+               let init_features = InitFeatures::supported();
+               assert!(init_features.initial_routing_sync());
+
+               // Attempt to pull out non-node-context feature flags from these InitFeatures.
+               let res = NodeFeatures::with_known_relevant_init_flags(&init_features);
+
+               {
+                       // Check that the flags are as expected: optional_data_loss_protect,
+                       // option_upfront_shutdown_script, var_onion_optin, payment_secret, and
+                       // basic_mpp.
+                       assert_eq!(res.flags.len(), 3);
+                       assert_eq!(res.flags[0], 0b00100010);
+                       assert_eq!(res.flags[1], 0b10000010);
+                       assert_eq!(res.flags[2], 0b00000010);
+               }
+
+               // Check that the initial_routing_sync feature was correctly blanked out.
+               let new_features: InitFeatures = Features::from_le_bytes(res.flags);
+               assert!(!new_features.initial_routing_sync());
+       }
 }