Add splicing feature flag (also triggers dual_funding)
[rust-lightning] / lightning / src / ln / peer_handler.rs
index cfa32a009fe5f9213be2f250b2e033473b16c1b2..9ce230861456927f53bf4a2b65dd33e800e17e5d 100644 (file)
@@ -36,9 +36,10 @@ use crate::util::atomic_counter::AtomicCounter;
 use crate::util::logger::{Level, Logger, WithContext};
 use crate::util::string::PrintableString;
 
+#[allow(unused_imports)]
 use crate::prelude::*;
+
 use crate::io;
-use alloc::collections::VecDeque;
 use crate::sync::{Mutex, MutexGuard, FairRwLock};
 use core::sync::atomic::{AtomicBool, AtomicU32, AtomicI32, Ordering};
 use core::{cmp, hash, fmt, mem};
@@ -247,12 +248,15 @@ impl ChannelMessageHandler for ErroringMessageHandler {
        fn handle_stfu(&self, their_node_id: &PublicKey, msg: &msgs::Stfu) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
+       #[cfg(splicing)]
        fn handle_splice(&self, their_node_id: &PublicKey, msg: &msgs::Splice) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
+       #[cfg(splicing)]
        fn handle_splice_ack(&self, their_node_id: &PublicKey, msg: &msgs::SpliceAck) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
+       #[cfg(splicing)]
        fn handle_splice_locked(&self, their_node_id: &PublicKey, msg: &msgs::SpliceLocked) {
                ErroringMessageHandler::push_error(&self, their_node_id, msg.channel_id);
        }
@@ -1474,7 +1478,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                let networks = self.message_handler.chan_handler.get_chain_hashes();
                                                                let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) };
                                                                self.enqueue_message(peer, &resp);
-                                                               peer.awaiting_pong_timer_tick_intervals = 0;
                                                        },
                                                        NextNoiseStep::ActThree => {
                                                                let their_node_id = try_potential_handleerror!(peer,
@@ -1487,7 +1490,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                let networks = self.message_handler.chan_handler.get_chain_hashes();
                                                                let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) };
                                                                self.enqueue_message(peer, &resp);
-                                                               peer.awaiting_pong_timer_tick_intervals = 0;
                                                        },
                                                        NextNoiseStep::NoiseComplete => {
                                                                if peer.pending_read_is_header {
@@ -1591,15 +1593,37 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
        }
 
        /// Process an incoming message and return a decision (ok, lightning error, peer handling error) regarding the next action with the peer
+       ///
        /// Returns the message back if it needs to be broadcasted to all other peers.
        fn handle_message(
                &self,
                peer_mutex: &Mutex<Peer>,
-               mut peer_lock: MutexGuard<Peer>,
-               message: wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>
-       ) -> Result<Option<wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError> {
+               peer_lock: MutexGuard<Peer>,
+               message: wire::Message<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>
+       ) -> Result<Option<wire::Message<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError> {
                let their_node_id = peer_lock.their_node_id.clone().expect("We know the peer's public key by the time we receive messages").0;
                let logger = WithContext::from(&self.logger, Some(their_node_id), None);
+
+               let message = match self.do_handle_message_holding_peer_lock(peer_lock, message, &their_node_id, &logger)? {
+                       Some(processed_message) => processed_message,
+                       None => return Ok(None),
+               };
+
+               self.do_handle_message_without_peer_lock(peer_mutex, message, &their_node_id, &logger)
+       }
+
+       // Conducts all message processing that requires us to hold the `peer_lock`.
+       //
+       // Returns `None` if the message was fully processed and otherwise returns the message back to
+       // allow it to be subsequently processed by `do_handle_message_without_peer_lock`.
+       fn do_handle_message_holding_peer_lock<'a>(
+               &self,
+               mut peer_lock: MutexGuard<Peer>,
+               message: wire::Message<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>,
+               their_node_id: &PublicKey,
+               logger: &WithContext<'a, L>
+       ) -> Result<Option<wire::Message<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError>
+       {
                peer_lock.received_message_since_timer_tick = true;
 
                // Need an Init as first message
@@ -1658,6 +1682,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                return Err(PeerHandleError { }.into());
                        }
 
+                       peer_lock.awaiting_pong_timer_tick_intervals = 0;
                        peer_lock.their_features = Some(msg.features);
                        return Ok(None);
                } else if peer_lock.their_features.is_none() {
@@ -1680,8 +1705,20 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        peer_lock.received_channel_announce_since_backlogged = true;
                }
 
-               mem::drop(peer_lock);
+               Ok(Some(message))
+       }
 
+       // Conducts all message processing that doesn't require us to hold the `peer_lock`.
+       //
+       // Returns the message back if it needs to be broadcasted to all other peers.
+       fn do_handle_message_without_peer_lock<'a>(
+               &self,
+               peer_mutex: &Mutex<Peer>,
+               message: wire::Message<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>,
+               their_node_id: &PublicKey,
+               logger: &WithContext<'a, L>
+       ) -> Result<Option<wire::Message<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError>
+       {
                if is_gossip_msg(message.type_id()) {
                        log_gossip!(logger, "Received message {:?} from {}", message, log_pubkey!(their_node_id));
                } else {
@@ -1750,13 +1787,16 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                self.message_handler.chan_handler.handle_stfu(&their_node_id, &msg);
                        }
 
+                       #[cfg(splicing)]
                        // Splicing messages:
                        wire::Message::Splice(msg) => {
                                self.message_handler.chan_handler.handle_splice(&their_node_id, &msg);
                        }
+                       #[cfg(splicing)]
                        wire::Message::SpliceAck(msg) => {
                                self.message_handler.chan_handler.handle_splice_ack(&their_node_id, &msg);
                        }
+                       #[cfg(splicing)]
                        wire::Message::SpliceLocked(msg) => {
                                self.message_handler.chan_handler.handle_splice_locked(&their_node_id, &msg);
                        }
@@ -1883,7 +1923,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                Ok(should_forward)
        }
 
-       fn forward_broadcast_msg(&self, peers: &HashMap<Descriptor, Mutex<Peer>>, msg: &wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>, except_node: Option<&PublicKey>) {
+       fn forward_broadcast_msg(&self, peers: &HashMap<Descriptor, Mutex<Peer>>, msg: &wire::Message<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>, except_node: Option<&PublicKey>) {
                match msg {
                        wire::Message::ChannelAnnouncement(ref msg) => {
                                log_gossip!(self.logger, "Sending message to all peers except {:?} or the announced channel's counterparties: {:?}", except_node, msg);
@@ -2275,7 +2315,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                        // We do not have the peers write lock, so we just store that we're
                                                                        // about to disconnect the peer and do it after we finish
                                                                        // processing most messages.
-                                                                       let msg = msg.map(|msg| wire::Message::<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>::Error(msg));
+                                                                       let msg = msg.map(|msg| wire::Message::<<<CMH as Deref>::Target as wire::CustomMessageReader>::CustomMessage>::Error(msg));
                                                                        peers_to_disconnect.insert(node_id, msg);
                                                                },
                                                                msgs::ErrorAction::DisconnectPeerWithWarning { msg } => {
@@ -2639,7 +2679,7 @@ mod tests {
        use crate::ln::ChannelId;
        use crate::ln::features::{InitFeatures, NodeFeatures};
        use crate::ln::peer_channel_encryptor::PeerChannelEncryptor;
-       use crate::ln::peer_handler::{CustomMessageHandler, PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler, filter_addresses};
+       use crate::ln::peer_handler::{CustomMessageHandler, PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler, filter_addresses, ErroringMessageHandler, MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER};
        use crate::ln::{msgs, wire};
        use crate::ln::msgs::{LightningError, SocketAddress};
        use crate::util::test_utils;
@@ -2648,11 +2688,13 @@ mod tests {
        use bitcoin::blockdata::constants::ChainHash;
        use bitcoin::secp256k1::{PublicKey, SecretKey};
 
-       use crate::prelude::*;
        use crate::sync::{Arc, Mutex};
        use core::convert::Infallible;
        use core::sync::atomic::{AtomicBool, Ordering};
 
+       #[allow(unused_imports)]
+       use crate::prelude::*;
+
        #[derive(Clone)]
        struct FileDescriptor {
                fd: u16,
@@ -3179,6 +3221,105 @@ mod tests {
                assert!(peers[0].read_event(&mut fd_a, &b_data).is_err());
        }
 
+       #[test]
+       fn test_inbound_conn_handshake_complete_awaiting_pong() {
+               // Test that we do not disconnect an outbound peer after the noise handshake completes due
+               // to a pong timeout for a ping that was never sent if a timer tick fires after we send act
+               // two of the noise handshake along with our init message but before we receive their init
+               // message.
+               let logger = test_utils::TestLogger::new();
+               let node_signer_a = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[42; 32]).unwrap());
+               let node_signer_b = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[43; 32]).unwrap());
+               let peer_a = PeerManager::new(MessageHandler {
+                       chan_handler: ErroringMessageHandler::new(),
+                       route_handler: IgnoringMessageHandler {},
+                       onion_message_handler: IgnoringMessageHandler {},
+                       custom_message_handler: IgnoringMessageHandler {},
+               }, 0, &[0; 32], &logger, &node_signer_a);
+               let peer_b = PeerManager::new(MessageHandler {
+                       chan_handler: ErroringMessageHandler::new(),
+                       route_handler: IgnoringMessageHandler {},
+                       onion_message_handler: IgnoringMessageHandler {},
+                       custom_message_handler: IgnoringMessageHandler {},
+               }, 0, &[1; 32], &logger, &node_signer_b);
+
+               let a_id = node_signer_a.get_node_id(Recipient::Node).unwrap();
+               let mut fd_a = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+               let mut fd_b = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+
+               // Exchange messages with both peers until they both complete the init handshake.
+               let act_one = peer_b.new_outbound_connection(a_id, fd_b.clone(), None).unwrap();
+               peer_a.new_inbound_connection(fd_a.clone(), None).unwrap();
+
+               assert_eq!(peer_a.read_event(&mut fd_a, &act_one).unwrap(), false);
+               peer_a.process_events();
+
+               let act_two = fd_a.outbound_data.lock().unwrap().split_off(0);
+               assert_eq!(peer_b.read_event(&mut fd_b, &act_two).unwrap(), false);
+               peer_b.process_events();
+
+               // Calling this here triggers the race on inbound connections.
+               peer_b.timer_tick_occurred();
+
+               let act_three_with_init_b = fd_b.outbound_data.lock().unwrap().split_off(0);
+               assert!(!peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete());
+               assert_eq!(peer_a.read_event(&mut fd_a, &act_three_with_init_b).unwrap(), false);
+               peer_a.process_events();
+               assert!(peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete());
+
+               let init_a = fd_a.outbound_data.lock().unwrap().split_off(0);
+               assert!(!init_a.is_empty());
+
+               assert!(!peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete());
+               assert_eq!(peer_b.read_event(&mut fd_b, &init_a).unwrap(), false);
+               peer_b.process_events();
+               assert!(peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete());
+
+               // Make sure we're still connected.
+               assert_eq!(peer_b.peers.read().unwrap().len(), 1);
+
+               // B should send a ping on the first timer tick after `handshake_complete`.
+               assert!(fd_b.outbound_data.lock().unwrap().split_off(0).is_empty());
+               peer_b.timer_tick_occurred();
+               peer_b.process_events();
+               assert!(!fd_b.outbound_data.lock().unwrap().split_off(0).is_empty());
+
+               let mut send_warning = || {
+                       {
+                               let peers = peer_a.peers.read().unwrap();
+                               let mut peer_b = peers.get(&fd_a).unwrap().lock().unwrap();
+                               peer_a.enqueue_message(&mut peer_b, &msgs::WarningMessage {
+                                       channel_id: ChannelId([0; 32]),
+                                       data: "no disconnect plz".to_string(),
+                               });
+                       }
+                       peer_a.process_events();
+                       let msg = fd_a.outbound_data.lock().unwrap().split_off(0);
+                       assert!(!msg.is_empty());
+                       assert_eq!(peer_b.read_event(&mut fd_b, &msg).unwrap(), false);
+                       peer_b.process_events();
+               };
+
+               // Fire more ticks until we reach the pong timeout. We send any message except pong to
+               // pretend the connection is still alive.
+               send_warning();
+               for _ in 0..MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER {
+                       peer_b.timer_tick_occurred();
+                       send_warning();
+               }
+               assert_eq!(peer_b.peers.read().unwrap().len(), 1);
+
+               // One more tick should enforce the pong timeout.
+               peer_b.timer_tick_occurred();
+               assert_eq!(peer_b.peers.read().unwrap().len(), 0);
+       }
+
        #[test]
        fn test_filter_addresses(){
                // Tests the filter_addresses function.