]> git.bitcoin.ninja Git - rust-lightning/blobdiff - lightning/src/ln/msgs.rs
Refactor features a bit more to describe what the constructors do
[rust-lightning] / lightning / src / ln / msgs.rs
index e48080103bd1b1642c390ae3eb8e5ea1fd712a05..f051b555a9aaf1a279f5b56f1137159daf886cd8 100644 (file)
@@ -105,6 +105,7 @@ impl FeatureContextInitNode for FeatureContextNode {}
 /// appears.
 pub struct Features<T: FeatureContext> {
        #[cfg(not(test))]
+       /// Note that, for convinience, flags is LITTLE endian (despite being big-endian on the wire)
        flags: Vec<u8>,
        // Used to test encoding of diverse msgs
        #[cfg(test)]
@@ -139,16 +140,16 @@ pub type NodeFeatures = Features<FeatureContextNode>;
 pub type ChannelFeatures = Features<FeatureContextChannel>;
 
 impl<T: FeatureContextInitNode> Features<T> {
-       /// Create a blank Features flags (visibility extended for fuzz tests)
+       /// Create a Features with the features we support
        #[cfg(not(feature = "fuzztarget"))]
-       pub(crate) fn new() -> Features<T> {
+       pub(crate) fn supported() -> Features<T> {
                Features {
                        flags: vec![2 | 1 << 5],
                        mark: PhantomData,
                }
        }
        #[cfg(feature = "fuzztarget")]
-       pub fn new() -> Features<T> {
+       pub fn supported() -> Features<T> {
                Features {
                        flags: vec![2 | 1 << 5],
                        mark: PhantomData,
@@ -157,16 +158,16 @@ impl<T: FeatureContextInitNode> Features<T> {
 }
 
 impl Features<FeatureContextChannel> {
-       /// Create a blank Features flags (visibility extended for fuzz tests)
+       /// Create a Features with the features we support
        #[cfg(not(feature = "fuzztarget"))]
-       pub(crate) fn new() -> Features<FeatureContextChannel> {
+       pub(crate) fn supported() -> Features<FeatureContextChannel> {
                Features {
                        flags: Vec::new(),
                        mark: PhantomData,
                }
        }
        #[cfg(feature = "fuzztarget")]
-       pub fn new() -> Features<FeatureContextChannel> {
+       pub fn supported() -> Features<FeatureContextChannel> {
                Features {
                        flags: Vec::new(),
                        mark: PhantomData,
@@ -175,6 +176,14 @@ impl Features<FeatureContextChannel> {
 }
 
 impl<T: FeatureContext> Features<T> {
+       /// Create a blank Features with no fetures set
+       pub fn empty() -> Features<T> {
+               Features {
+                       flags: Vec::new(),
+                       mark: PhantomData,
+               }
+       }
+
        pub(crate) fn requires_unknown_bits(&self) -> bool {
                self.flags.iter().enumerate().any(|(idx, &byte)| {
                        ( idx != 0 && (byte & 0x55) != 0 ) || ( idx == 0 && (byte & 0x14) != 0 )
@@ -192,6 +201,23 @@ impl<T: FeatureContext> Features<T> {
        pub(crate) fn byte_count(&self) -> usize {
                self.flags.len()
        }
+
+       #[cfg(test)]
+       pub(crate) fn set_require_unknown_bits(&mut self) {
+               let newlen = cmp::max(2, self.flags.len());
+               self.flags.resize(newlen, 0u8);
+               self.flags[1] |= 0x40;
+       }
+
+       #[cfg(test)]
+       pub(crate) fn clear_require_unknown_bits(&mut self) {
+               let newlen = cmp::max(2, self.flags.len());
+               self.flags.resize(newlen, 0u8);
+               self.flags[1] &= !0x40;
+               if self.flags.len() == 2 && self.flags[1] == 0 {
+                       self.flags.resize(1, 0u8);
+               }
+       }
 }
 
 impl<T: FeatureContextInitNode> Features<T> {
@@ -248,18 +274,25 @@ impl Features<FeatureContextInit> {
 impl<T: FeatureContext> Writeable for Features<T> {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
                w.size_hint(self.flags.len() + 2);
-               self.flags.write(w)
+               (self.flags.len() as u16).write(w)?;
+               for f in self.flags.iter().rev() { // We have to swap the endianness back to BE for writing
+                       f.write(w)?;
+               }
+               Ok(())
        }
 }
 
 impl<R: ::std::io::Read, T: FeatureContext> Readable<R> for Features<T> {
        fn read(r: &mut R) -> Result<Self, DecodeError> {
+               let mut flags: Vec<u8> = Readable::read(r)?;
+               flags.reverse(); // Swap to big-endian
                Ok(Self {
-                       flags: Readable::read(r)?,
+                       flags,
                        mark: PhantomData,
                })
        }
 }
+
 /// An init message to be sent or received from a peer
 pub struct Init {
        pub(crate) features: InitFeatures,
@@ -1272,13 +1305,7 @@ impl Writeable for UnsignedChannelAnnouncement {
 impl<R: Read> Readable<R> for UnsignedChannelAnnouncement {
        fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Self {
-                       features: {
-                               let f: ChannelFeatures = Readable::read(r)?;
-                               if f.requires_unknown_bits() {
-                                       return Err(DecodeError::UnknownRequiredFeature);
-                               }
-                               f
-                       },
+                       features: Readable::read(r)?,
                        chain_hash: Readable::read(r)?,
                        short_channel_id: Readable::read(r)?,
                        node_id_1: Readable::read(r)?,
@@ -1406,9 +1433,6 @@ impl Writeable for UnsignedNodeAnnouncement {
 impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
        fn read(r: &mut R) -> Result<Self, DecodeError> {
                let features: NodeFeatures = Readable::read(r)?;
-               if features.requires_unknown_bits() {
-                       return Err(DecodeError::UnknownRequiredFeature);
-               }
                let timestamp: u32 = Readable::read(r)?;
                let node_id: PublicKey = Readable::read(r)?;
                let mut rgb = [0; 3];
@@ -1601,7 +1625,7 @@ mod tests {
                let sig_2 = get_sig_on!(privkey_2, secp_ctx, String::from("01010101010101010101010101010101"));
                let sig_3 = get_sig_on!(privkey_3, secp_ctx, String::from("01010101010101010101010101010101"));
                let sig_4 = get_sig_on!(privkey_4, secp_ctx, String::from("01010101010101010101010101010101"));
-               let mut features = ChannelFeatures::new();
+               let mut features = ChannelFeatures::supported();
                if unknown_features_bits {
                        features.flags = vec![0xFF, 0xFF];
                }
@@ -1657,9 +1681,12 @@ mod tests {
                let secp_ctx = Secp256k1::new();
                let (privkey_1, pubkey_1) = get_keys_from!("0101010101010101010101010101010101010101010101010101010101010101", secp_ctx);
                let sig_1 = get_sig_on!(privkey_1, secp_ctx, String::from("01010101010101010101010101010101"));
-               let mut features = NodeFeatures::new();
+               let mut features = NodeFeatures::empty();
                if unknown_features_bits {
                        features.flags = vec![0xFF, 0xFF];
+               } else {
+                       // Set to some features we may support
+                       features.flags = vec![2 | 1 << 5];
                }
                let mut addresses = Vec::new();
                if ipv4 {