Send channel_reestablish out-of-band to ensure ordered deliver
[rust-lightning] / src / ln / channelmanager.rs
index 7d11fe987f2c67d07a8a3fbc2cbe43bb3df1a7e0..c5356e291789215d98ebbdae6cfa676dcf6d01f7 100644 (file)
@@ -2676,9 +2676,10 @@ impl ChannelMessageHandler for ChannelManager {
                }
        }
 
-       fn peer_connected(&self, their_node_id: &PublicKey) -> Vec<msgs::ChannelReestablish> {
-               let mut res = Vec::new();
-               let mut channel_state = self.channel_state.lock().unwrap();
+       fn peer_connected(&self, their_node_id: &PublicKey) {
+               let mut channel_state_lock = self.channel_state.lock().unwrap();
+               let channel_state = channel_state_lock.borrow_parts();
+               let pending_msg_events = channel_state.pending_msg_events;
                channel_state.by_id.retain(|_, chan| {
                        if chan.get_their_node_id() == *their_node_id {
                                if !chan.have_received_message() {
@@ -2688,13 +2689,15 @@ impl ChannelMessageHandler for ChannelManager {
                                        // drop it.
                                        false
                                } else {
-                                       res.push(chan.get_channel_reestablish());
+                                       pending_msg_events.push(events::MessageSendEvent::SendChannelReestablish {
+                                               node_id: chan.get_their_node_id(),
+                                               msg: chan.get_channel_reestablish(),
+                                       });
                                        true
                                }
                        } else { true }
                });
                //TODO: Also re-broadcast announcement_signatures
-               res
        }
 
        fn handle_error(&self, their_node_id: &PublicKey, msg: &msgs::ErrorMessage) {
@@ -5197,6 +5200,23 @@ mod tests {
                assert_eq!(channel_state.short_to_id.len(), 0);
        }
 
+       macro_rules! get_chan_reestablish_msgs {
+               ($src_node: expr, $dst_node: expr) => {
+                       {
+                               let mut res = Vec::with_capacity(1);
+                               for msg in $src_node.node.get_and_clear_pending_msg_events() {
+                                       if let MessageSendEvent::SendChannelReestablish { ref node_id, ref msg } = msg {
+                                               assert_eq!(*node_id, $dst_node.node.get_our_node_id());
+                                               res.push(msg.clone());
+                                       } else {
+                                               panic!("Unexpected event")
+                                       }
+                               }
+                               res
+                       }
+               }
+       }
+
        macro_rules! handle_chan_reestablish_msgs {
                ($src_node: expr, $dst_node: expr) => {
                        {
@@ -5255,8 +5275,10 @@ mod tests {
        /// pending_htlc_adds includes both the holding cell and in-flight update_add_htlcs, whereas
        /// for claims/fails they are separated out.
        fn reconnect_nodes(node_a: &Node, node_b: &Node, pre_all_htlcs: bool, pending_htlc_adds: (i64, i64), pending_htlc_claims: (usize, usize), pending_cell_htlc_claims: (usize, usize), pending_cell_htlc_fails: (usize, usize), pending_raa: (bool, bool)) {
-               let reestablish_1 = node_a.node.peer_connected(&node_b.node.get_our_node_id());
-               let reestablish_2 = node_b.node.peer_connected(&node_a.node.get_our_node_id());
+               node_a.node.peer_connected(&node_b.node.get_our_node_id());
+               let reestablish_1 = get_chan_reestablish_msgs!(node_a, node_b);
+               node_b.node.peer_connected(&node_a.node.get_our_node_id());
+               let reestablish_2 = get_chan_reestablish_msgs!(node_b, node_a);
 
                let mut resp_1 = Vec::new();
                for msg in reestablish_1 {
@@ -5754,9 +5776,11 @@ mod tests {
                nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
                nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
 
-               let reestablish_1 = nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id());
+               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id());
+               let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
                assert_eq!(reestablish_1.len(), 1);
-               let reestablish_2 = nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id());
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id());
+               let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
                assert_eq!(reestablish_2.len(), 1);
 
                nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &reestablish_2[0]).unwrap();
@@ -6042,9 +6066,11 @@ mod tests {
                        nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
                        nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
 
-                       let reestablish_1 = nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id());
+                       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id());
+                       let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
                        assert_eq!(reestablish_1.len(), 1);
-                       let reestablish_2 = nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id());
+                       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id());
+                       let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
                        assert_eq!(reestablish_2.len(), 1);
 
                        nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &reestablish_2[0]).unwrap();
@@ -6062,9 +6088,11 @@ mod tests {
                        assert!(nodes[0].node.get_and_clear_pending_events().is_empty());
                        assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
 
-                       let reestablish_1 = nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id());
+                       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id());
+                       let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
                        assert_eq!(reestablish_1.len(), 1);
-                       let reestablish_2 = nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id());
+                       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id());
+                       let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
                        assert_eq!(reestablish_2.len(), 1);
 
                        nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &reestablish_2[0]).unwrap();