]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Merge pull request #1078 from TheBlueMatt/2021-09-chan-types
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Wed, 3 Nov 2021 16:58:33 +0000 (16:58 +0000)
committerGitHub <noreply@github.com>
Wed, 3 Nov 2021 16:58:33 +0000 (16:58 +0000)
Implement channel_type negotiation

1  2 
lightning-background-processor/src/lib.rs
lightning/src/ln/channel.rs
lightning/src/ln/features.rs

index 50743774b7aa487753f416de66f1c952a7c86a3f,0fef55a95c2c5369408613cebfa574e9b8b95c05..593f84a90898534742f5dbd0c0458ea3f53cf059
@@@ -14,8 -14,9 +14,8 @@@ use lightning::chain::chainmonitor::{Ch
  use lightning::chain::keysinterface::{Sign, KeysInterface};
  use lightning::ln::channelmanager::ChannelManager;
  use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler};
 -use lightning::ln::peer_handler::{PeerManager, SocketDescriptor};
 -use lightning::ln::peer_handler::CustomMessageHandler;
 -use lightning::routing::network_graph::NetGraphMsgHandler;
 +use lightning::ln::peer_handler::{CustomMessageHandler, PeerManager, SocketDescriptor};
 +use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
  use lightning::util::events::{Event, EventHandler, EventsProvider};
  use lightning::util::logger::Logger;
  use std::sync::Arc;
@@@ -103,8 -104,7 +103,8 @@@ ChannelManagerPersister<Signer, M, T, K
  /// Decorates an [`EventHandler`] with common functionality provided by standard [`EventHandler`]s.
  struct DecoratingEventHandler<
        E: EventHandler,
 -      N: Deref<Target = NetGraphMsgHandler<A, L>>,
 +      N: Deref<Target = NetGraphMsgHandler<G, A, L>>,
 +      G: Deref<Target = NetworkGraph>,
        A: Deref,
        L: Deref,
  >
@@@ -115,11 -115,10 +115,11 @@@ where A::Target: chain::Access, L::Targ
  
  impl<
        E: EventHandler,
 -      N: Deref<Target = NetGraphMsgHandler<A, L>>,
 +      N: Deref<Target = NetGraphMsgHandler<G, A, L>>,
 +      G: Deref<Target = NetworkGraph>,
        A: Deref,
        L: Deref,
 -> EventHandler for DecoratingEventHandler<E, N, A, L>
 +> EventHandler for DecoratingEventHandler<E, N, G, A, L>
  where A::Target: chain::Access, L::Target: Logger {
        fn handle_event(&self, event: &Event) {
                if let Some(event_handler) = &self.net_graph_msg_handler {
@@@ -155,7 -154,7 +155,7 @@@ impl BackgroundProcessor 
        /// functionality implemented by other handlers.
        /// * [`NetGraphMsgHandler`] if given will update the [`NetworkGraph`] based on payment failures.
        ///
-       /// [top-level documentation]: Self
+       /// [top-level documentation]: BackgroundProcessor
        /// [`join`]: Self::join
        /// [`stop`]: Self::stop
        /// [`ChannelManager`]: lightning::ln::channelmanager::ChannelManager
                T: 'static + Deref + Send + Sync,
                K: 'static + Deref + Send + Sync,
                F: 'static + Deref + Send + Sync,
 +              G: 'static + Deref<Target = NetworkGraph> + Send + Sync,
                L: 'static + Deref + Send + Sync,
                P: 'static + Deref + Send + Sync,
                Descriptor: 'static + SocketDescriptor + Send + Sync,
                CMH: 'static + Deref + Send + Sync,
                RMH: 'static + Deref + Send + Sync,
 -              EH: 'static + EventHandler + Send + Sync,
 +              EH: 'static + EventHandler + Send,
                CMP: 'static + Send + ChannelManagerPersister<Signer, CW, T, K, F, L>,
                M: 'static + Deref<Target = ChainMonitor<Signer, CF, T, F, L, P>> + Send + Sync,
                CM: 'static + Deref<Target = ChannelManager<Signer, CW, T, K, F, L>> + Send + Sync,
 -              NG: 'static + Deref<Target = NetGraphMsgHandler<CA, L>> + Send + Sync,
 +              NG: 'static + Deref<Target = NetGraphMsgHandler<G, CA, L>> + Send + Sync,
                UMH: 'static + Deref + Send + Sync,
                PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L, UMH>> + Send + Sync,
        >(
                                        // timer, we should have disconnected all sockets by now (and they're probably
                                        // dead anyway), so disconnect them by calling `timer_tick_occurred()` twice.
                                        log_trace!(logger, "Awoke after more than double our ping timer, disconnecting peers.");
 -                                      peer_manager.timer_tick_occurred();
 -                                      peer_manager.timer_tick_occurred();
 +                                      peer_manager.disconnect_all_peers();
                                        last_ping_call = Instant::now();
                                } else if last_ping_call.elapsed().as_secs() > PING_TIMER {
                                        log_trace!(logger, "Calling PeerManager's timer_tick_occurred");
@@@ -317,8 -316,6 +317,8 @@@ mod tests 
        use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
        use lightning::util::ser::Writeable;
        use lightning::util::test_utils;
 +      use lightning_invoice::payment::{InvoicePayer, RetryAttempts};
 +      use lightning_invoice::utils::DefaultRouter;
        use lightning_persister::FilesystemPersister;
        use std::fs;
        use std::path::PathBuf;
  
        struct Node {
                node: Arc<SimpleArcChannelManager<ChainMonitor, test_utils::TestBroadcaster, test_utils::TestFeeEstimator, test_utils::TestLogger>>,
 -              net_graph_msg_handler: Option<Arc<NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>>>,
 +              net_graph_msg_handler: Option<Arc<NetGraphMsgHandler<Arc<NetworkGraph>, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>>>,
                peer_manager: Arc<PeerManager<TestDescriptor, Arc<test_utils::TestChannelMessageHandler>, Arc<test_utils::TestRoutingMessageHandler>, Arc<test_utils::TestLogger>, IgnoringMessageHandler>>,
                chain_monitor: Arc<ChainMonitor>,
                persister: Arc<FilesystemPersister>,
                tx_broadcaster: Arc<test_utils::TestBroadcaster>,
 +              network_graph: Arc<NetworkGraph>,
                logger: Arc<test_utils::TestLogger>,
                best_block: BestBlock,
        }
                        let best_block = BestBlock::from_genesis(network);
                        let params = ChainParameters { network, best_block };
                        let manager = Arc::new(ChannelManager::new(fee_estimator.clone(), chain_monitor.clone(), tx_broadcaster.clone(), logger.clone(), keys_manager.clone(), UserConfig::default(), params));
 -                      let network_graph = NetworkGraph::new(genesis_block.header.block_hash());
 -                      let net_graph_msg_handler = Some(Arc::new(NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), logger.clone())));
 +                      let network_graph = Arc::new(NetworkGraph::new(genesis_block.header.block_hash()));
 +                      let net_graph_msg_handler = Some(Arc::new(NetGraphMsgHandler::new(network_graph.clone(), Some(chain_source.clone()), logger.clone())));
                        let msg_handler = MessageHandler { chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new()), route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new() )};
                        let peer_manager = Arc::new(PeerManager::new(msg_handler, keys_manager.get_node_secret(), &seed, logger.clone(), IgnoringMessageHandler{}));
 -                      let node = Node { node: manager, net_graph_msg_handler, peer_manager, chain_monitor, persister, tx_broadcaster, logger, best_block };
 +                      let node = Node { node: manager, net_graph_msg_handler, peer_manager, chain_monitor, persister, tx_broadcaster, network_graph, logger, best_block };
                        nodes.push(node);
                }
  
  
                assert!(bg_processor.stop().is_ok());
        }
 +
 +      #[test]
 +      fn test_invoice_payer() {
 +              let nodes = create_nodes(2, "test_invoice_payer".to_string());
 +
 +              // Initiate the background processors to watch each node.
 +              let data_dir = nodes[0].persister.get_data_dir();
 +              let persister = move |node: &ChannelManager<InMemorySigner, Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>>| FilesystemPersister::persist_manager(data_dir.clone(), node);
 +              let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger));
 +              let scorer = Arc::new(Mutex::new(test_utils::TestScorer::default()));
 +              let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));
 +              let event_handler = Arc::clone(&invoice_payer);
 +              let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
 +              assert!(bg_processor.stop().is_ok());
 +      }
  }
index dffde0be7918f4ebab5f44c6d9d873a8db4666b3,68bcf89985199954cf28bf284b2de2421e6d2de0..e39478bd740c063ce1f29a7dc853238ec75fcf48
@@@ -23,7 -23,7 +23,7 @@@ use bitcoin::secp256k1::{Secp256k1,Sign
  use bitcoin::secp256k1;
  
  use ln::{PaymentPreimage, PaymentHash};
- use ln::features::{ChannelFeatures, InitFeatures};
+ use ln::features::{ChannelFeatures, ChannelTypeFeatures, InitFeatures};
  use ln::msgs;
  use ln::msgs::{DecodeError, OptionalField, DataLossProtect};
  use ln::script::{self, ShutdownScript};
@@@ -550,6 -550,9 +550,9 @@@ pub(super) struct Channel<Signer: Sign
        // is fine, but as a sanity check in our failure to generate the second claim, we check here
        // that the original was a claim, and that we aren't now trying to fulfill a failed HTLC.
        historical_inbound_htlc_fulfills: HashSet<u64>,
+       /// This channel's type, as negotiated during channel open
+       channel_type: ChannelTypeFeatures,
  }
  
  #[cfg(any(test, feature = "fuzztarget"))]
@@@ -775,6 -778,11 +778,11 @@@ impl<Signer: Sign> Channel<Signer> 
  
                        #[cfg(any(test, feature = "fuzztarget"))]
                        historical_inbound_htlc_fulfills: HashSet::new(),
+                       // We currently only actually support one channel type, so don't retry with new types
+                       // on error messages. When we support more we'll need fallback support (assuming we
+                       // want to support old types).
+                       channel_type: ChannelTypeFeatures::only_static_remote_key(),
                })
        }
  
                where K::Target: KeysInterface<Signer = Signer>,
            F::Target: FeeEstimator
        {
+               // First check the channel type is known, failing before we do anything else if we don't
+               // support this channel type.
+               let channel_type = if let Some(channel_type) = &msg.channel_type {
+                       if channel_type.supports_any_optional_bits() {
+                               return Err(ChannelError::Close("Channel Type field contained optional bits - this is not allowed".to_owned()));
+                       }
+                       if *channel_type != ChannelTypeFeatures::only_static_remote_key() {
+                               return Err(ChannelError::Close("Channel Type was not understood".to_owned()));
+                       }
+                       channel_type.clone()
+               } else {
+                       ChannelTypeFeatures::from_counterparty_init(&their_features)
+               };
+               if !channel_type.supports_static_remote_key() {
+                       return Err(ChannelError::Close("Channel Type was not understood - we require static remote key".to_owned()));
+               }
                let holder_signer = keys_provider.get_channel_signer(true, msg.funding_satoshis);
                let pubkeys = holder_signer.pubkeys().clone();
                let counterparty_pubkeys = ChannelPublicKeys {
  
                        #[cfg(any(test, feature = "fuzztarget"))]
                        historical_inbound_htlc_fulfills: HashSet::new(),
+                       channel_type,
                };
  
                Ok(chan)
                                Some(script) => script.clone().into_inner(),
                                None => Builder::new().into_script(),
                        }),
+                       channel_type: Some(self.channel_type.clone()),
                }
        }
  
@@@ -5240,6 -5268,7 +5268,7 @@@ impl<Signer: Sign> Writeable for Channe
                        (7, self.shutdown_scriptpubkey, option),
                        (9, self.target_closing_feerate_sats_per_kw, option),
                        (11, self.monitor_pending_finalized_fulfills, vec_type),
+                       (13, self.channel_type, required),
                });
  
                Ok(())
@@@ -5474,6 -5503,9 +5503,9 @@@ impl<'a, Signer: Sign, K: Deref> Readab
                let mut announcement_sigs = None;
                let mut target_closing_feerate_sats_per_kw = None;
                let mut monitor_pending_finalized_fulfills = Some(Vec::new());
+               // Prior to supporting channel type negotiation, all of our channels were static_remotekey
+               // only, so we default to that if none was written.
+               let mut channel_type = Some(ChannelTypeFeatures::only_static_remote_key());
                read_tlv_fields!(reader, {
                        (0, announcement_sigs, option),
                        (1, minimum_depth, option),
                        (7, shutdown_scriptpubkey, option),
                        (9, target_closing_feerate_sats_per_kw, option),
                        (11, monitor_pending_finalized_fulfills, vec_type),
+                       (13, channel_type, option),
                });
  
+               let chan_features = channel_type.as_ref().unwrap();
+               if chan_features.supports_unknown_bits() || chan_features.requires_unknown_bits() {
+                       // If the channel was written by a new version and negotiated with features we don't
+                       // understand yet, refuse to read it.
+                       return Err(DecodeError::UnknownRequiredFeature);
+               }
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&keys_source.get_secure_random_bytes());
  
  
                        #[cfg(any(test, feature = "fuzztarget"))]
                        historical_inbound_htlc_fulfills,
+                       channel_type: channel_type.unwrap(),
                })
        }
  }
@@@ -5768,7 -5810,6 +5810,7 @@@ mod tests 
                                first_hop_htlc_msat: 548,
                                payment_id: PaymentId([42; 32]),
                                payment_secret: None,
 +                              payee: None,
                        }
                });
  
index 888dcd3ac0c076f36eb8058712f8c708d4827d42,46a296001ef2411a45407fe03cc8e04d786a4312..32ba9de758632036942318c756f64708793e1b1a
@@@ -22,7 -22,7 +22,7 @@@
  //! [BOLT #9]: https://github.com/lightningnetwork/lightning-rfc/blob/master/09-features.md
  //! [messages]: crate::ln::msgs
  
- use io;
+ use {io, io_extras};
  use prelude::*;
  use core::{cmp, fmt};
  use core::hash::{Hash, Hasher};
@@@ -194,6 -194,30 +194,30 @@@ mod sealed 
                        BasicMPP,
                ],
        });
+       // This isn't a "real" feature context, and is only used in the channel_type field in an
+       // `OpenChannel` message.
+       define_context!(ChannelTypeContext {
+               required_features: [
+                       // Byte 0
+                       ,
+                       // Byte 1
+                       StaticRemoteKey,
+                       // Byte 2
+                       ,
+                       // Byte 3
+                       ,
+               ],
+               optional_features: [
+                       // Byte 0
+                       ,
+                       // Byte 1
+                       ,
+                       // Byte 2
+                       ,
+                       // Byte 3
+                       ,
+               ],
+       });
  
        /// Defines a feature with the given bits for the specified [`Context`]s. The generated trait is
        /// useful for manipulating feature flags.
        define_feature!(9, VariableLengthOnion, [InitContext, NodeContext, InvoiceContext],
                "Feature flags for `var_onion_optin`.", set_variable_length_onion_optional,
                set_variable_length_onion_required);
-       define_feature!(13, StaticRemoteKey, [InitContext, NodeContext],
+       define_feature!(13, StaticRemoteKey, [InitContext, NodeContext, ChannelTypeContext],
                "Feature flags for `option_static_remotekey`.", set_static_remote_key_optional,
                set_static_remote_key_required);
        define_feature!(15, PaymentSecret, [InitContext, NodeContext, InvoiceContext],
@@@ -388,6 -412,18 +412,18 @@@ pub type ChannelFeatures = Features<sea
  /// Features used within an invoice.
  pub type InvoiceFeatures = Features<sealed::InvoiceContext>;
  
+ /// Features used within the channel_type field in an OpenChannel message.
+ ///
+ /// A channel is always of some known "type", describing the transaction formats used and the exact
+ /// semantics of our interaction with our peer.
+ ///
+ /// Note that because a channel is a specific type which is proposed by the opener and accepted by
+ /// the counterparty, only required features are allowed here.
+ ///
+ /// This is serialized differently from other feature types - it is not prefixed by a length, and
+ /// thus must only appear inside a TLV where its length is known in advance.
+ pub type ChannelTypeFeatures = Features<sealed::ChannelTypeContext>;
  impl InitFeatures {
        /// Writes all features present up to, and including, 13.
        pub(crate) fn write_up_to_13<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
@@@ -432,16 -468,38 +468,38 @@@ impl InvoiceFeatures 
        /// Getting a route for a keysend payment to a private node requires providing the payee's
        /// features (since they were not announced in a node announcement). However, keysend payments
        /// don't have an invoice to pull the payee's features from, so this method is provided for use in
 -      /// [`get_keysend_route`], thus omitting the need for payers to manually construct an
 -      /// `InvoiceFeatures` for [`get_route`].
 +      /// [`Payee::for_keysend`], thus omitting the need for payers to manually construct an
 +      /// `InvoiceFeatures` for [`find_route`].
        ///
 -      /// [`get_keysend_route`]: crate::routing::router::get_keysend_route
 -      /// [`get_route`]: crate::routing::router::get_route
 +      /// [`Payee::for_keysend`]: crate::routing::router::Payee::for_keysend
 +      /// [`find_route`]: crate::routing::router::find_route
        pub(crate) fn for_keysend() -> InvoiceFeatures {
                InvoiceFeatures::empty().set_variable_length_onion_optional()
        }
  }
  
+ impl ChannelTypeFeatures {
+       /// Constructs the implicit channel type based on the common supported types between us and our
+       /// counterparty
+       pub(crate) fn from_counterparty_init(counterparty_init: &InitFeatures) -> Self {
+               let mut ret = counterparty_init.to_context_internal();
+               // ChannelTypeFeatures must only contain required bits, so we OR the required forms of all
+               // optional bits and then AND out the optional ones.
+               for byte in ret.flags.iter_mut() {
+                       *byte |= (*byte & 0b10_10_10_10) >> 1;
+                       *byte &= 0b01_01_01_01;
+               }
+               ret
+       }
+       /// Constructs a ChannelTypeFeatures with only static_remotekey set
+       pub(crate) fn only_static_remote_key() -> Self {
+               let mut ret = Self::empty();
+               <sealed::ChannelTypeContext as sealed::StaticRemoteKey>::set_required_bit(&mut ret.flags);
+               ret
+       }
+ }
  impl ToBase32 for InvoiceFeatures {
        fn write_base32<W: WriteBase32>(&self, writer: &mut W) -> Result<(), <W as WriteBase32>::Err> {
                // Explanation for the "4": the normal way to round up when dividing is to add the divisor
@@@ -553,6 -611,25 +611,25 @@@ impl<T: sealed::Context> Features<T> 
                &self.flags
        }
  
+       fn write_be<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               for f in self.flags.iter().rev() { // Swap back to big-endian
+                       f.write(w)?;
+               }
+               Ok(())
+       }
+       fn from_be_bytes(mut flags: Vec<u8>) -> Features<T> {
+               flags.reverse(); // Swap to little-endian
+               Self {
+                       flags,
+                       mark: PhantomData,
+               }
+       }
+       pub(crate) fn supports_any_optional_bits(&self) -> bool {
+               self.flags.iter().any(|&byte| (byte & 0b10_10_10_10) != 0)
+       }
        /// Returns true if this `Features` object contains unknown feature flags which are set as
        /// "required".
        pub fn requires_unknown_bits(&self) -> bool {
@@@ -692,31 -769,44 +769,44 @@@ impl<T: sealed::ShutdownAnySegwit> Feat
                self
        }
  }
- impl<T: sealed::Context> Writeable for Features<T> {
-       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.flags.len() as u16).write(w)?;
-               for f in self.flags.iter().rev() { // Swap back to big-endian
-                       f.write(w)?;
+ macro_rules! impl_feature_len_prefixed_write {
+       ($features: ident) => {
+               impl Writeable for $features {
+                       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+                               (self.flags.len() as u16).write(w)?;
+                               self.write_be(w)
+                       }
+               }
+               impl Readable for $features {
+                       fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
+                               Ok(Self::from_be_bytes(Vec::<u8>::read(r)?))
+                       }
                }
-               Ok(())
        }
  }
- impl<T: sealed::Context> Readable for Features<T> {
+ impl_feature_len_prefixed_write!(InitFeatures);
+ impl_feature_len_prefixed_write!(ChannelFeatures);
+ impl_feature_len_prefixed_write!(NodeFeatures);
+ impl_feature_len_prefixed_write!(InvoiceFeatures);
+ // Because ChannelTypeFeatures only appears inside of TLVs, it doesn't have a length prefix when
+ // serialized. Thus, we can't use `impl_feature_len_prefixed_write`, above, and have to write our
+ // own serialization.
+ impl Writeable for ChannelTypeFeatures {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.write_be(w)
+       }
+ }
+ impl Readable for ChannelTypeFeatures {
        fn read<R: io::Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let mut flags: Vec<u8> = Readable::read(r)?;
-               flags.reverse(); // Swap to little-endian
-               Ok(Self {
-                       flags,
-                       mark: PhantomData,
-               })
+               let v = io_extras::read_to_end(r)?;
+               Ok(Self::from_be_bytes(v))
        }
  }
  
  #[cfg(test)]
  mod tests {
-       use super::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
+       use super::{ChannelFeatures, ChannelTypeFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
        use bitcoin::bech32::{Base32Len, FromBase32, ToBase32, u5};
  
        #[test]
                let features_deserialized = InvoiceFeatures::from_base32(&features_as_u5s).unwrap();
                assert_eq!(features, features_deserialized);
        }
+       #[test]
+       fn test_channel_type_mapping() {
+               // If we map an InvoiceFeatures with StaticRemoteKey optional, it should map into a
+               // required-StaticRemoteKey ChannelTypeFeatures.
+               let init_features = InitFeatures::empty().set_static_remote_key_optional();
+               let converted_features = ChannelTypeFeatures::from_counterparty_init(&init_features);
+               assert_eq!(converted_features, ChannelTypeFeatures::only_static_remote_key());
+               assert!(!converted_features.supports_any_optional_bits());
+               assert!(converted_features.requires_static_remote_key());
+       }
  }