Fix race between handshake_complete and timer_tick_occurred
author Wilmer Paulino <wilmer@wilmerpaulino.com>
Fri, 12 Apr 2024 17:23:45 +0000 (10:23 -0700)
committer Wilmer Paulino <wilmer@wilmerpaulino.com>
Fri, 12 Apr 2024 17:57:34 +0000 (10:57 -0700)
The initial noise handshake on connection establishment must complete
within a single timer tick. This timeout is enforced via
`awaiting_pong_timer_tick_intervals` whenever a timer tick fires while
our handshake has yet to complete. Currently, on an inbound connection,
if a timer tick fires after we've sent act two of the noise handshake
along with our init message and before receiving the counterparty's init
message, we begin enforcing this timeout. Even if we immediately
continue to process the counterparty's init message to complete the
handshake, the timeout enforcement is not cleared. With the handshake
complete, `awaiting_pong_timer_tick_intervals` is now tracked to enforce
a pong timeout, except a ping was never actually sent. If a single timer
tick fires again without having received a message from the peer, or
enough timer ticks fire to trigger the
`MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER` logic, we'll end up
disconnecting the peer due to a timeout for a pong we'll never receive.

We fix this by always resetting `awaiting_pong_timer_tick_intervals`
upon processing our counterparty's init message.

lightning/src/ln/peer_handler.rs

index 9c27a23467ce7a8e7134225f352fc961a3998ace..da96693f079cf79d6a642bd6c8f772b8646795b0 100644 (file)
@@ -1475,7 +1475,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                let networks = self.message_handler.chan_handler.get_chain_hashes();
                                                                let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) };
                                                                self.enqueue_message(peer, &resp);
-                                                               peer.awaiting_pong_timer_tick_intervals = 0;
                                                        },
                                                        NextNoiseStep::ActThree => {
                                                                let their_node_id = try_potential_handleerror!(peer,
@@ -1488,7 +1487,6 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                let networks = self.message_handler.chan_handler.get_chain_hashes();
                                                                let resp = msgs::Init { features, networks, remote_network_address: filter_addresses(peer.their_socket_address.clone()) };
                                                                self.enqueue_message(peer, &resp);
-                                                               peer.awaiting_pong_timer_tick_intervals = 0;
                                                        },
                                                        NextNoiseStep::NoiseComplete => {
                                                                if peer.pending_read_is_header {
@@ -1681,6 +1679,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                return Err(PeerHandleError { }.into());
                        }
 
+                       peer_lock.awaiting_pong_timer_tick_intervals = 0;
                        peer_lock.their_features = Some(msg.features);
                        return Ok(None);
                } else if peer_lock.their_features.is_none() {
@@ -2674,7 +2673,7 @@ mod tests {
        use crate::ln::ChannelId;
        use crate::ln::features::{InitFeatures, NodeFeatures};
        use crate::ln::peer_channel_encryptor::PeerChannelEncryptor;
-       use crate::ln::peer_handler::{CustomMessageHandler, PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler, filter_addresses};
+       use crate::ln::peer_handler::{CustomMessageHandler, PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler, filter_addresses, ErroringMessageHandler, MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER};
        use crate::ln::{msgs, wire};
        use crate::ln::msgs::{LightningError, SocketAddress};
        use crate::util::test_utils;
@@ -3216,6 +3215,105 @@ mod tests {
                assert!(peers[0].read_event(&mut fd_a, &b_data).is_err());
        }
 
+       #[test]
+       fn test_inbound_conn_handshake_complete_awaiting_pong() {
+               // Test that we do not disconnect an outbound peer after the noise handshake completes due
+               // to a pong timeout for a ping that was never sent if a timer tick fires after we send act
+               // two of the noise handshake along with our init message but before we receive their init
+               // message.
+               let logger = test_utils::TestLogger::new();
+               let node_signer_a = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[42; 32]).unwrap());
+               let node_signer_b = test_utils::TestNodeSigner::new(SecretKey::from_slice(&[43; 32]).unwrap());
+               let peer_a = PeerManager::new(MessageHandler {
+                       chan_handler: ErroringMessageHandler::new(),
+                       route_handler: IgnoringMessageHandler {},
+                       onion_message_handler: IgnoringMessageHandler {},
+                       custom_message_handler: IgnoringMessageHandler {},
+               }, 0, &[0; 32], &logger, &node_signer_a);
+               let peer_b = PeerManager::new(MessageHandler {
+                       chan_handler: ErroringMessageHandler::new(),
+                       route_handler: IgnoringMessageHandler {},
+                       onion_message_handler: IgnoringMessageHandler {},
+                       custom_message_handler: IgnoringMessageHandler {},
+               }, 0, &[1; 32], &logger, &node_signer_b);
+
+               let a_id = node_signer_a.get_node_id(Recipient::Node).unwrap();
+               let mut fd_a = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+               let mut fd_b = FileDescriptor {
+                       fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())),
+                       disconnect: Arc::new(AtomicBool::new(false)),
+               };
+
+               // Exchange messages with both peers until they both complete the init handshake.
+               let act_one = peer_b.new_outbound_connection(a_id, fd_b.clone(), None).unwrap();
+               peer_a.new_inbound_connection(fd_a.clone(), None).unwrap();
+
+               assert_eq!(peer_a.read_event(&mut fd_a, &act_one).unwrap(), false);
+               peer_a.process_events();
+
+               let act_two = fd_a.outbound_data.lock().unwrap().split_off(0);
+               assert_eq!(peer_b.read_event(&mut fd_b, &act_two).unwrap(), false);
+               peer_b.process_events();
+
+               // Calling this here triggers the race on inbound connections.
+               peer_b.timer_tick_occurred();
+
+               let act_three_with_init_b = fd_b.outbound_data.lock().unwrap().split_off(0);
+               assert!(!peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete());
+               assert_eq!(peer_a.read_event(&mut fd_a, &act_three_with_init_b).unwrap(), false);
+               peer_a.process_events();
+               assert!(peer_a.peers.read().unwrap().get(&fd_a).unwrap().lock().unwrap().handshake_complete());
+
+               let init_a = fd_a.outbound_data.lock().unwrap().split_off(0);
+               assert!(!init_a.is_empty());
+
+               assert!(!peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete());
+               assert_eq!(peer_b.read_event(&mut fd_b, &init_a).unwrap(), false);
+               peer_b.process_events();
+               assert!(peer_b.peers.read().unwrap().get(&fd_b).unwrap().lock().unwrap().handshake_complete());
+
+               // Make sure we're still connected.
+               assert_eq!(peer_b.peers.read().unwrap().len(), 1);
+
+               // B should send a ping on the first timer tick after `handshake_complete`.
+               assert!(fd_b.outbound_data.lock().unwrap().split_off(0).is_empty());
+               peer_b.timer_tick_occurred();
+               peer_b.process_events();
+               assert!(!fd_b.outbound_data.lock().unwrap().split_off(0).is_empty());
+
+               let mut send_warning = || {
+                       {
+                               let peers = peer_a.peers.read().unwrap();
+                               let mut peer_b = peers.get(&fd_a).unwrap().lock().unwrap();
+                               peer_a.enqueue_message(&mut peer_b, &msgs::WarningMessage {
+                                       channel_id: ChannelId([0; 32]),
+                                       data: "no disconnect plz".to_string(),
+                               });
+                       }
+                       peer_a.process_events();
+                       let msg = fd_a.outbound_data.lock().unwrap().split_off(0);
+                       assert!(!msg.is_empty());
+                       assert_eq!(peer_b.read_event(&mut fd_b, &msg).unwrap(), false);
+                       peer_b.process_events();
+               };
+
+               // Fire more ticks until we reach the pong timeout. We send any message except pong to
+               // pretend the connection is still alive.
+               send_warning();
+               for _ in 0..MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER {
+                       peer_b.timer_tick_occurred();
+                       send_warning();
+               }
+               assert_eq!(peer_b.peers.read().unwrap().len(), 1);
+
+               // One more tick should enforce the pong timeout.
+               peer_b.timer_tick_occurred();
+               assert_eq!(peer_b.peers.read().unwrap().len(), 0);
+       }
+
        #[test]
        fn test_filter_addresses(){
                // Tests the filter_addresses function.