Implement Flat Features
[rust-lightning] / lightning / src / ln / msgs.rs
index ea1953eb83cffee3c5a74f48acccb99142289d2b..e48080103bd1b1642c390ae3eb8e5ea1fd712a05 100644 (file)
@@ -25,6 +25,7 @@ use std::error::Error;
 use std::{cmp, fmt};
 use std::io::Read;
 use std::result::Result;
+use std::marker::PhantomData;
 
 use util::events;
 use util::ser::{Readable, Writeable, Writer};
@@ -52,49 +53,128 @@ pub enum DecodeError {
        Io(::std::io::Error),
 }
 
-/// Tracks localfeatures which are only in init messages
-#[derive(Clone, PartialEq)]
-pub struct LocalFeatures {
+/// The context in which a Feature object appears determines which bits of features the node
+/// supports will be set. We use this when creating our own Feature objects to select which bits to
+/// set and when passing around Feature objects to ensure the bits we're checking for are
+/// available.
+///
+/// This Context represents when the Feature appears in the init message, sent between peers and not
+/// rumored around the P2P network.
+pub struct FeatureContextInit {}
+/// The context in which a Feature object appears determines which bits of features the node
+/// supports will be set. We use this when creating our own Feature objects to select which bits to
+/// set and when passing around Feature objects to ensure the bits we're checking for are
+/// available.
+///
+/// This Context represents when the Feature appears in the node_announcement message, as it is
+/// rumored around the P2P network.
+pub struct FeatureContextNode {}
+/// The context in which a Feature object appears determines which bits of features the node
+/// supports will be set. We use this when creating our own Feature objects to select which bits to
+/// set and when passing around Feature objects to ensure the bits we're checking for are
+/// available.
+///
+/// This Context represents when the Feature appears in the ChannelAnnouncement message, as it is
+/// rumored around the P2P network.
+pub struct FeatureContextChannel {}
+/// The context in which a Feature object appears determines which bits of features the node
+/// supports will be set. We use this when creating our own Feature objects to select which bits to
+/// set and when passing around Feature objects to ensure the bits we're checking for are
+/// available.
+///
+/// This Context represents when the Feature appears in an invoice, used to determine the different
+/// options available for routing a payment.
+///
+/// Note that this is currently unused as invoices come to us via a different crate and are not
+/// native to rust-lightning directly.
+pub struct FeatureContextInvoice {}
+
+/// An internal trait capturing the various future context types
+pub trait FeatureContext {}
+impl FeatureContext for FeatureContextInit {}
+impl FeatureContext for FeatureContextNode {}
+impl FeatureContext for FeatureContextChannel {}
+impl FeatureContext for FeatureContextInvoice {}
+
+/// An internal trait capturing FeatureContextInit and FeatureContextNode
+pub trait FeatureContextInitNode : FeatureContext {}
+impl FeatureContextInitNode for FeatureContextInit {}
+impl FeatureContextInitNode for FeatureContextNode {}
+
+/// Tracks the set of features which a node implements, templated by the context in which it
+/// appears.
+pub struct Features<T: FeatureContext> {
+       #[cfg(not(test))]
        flags: Vec<u8>,
+       // Used to test encoding of diverse msgs
+       #[cfg(test)]
+       pub flags: Vec<u8>,
+       mark: PhantomData<T>,
 }
 
-impl LocalFeatures {
-       /// Create a blank LocalFeatures flags (visibility extended for fuzz tests)
+impl<T: FeatureContext> Clone for Features<T> {
+       fn clone(&self) -> Self {
+               Self {
+                       flags: self.flags.clone(),
+                       mark: PhantomData,
+               }
+       }
+}
+impl<T: FeatureContext> PartialEq for Features<T> {
+       fn eq(&self, o: &Self) -> bool {
+               self.flags.eq(&o.flags)
+       }
+}
+impl<T: FeatureContext> fmt::Debug for Features<T> {
+       fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
+               self.flags.fmt(fmt)
+       }
+}
+
+/// A feature message as it appears in an init message
+pub type InitFeatures = Features<FeatureContextInit>;
+/// A feature message as it appears in a node_announcement message
+pub type NodeFeatures = Features<FeatureContextNode>;
+/// A feature message as it appears in a channel_announcement message
+pub type ChannelFeatures = Features<FeatureContextChannel>;
+
+impl<T: FeatureContextInitNode> Features<T> {
+       /// Create a blank Features flags (visibility extended for fuzz tests)
        #[cfg(not(feature = "fuzztarget"))]
-       pub(crate) fn new() -> LocalFeatures {
-               LocalFeatures {
+       pub(crate) fn new() -> Features<T> {
+               Features {
                        flags: vec![2 | 1 << 5],
+                       mark: PhantomData,
                }
        }
        #[cfg(feature = "fuzztarget")]
-       pub fn new() -> LocalFeatures {
-               LocalFeatures {
+       pub fn new() -> Features<T> {
+               Features {
                        flags: vec![2 | 1 << 5],
+                       mark: PhantomData,
                }
        }
+}
 
-       pub(crate) fn supports_data_loss_protect(&self) -> bool {
-               self.flags.len() > 0 && (self.flags[0] & 3) != 0
-       }
-       pub(crate) fn initial_routing_sync(&self) -> bool {
-               self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
-       }
-       pub(crate) fn set_initial_routing_sync(&mut self) {
-               if self.flags.len() == 0 {
-                       self.flags.resize(1, 1 << 3);
-               } else {
-                       self.flags[0] |= 1 << 3;
+impl Features<FeatureContextChannel> {
+       /// Create a blank Features flags (visibility extended for fuzz tests)
+       #[cfg(not(feature = "fuzztarget"))]
+       pub(crate) fn new() -> Features<FeatureContextChannel> {
+               Features {
+                       flags: Vec::new(),
+                       mark: PhantomData,
                }
        }
-
-       pub(crate) fn supports_upfront_shutdown_script(&self) -> bool {
-               self.flags.len() > 0 && (self.flags[0] & (3 << 4)) != 0
-       }
-       #[cfg(test)]
-       pub(crate) fn unset_upfront_shutdown_script(&mut self) {
-               self.flags[0] ^= 1 << 5;
+       #[cfg(feature = "fuzztarget")]
+       pub fn new() -> Features<FeatureContextChannel> {
+               Features {
+                       flags: Vec::new(),
+                       mark: PhantomData,
+               }
        }
+}
 
+impl<T: FeatureContext> Features<T> {
        pub(crate) fn requires_unknown_bits(&self) -> bool {
                self.flags.iter().enumerate().any(|(idx, &byte)| {
                        ( idx != 0 && (byte & 0x55) != 0 ) || ( idx == 0 && (byte & 0x14) != 0 )
@@ -106,48 +186,83 @@ impl LocalFeatures {
                        ( idx != 0 && byte != 0 ) || ( idx == 0 && (byte & 0xc4) != 0 )
                })
        }
+
+       /// The number of bytes required to represent the feaature flags present. This does not include
+       /// the length bytes which are included in the serialized form.
+       pub(crate) fn byte_count(&self) -> usize {
+               self.flags.len()
+       }
 }
 
-/// Tracks globalfeatures which are in init messages and routing announcements
-#[derive(Clone, PartialEq, Debug)]
-pub struct GlobalFeatures {
-       #[cfg(not(test))]
-       flags: Vec<u8>,
-       // Used to test encoding of diverse msgs
+impl<T: FeatureContextInitNode> Features<T> {
+       pub(crate) fn supports_data_loss_protect(&self) -> bool {
+               self.flags.len() > 0 && (self.flags[0] & 3) != 0
+       }
+
+       pub(crate) fn supports_upfront_shutdown_script(&self) -> bool {
+               self.flags.len() > 0 && (self.flags[0] & (3 << 4)) != 0
+       }
        #[cfg(test)]
-       pub flags: Vec<u8>
+       pub(crate) fn unset_upfront_shutdown_script(&mut self) {
+               self.flags[0] ^= 1 << 5;
+       }
 }
 
-impl GlobalFeatures {
-       pub(crate) fn new() -> GlobalFeatures {
-               GlobalFeatures {
-                       flags: Vec::new(),
+impl Features<FeatureContextInit> {
+       pub(crate) fn initial_routing_sync(&self) -> bool {
+               self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
+       }
+       pub(crate) fn set_initial_routing_sync(&mut self) {
+               if self.flags.len() == 0 {
+                       self.flags.resize(1, 1 << 3);
+               } else {
+                       self.flags[0] |= 1 << 3;
                }
        }
 
-       pub(crate) fn requires_unknown_bits(&self) -> bool {
-               for &byte in self.flags.iter() {
-                       if (byte & 0x55) != 0 {
-                               return true;
+       /// Writes all features present up to, and including, 13.
+       pub(crate) fn write_up_to_13<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               let len = cmp::min(2, self.flags.len());
+               w.size_hint(len + 2);
+               (len as u16).write(w)?;
+               for i in (0..len).rev() {
+                       if i == 0 {
+                               self.flags[i].write(w)?;
+                       } else {
+                               (self.flags[i] & ((1 << (14 - 8)) - 1)).write(w)?;
                        }
                }
-               return false;
+               Ok(())
        }
 
-       pub(crate) fn supports_unknown_bits(&self) -> bool {
-               for &byte in self.flags.iter() {
-                       if byte != 0 {
-                               return true;
-                       }
+       /// or's another InitFeatures into this one.
+       pub(crate) fn or(&mut self, o: &InitFeatures) {
+               let total_feature_len = cmp::max(self.flags.len(), o.flags.len());
+               self.flags.resize(total_feature_len, 0u8);
+               for (feature, o_feature) in self.flags.iter_mut().zip(o.flags.iter()) {
+                       *feature |= *o_feature;
                }
-               return false;
        }
 }
 
+impl<T: FeatureContext> Writeable for Features<T> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               w.size_hint(self.flags.len() + 2);
+               self.flags.write(w)
+       }
+}
+
+impl<R: ::std::io::Read, T: FeatureContext> Readable<R> for Features<T> {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               Ok(Self {
+                       flags: Readable::read(r)?,
+                       mark: PhantomData,
+               })
+       }
+}
 /// An init message to be sent or received from a peer
 pub struct Init {
-       pub(crate) global_features: GlobalFeatures,
-       pub(crate) local_features: LocalFeatures,
+       pub(crate) features: InitFeatures,
 }
 
 /// An error message to be sent or received from a peer
@@ -461,7 +576,7 @@ impl<R: ::std::io::Read>  Readable<R> for Result<NetAddress, u8> {
 /// The unsigned part of a node_announcement
 #[derive(PartialEq, Clone, Debug)]
 pub struct UnsignedNodeAnnouncement {
-       pub(crate) features: GlobalFeatures,
+       pub(crate) features: NodeFeatures,
        pub(crate) timestamp: u32,
        /// The node_id this announcement originated from (don't rebroadcast the node_announcement back
        /// to this node).
@@ -485,7 +600,7 @@ pub struct NodeAnnouncement {
 /// The unsigned part of a channel_announcement
 #[derive(PartialEq, Clone, Debug)]
 pub struct UnsignedChannelAnnouncement {
-       pub(crate) features: GlobalFeatures,
+       pub(crate) features: ChannelFeatures,
        pub(crate) chain_hash: Sha256dHash,
        pub(crate) short_channel_id: u64,
        /// One of the two node_ids which are endpoints of this channel
@@ -616,9 +731,9 @@ pub enum OptionalField<T> {
 pub trait ChannelMessageHandler : events::MessageSendEventsProvider + Send + Sync {
        //Channel init:
        /// Handle an incoming open_channel message from the given peer.
-       fn handle_open_channel(&self, their_node_id: &PublicKey, their_local_features: LocalFeatures, msg: &OpenChannel);
+       fn handle_open_channel(&self, their_node_id: &PublicKey, their_features: InitFeatures, msg: &OpenChannel);
        /// Handle an incoming accept_channel message from the given peer.
-       fn handle_accept_channel(&self, their_node_id: &PublicKey, their_local_features: LocalFeatures, msg: &AcceptChannel);
+       fn handle_accept_channel(&self, their_node_id: &PublicKey, their_features: InitFeatures, msg: &AcceptChannel);
        /// Handle an incoming funding_created message from the given peer.
        fn handle_funding_created(&self, their_node_id: &PublicKey, msg: &FundingCreated);
        /// Handle an incoming funding_signed message from the given peer.
@@ -918,24 +1033,25 @@ impl_writeable!(FundingLocked, 32+33, {
        next_per_commitment_point
 });
 
-impl_writeable_len_match!(GlobalFeatures, {
-               { GlobalFeatures { ref flags }, flags.len() + 2 }
-       }, {
-       flags
-});
-
-impl_writeable_len_match!(LocalFeatures, {
-               { LocalFeatures { ref flags }, flags.len() + 2 }
-       }, {
-       flags
-});
+impl Writeable for Init {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               // global_features gets the bottom 13 bits of our features, and local_features gets all of
+               // our relevant feature bits. This keeps us compatible with old nodes.
+               self.features.write_up_to_13(w)?;
+               self.features.write(w)
+       }
+}
 
-impl_writeable_len_match!(Init, {
-               { Init { ref global_features, ref local_features }, global_features.flags.len() + local_features.flags.len() + 4 }
-       }, {
-       global_features,
-       local_features
-});
+impl<R: Read> Readable<R> for Init {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               let global_features: InitFeatures = Readable::read(r)?;
+               let mut features: InitFeatures = Readable::read(r)?;
+               features.or(&global_features);
+               Ok(Init {
+                       features
+               })
+       }
+}
 
 impl_writeable_len_match!(OpenChannel, {
                { OpenChannel { shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 319 + 2 + script.len() },
@@ -1140,7 +1256,7 @@ impl<R: Read> Readable<R> for Pong {
 
 impl Writeable for UnsignedChannelAnnouncement {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(2 + 2*32 + 4*33 + self.features.flags.len() + self.excess_data.len());
+               w.size_hint(2 + 2*32 + 4*33 + self.features.byte_count() + self.excess_data.len());
                self.features.write(w)?;
                self.chain_hash.write(w)?;
                self.short_channel_id.write(w)?;
@@ -1157,7 +1273,7 @@ impl<R: Read> Readable<R> for UnsignedChannelAnnouncement {
        fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Self {
                        features: {
-                               let f: GlobalFeatures = Readable::read(r)?;
+                               let f: ChannelFeatures = Readable::read(r)?;
                                if f.requires_unknown_bits() {
                                        return Err(DecodeError::UnknownRequiredFeature);
                                }
@@ -1180,7 +1296,7 @@ impl<R: Read> Readable<R> for UnsignedChannelAnnouncement {
 
 impl_writeable_len_match!(ChannelAnnouncement, {
                { ChannelAnnouncement { contents: UnsignedChannelAnnouncement {ref features, ref excess_data, ..}, .. },
-                       2 + 2*32 + 4*33 + features.flags.len() + excess_data.len() + 4*64 }
+                       2 + 2*32 + 4*33 + features.byte_count() + excess_data.len() + 4*64 }
        }, {
        node_signature_1,
        node_signature_2,
@@ -1263,7 +1379,7 @@ impl<R: Read> Readable<R> for ErrorMessage {
 
 impl Writeable for UnsignedNodeAnnouncement {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(64 + 76 + self.features.flags.len() + self.addresses.len()*38 + self.excess_address_data.len() + self.excess_data.len());
+               w.size_hint(64 + 76 + self.features.byte_count() + self.addresses.len()*38 + self.excess_address_data.len() + self.excess_data.len());
                self.features.write(w)?;
                self.timestamp.write(w)?;
                self.node_id.write(w)?;
@@ -1289,7 +1405,7 @@ impl Writeable for UnsignedNodeAnnouncement {
 
 impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
        fn read(r: &mut R) -> Result<Self, DecodeError> {
-               let features: GlobalFeatures = Readable::read(r)?;
+               let features: NodeFeatures = Readable::read(r)?;
                if features.requires_unknown_bits() {
                        return Err(DecodeError::UnknownRequiredFeature);
                }
@@ -1376,7 +1492,7 @@ impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
 
 impl_writeable_len_match!(NodeAnnouncement, {
                { NodeAnnouncement { contents: UnsignedNodeAnnouncement { ref features, ref addresses, ref excess_address_data, ref excess_data, ..}, .. },
-                       64 + 76 + features.flags.len() + addresses.len()*38 + excess_address_data.len() + excess_data.len() }
+                       64 + 76 + features.byte_count() + addresses.len()*38 + excess_address_data.len() + excess_data.len() }
        }, {
        signature,
        contents
@@ -1386,7 +1502,7 @@ impl_writeable_len_match!(NodeAnnouncement, {
 mod tests {
        use hex;
        use ln::msgs;
-       use ln::msgs::{GlobalFeatures, LocalFeatures, OptionalField, OnionErrorPacket};
+       use ln::msgs::{ChannelFeatures, InitFeatures, NodeFeatures, OptionalField, OnionErrorPacket};
        use ln::channelmanager::{PaymentPreimage, PaymentHash};
        use util::ser::Writeable;
 
@@ -1400,6 +1516,8 @@ mod tests {
        use secp256k1::key::{PublicKey,SecretKey};
        use secp256k1::{Secp256k1, Message};
 
+       use std::marker::PhantomData;
+
        #[test]
        fn encoding_channel_reestablish_no_secret() {
                let cr = msgs::ChannelReestablish {
@@ -1483,7 +1601,7 @@ mod tests {
                let sig_2 = get_sig_on!(privkey_2, secp_ctx, String::from("01010101010101010101010101010101"));
                let sig_3 = get_sig_on!(privkey_3, secp_ctx, String::from("01010101010101010101010101010101"));
                let sig_4 = get_sig_on!(privkey_4, secp_ctx, String::from("01010101010101010101010101010101"));
-               let mut features = GlobalFeatures::new();
+               let mut features = ChannelFeatures::new();
                if unknown_features_bits {
                        features.flags = vec![0xFF, 0xFF];
                }
@@ -1539,7 +1657,7 @@ mod tests {
                let secp_ctx = Secp256k1::new();
                let (privkey_1, pubkey_1) = get_keys_from!("0101010101010101010101010101010101010101010101010101010101010101", secp_ctx);
                let sig_1 = get_sig_on!(privkey_1, secp_ctx, String::from("01010101010101010101010101010101"));
-               let mut features = GlobalFeatures::new();
+               let mut features = NodeFeatures::new();
                if unknown_features_bits {
                        features.flags = vec![0xFF, 0xFF];
                }
@@ -1594,7 +1712,7 @@ mod tests {
                if unknown_features_bits {
                        target_value.append(&mut hex::decode("0002ffff").unwrap());
                } else {
-                       target_value.append(&mut hex::decode("0000").unwrap());
+                       target_value.append(&mut hex::decode("000122").unwrap());
                }
                target_value.append(&mut hex::decode("013413a7031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f2020201010101010101010101010101010101010101010101010101010101010101010").unwrap());
                target_value.append(&mut vec![(addr_len >> 8) as u8, addr_len as u8]);
@@ -1993,40 +2111,26 @@ mod tests {
                assert_eq!(encoded_value, target_value);
        }
 
-       fn do_encoding_init(unknown_global_bits: bool, initial_routing_sync: bool) {
-               let mut global = GlobalFeatures::new();
-               if unknown_global_bits {
-                       global.flags = vec![0xFF, 0xFF];
-               }
-               let mut local = LocalFeatures::new();
-               if initial_routing_sync {
-                       local.set_initial_routing_sync();
-               }
-               let init = msgs::Init {
-                       global_features: global,
-                       local_features: local,
-               };
-               let encoded_value = init.encode();
-               let mut target_value = Vec::new();
-               if unknown_global_bits {
-                       target_value.append(&mut hex::decode("0002ffff").unwrap());
-               } else {
-                       target_value.append(&mut hex::decode("0000").unwrap());
-               }
-               if initial_routing_sync {
-                       target_value.append(&mut hex::decode("00012a").unwrap());
-               } else {
-                       target_value.append(&mut hex::decode("000122").unwrap());
-               }
-               assert_eq!(encoded_value, target_value);
-       }
-
        #[test]
        fn encoding_init() {
-               do_encoding_init(false, false);
-               do_encoding_init(true, false);
-               do_encoding_init(false, true);
-               do_encoding_init(true, true);
+               assert_eq!(msgs::Init {
+                       features: InitFeatures {
+                               flags: vec![0xFF, 0xFF, 0xFF],
+                               mark: PhantomData,
+                       },
+               }.encode(), hex::decode("00023fff0003ffffff").unwrap());
+               assert_eq!(msgs::Init {
+                       features: InitFeatures {
+                               flags: vec![0xFF],
+                               mark: PhantomData,
+                       },
+               }.encode(), hex::decode("0001ff0001ff").unwrap());
+               assert_eq!(msgs::Init {
+                       features: InitFeatures {
+                               flags: vec![],
+                               mark: PhantomData,
+                       },
+               }.encode(), hex::decode("00000000").unwrap());
        }
 
        #[test]