Test that we don't forget to track any outputs at monitor-load
[rust-lightning] / lightning / src / ln / features.rs
index d7e0299e8876143b09a2052d6380c842a257b827..641c1ffb6f6230fcd0ec6b8f294a78a15c66be1b 100644 (file)
@@ -33,6 +33,14 @@ mod sealed { // You should just use the type aliases instead.
        pub trait VariableLengthOnion: Context {}
        impl VariableLengthOnion for InitContext {}
        impl VariableLengthOnion for NodeContext {}
+
+       pub trait PaymentSecret: Context {}
+       impl PaymentSecret for InitContext {}
+       impl PaymentSecret for NodeContext {}
+
+       pub trait BasicMPP: Context {}
+       impl BasicMPP for InitContext {}
+       impl BasicMPP for NodeContext {}
 }
 
 /// Tracks the set of features which a node implements, templated by the context in which it
@@ -73,7 +81,7 @@ impl InitFeatures {
        /// Create a Features with the features we support
        pub fn supported() -> InitFeatures {
                InitFeatures {
-                       flags: vec![2 | 1 << 5, 1 << (9-8)],
+                       flags: vec![2 | 1 << 5, 1 << (9-8) | 1 << (15 - 8), 1 << (17 - 8*2)],
                        mark: PhantomData,
                }
        }
@@ -136,14 +144,14 @@ impl NodeFeatures {
        #[cfg(not(feature = "fuzztarget"))]
        pub(crate) fn supported() -> NodeFeatures {
                NodeFeatures {
-                       flags: vec![2 | 1 << 5, 1 << (9-8)],
+                       flags: vec![2 | 1 << 5, 1 << (9-8) | 1 << (15 - 8), 1 << (17 - 8*2)],
                        mark: PhantomData,
                }
        }
        #[cfg(feature = "fuzztarget")]
        pub fn supported() -> NodeFeatures {
                NodeFeatures {
-                       flags: vec![2 | 1 << 5, 1 << (9-8)],
+                       flags: vec![2 | 1 << 5, 1 << (9-8) | 1 << (15 - 8), 1 << (17 - 8*2)],
                        mark: PhantomData,
                }
        }
@@ -188,7 +196,8 @@ impl<T: sealed::Context> Features<T> {
                self.flags.iter().enumerate().any(|(idx, &byte)| {
                        (match idx {
                                0 => (byte & 0b01000100),
-                               1 => (byte & 0b01010100),
+                               1 => (byte & 0b00010100),
+                               2 => (byte & 0b01010100),
                                _ => (byte & 0b01010101),
                        }) != 0
                })
@@ -198,7 +207,8 @@ impl<T: sealed::Context> Features<T> {
                self.flags.iter().enumerate().any(|(idx, &byte)| {
                        (match idx {
                                0 => (byte & 0b11000100),
-                               1 => (byte & 0b11111100),
+                               1 => (byte & 0b00111100),
+                               2 => (byte & 0b11111100),
                                _ => byte,
                        }) != 0
                })
@@ -212,16 +222,19 @@ impl<T: sealed::Context> Features<T> {
 
        #[cfg(test)]
        pub(crate) fn set_require_unknown_bits(&mut self) {
-               let newlen = cmp::max(2, self.flags.len());
+               let newlen = cmp::max(3, self.flags.len());
                self.flags.resize(newlen, 0u8);
-               self.flags[1] |= 0x40;
+               self.flags[2] |= 0x40;
        }
 
        #[cfg(test)]
        pub(crate) fn clear_require_unknown_bits(&mut self) {
-               let newlen = cmp::max(2, self.flags.len());
+               let newlen = cmp::max(3, self.flags.len());
                self.flags.resize(newlen, 0u8);
-               self.flags[1] &= !0x40;
+               self.flags[2] &= !0x40;
+               if self.flags.len() == 3 && self.flags[2] == 0 {
+                       self.flags.resize(2, 0u8);
+               }
                if self.flags.len() == 2 && self.flags[1] == 0 {
                        self.flags.resize(1, 0u8);
                }
@@ -263,6 +276,23 @@ impl<T: sealed::InitialRoutingSync> Features<T> {
        }
 }
 
+impl<T: sealed::PaymentSecret> Features<T> {
+       #[allow(dead_code)]
+       // Note that we never need to test this since what really matters is the invoice - iff the
+       // invoice provides a payment_secret, we assume all the way through that we can do MPP.
+       pub(crate) fn payment_secret(&self) -> bool {
+               self.flags.len() > 1 && (self.flags[1] & (3 << (12-8))) != 0
+       }
+}
+
+impl<T: sealed::BasicMPP> Features<T> {
+       // We currently never test for this since we don't actually *generate* multipath routes.
+       #[allow(dead_code)]
+       pub(crate) fn basic_mpp(&self) -> bool {
+               self.flags.len() > 2 && (self.flags[2] & (3 << (16-8*2))) != 0
+       }
+}
+
 impl<T: sealed::Context> Writeable for Features<T> {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
                w.size_hint(self.flags.len() + 2);
@@ -284,3 +314,38 @@ impl<R: ::std::io::Read, T: sealed::Context> Readable<R> for Features<T> {
                })
        }
 }
+
+#[cfg(test)]
+mod tests {
+       use super::{ChannelFeatures, InitFeatures, NodeFeatures};
+
+       #[test]
+       fn sanity_test_our_features() {
+               assert!(!ChannelFeatures::supported().requires_unknown_bits());
+               assert!(!ChannelFeatures::supported().supports_unknown_bits());
+               assert!(!InitFeatures::supported().requires_unknown_bits());
+               assert!(!InitFeatures::supported().supports_unknown_bits());
+               assert!(!NodeFeatures::supported().requires_unknown_bits());
+               assert!(!NodeFeatures::supported().supports_unknown_bits());
+
+               assert!(InitFeatures::supported().supports_upfront_shutdown_script());
+               assert!(NodeFeatures::supported().supports_upfront_shutdown_script());
+
+               assert!(InitFeatures::supported().supports_data_loss_protect());
+               assert!(NodeFeatures::supported().supports_data_loss_protect());
+
+               let mut init_features = InitFeatures::supported();
+               init_features.set_initial_routing_sync();
+               assert!(!init_features.requires_unknown_bits());
+               assert!(!init_features.supports_unknown_bits());
+       }
+
+       #[test]
+       fn sanity_test_unkown_bits_testing() {
+               let mut features = ChannelFeatures::supported();
+               features.set_require_unknown_bits();
+               assert!(features.requires_unknown_bits());
+               features.clear_require_unknown_bits();
+               assert!(!features.requires_unknown_bits());
+       }
+}