Move features into a separate module out of msgs.
[rust-lightning] / lightning / src / ln / msgs.rs
index 277d96d168733e890be98391b6017aa7df138b50..70bf19f41e0b337dda57b9ca94c02b3b2f8005b8 100644 (file)
@@ -21,6 +21,8 @@ use secp256k1;
 use bitcoin_hashes::sha256d::Hash as Sha256dHash;
 use bitcoin::blockdata::script::Script;
 
+use ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
+
 use std::error::Error;
 use std::{cmp, fmt};
 use std::io::Read;
@@ -52,102 +54,9 @@ pub enum DecodeError {
        Io(::std::io::Error),
 }
 
-/// Tracks localfeatures which are only in init messages
-#[derive(Clone, PartialEq)]
-pub struct LocalFeatures {
-       flags: Vec<u8>,
-}
-
-impl LocalFeatures {
-       /// Create a blank LocalFeatures flags (visibility extended for fuzz tests)
-       #[cfg(not(feature = "fuzztarget"))]
-       pub(crate) fn new() -> LocalFeatures {
-               LocalFeatures {
-                       flags: vec![2 | 1 << 5],
-               }
-       }
-       #[cfg(feature = "fuzztarget")]
-       pub fn new() -> LocalFeatures {
-               LocalFeatures {
-                       flags: vec![2 | 1 << 5],
-               }
-       }
-
-       pub(crate) fn supports_data_loss_protect(&self) -> bool {
-               self.flags.len() > 0 && (self.flags[0] & 3) != 0
-       }
-       pub(crate) fn initial_routing_sync(&self) -> bool {
-               self.flags.len() > 0 && (self.flags[0] & (1 << 3)) != 0
-       }
-       pub(crate) fn set_initial_routing_sync(&mut self) {
-               if self.flags.len() == 0 {
-                       self.flags.resize(1, 1 << 3);
-               } else {
-                       self.flags[0] |= 1 << 3;
-               }
-       }
-
-       pub(crate) fn supports_upfront_shutdown_script(&self) -> bool {
-               self.flags.len() > 0 && (self.flags[0] & (3 << 4)) != 0
-       }
-       #[cfg(test)]
-       pub(crate) fn unset_upfront_shutdown_script(&mut self) {
-               self.flags[0] ^= 1 << 5;
-       }
-
-       pub(crate) fn requires_unknown_bits(&self) -> bool {
-               self.flags.iter().enumerate().any(|(idx, &byte)| {
-                       ( idx != 0 && (byte & 0x55) != 0 ) || ( idx == 0 && (byte & 0x14) != 0 )
-               })
-       }
-
-       pub(crate) fn supports_unknown_bits(&self) -> bool {
-               self.flags.iter().enumerate().any(|(idx, &byte)| {
-                       ( idx != 0 && byte != 0 ) || ( idx == 0 && (byte & 0xc4) != 0 )
-               })
-       }
-}
-
-/// Tracks globalfeatures which are in init messages and routing announcements
-#[derive(Clone, PartialEq, Debug)]
-pub struct GlobalFeatures {
-       #[cfg(not(test))]
-       flags: Vec<u8>,
-       // Used to test encoding of diverse msgs
-       #[cfg(test)]
-       pub flags: Vec<u8>
-}
-
-impl GlobalFeatures {
-       pub(crate) fn new() -> GlobalFeatures {
-               GlobalFeatures {
-                       flags: Vec::new(),
-               }
-       }
-
-       pub(crate) fn requires_unknown_bits(&self) -> bool {
-               for &byte in self.flags.iter() {
-                       if (byte & 0x55) != 0 {
-                               return true;
-                       }
-               }
-               return false;
-       }
-
-       pub(crate) fn supports_unknown_bits(&self) -> bool {
-               for &byte in self.flags.iter() {
-                       if byte != 0 {
-                               return true;
-                       }
-               }
-               return false;
-       }
-}
-
 /// An init message to be sent or received from a peer
 pub struct Init {
-       pub(crate) global_features: GlobalFeatures,
-       pub(crate) local_features: LocalFeatures,
+       pub(crate) features: InitFeatures,
 }
 
 /// An error message to be sent or received from a peer
@@ -461,7 +370,7 @@ impl<R: ::std::io::Read>  Readable<R> for Result<NetAddress, u8> {
 /// The unsigned part of a node_announcement
 #[derive(PartialEq, Clone, Debug)]
 pub struct UnsignedNodeAnnouncement {
-       pub(crate) features: GlobalFeatures,
+       pub(crate) features: NodeFeatures,
        pub(crate) timestamp: u32,
        /// The node_id this announcement originated from (don't rebroadcast the node_announcement back
        /// to this node).
@@ -485,7 +394,7 @@ pub struct NodeAnnouncement {
 /// The unsigned part of a channel_announcement
 #[derive(PartialEq, Clone, Debug)]
 pub struct UnsignedChannelAnnouncement {
-       pub(crate) features: GlobalFeatures,
+       pub(crate) features: ChannelFeatures,
        pub(crate) chain_hash: Sha256dHash,
        pub(crate) short_channel_id: u64,
        /// One of the two node_ids which are endpoints of this channel
@@ -616,42 +525,42 @@ pub enum OptionalField<T> {
 pub trait ChannelMessageHandler : events::MessageSendEventsProvider + Send + Sync {
        //Channel init:
        /// Handle an incoming open_channel message from the given peer.
-       fn handle_open_channel(&self, their_node_id: &PublicKey, their_local_features: LocalFeatures, msg: &OpenChannel) -> Result<(), LightningError>;
+       fn handle_open_channel(&self, their_node_id: &PublicKey, their_features: InitFeatures, msg: &OpenChannel);
        /// Handle an incoming accept_channel message from the given peer.
-       fn handle_accept_channel(&self, their_node_id: &PublicKey, their_local_features: LocalFeatures, msg: &AcceptChannel) -> Result<(), LightningError>;
+       fn handle_accept_channel(&self, their_node_id: &PublicKey, their_features: InitFeatures, msg: &AcceptChannel);
        /// Handle an incoming funding_created message from the given peer.
-       fn handle_funding_created(&self, their_node_id: &PublicKey, msg: &FundingCreated) -> Result<(), LightningError>;
+       fn handle_funding_created(&self, their_node_id: &PublicKey, msg: &FundingCreated);
        /// Handle an incoming funding_signed message from the given peer.
-       fn handle_funding_signed(&self, their_node_id: &PublicKey, msg: &FundingSigned) -> Result<(), LightningError>;
+       fn handle_funding_signed(&self, their_node_id: &PublicKey, msg: &FundingSigned);
        /// Handle an incoming funding_locked message from the given peer.
-       fn handle_funding_locked(&self, their_node_id: &PublicKey, msg: &FundingLocked) -> Result<(), LightningError>;
+       fn handle_funding_locked(&self, their_node_id: &PublicKey, msg: &FundingLocked);
 
        // Channl close:
        /// Handle an incoming shutdown message from the given peer.
-       fn handle_shutdown(&self, their_node_id: &PublicKey, msg: &Shutdown) -> Result<(), LightningError>;
+       fn handle_shutdown(&self, their_node_id: &PublicKey, msg: &Shutdown);
        /// Handle an incoming closing_signed message from the given peer.
-       fn handle_closing_signed(&self, their_node_id: &PublicKey, msg: &ClosingSigned) -> Result<(), LightningError>;
+       fn handle_closing_signed(&self, their_node_id: &PublicKey, msg: &ClosingSigned);
 
        // HTLC handling:
        /// Handle an incoming update_add_htlc message from the given peer.
-       fn handle_update_add_htlc(&self, their_node_id: &PublicKey, msg: &UpdateAddHTLC) -> Result<(), LightningError>;
+       fn handle_update_add_htlc(&self, their_node_id: &PublicKey, msg: &UpdateAddHTLC);
        /// Handle an incoming update_fulfill_htlc message from the given peer.
-       fn handle_update_fulfill_htlc(&self, their_node_id: &PublicKey, msg: &UpdateFulfillHTLC) -> Result<(), LightningError>;
+       fn handle_update_fulfill_htlc(&self, their_node_id: &PublicKey, msg: &UpdateFulfillHTLC);
        /// Handle an incoming update_fail_htlc message from the given peer.
-       fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &UpdateFailHTLC) -> Result<(), LightningError>;
+       fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &UpdateFailHTLC);
        /// Handle an incoming update_fail_malformed_htlc message from the given peer.
-       fn handle_update_fail_malformed_htlc(&self, their_node_id: &PublicKey, msg: &UpdateFailMalformedHTLC) -> Result<(), LightningError>;
+       fn handle_update_fail_malformed_htlc(&self, their_node_id: &PublicKey, msg: &UpdateFailMalformedHTLC);
        /// Handle an incoming commitment_signed message from the given peer.
-       fn handle_commitment_signed(&self, their_node_id: &PublicKey, msg: &CommitmentSigned) -> Result<(), LightningError>;
+       fn handle_commitment_signed(&self, their_node_id: &PublicKey, msg: &CommitmentSigned);
        /// Handle an incoming revoke_and_ack message from the given peer.
-       fn handle_revoke_and_ack(&self, their_node_id: &PublicKey, msg: &RevokeAndACK) -> Result<(), LightningError>;
+       fn handle_revoke_and_ack(&self, their_node_id: &PublicKey, msg: &RevokeAndACK);
 
        /// Handle an incoming update_fee message from the given peer.
-       fn handle_update_fee(&self, their_node_id: &PublicKey, msg: &UpdateFee) -> Result<(), LightningError>;
+       fn handle_update_fee(&self, their_node_id: &PublicKey, msg: &UpdateFee);
 
        // Channel-to-announce:
        /// Handle an incoming announcement_signatures message from the given peer.
-       fn handle_announcement_signatures(&self, their_node_id: &PublicKey, msg: &AnnouncementSignatures) -> Result<(), LightningError>;
+       fn handle_announcement_signatures(&self, their_node_id: &PublicKey, msg: &AnnouncementSignatures);
 
        // Connection loss/reestablish:
        /// Indicates a connection to the peer failed/an existing connection was lost. If no connection
@@ -663,7 +572,7 @@ pub trait ChannelMessageHandler : events::MessageSendEventsProvider + Send + Syn
        /// Handle a peer reconnecting, possibly generating channel_reestablish message(s).
        fn peer_connected(&self, their_node_id: &PublicKey);
        /// Handle an incoming channel_reestablish message from the given peer.
-       fn handle_channel_reestablish(&self, their_node_id: &PublicKey, msg: &ChannelReestablish) -> Result<(), LightningError>;
+       fn handle_channel_reestablish(&self, their_node_id: &PublicKey, msg: &ChannelReestablish);
 
        // Error:
        /// Handle an incoming error message from the given peer.
@@ -918,24 +827,24 @@ impl_writeable!(FundingLocked, 32+33, {
        next_per_commitment_point
 });
 
-impl_writeable_len_match!(GlobalFeatures, {
-               { GlobalFeatures { ref flags }, flags.len() + 2 }
-       }, {
-       flags
-});
-
-impl_writeable_len_match!(LocalFeatures, {
-               { LocalFeatures { ref flags }, flags.len() + 2 }
-       }, {
-       flags
-});
+impl Writeable for Init {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               // global_features gets the bottom 13 bits of our features, and local_features gets all of
+               // our relevant feature bits. This keeps us compatible with old nodes.
+               self.features.write_up_to_13(w)?;
+               self.features.write(w)
+       }
+}
 
-impl_writeable_len_match!(Init, {
-               { Init { ref global_features, ref local_features }, global_features.flags.len() + local_features.flags.len() + 4 }
-       }, {
-       global_features,
-       local_features
-});
+impl<R: Read> Readable<R> for Init {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               let global_features: InitFeatures = Readable::read(r)?;
+               let features: InitFeatures = Readable::read(r)?;
+               Ok(Init {
+                       features: features.or(global_features),
+               })
+       }
+}
 
 impl_writeable_len_match!(OpenChannel, {
                { OpenChannel { shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 319 + 2 + script.len() },
@@ -1140,7 +1049,7 @@ impl<R: Read> Readable<R> for Pong {
 
 impl Writeable for UnsignedChannelAnnouncement {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(2 + 2*32 + 4*33 + self.features.flags.len() + self.excess_data.len());
+               w.size_hint(2 + 2*32 + 4*33 + self.features.byte_count() + self.excess_data.len());
                self.features.write(w)?;
                self.chain_hash.write(w)?;
                self.short_channel_id.write(w)?;
@@ -1156,13 +1065,7 @@ impl Writeable for UnsignedChannelAnnouncement {
 impl<R: Read> Readable<R> for UnsignedChannelAnnouncement {
        fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Self {
-                       features: {
-                               let f: GlobalFeatures = Readable::read(r)?;
-                               if f.requires_unknown_bits() {
-                                       return Err(DecodeError::UnknownRequiredFeature);
-                               }
-                               f
-                       },
+                       features: Readable::read(r)?,
                        chain_hash: Readable::read(r)?,
                        short_channel_id: Readable::read(r)?,
                        node_id_1: Readable::read(r)?,
@@ -1180,7 +1083,7 @@ impl<R: Read> Readable<R> for UnsignedChannelAnnouncement {
 
 impl_writeable_len_match!(ChannelAnnouncement, {
                { ChannelAnnouncement { contents: UnsignedChannelAnnouncement {ref features, ref excess_data, ..}, .. },
-                       2 + 2*32 + 4*33 + features.flags.len() + excess_data.len() + 4*64 }
+                       2 + 2*32 + 4*33 + features.byte_count() + excess_data.len() + 4*64 }
        }, {
        node_signature_1,
        node_signature_2,
@@ -1263,7 +1166,7 @@ impl<R: Read> Readable<R> for ErrorMessage {
 
 impl Writeable for UnsignedNodeAnnouncement {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(64 + 76 + self.features.flags.len() + self.addresses.len()*38 + self.excess_address_data.len() + self.excess_data.len());
+               w.size_hint(64 + 76 + self.features.byte_count() + self.addresses.len()*38 + self.excess_address_data.len() + self.excess_data.len());
                self.features.write(w)?;
                self.timestamp.write(w)?;
                self.node_id.write(w)?;
@@ -1289,10 +1192,7 @@ impl Writeable for UnsignedNodeAnnouncement {
 
 impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
        fn read(r: &mut R) -> Result<Self, DecodeError> {
-               let features: GlobalFeatures = Readable::read(r)?;
-               if features.requires_unknown_bits() {
-                       return Err(DecodeError::UnknownRequiredFeature);
-               }
+               let features: NodeFeatures = Readable::read(r)?;
                let timestamp: u32 = Readable::read(r)?;
                let node_id: PublicKey = Readable::read(r)?;
                let mut rgb = [0; 3];
@@ -1376,7 +1276,7 @@ impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
 
 impl_writeable_len_match!(NodeAnnouncement, {
                { NodeAnnouncement { contents: UnsignedNodeAnnouncement { ref features, ref addresses, ref excess_address_data, ref excess_data, ..}, .. },
-                       64 + 76 + features.flags.len() + addresses.len()*38 + excess_address_data.len() + excess_data.len() }
+                       64 + 76 + features.byte_count() + addresses.len()*38 + excess_address_data.len() + excess_data.len() }
        }, {
        signature,
        contents
@@ -1386,7 +1286,7 @@ impl_writeable_len_match!(NodeAnnouncement, {
 mod tests {
        use hex;
        use ln::msgs;
-       use ln::msgs::{GlobalFeatures, LocalFeatures, OptionalField, OnionErrorPacket};
+       use ln::msgs::{ChannelFeatures, InitFeatures, NodeFeatures, OptionalField, OnionErrorPacket};
        use ln::channelmanager::{PaymentPreimage, PaymentHash};
        use util::ser::Writeable;
 
@@ -1483,9 +1383,9 @@ mod tests {
                let sig_2 = get_sig_on!(privkey_2, secp_ctx, String::from("01010101010101010101010101010101"));
                let sig_3 = get_sig_on!(privkey_3, secp_ctx, String::from("01010101010101010101010101010101"));
                let sig_4 = get_sig_on!(privkey_4, secp_ctx, String::from("01010101010101010101010101010101"));
-               let mut features = GlobalFeatures::new();
+               let mut features = ChannelFeatures::supported();
                if unknown_features_bits {
-                       features.flags = vec![0xFF, 0xFF];
+                       features = ChannelFeatures::from_le_bytes(vec![0xFF, 0xFF]);
                }
                let unsigned_channel_announcement = msgs::UnsignedChannelAnnouncement {
                        features,
@@ -1539,10 +1439,12 @@ mod tests {
                let secp_ctx = Secp256k1::new();
                let (privkey_1, pubkey_1) = get_keys_from!("0101010101010101010101010101010101010101010101010101010101010101", secp_ctx);
                let sig_1 = get_sig_on!(privkey_1, secp_ctx, String::from("01010101010101010101010101010101"));
-               let mut features = GlobalFeatures::new();
-               if unknown_features_bits {
-                       features.flags = vec![0xFF, 0xFF];
-               }
+               let features = if unknown_features_bits {
+                       NodeFeatures::from_le_bytes(vec![0xFF, 0xFF])
+               } else {
+                       // Set to some features we may support
+                       NodeFeatures::from_le_bytes(vec![2 | 1 << 5])
+               };
                let mut addresses = Vec::new();
                if ipv4 {
                        addresses.push(msgs::NetAddress::IPv4 {
@@ -1594,7 +1496,7 @@ mod tests {
                if unknown_features_bits {
                        target_value.append(&mut hex::decode("0002ffff").unwrap());
                } else {
-                       target_value.append(&mut hex::decode("0000").unwrap());
+                       target_value.append(&mut hex::decode("000122").unwrap());
                }
                target_value.append(&mut hex::decode("013413a7031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f2020201010101010101010101010101010101010101010101010101010101010101010").unwrap());
                target_value.append(&mut vec![(addr_len >> 8) as u8, addr_len as u8]);
@@ -1993,40 +1895,17 @@ mod tests {
                assert_eq!(encoded_value, target_value);
        }
 
-       fn do_encoding_init(unknown_global_bits: bool, initial_routing_sync: bool) {
-               let mut global = GlobalFeatures::new();
-               if unknown_global_bits {
-                       global.flags = vec![0xFF, 0xFF];
-               }
-               let mut local = LocalFeatures::new();
-               if initial_routing_sync {
-                       local.set_initial_routing_sync();
-               }
-               let init = msgs::Init {
-                       global_features: global,
-                       local_features: local,
-               };
-               let encoded_value = init.encode();
-               let mut target_value = Vec::new();
-               if unknown_global_bits {
-                       target_value.append(&mut hex::decode("0002ffff").unwrap());
-               } else {
-                       target_value.append(&mut hex::decode("0000").unwrap());
-               }
-               if initial_routing_sync {
-                       target_value.append(&mut hex::decode("00012a").unwrap());
-               } else {
-                       target_value.append(&mut hex::decode("000122").unwrap());
-               }
-               assert_eq!(encoded_value, target_value);
-       }
-
        #[test]
        fn encoding_init() {
-               do_encoding_init(false, false);
-               do_encoding_init(true, false);
-               do_encoding_init(false, true);
-               do_encoding_init(true, true);
+               assert_eq!(msgs::Init {
+                       features: InitFeatures::from_le_bytes(vec![0xFF, 0xFF, 0xFF]),
+               }.encode(), hex::decode("00023fff0003ffffff").unwrap());
+               assert_eq!(msgs::Init {
+                       features: InitFeatures::from_le_bytes(vec![0xFF]),
+               }.encode(), hex::decode("0001ff0001ff").unwrap());
+               assert_eq!(msgs::Init {
+                       features: InitFeatures::from_le_bytes(vec![]),
+               }.encode(), hex::decode("00000000").unwrap());
        }
 
        #[test]