Merge pull request #1897 from TheBlueMatt/2022-11-monitor-updates-always-async
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Wed, 22 Feb 2023 19:12:31 +0000 (19:12 +0000)
committerGitHub <noreply@github.com>
Wed, 22 Feb 2023 19:12:31 +0000 (19:12 +0000)
Always process `ChannelMonitorUpdate`s asynchronously

24 files changed:
fuzz/src/chanmon_consistency.rs
fuzz/src/full_stack.rs
fuzz/src/indexedmap.rs
lightning-background-processor/src/lib.rs
lightning-invoice/src/utils.rs
lightning-net-tokio/src/lib.rs
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/msgs.rs
lightning/src/ln/onion_route_tests.rs
lightning/src/ln/outbound_payment.rs
lightning/src/ln/payment_tests.rs
lightning/src/ln/peer_handler.rs
lightning/src/ln/priv_short_conf_tests.rs
lightning/src/ln/reload_tests.rs
lightning/src/ln/reorg_tests.rs
lightning/src/ln/shutdown_tests.rs
lightning/src/onion_message/functional_tests.rs
lightning/src/onion_message/messenger.rs
lightning/src/routing/gossip.rs
lightning/src/util/indexed_map.rs
lightning/src/util/test_utils.rs

index 41b10fb5a32d051f8d341fba2709b4033a60be6b..4386302f845b0d95279bbba99c975dede237df8a 100644 (file)
@@ -474,8 +474,8 @@ pub fn do_test<Out: Output>(data: &[u8], underlying_out: Out) {
        let mut channel_txn = Vec::new();
        macro_rules! make_channel {
                ($source: expr, $dest: expr, $chan_id: expr) => { {
-                       $source.peer_connected(&$dest.get_our_node_id(), &Init { features: $dest.init_features(), remote_network_address: None }).unwrap();
-                       $dest.peer_connected(&$source.get_our_node_id(), &Init { features: $source.init_features(), remote_network_address: None }).unwrap();
+                       $source.peer_connected(&$dest.get_our_node_id(), &Init { features: $dest.init_features(), remote_network_address: None }, true).unwrap();
+                       $dest.peer_connected(&$source.get_our_node_id(), &Init { features: $source.init_features(), remote_network_address: None }, false).unwrap();
 
                        $source.create_channel($dest.get_our_node_id(), 100_000, 42, 0, None).unwrap();
                        let open_channel = {
@@ -979,31 +979,31 @@ pub fn do_test<Out: Output>(data: &[u8], underlying_out: Out) {
 
                        0x0c => {
                                if !chan_a_disconnected {
-                                       nodes[0].peer_disconnected(&nodes[1].get_our_node_id(), false);
-                                       nodes[1].peer_disconnected(&nodes[0].get_our_node_id(), false);
+                                       nodes[0].peer_disconnected(&nodes[1].get_our_node_id());
+                                       nodes[1].peer_disconnected(&nodes[0].get_our_node_id());
                                        chan_a_disconnected = true;
                                        drain_msg_events_on_disconnect!(0);
                                }
                        },
                        0x0d => {
                                if !chan_b_disconnected {
-                                       nodes[1].peer_disconnected(&nodes[2].get_our_node_id(), false);
-                                       nodes[2].peer_disconnected(&nodes[1].get_our_node_id(), false);
+                                       nodes[1].peer_disconnected(&nodes[2].get_our_node_id());
+                                       nodes[2].peer_disconnected(&nodes[1].get_our_node_id());
                                        chan_b_disconnected = true;
                                        drain_msg_events_on_disconnect!(2);
                                }
                        },
                        0x0e => {
                                if chan_a_disconnected {
-                                       nodes[0].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }).unwrap();
-                                       nodes[1].peer_connected(&nodes[0].get_our_node_id(), &Init { features: nodes[0].init_features(), remote_network_address: None }).unwrap();
+                                       nodes[0].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }, true).unwrap();
+                                       nodes[1].peer_connected(&nodes[0].get_our_node_id(), &Init { features: nodes[0].init_features(), remote_network_address: None }, false).unwrap();
                                        chan_a_disconnected = false;
                                }
                        },
                        0x0f => {
                                if chan_b_disconnected {
-                                       nodes[1].peer_connected(&nodes[2].get_our_node_id(), &Init { features: nodes[2].init_features(), remote_network_address: None }).unwrap();
-                                       nodes[2].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }).unwrap();
+                                       nodes[1].peer_connected(&nodes[2].get_our_node_id(), &Init { features: nodes[2].init_features(), remote_network_address: None }, true).unwrap();
+                                       nodes[2].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }, false).unwrap();
                                        chan_b_disconnected = false;
                                }
                        },
@@ -1040,7 +1040,7 @@ pub fn do_test<Out: Output>(data: &[u8], underlying_out: Out) {
 
                        0x2c => {
                                if !chan_a_disconnected {
-                                       nodes[1].peer_disconnected(&nodes[0].get_our_node_id(), false);
+                                       nodes[1].peer_disconnected(&nodes[0].get_our_node_id());
                                        chan_a_disconnected = true;
                                        drain_msg_events_on_disconnect!(0);
                                }
@@ -1054,14 +1054,14 @@ pub fn do_test<Out: Output>(data: &[u8], underlying_out: Out) {
                        },
                        0x2d => {
                                if !chan_a_disconnected {
-                                       nodes[0].peer_disconnected(&nodes[1].get_our_node_id(), false);
+                                       nodes[0].peer_disconnected(&nodes[1].get_our_node_id());
                                        chan_a_disconnected = true;
                                        nodes[0].get_and_clear_pending_msg_events();
                                        ab_events.clear();
                                        ba_events.clear();
                                }
                                if !chan_b_disconnected {
-                                       nodes[2].peer_disconnected(&nodes[1].get_our_node_id(), false);
+                                       nodes[2].peer_disconnected(&nodes[1].get_our_node_id());
                                        chan_b_disconnected = true;
                                        nodes[2].get_and_clear_pending_msg_events();
                                        bc_events.clear();
@@ -1073,7 +1073,7 @@ pub fn do_test<Out: Output>(data: &[u8], underlying_out: Out) {
                        },
                        0x2e => {
                                if !chan_b_disconnected {
-                                       nodes[1].peer_disconnected(&nodes[2].get_our_node_id(), false);
+                                       nodes[1].peer_disconnected(&nodes[2].get_our_node_id());
                                        chan_b_disconnected = true;
                                        drain_msg_events_on_disconnect!(2);
                                }
@@ -1198,13 +1198,13 @@ pub fn do_test<Out: Output>(data: &[u8], underlying_out: Out) {
 
                                // Next, make sure peers are all connected to each other
                                if chan_a_disconnected {
-                                       nodes[0].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }).unwrap();
-                                       nodes[1].peer_connected(&nodes[0].get_our_node_id(), &Init { features: nodes[0].init_features(), remote_network_address: None }).unwrap();
+                                       nodes[0].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }, true).unwrap();
+                                       nodes[1].peer_connected(&nodes[0].get_our_node_id(), &Init { features: nodes[0].init_features(), remote_network_address: None }, false).unwrap();
                                        chan_a_disconnected = false;
                                }
                                if chan_b_disconnected {
-                                       nodes[1].peer_connected(&nodes[2].get_our_node_id(), &Init { features: nodes[2].init_features(), remote_network_address: None }).unwrap();
-                                       nodes[2].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }).unwrap();
+                                       nodes[1].peer_connected(&nodes[2].get_our_node_id(), &Init { features: nodes[2].init_features(), remote_network_address: None }, true).unwrap();
+                                       nodes[2].peer_connected(&nodes[1].get_our_node_id(), &Init { features: nodes[1].init_features(), remote_network_address: None }, false).unwrap();
                                        chan_b_disconnected = false;
                                }
 
index 9b3b76c2b3a7d4e60f9d7354406411b2b42a294f..ca42466880a110e3e0e24873efe85f377a4974ad 100644 (file)
@@ -634,11 +634,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                        if let Err(e) = channelmanager.funding_transaction_generated(&funding_generation.0, &funding_generation.1, tx.clone()) {
                                                // It's possible the channel has been closed in the mean time, but any other
                                                // failure may be a bug.
-                                               if let APIError::ChannelUnavailable { err } = e {
-                                                       if !err.starts_with("Can't find a peer matching the passed counterparty node_id ") {
-                                                               assert_eq!(err, "No such channel");
-                                                       }
-                                               } else { panic!(); }
+                                               if let APIError::ChannelUnavailable { .. } = e { } else { panic!(); }
                                        }
                                        pending_funding_signatures.insert(funding_output, tx);
                                }
index 795d6175bb5dc2dd8aee2b0cdce76ccc6a95c12b..7cbb8957ad247f807ff91f8851af44f0b8a546a8 100644 (file)
@@ -13,14 +13,27 @@ use hashbrown::HashSet;
 
 use crate::utils::test_logger;
 
-fn check_eq(btree: &BTreeMap<u8, u8>, indexed: &IndexedMap<u8, u8>) {
+use std::ops::{RangeBounds, Bound};
+
+struct ExclLowerInclUpper(u8, u8);
+impl RangeBounds<u8> for ExclLowerInclUpper {
+       fn start_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.0) }
+       fn end_bound(&self) -> Bound<&u8> { Bound::Included(&self.1) }
+}
+struct ExclLowerExclUpper(u8, u8);
+impl RangeBounds<u8> for ExclLowerExclUpper {
+       fn start_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.0) }
+       fn end_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.1) }
+}
+
+fn check_eq(btree: &BTreeMap<u8, u8>, mut indexed: IndexedMap<u8, u8>) {
        assert_eq!(btree.len(), indexed.len());
        assert_eq!(btree.is_empty(), indexed.is_empty());
 
        let mut btree_clone = btree.clone();
        assert!(btree_clone == *btree);
        let mut indexed_clone = indexed.clone();
-       assert!(indexed_clone == *indexed);
+       assert!(indexed_clone == indexed);
 
        for k in 0..=255 {
                assert_eq!(btree.contains_key(&k), indexed.contains_key(&k));
@@ -43,16 +56,27 @@ fn check_eq(btree: &BTreeMap<u8, u8>, indexed: &IndexedMap<u8, u8>) {
        }
 
        const STRIDE: u8 = 16;
-       for k in 0..=255/STRIDE {
-               let lower_bound = k * STRIDE;
-               let upper_bound = lower_bound + (STRIDE - 1);
-               let mut btree_iter = btree.range(lower_bound..=upper_bound);
-               let mut indexed_iter = indexed.range(lower_bound..=upper_bound);
-               loop {
-                       let b_v = btree_iter.next();
-                       let i_v = indexed_iter.next();
-                       assert_eq!(b_v, i_v);
-                       if b_v.is_none() { break; }
+       for range_type in 0..4 {
+               for k in 0..=255/STRIDE {
+                       let lower_bound = k * STRIDE;
+                       let upper_bound = lower_bound + (STRIDE - 1);
+                       macro_rules! range { ($map: expr) => {
+                               match range_type {
+                                       0 => $map.range(lower_bound..upper_bound),
+                                       1 => $map.range(lower_bound..=upper_bound),
+                                       2 => $map.range(ExclLowerInclUpper(lower_bound, upper_bound)),
+                                       3 => $map.range(ExclLowerExclUpper(lower_bound, upper_bound)),
+                                       _ => unreachable!(),
+                               }
+                       } }
+                       let mut btree_iter = range!(btree);
+                       let mut indexed_iter = range!(indexed);
+                       loop {
+                               let b_v = btree_iter.next();
+                               let i_v = indexed_iter.next();
+                               assert_eq!(b_v, i_v);
+                               if b_v.is_none() { break; }
+                       }
                }
        }
 
@@ -91,7 +115,7 @@ pub fn do_test(data: &[u8]) {
                let prev_value_i = indexed.insert(tuple[0], tuple[1]);
                assert_eq!(prev_value_b, prev_value_i);
        }
-       check_eq(&btree, &indexed);
+       check_eq(&btree, indexed.clone());
 
        // Now, modify the maps in all the ways we have to do so, checking that the maps remain
        // equivalent as we go.
@@ -99,7 +123,7 @@ pub fn do_test(data: &[u8]) {
                *v = *k;
                *btree.get_mut(k).unwrap() = *k;
        }
-       check_eq(&btree, &indexed);
+       check_eq(&btree, indexed.clone());
 
        for k in 0..=255 {
                match btree.entry(k) {
@@ -124,7 +148,7 @@ pub fn do_test(data: &[u8]) {
                        },
                }
        }
-       check_eq(&btree, &indexed);
+       check_eq(&btree, indexed);
 }
 
 pub fn indexedmap_test<Out: test_logger::Output>(data: &[u8], _out: Out) {
index a21f00f867cb36c1dc4d5c596f29dc2af551c5de..97d0462eb5010980850c64b53dc37a07c512f476 100644 (file)
@@ -963,8 +963,8 @@ mod tests {
 
                for i in 0..num_nodes {
                        for j in (i+1)..num_nodes {
-                               nodes[i].node.peer_connected(&nodes[j].node.get_our_node_id(), &Init { features: nodes[j].node.init_features(), remote_network_address: None }).unwrap();
-                               nodes[j].node.peer_connected(&nodes[i].node.get_our_node_id(), &Init { features: nodes[i].node.init_features(), remote_network_address: None }).unwrap();
+                               nodes[i].node.peer_connected(&nodes[j].node.get_our_node_id(), &Init { features: nodes[j].node.init_features(), remote_network_address: None }, true).unwrap();
+                               nodes[j].node.peer_connected(&nodes[i].node.get_our_node_id(), &Init { features: nodes[i].node.init_features(), remote_network_address: None }, false).unwrap();
                        }
                }
 
index f6cc87fa72af2d0645b3b1538199b28f4265d5a3..868f0b297b1162da0551c8923b4ced23ab880ccb 100644 (file)
@@ -842,13 +842,13 @@ mod test {
 
                // With only one sufficient-value peer connected we should only get its hint
                scid_aliases.remove(&chan_b.0.short_channel_id_alias.unwrap());
-               nodes[0].node.peer_disconnected(&nodes[2].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[2].node.get_our_node_id());
                match_invoice_routes(Some(1_000_000_000), &nodes[0], scid_aliases.clone());
 
                // If we don't have any sufficient-value peers connected we should get all hints with
                // sufficient value, even though there is a connected insufficient-value peer.
                scid_aliases.insert(chan_b.0.short_channel_id_alias.unwrap());
-               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
                match_invoice_routes(Some(1_000_000_000), &nodes[0], scid_aliases);
        }
 
index b259f77eff8f9ace67fec3b2a666122b0ae3c12a..81e6157c656f6e39e76525e33ca09e1b55870325 100644 (file)
@@ -617,7 +617,7 @@ mod tests {
                fn handle_channel_update(&self, _msg: &ChannelUpdate) -> Result<bool, LightningError> { Ok(false) }
                fn get_next_channel_announcement(&self, _starting_point: u64) -> Option<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> { None }
                fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option<NodeAnnouncement> { None }
-               fn peer_connected(&self, _their_node_id: &PublicKey, _init_msg: &Init) -> Result<(), ()> { Ok(()) }
+               fn peer_connected(&self, _their_node_id: &PublicKey, _init_msg: &Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
                fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: ReplyChannelRange) -> Result<(), LightningError> { Ok(()) }
                fn handle_reply_short_channel_ids_end(&self, _their_node_id: &PublicKey, _msg: ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) }
                fn handle_query_channel_range(&self, _their_node_id: &PublicKey, _msg: QueryChannelRange) -> Result<(), LightningError> { Ok(()) }
@@ -643,13 +643,13 @@ mod tests {
                fn handle_update_fee(&self, _their_node_id: &PublicKey, _msg: &UpdateFee) {}
                fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &AnnouncementSignatures) {}
                fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &ChannelUpdate) {}
-               fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
+               fn peer_disconnected(&self, their_node_id: &PublicKey) {
                        if *their_node_id == self.expected_pubkey {
                                self.disconnected_flag.store(true, Ordering::SeqCst);
                                self.pubkey_disconnected.clone().try_send(()).unwrap();
                        }
                }
-               fn peer_connected(&self, their_node_id: &PublicKey, _init_msg: &Init) -> Result<(), ()> {
+               fn peer_connected(&self, their_node_id: &PublicKey, _init_msg: &Init, _inbound: bool) -> Result<(), ()> {
                        if *their_node_id == self.expected_pubkey {
                                self.pubkey_connected.clone().try_send(()).unwrap();
                        }
index b1d43ca07a0d3ce6b62ba3d7e985db13758574f7..1532884fa605c5ba363720aa6701b602c83d92a2 100644 (file)
@@ -181,8 +181,8 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
        assert_eq!(nodes[0].node.list_channels().len(), 1);
 
        if disconnect {
-               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
                reconnect_nodes(&nodes[0], &nodes[1], (true, true), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
        }
 
@@ -234,8 +234,8 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
        assert_eq!(nodes[0].node.list_channels().len(), 1);
 
        if disconnect {
-               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
                reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
        }
 
@@ -337,8 +337,8 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
        };
 
        if disconnect_count & !disconnect_flags > 0 {
-               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        }
 
        // Now fix monitor updating...
@@ -348,13 +348,13 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
        check_added_monitors!(nodes[0], 0);
 
        macro_rules! disconnect_reconnect_peers { () => { {
-               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
-               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
                let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
                assert_eq!(reestablish_1.len(), 1);
-               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
                let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
                assert_eq!(reestablish_2.len(), 1);
 
@@ -373,10 +373,10 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
                assert!(nodes[0].node.get_and_clear_pending_events().is_empty());
                assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
 
-               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
                let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
                assert_eq!(reestablish_1.len(), 1);
-               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
                let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
                assert_eq!(reestablish_2.len(), 1);
 
@@ -1110,8 +1110,8 @@ fn test_monitor_update_fail_reestablish() {
 
        let (payment_preimage, payment_hash, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 1_000_000);
 
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
 
        nodes[2].node.claim_funds(payment_preimage);
        check_added_monitors!(nodes[2], 1);
@@ -1130,8 +1130,8 @@ fn test_monitor_update_fail_reestablish() {
        commitment_signed_dance!(nodes[1], nodes[2], updates.commitment_signed, false);
 
        chanmon_cfgs[1].persister.set_update_ret(ChannelMonitorUpdateStatus::InProgress);
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
 
        let as_reestablish = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
        let bs_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
@@ -1146,11 +1146,11 @@ fn test_monitor_update_fail_reestablish() {
        nodes[1].node.get_and_clear_pending_msg_events(); // Free the holding cell
        check_added_monitors!(nodes[1], 1);
 
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
 
        assert_eq!(get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap(), as_reestablish);
        assert_eq!(get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap(), bs_reestablish);
@@ -1315,15 +1315,15 @@ fn claim_while_disconnected_monitor_update_fail() {
        // Forward a payment for B to claim
        let (payment_preimage_1, payment_hash_1, _) = route_payment(&nodes[0], &[&nodes[1]], 1_000_000);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        nodes[1].node.claim_funds(payment_preimage_1);
        check_added_monitors!(nodes[1], 1);
        expect_payment_claimed!(nodes[1], payment_hash_1, 1_000_000);
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
 
        let as_reconnect = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
        let bs_reconnect = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
@@ -1451,11 +1451,11 @@ fn monitor_failed_no_reestablish_response() {
 
        // Now disconnect and immediately reconnect, delivering the channel_reestablish while nodes[1]
        // is still failing to update monitors.
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
 
        let as_reconnect = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
        let bs_reconnect = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
@@ -1879,8 +1879,8 @@ fn do_during_funding_monitor_fail(confirm_a_first: bool, restore_b_before_conf:
        }
 
        // Make sure nodes[1] isn't stupid enough to re-send the ChannelReady on reconnect
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        reconnect_nodes(&nodes[0], &nodes[1], (false, confirm_a_first), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
        assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
@@ -2047,12 +2047,12 @@ fn test_pending_update_fee_ack_on_reconnect() {
        let bs_first_raa = get_event_msg!(nodes[1], MessageSendEvent::SendRevokeAndACK, nodes[0].node.get_our_node_id());
        // bs_first_raa is not delivered until it is re-generated after reconnect
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let as_connect_msg = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let bs_connect_msg = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
 
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &as_connect_msg);
@@ -2175,12 +2175,12 @@ fn do_update_fee_resend_test(deliver_update: bool, parallel_updates: bool) {
                assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
        }
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let as_connect_msg = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let bs_connect_msg = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
 
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &as_connect_msg);
@@ -2308,15 +2308,15 @@ fn do_channel_holding_cell_serialize(disconnect: bool, reload_a: bool) {
                        let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
                        reload_node!(nodes[0], &nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister, new_chain_monitor, nodes_0_deserialized);
                } else {
-                       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+                       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
                }
-               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
                // Now reconnect the two
-               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
                let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
                assert_eq!(reestablish_1.len(), 1);
-               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
                let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
                assert_eq!(reestablish_2.len(), 1);
 
@@ -2499,8 +2499,8 @@ fn do_test_reconnect_dup_htlc_claims(htlc_status: HTLCStatusAtDupClaim, second_f
                assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        }
 
-       nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id(), false);
-       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id());
+       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id());
 
        if second_fails {
                reconnect_nodes(&nodes[1], &nodes[2], (false, false), (0, 0), (0, 0), (1, 0), (0, 0), (0, 0), (false, false));
index bdfd29530db05426ade0390fa5079dfd3ede5790..61de296c1442f3096fcb8b88efbd341157c4de6d 100644 (file)
@@ -605,6 +605,15 @@ pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = C
 /// offline for a full minute. In order to track this, you must call
 /// timer_tick_occurred roughly once per minute, though it doesn't have to be perfect.
 ///
+/// To avoid trivial DoS issues, ChannelManager limits the number of inbound connections and
+/// inbound channels without confirmed funding transactions. This may result in nodes which we do
+/// not have a channel with being unable to connect to us or open new channels with us if we have
+/// many peers with unfunded channels.
+///
+/// Because it is an indication of trust, inbound channels which we've accepted as 0conf are
+/// exempted from the count of unfunded channels. Similarly, outbound channels and connections are
+/// never limited. Please ensure you limit the count of such channels yourself.
+///
 /// Rather than using a plain ChannelManager, it is preferable to use either a SimpleArcChannelManager
 /// a SimpleRefChannelManager, for conciseness. See their documentation for more details, but
 /// essentially you should default to using a SimpleRefChannelManager, and use a
@@ -955,6 +964,19 @@ pub(crate) const MPP_TIMEOUT_TICKS: u8 = 3;
 /// [`OutboundPayments::remove_stale_resolved_payments`].
 pub(crate) const IDEMPOTENCY_TIMEOUT_TICKS: u8 = 7;
 
+/// The maximum number of unfunded channels we can have per-peer before we start rejecting new
+/// (inbound) ones. The number of peers with unfunded channels is limited separately in
+/// [`MAX_UNFUNDED_CHANNEL_PEERS`].
+const MAX_UNFUNDED_CHANS_PER_PEER: usize = 4;
+
+/// The maximum number of peers from which we will allow pending unfunded channels. Once we reach
+/// this many peers we reject new (inbound) channels from peers with which we don't have a channel.
+const MAX_UNFUNDED_CHANNEL_PEERS: usize = 50;
+
+/// The maximum number of peers which we do not have a (funded) channel with. Once we reach this
+/// many peers we reject new (inbound) connections.
+const MAX_NO_CHANNEL_PEERS: usize = 250;
+
 /// Information needed for constructing an invoice route hint for this channel.
 #[derive(Clone, Debug, PartialEq)]
 pub struct CounterpartyForwardingInfo {
@@ -2676,7 +2698,7 @@ where
                                        (chan, funding_msg)
                                },
                                Err(_) => { return Err(APIError::ChannelUnavailable {
-                                       err: "Error deriving keys or signing initial commitment transactions - either our RNG or our counterparty's RNG is broken or the Signer refused to sign".to_owned()
+                                       err: "Signer refused to sign the initial commitment transaction".to_owned()
                                }) },
                        }
                };
@@ -3747,16 +3769,19 @@ where
                // being fully configured. See the docs for `ChannelManagerReadArgs` for more.
                match source {
                        HTLCSource::OutboundRoute { ref path, ref session_priv, ref payment_id, ref payment_params, .. } => {
-                               self.pending_outbound_payments.fail_htlc(source, payment_hash, onion_error, path, session_priv, payment_id, payment_params, self.probing_cookie_secret, &self.secp_ctx, &self.pending_events, &self.logger);
+                               if self.pending_outbound_payments.fail_htlc(source, payment_hash, onion_error, path,
+                                       session_priv, payment_id, payment_params, self.probing_cookie_secret, &self.secp_ctx,
+                                       &self.pending_events, &self.logger)
+                               { self.push_pending_forwards_ev(); }
                        },
                        HTLCSource::PreviousHopData(HTLCPreviousHopData { ref short_channel_id, ref htlc_id, ref incoming_packet_shared_secret, ref phantom_shared_secret, ref outpoint }) => {
                                log_trace!(self.logger, "Failing HTLC with payment_hash {} backwards from us with {:?}", log_bytes!(payment_hash.0), onion_error);
                                let err_packet = onion_error.get_encrypted_failure_packet(incoming_packet_shared_secret, phantom_shared_secret);
 
-                               let mut forward_event = None;
+                               let mut push_forward_ev = false;
                                let mut forward_htlcs = self.forward_htlcs.lock().unwrap();
                                if forward_htlcs.is_empty() {
-                                       forward_event = Some(Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS));
+                                       push_forward_ev = true;
                                }
                                match forward_htlcs.entry(*short_channel_id) {
                                        hash_map::Entry::Occupied(mut entry) => {
@@ -3767,12 +3792,8 @@ where
                                        }
                                }
                                mem::drop(forward_htlcs);
+                               if push_forward_ev { self.push_pending_forwards_ev(); }
                                let mut pending_events = self.pending_events.lock().unwrap();
-                               if let Some(time) = forward_event {
-                                       pending_events.push(events::Event::PendingHTLCsForwardable {
-                                               time_forwardable: time
-                                       });
-                               }
                                pending_events.push(events::Event::HTLCHandlingFailed {
                                        prev_channel_id: outpoint.to_channel_id(),
                                        failed_next_destination: destination,
@@ -4221,11 +4242,13 @@ where
        fn do_accept_inbound_channel(&self, temporary_channel_id: &[u8; 32], counterparty_node_id: &PublicKey, accept_0conf: bool, user_channel_id: u128) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
+               let peers_without_funded_channels = self.peers_without_funded_channels(|peer| !peer.channel_by_id.is_empty());
                let per_peer_state = self.per_peer_state.read().unwrap();
                let peer_state_mutex = per_peer_state.get(counterparty_node_id)
                        .ok_or_else(|| APIError::ChannelUnavailable { err: format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id) })?;
                let mut peer_state_lock = peer_state_mutex.lock().unwrap();
                let peer_state = &mut *peer_state_lock;
+               let is_only_peer_channel = peer_state.channel_by_id.len() == 1;
                match peer_state.channel_by_id.entry(temporary_channel_id.clone()) {
                        hash_map::Entry::Occupied(mut channel) => {
                                if !channel.get().inbound_is_awaiting_accept() {
@@ -4243,6 +4266,21 @@ where
                                        peer_state.pending_msg_events.push(send_msg_err_event);
                                        let _ = remove_channel!(self, channel);
                                        return Err(APIError::APIMisuseError { err: "Please use accept_inbound_channel_from_trusted_peer_0conf to accept channels with zero confirmations.".to_owned() });
+                               } else {
+                                       // If this peer already has some channels, a new channel won't increase our number of peers
+                                       // with unfunded channels, so as long as we aren't over the maximum number of unfunded
+                                       // channels per-peer we can accept channels from a peer with existing ones.
+                                       if is_only_peer_channel && peers_without_funded_channels >= MAX_UNFUNDED_CHANNEL_PEERS {
+                                               let send_msg_err_event = events::MessageSendEvent::HandleError {
+                                                       node_id: channel.get().get_counterparty_node_id(),
+                                                       action: msgs::ErrorAction::SendErrorMessage{
+                                                               msg: msgs::ErrorMessage { channel_id: temporary_channel_id.clone(), data: "Have too many peers with unfunded channels, not accepting new ones".to_owned(), }
+                                                       }
+                                               };
+                                               peer_state.pending_msg_events.push(send_msg_err_event);
+                                               let _ = remove_channel!(self, channel);
+                                               return Err(APIError::APIMisuseError { err: "Too many peers with unfunded channels, refusing to accept new ones".to_owned() });
+                                       }
                                }
 
                                peer_state.pending_msg_events.push(events::MessageSendEvent::SendAcceptChannel {
@@ -4257,6 +4295,43 @@ where
                Ok(())
        }
 
+       /// Gets the number of peers which match the given filter and do not have any funded, outbound,
+       /// or 0-conf channels.
+       ///
+       /// The filter is called for each peer and provided with the number of unfunded, inbound, and
+       /// non-0-conf channels we have with the peer.
+       fn peers_without_funded_channels<Filter>(&self, maybe_count_peer: Filter) -> usize
+       where Filter: Fn(&PeerState<<SP::Target as SignerProvider>::Signer>) -> bool {
+               let mut peers_without_funded_channels = 0;
+               let best_block_height = self.best_block.read().unwrap().height();
+               {
+                       let peer_state_lock = self.per_peer_state.read().unwrap();
+                       for (_, peer_mtx) in peer_state_lock.iter() {
+                               let peer = peer_mtx.lock().unwrap();
+                               if !maybe_count_peer(&*peer) { continue; }
+                               let num_unfunded_channels = Self::unfunded_channel_count(&peer, best_block_height);
+                               if num_unfunded_channels == peer.channel_by_id.len() {
+                                       peers_without_funded_channels += 1;
+                               }
+                       }
+               }
+               return peers_without_funded_channels;
+       }
+
+       fn unfunded_channel_count(
+               peer: &PeerState<<SP::Target as SignerProvider>::Signer>, best_block_height: u32
+       ) -> usize {
+               let mut num_unfunded_channels = 0;
+               for (_, chan) in peer.channel_by_id.iter() {
+                       if !chan.is_outbound() && chan.minimum_depth().unwrap_or(1) != 0 &&
+                               chan.get_funding_tx_confirmations(best_block_height) == 0
+                       {
+                               num_unfunded_channels += 1;
+                       }
+               }
+               num_unfunded_channels
+       }
+
        fn internal_open_channel(&self, counterparty_node_id: &PublicKey, msg: &msgs::OpenChannel) -> Result<(), MsgHandleErrInternal> {
                if msg.chain_hash != self.genesis_hash {
                        return Err(MsgHandleErrInternal::send_err_msg_no_close("Unknown genesis block hash".to_owned(), msg.temporary_channel_id.clone()));
@@ -4269,8 +4344,13 @@ where
                let mut random_bytes = [0u8; 16];
                random_bytes.copy_from_slice(&self.entropy_source.get_secure_random_bytes()[..16]);
                let user_channel_id = u128::from_be_bytes(random_bytes);
-
                let outbound_scid_alias = self.create_and_insert_outbound_scid_alias();
+
+               // Get the number of peers with channels, but without funded ones. We don't care too much
+               // about peers that never open a channel, so we filter by peers that have at least one
+               // channel, and then limit the number of those with unfunded channels.
+               let channeled_peers_without_funding = self.peers_without_funded_channels(|node| !node.channel_by_id.is_empty());
+
                let per_peer_state = self.per_peer_state.read().unwrap();
                let peer_state_mutex = per_peer_state.get(counterparty_node_id)
                    .ok_or_else(|| {
@@ -4279,9 +4359,29 @@ where
                        })?;
                let mut peer_state_lock = peer_state_mutex.lock().unwrap();
                let peer_state = &mut *peer_state_lock;
+
+               // If this peer already has some channels, a new channel won't increase our number of peers
+               // with unfunded channels, so as long as we aren't over the maximum number of unfunded
+               // channels per-peer we can accept channels from a peer with existing ones.
+               if peer_state.channel_by_id.is_empty() &&
+                       channeled_peers_without_funding >= MAX_UNFUNDED_CHANNEL_PEERS &&
+                       !self.default_configuration.manually_accept_inbound_channels
+               {
+                       return Err(MsgHandleErrInternal::send_err_msg_no_close(
+                               "Have too many peers with unfunded channels, not accepting new ones".to_owned(),
+                               msg.temporary_channel_id.clone()));
+               }
+
+               let best_block_height = self.best_block.read().unwrap().height();
+               if Self::unfunded_channel_count(peer_state, best_block_height) >= MAX_UNFUNDED_CHANS_PER_PEER {
+                       return Err(MsgHandleErrInternal::send_err_msg_no_close(
+                               format!("Refusing more than {} unfunded channels.", MAX_UNFUNDED_CHANS_PER_PEER),
+                               msg.temporary_channel_id.clone()));
+               }
+
                let mut channel = match Channel::new_from_req(&self.fee_estimator, &self.entropy_source, &self.signer_provider,
-                       counterparty_node_id.clone(), &self.channel_type_features(), &peer_state.latest_features, msg, user_channel_id, &self.default_configuration,
-                       self.best_block.read().unwrap().height(), &self.logger, outbound_scid_alias)
+                       counterparty_node_id.clone(), &self.channel_type_features(), &peer_state.latest_features, msg, user_channel_id,
+                       &self.default_configuration, best_block_height, &self.logger, outbound_scid_alias)
                {
                        Err(e) => {
                                self.outbound_scid_aliases.lock().unwrap().remove(&outbound_scid_alias);
@@ -4730,7 +4830,7 @@ where
        #[inline]
        fn forward_htlcs(&self, per_source_pending_forwards: &mut [(u64, OutPoint, u128, Vec<(PendingHTLCInfo, u64)>)]) {
                for &mut (prev_short_channel_id, prev_funding_outpoint, prev_user_channel_id, ref mut pending_forwards) in per_source_pending_forwards {
-                       let mut forward_event = None;
+                       let mut push_forward_event = false;
                        let mut new_intercept_events = Vec::new();
                        let mut failed_intercept_forwards = Vec::new();
                        if !pending_forwards.is_empty() {
@@ -4788,7 +4888,7 @@ where
                                                                // We don't want to generate a PendingHTLCsForwardable event if only intercepted
                                                                // payments are being processed.
                                                                if forward_htlcs_empty {
-                                                                       forward_event = Some(Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS));
+                                                                       push_forward_event = true;
                                                                }
                                                                entry.insert(vec!(HTLCForwardInfo::AddHTLC(PendingAddHTLCInfo {
                                                                        prev_short_channel_id, prev_funding_outpoint, prev_htlc_id, prev_user_channel_id, forward_info })));
@@ -4806,16 +4906,21 @@ where
                                let mut events = self.pending_events.lock().unwrap();
                                events.append(&mut new_intercept_events);
                        }
+                       if push_forward_event { self.push_pending_forwards_ev() }
+               }
+       }
 
-                       match forward_event {
-                               Some(time) => {
-                                       let mut pending_events = self.pending_events.lock().unwrap();
-                                       pending_events.push(events::Event::PendingHTLCsForwardable {
-                                               time_forwardable: time
-                                       });
-                               }
-                               None => {},
-                       }
+       // We only want to push a PendingHTLCsForwardable event if no others are queued.
+       fn push_pending_forwards_ev(&self) {
+               let mut pending_events = self.pending_events.lock().unwrap();
+               let forward_ev_exists = pending_events.iter()
+                       .find(|ev| if let events::Event::PendingHTLCsForwardable { .. } = ev { true } else { false })
+                       .is_some();
+               if !forward_ev_exists {
+                       pending_events.push(events::Event::PendingHTLCsForwardable {
+                               time_forwardable:
+                                       Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS),
+                       });
                }
        }
 
@@ -6079,13 +6184,13 @@ where
                let _ = handle_error!(self, self.internal_channel_reestablish(counterparty_node_id, msg), *counterparty_node_id);
        }
 
-       fn peer_disconnected(&self, counterparty_node_id: &PublicKey, no_connection_possible: bool) {
+       fn peer_disconnected(&self, counterparty_node_id: &PublicKey) {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
                let mut failed_channels = Vec::new();
                let mut per_peer_state = self.per_peer_state.write().unwrap();
                let remove_peer = {
-                       log_debug!(self.logger, "Marking channels with {} disconnected and generating channel_updates. We believe we {} make future connections to this peer.",
-                               log_pubkey!(counterparty_node_id), if no_connection_possible { "cannot" } else { "can" });
+                       log_debug!(self.logger, "Marking channels with {} disconnected and generating channel_updates.",
+                               log_pubkey!(counterparty_node_id));
                        if let Some(peer_state_mutex) = per_peer_state.get(counterparty_node_id) {
                                let mut peer_state_lock = peer_state_mutex.lock().unwrap();
                                let peer_state = &mut *peer_state_lock;
@@ -6127,7 +6232,7 @@ where
                                debug_assert!(peer_state.is_connected, "A disconnected peer cannot disconnect");
                                peer_state.is_connected = false;
                                peer_state.ok_to_remove(true)
-                       } else { true }
+                       } else { debug_assert!(false, "Unconnected peer disconnected"); true }
                };
                if remove_peer {
                        per_peer_state.remove(counterparty_node_id);
@@ -6139,20 +6244,28 @@ where
                }
        }
 
-       fn peer_connected(&self, counterparty_node_id: &PublicKey, init_msg: &msgs::Init) -> Result<(), ()> {
+       fn peer_connected(&self, counterparty_node_id: &PublicKey, init_msg: &msgs::Init, inbound: bool) -> Result<(), ()> {
                if !init_msg.features.supports_static_remote_key() {
-                       log_debug!(self.logger, "Peer {} does not support static remote key, disconnecting with no_connection_possible", log_pubkey!(counterparty_node_id));
+                       log_debug!(self.logger, "Peer {} does not support static remote key, disconnecting", log_pubkey!(counterparty_node_id));
                        return Err(());
                }
 
-               log_debug!(self.logger, "Generating channel_reestablish events for {}", log_pubkey!(counterparty_node_id));
-
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
+               // If we have too many peers connected which don't have funded channels, disconnect the
+               // peer immediately (as long as it doesn't have funded channels). If we have a bunch of
+               // unfunded channels taking up space in memory for disconnected peers, we still let new
+               // peers connect, but we'll reject new channels from them.
+               let connected_peers_without_funded_channels = self.peers_without_funded_channels(|node| node.is_connected);
+               let inbound_peer_limited = inbound && connected_peers_without_funded_channels >= MAX_NO_CHANNEL_PEERS;
+
                {
                        let mut peer_state_lock = self.per_peer_state.write().unwrap();
                        match peer_state_lock.entry(counterparty_node_id.clone()) {
                                hash_map::Entry::Vacant(e) => {
+                                       if inbound_peer_limited {
+                                               return Err(());
+                                       }
                                        e.insert(Mutex::new(PeerState {
                                                channel_by_id: HashMap::new(),
                                                latest_features: init_msg.features.clone(),
@@ -6164,14 +6277,24 @@ where
                                hash_map::Entry::Occupied(e) => {
                                        let mut peer_state = e.get().lock().unwrap();
                                        peer_state.latest_features = init_msg.features.clone();
+
+                                       let best_block_height = self.best_block.read().unwrap().height();
+                                       if inbound_peer_limited &&
+                                               Self::unfunded_channel_count(&*peer_state, best_block_height) ==
+                                               peer_state.channel_by_id.len()
+                                       {
+                                               return Err(());
+                                       }
+
                                        debug_assert!(!peer_state.is_connected, "A peer shouldn't be connected twice");
                                        peer_state.is_connected = true;
                                },
                        }
                }
 
-               let per_peer_state = self.per_peer_state.read().unwrap();
+               log_debug!(self.logger, "Generating channel_reestablish events for {}", log_pubkey!(counterparty_node_id));
 
+               let per_peer_state = self.per_peer_state.read().unwrap();
                for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
                        let mut peer_state_lock = peer_state_mutex.lock().unwrap();
                        let peer_state = &mut *peer_state_lock;
@@ -7376,7 +7499,8 @@ where
                        }
                }
 
-               if !forward_htlcs.is_empty() {
+               let pending_outbounds = OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()) };
+               if !forward_htlcs.is_empty() || pending_outbounds.needs_abandon() {
                        // If we have pending HTLCs to forward, assume we either dropped a
                        // `PendingHTLCsForwardable` or the user received it but never processed it as they
                        // shut down before the timer hit. Either way, set the time_forwardable to a small
@@ -7553,7 +7677,7 @@ where
 
                        inbound_payment_key: expanded_inbound_key,
                        pending_inbound_payments: Mutex::new(pending_inbound_payments),
-                       pending_outbound_payments: OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()), },
+                       pending_outbound_payments: pending_outbounds,
                        pending_intercepted_htlcs: Mutex::new(pending_intercepted_htlcs.unwrap()),
 
                        forward_htlcs: Mutex::new(forward_htlcs),
@@ -8040,8 +8164,8 @@ mod tests {
 
                let chan = create_announced_chan_between_nodes(&nodes, 0, 1);
 
-               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
                nodes[0].node.force_close_broadcasting_latest_txn(&chan.2, &nodes[1].node.get_our_node_id()).unwrap();
                check_closed_broadcast!(nodes[0], true);
@@ -8261,6 +8385,213 @@ mod tests {
                check_unkown_peer_error(nodes[0].node.update_channel_config(&unkown_public_key, &[channel_id], &ChannelConfig::default()), unkown_public_key);
        }
 
+       #[test]
+       fn test_connection_limiting() {
+               // Test that we limit un-channel'd peers and un-funded channels properly.
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+               // Note that create_network connects the nodes together for us
+
+               nodes[0].node.create_channel(nodes[1].node.get_our_node_id(), 100_000, 0, 42, None).unwrap();
+               let mut open_channel_msg = get_event_msg!(nodes[0], MessageSendEvent::SendOpenChannel, nodes[1].node.get_our_node_id());
+
+               let mut funding_tx = None;
+               for idx in 0..super::MAX_UNFUNDED_CHANS_PER_PEER {
+                       nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), &open_channel_msg);
+                       let accept_channel = get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, nodes[0].node.get_our_node_id());
+
+                       if idx == 0 {
+                               nodes[0].node.handle_accept_channel(&nodes[1].node.get_our_node_id(), &accept_channel);
+                               let (temporary_channel_id, tx, _) = create_funding_transaction(&nodes[0], &nodes[1].node.get_our_node_id(), 100_000, 42);
+                               funding_tx = Some(tx.clone());
+                               nodes[0].node.funding_transaction_generated(&temporary_channel_id, &nodes[1].node.get_our_node_id(), tx).unwrap();
+                               let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
+
+                               nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &funding_created_msg);
+                               check_added_monitors!(nodes[1], 1);
+                               let funding_signed = get_event_msg!(nodes[1], MessageSendEvent::SendFundingSigned, nodes[0].node.get_our_node_id());
+
+                               nodes[0].node.handle_funding_signed(&nodes[1].node.get_our_node_id(), &funding_signed);
+                               check_added_monitors!(nodes[0], 1);
+                       }
+                       open_channel_msg.temporary_channel_id = nodes[0].keys_manager.get_secure_random_bytes();
+               }
+
+               // A MAX_UNFUNDED_CHANS_PER_PEER + 1 channel will be summarily rejected
+               open_channel_msg.temporary_channel_id = nodes[0].keys_manager.get_secure_random_bytes();
+               nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), &open_channel_msg);
+               assert_eq!(get_err_msg!(nodes[1], nodes[0].node.get_our_node_id()).channel_id,
+                       open_channel_msg.temporary_channel_id);
+
+               // Further, because all of our channels with nodes[0] are inbound, and none of them funded,
+               // it doesn't count as a "protected" peer, i.e. it counts towards the MAX_NO_CHANNEL_PEERS
+               // limit.
+               let mut peer_pks = Vec::with_capacity(super::MAX_NO_CHANNEL_PEERS);
+               for _ in 1..super::MAX_NO_CHANNEL_PEERS {
+                       let random_pk = PublicKey::from_secret_key(&nodes[0].node.secp_ctx,
+                               &SecretKey::from_slice(&nodes[1].keys_manager.get_secure_random_bytes()).unwrap());
+                       peer_pks.push(random_pk);
+                       nodes[1].node.peer_connected(&random_pk, &msgs::Init {
+                               features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap();
+               }
+               let last_random_pk = PublicKey::from_secret_key(&nodes[0].node.secp_ctx,
+                       &SecretKey::from_slice(&nodes[1].keys_manager.get_secure_random_bytes()).unwrap());
+               nodes[1].node.peer_connected(&last_random_pk, &msgs::Init {
+                       features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap_err();
+
+               // Also importantly, because nodes[0] isn't "protected", we will refuse a reconnection from
+               // them if we have too many un-channel'd peers.
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+               let chan_closed_events = nodes[1].node.get_and_clear_pending_events();
+               assert_eq!(chan_closed_events.len(), super::MAX_UNFUNDED_CHANS_PER_PEER - 1);
+               for ev in chan_closed_events {
+                       if let Event::ChannelClosed { .. } = ev { } else { panic!(); }
+               }
+               nodes[1].node.peer_connected(&last_random_pk, &msgs::Init {
+                       features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap();
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init {
+                       features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap_err();
+
+               // but of course if the connection is outbound its allowed...
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init {
+                       features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+
+               // Now nodes[0] is disconnected but still has a pending, un-funded channel lying around.
+               // Even though we accept one more connection from new peers, we won't actually let them
+               // open channels.
+               assert!(peer_pks.len() > super::MAX_UNFUNDED_CHANNEL_PEERS - 1);
+               for i in 0..super::MAX_UNFUNDED_CHANNEL_PEERS - 1 {
+                       nodes[1].node.handle_open_channel(&peer_pks[i], &open_channel_msg);
+                       get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, peer_pks[i]);
+                       open_channel_msg.temporary_channel_id = nodes[0].keys_manager.get_secure_random_bytes();
+               }
+               nodes[1].node.handle_open_channel(&last_random_pk, &open_channel_msg);
+               assert_eq!(get_err_msg!(nodes[1], last_random_pk).channel_id,
+                       open_channel_msg.temporary_channel_id);
+
+               // Of course, however, outbound channels are always allowed
+               nodes[1].node.create_channel(last_random_pk, 100_000, 0, 42, None).unwrap();
+               get_event_msg!(nodes[1], MessageSendEvent::SendOpenChannel, last_random_pk);
+
+               // If we fund the first channel, nodes[0] has a live on-chain channel with us, it is now
+               // "protected" and can connect again.
+               mine_transaction(&nodes[1], funding_tx.as_ref().unwrap());
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init {
+                       features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap();
+               get_event_msg!(nodes[1], MessageSendEvent::SendChannelReestablish, nodes[0].node.get_our_node_id());
+
+               // Further, because the first channel was funded, we can open another channel with
+               // last_random_pk.
+               nodes[1].node.handle_open_channel(&last_random_pk, &open_channel_msg);
+               get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, last_random_pk);
+       }
+
+       #[test]
+       fn test_outbound_chans_unlimited() {
+               // Test that we never refuse an outbound channel even if a peer is unfuned-channel-limited
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+               // Note that create_network connects the nodes together for us
+
+               nodes[0].node.create_channel(nodes[1].node.get_our_node_id(), 100_000, 0, 42, None).unwrap();
+               let mut open_channel_msg = get_event_msg!(nodes[0], MessageSendEvent::SendOpenChannel, nodes[1].node.get_our_node_id());
+
+               for _ in 0..super::MAX_UNFUNDED_CHANS_PER_PEER {
+                       nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), &open_channel_msg);
+                       get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, nodes[0].node.get_our_node_id());
+                       open_channel_msg.temporary_channel_id = nodes[0].keys_manager.get_secure_random_bytes();
+               }
+
+               // Once we have MAX_UNFUNDED_CHANS_PER_PEER unfunded channels, new inbound channels will be
+               // rejected.
+               nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), &open_channel_msg);
+               assert_eq!(get_err_msg!(nodes[1], nodes[0].node.get_our_node_id()).channel_id,
+                       open_channel_msg.temporary_channel_id);
+
+               // but we can still open an outbound channel.
+               nodes[1].node.create_channel(nodes[0].node.get_our_node_id(), 100_000, 0, 42, None).unwrap();
+               get_event_msg!(nodes[1], MessageSendEvent::SendOpenChannel, nodes[0].node.get_our_node_id());
+
+               // but even with such an outbound channel, additional inbound channels will still fail.
+               nodes[1].node.handle_open_channel(&nodes[0].node.get_our_node_id(), &open_channel_msg);
+               assert_eq!(get_err_msg!(nodes[1], nodes[0].node.get_our_node_id()).channel_id,
+                       open_channel_msg.temporary_channel_id);
+       }
+
+       #[test]
+       fn test_0conf_limiting() {
+               // Tests that we properly limit inbound channels when we have the manual-channel-acceptance
+               // flag set and (sometimes) accept channels as 0conf.
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let mut settings = test_default_channel_config();
+               settings.manually_accept_inbound_channels = true;
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, Some(settings)]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+               // Note that create_network connects the nodes together for us
+
+               nodes[0].node.create_channel(nodes[1].node.get_our_node_id(), 100_000, 0, 42, None).unwrap();
+               let mut open_channel_msg = get_event_msg!(nodes[0], MessageSendEvent::SendOpenChannel, nodes[1].node.get_our_node_id());
+
+               // First, get us up to MAX_UNFUNDED_CHANNEL_PEERS so we can test at the edge
+               for _ in 0..super::MAX_UNFUNDED_CHANNEL_PEERS - 1 {
+                       let random_pk = PublicKey::from_secret_key(&nodes[0].node.secp_ctx,
+                               &SecretKey::from_slice(&nodes[1].keys_manager.get_secure_random_bytes()).unwrap());
+                       nodes[1].node.peer_connected(&random_pk, &msgs::Init {
+                               features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap();
+
+                       nodes[1].node.handle_open_channel(&random_pk, &open_channel_msg);
+                       let events = nodes[1].node.get_and_clear_pending_events();
+                       match events[0] {
+                               Event::OpenChannelRequest { temporary_channel_id, .. } => {
+                                       nodes[1].node.accept_inbound_channel(&temporary_channel_id, &random_pk, 23).unwrap();
+                               }
+                               _ => panic!("Unexpected event"),
+                       }
+                       get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, random_pk);
+                       open_channel_msg.temporary_channel_id = nodes[0].keys_manager.get_secure_random_bytes();
+               }
+
+               // If we try to accept a channel from another peer non-0conf it will fail.
+               let last_random_pk = PublicKey::from_secret_key(&nodes[0].node.secp_ctx,
+                       &SecretKey::from_slice(&nodes[1].keys_manager.get_secure_random_bytes()).unwrap());
+               nodes[1].node.peer_connected(&last_random_pk, &msgs::Init {
+                       features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap();
+               nodes[1].node.handle_open_channel(&last_random_pk, &open_channel_msg);
+               let events = nodes[1].node.get_and_clear_pending_events();
+               match events[0] {
+                       Event::OpenChannelRequest { temporary_channel_id, .. } => {
+                               match nodes[1].node.accept_inbound_channel(&temporary_channel_id, &last_random_pk, 23) {
+                                       Err(APIError::APIMisuseError { err }) =>
+                                               assert_eq!(err, "Too many peers with unfunded channels, refusing to accept new ones"),
+                                       _ => panic!(),
+                               }
+                       }
+                       _ => panic!("Unexpected event"),
+               }
+               assert_eq!(get_err_msg!(nodes[1], last_random_pk).channel_id,
+                       open_channel_msg.temporary_channel_id);
+
+               // ...however if we accept the same channel 0conf it should work just fine.
+               nodes[1].node.handle_open_channel(&last_random_pk, &open_channel_msg);
+               let events = nodes[1].node.get_and_clear_pending_events();
+               match events[0] {
+                       Event::OpenChannelRequest { temporary_channel_id, .. } => {
+                               nodes[1].node.accept_inbound_channel_from_trusted_peer_0conf(&temporary_channel_id, &last_random_pk, 23).unwrap();
+                       }
+                       _ => panic!("Unexpected event"),
+               }
+               get_event_msg!(nodes[1], MessageSendEvent::SendAcceptChannel, last_random_pk);
+       }
+
        #[cfg(anchors)]
        #[test]
        fn test_anchors_zero_fee_htlc_tx_fallback() {
@@ -8371,8 +8702,8 @@ pub mod bench {
                });
                let node_b_holder = NodeHolder { node: &node_b };
 
-               node_a.peer_connected(&node_b.get_our_node_id(), &Init { features: node_b.init_features(), remote_network_address: None }).unwrap();
-               node_b.peer_connected(&node_a.get_our_node_id(), &Init { features: node_a.init_features(), remote_network_address: None }).unwrap();
+               node_a.peer_connected(&node_b.get_our_node_id(), &Init { features: node_b.init_features(), remote_network_address: None }, true).unwrap();
+               node_b.peer_connected(&node_a.get_our_node_id(), &Init { features: node_a.init_features(), remote_network_address: None }, false).unwrap();
                node_a.create_channel(node_b.get_our_node_id(), 8_000_000, 100_000_000, 42, None).unwrap();
                node_b.handle_open_channel(&node_a.get_our_node_id(), &get_event_msg!(node_a_holder, MessageSendEvent::SendOpenChannel, node_b.get_our_node_id()));
                node_a.handle_accept_channel(&node_b.get_our_node_id(), &get_event_msg!(node_b_holder, MessageSendEvent::SendAcceptChannel, node_a.get_our_node_id()));
index 0c21588fb717af4aee2ef93a52f9da16abaaaaa3..fd867bc6ab7e07287c0572ae1cdb577e9e290ab8 100644 (file)
@@ -2374,8 +2374,8 @@ pub fn create_network<'a, 'b: 'a, 'c: 'b>(node_count: usize, cfgs: &'b Vec<NodeC
 
        for i in 0..node_count {
                for j in (i+1)..node_count {
-                       nodes[i].node.peer_connected(&nodes[j].node.get_our_node_id(), &msgs::Init { features: nodes[j].override_init_features.borrow().clone().unwrap_or_else(|| nodes[j].node.init_features()), remote_network_address: None }).unwrap();
-                       nodes[j].node.peer_connected(&nodes[i].node.get_our_node_id(), &msgs::Init { features: nodes[i].override_init_features.borrow().clone().unwrap_or_else(|| nodes[i].node.init_features()), remote_network_address: None }).unwrap();
+                       nodes[i].node.peer_connected(&nodes[j].node.get_our_node_id(), &msgs::Init { features: nodes[j].override_init_features.borrow().clone().unwrap_or_else(|| nodes[j].node.init_features()), remote_network_address: None }, true).unwrap();
+                       nodes[j].node.peer_connected(&nodes[i].node.get_our_node_id(), &msgs::Init { features: nodes[i].override_init_features.borrow().clone().unwrap_or_else(|| nodes[i].node.init_features()), remote_network_address: None }, false).unwrap();
                }
        }
 
@@ -2658,9 +2658,9 @@ macro_rules! handle_chan_reestablish_msgs {
 /// pending_htlc_adds includes both the holding cell and in-flight update_add_htlcs, whereas
 /// for claims/fails they are separated out.
 pub fn reconnect_nodes<'a, 'b, 'c>(node_a: &Node<'a, 'b, 'c>, node_b: &Node<'a, 'b, 'c>, send_channel_ready: (bool, bool), pending_htlc_adds: (i64, i64), pending_htlc_claims: (usize, usize), pending_htlc_fails: (usize, usize), pending_cell_htlc_claims: (usize, usize), pending_cell_htlc_fails: (usize, usize), pending_raa: (bool, bool))  {
-       node_a.node.peer_connected(&node_b.node.get_our_node_id(), &msgs::Init { features: node_b.node.init_features(), remote_network_address: None }).unwrap();
+       node_a.node.peer_connected(&node_b.node.get_our_node_id(), &msgs::Init { features: node_b.node.init_features(), remote_network_address: None }, true).unwrap();
        let reestablish_1 = get_chan_reestablish_msgs!(node_a, node_b);
-       node_b.node.peer_connected(&node_a.node.get_our_node_id(), &msgs::Init { features: node_a.node.init_features(), remote_network_address: None }).unwrap();
+       node_b.node.peer_connected(&node_a.node.get_our_node_id(), &msgs::Init { features: node_a.node.init_features(), remote_network_address: None }, false).unwrap();
        let reestablish_2 = get_chan_reestablish_msgs!(node_b, node_a);
 
        if send_channel_ready.0 {
index aa258b5713988ec56119f6caaa2920d271c0d4f8..d6d2972d0736e3ec642fccee7363e44329560584 100644 (file)
@@ -3524,8 +3524,8 @@ fn test_dup_events_on_peer_disconnect() {
        nodes[0].node.handle_update_fulfill_htlc(&nodes[1].node.get_our_node_id(), &claim_msgs.update_fulfill_htlcs[0]);
        expect_payment_sent_without_paths!(nodes[0], payment_preimage);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (1, 0), (0, 0), (0, 0), (0, 0), (false, false));
        expect_payment_path_successful!(nodes[0]);
@@ -3565,8 +3565,8 @@ fn test_peer_disconnected_before_funding_broadcasted() {
 
        // Ensure that the channel is closed with `ClosureReason::DisconnectedPeer` when the peers are
        // disconnected before the funding transaction was broadcasted.
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        check_closed_event!(nodes[0], 1, ClosureReason::DisconnectedPeer);
        check_closed_event!(nodes[1], 1, ClosureReason::DisconnectedPeer);
@@ -3582,8 +3582,8 @@ fn test_simple_peer_disconnect() {
        create_announced_chan_between_nodes(&nodes, 0, 1);
        create_announced_chan_between_nodes(&nodes, 1, 2);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        reconnect_nodes(&nodes[0], &nodes[1], (true, true), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
        let payment_preimage_1 = route_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], 1000000).0;
@@ -3591,8 +3591,8 @@ fn test_simple_peer_disconnect() {
        fail_payment(&nodes[0], &vec!(&nodes[1], &nodes[2]), payment_hash_2);
        claim_payment(&nodes[0], &vec!(&nodes[1], &nodes[2]), payment_preimage_1);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
        let (payment_preimage_3, payment_hash_3, _) = route_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], 1000000);
@@ -3600,8 +3600,8 @@ fn test_simple_peer_disconnect() {
        let payment_hash_5 = route_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], 1000000).1;
        let payment_hash_6 = route_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], 1000000).1;
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        claim_payment_along_route(&nodes[0], &[&[&nodes[1], &nodes[2]]], true, payment_preimage_3);
        fail_payment_along_route(&nodes[0], &[&[&nodes[1], &nodes[2]]], true, payment_hash_5);
@@ -3701,8 +3701,8 @@ fn do_test_drop_messages_peer_disconnect(messages_delivered: u8, simulate_broken
                }
        }
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        if messages_delivered < 3 {
                if simulate_broken_lnd {
                        // lnd has a long-standing bug where they send a channel_ready prior to a
@@ -3751,8 +3751,8 @@ fn do_test_drop_messages_peer_disconnect(messages_delivered: u8, simulate_broken
                };
        }
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
        nodes[1].node.process_pending_htlc_forwards();
@@ -3834,8 +3834,8 @@ fn do_test_drop_messages_peer_disconnect(messages_delivered: u8, simulate_broken
                }
        }
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        if messages_delivered < 2 {
                reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (1, 0), (0, 0), (0, 0), (0, 0), (false, false));
                if messages_delivered < 1 {
@@ -3861,8 +3861,8 @@ fn do_test_drop_messages_peer_disconnect(messages_delivered: u8, simulate_broken
                expect_payment_path_successful!(nodes[0]);
        }
        if messages_delivered <= 5 {
-               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+               nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+               nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
        }
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
@@ -3976,13 +3976,13 @@ fn test_drop_messages_peer_disconnect_dual_htlc() {
                _ => panic!("Unexpected event"),
        }
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
        assert_eq!(reestablish_1.len(), 1);
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
        assert_eq!(reestablish_2.len(), 1);
 
@@ -6292,12 +6292,12 @@ fn test_update_add_htlc_bolt2_receiver_check_repeated_id_ignore() {
        nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &updates.update_add_htlcs[0]);
 
        //Disconnect and Reconnect
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
        assert_eq!(reestablish_1.len(), 1);
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
        assert_eq!(reestablish_2.len(), 1);
        nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &reestablish_2[0]);
@@ -7032,8 +7032,8 @@ fn test_announce_disable_channels() {
        create_announced_chan_between_nodes(&nodes, 0, 1);
 
        // Disconnect peers
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        nodes[0].node.timer_tick_occurred(); // Enabled -> DisabledStaged
        nodes[0].node.timer_tick_occurred(); // DisabledStaged -> Disabled
@@ -7053,10 +7053,10 @@ fn test_announce_disable_channels() {
                }
        }
        // Reconnect peers
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
        assert_eq!(reestablish_1.len(), 3);
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
        assert_eq!(reestablish_2.len(), 3);
 
@@ -8832,13 +8832,11 @@ fn test_error_chans_closed() {
                _ => panic!("Unexpected event"),
        }
        // Note that at this point users of a standard PeerHandler will end up calling
-       // peer_disconnected with no_connection_possible set to false, duplicating the
-       // close-all-channels logic. That's OK, we don't want to end up not force-closing channels for
-       // users with their own peer handling logic. We duplicate the call here, however.
+       // peer_disconnected.
        assert_eq!(nodes[0].node.list_usable_channels().len(), 1);
        assert!(nodes[0].node.list_usable_channels()[0].channel_id == chan_3.2);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), true);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
        assert_eq!(nodes[0].node.list_usable_channels().len(), 1);
        assert!(nodes[0].node.list_usable_channels()[0].channel_id == chan_3.2);
 }
@@ -8954,8 +8952,8 @@ fn do_test_tx_confirmed_skipping_blocks_immediate_broadcast(test_height_before_t
        create_announced_chan_between_nodes(&nodes, 0, 1);
        let (chan_announce, _, channel_id, _) = create_announced_chan_between_nodes(&nodes, 1, 2);
        let (_, payment_hash, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 1_000_000);
-       nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id(), false);
-       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id());
+       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id());
 
        nodes[1].node.force_close_broadcasting_latest_txn(&channel_id, &nodes[2].node.get_our_node_id()).unwrap();
        check_closed_broadcast!(nodes[1], true);
index e8b40c71d89e7f1825335bbe4a0abb0ed817970d..6e49a46f08a02a25065652e28cf3895afc031fe8 100644 (file)
@@ -993,21 +993,15 @@ pub trait ChannelMessageHandler : MessageSendEventsProvider {
        fn handle_announcement_signatures(&self, their_node_id: &PublicKey, msg: &AnnouncementSignatures);
 
        // Connection loss/reestablish:
-       /// Indicates a connection to the peer failed/an existing connection was lost. If no connection
-       /// is believed to be possible in the future (eg they're sending us messages we don't
-       /// understand or indicate they require unknown feature bits), `no_connection_possible` is set
-       /// and any outstanding channels should be failed.
-       ///
-       /// Note that in some rare cases this may be called without a corresponding
-       /// [`Self::peer_connected`].
-       fn peer_disconnected(&self, their_node_id: &PublicKey, no_connection_possible: bool);
+       /// Indicates a connection to the peer failed/an existing connection was lost.
+       fn peer_disconnected(&self, their_node_id: &PublicKey);
 
        /// Handle a peer reconnecting, possibly generating `channel_reestablish` message(s).
        ///
        /// May return an `Err(())` if the features the peer supports are not sufficient to communicate
        /// with us. Implementors should be somewhat conservative about doing so, however, as other
        /// message handlers may still wish to communicate with this peer.
-       fn peer_connected(&self, their_node_id: &PublicKey, msg: &Init) -> Result<(), ()>;
+       fn peer_connected(&self, their_node_id: &PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>;
        /// Handle an incoming `channel_reestablish` message from the given peer.
        fn handle_channel_reestablish(&self, their_node_id: &PublicKey, msg: &ChannelReestablish);
 
@@ -1065,7 +1059,7 @@ pub trait RoutingMessageHandler : MessageSendEventsProvider {
        /// May return an `Err(())` if the features the peer supports are not sufficient to communicate
        /// with us. Implementors should be somewhat conservative about doing so, however, as other
        /// message handlers may still wish to communicate with this peer.
-       fn peer_connected(&self, their_node_id: &PublicKey, init: &Init) -> Result<(), ()>;
+       fn peer_connected(&self, their_node_id: &PublicKey, init: &Init, inbound: bool) -> Result<(), ()>;
        /// Handles the reply of a query we initiated to learn about channels
        /// for a given range of blocks. We can expect to receive one or more
        /// replies to a single query.
@@ -1112,13 +1106,10 @@ pub trait OnionMessageHandler : OnionMessageProvider {
        /// May return an `Err(())` if the features the peer supports are not sufficient to communicate
        /// with us. Implementors should be somewhat conservative about doing so, however, as other
        /// message handlers may still wish to communicate with this peer.
-       fn peer_connected(&self, their_node_id: &PublicKey, init: &Init) -> Result<(), ()>;
+       fn peer_connected(&self, their_node_id: &PublicKey, init: &Init, inbound: bool) -> Result<(), ()>;
        /// Indicates a connection to the peer failed/an existing connection was lost. Allows handlers to
        /// drop and refuse to forward onion messages to this peer.
-       ///
-       /// Note that in some rare cases this may be called without a corresponding
-       /// [`Self::peer_connected`].
-       fn peer_disconnected(&self, their_node_id: &PublicKey, no_connection_possible: bool);
+       fn peer_disconnected(&self, their_node_id: &PublicKey);
 
        // Handler information:
        /// Gets the node feature flags which this handler itself supports. All available handlers are
index bb69bbb32d51fd7ee95108c3c89ad06902fc58b7..da5aadb83065af7d6de274c30a021e4b9b05fc23 100644 (file)
@@ -576,8 +576,8 @@ fn test_onion_failure() {
        let short_channel_id = channels[1].0.contents.short_channel_id;
        run_onion_failure_test("channel_disabled", 0, &nodes, &route, &payment_hash, &payment_secret, |_| {}, || {
                // disconnect event to the channel between nodes[1] ~ nodes[2]
-               nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id(), false);
-               nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+               nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id());
+               nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id());
        }, true, Some(UPDATE|20), Some(NetworkUpdate::ChannelUpdateMessage{msg: ChannelUpdate::dummy(short_channel_id)}), Some(short_channel_id));
        reconnect_nodes(&nodes[1], &nodes[2], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
index 6ebd7bc547fe42b1db2242b89ca249ca514c8ae3..f252d88ac364e87947ed42c31bef96b5c276375f 100644 (file)
@@ -15,7 +15,7 @@ use bitcoin::secp256k1::{self, Secp256k1, SecretKey};
 
 use crate::chain::keysinterface::{EntropySource, NodeSigner, Recipient};
 use crate::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
-use crate::ln::channelmanager::{ChannelDetails, HTLCSource, IDEMPOTENCY_TIMEOUT_TICKS, MIN_HTLC_RELAY_HOLDING_CELL_MILLIS, PaymentId};
+use crate::ln::channelmanager::{ChannelDetails, HTLCSource, IDEMPOTENCY_TIMEOUT_TICKS, PaymentId};
 use crate::ln::channelmanager::MIN_FINAL_CLTV_EXPIRY_DELTA as LDK_DEFAULT_MIN_FINAL_CLTV_EXPIRY_DELTA;
 use crate::ln::msgs::DecodeError;
 use crate::ln::onion_utils::HTLCFailReason;
@@ -30,7 +30,6 @@ use crate::util::time::tests::SinceEpoch;
 use core::cmp;
 use core::fmt::{self, Display, Formatter};
 use core::ops::Deref;
-use core::time::Duration;
 
 use crate::prelude::*;
 use crate::sync::Mutex;
@@ -546,6 +545,12 @@ impl OutboundPayments {
                });
        }
 
+       pub(super) fn needs_abandon(&self) -> bool {
+               let outbounds = self.pending_outbound_payments.lock().unwrap();
+               outbounds.iter().any(|(_, pmt)|
+                       !pmt.is_auto_retryable_now() && pmt.remaining_parts() == 0 && !pmt.is_fulfilled())
+       }
+
        /// Will return `Ok(())` iff at least one HTLC is sent for the payment.
        fn pay_internal<R: Deref, NS: Deref, ES: Deref, IH, SP, L: Deref>(
                &self, payment_id: PaymentId,
@@ -1006,12 +1011,13 @@ impl OutboundPayments {
                });
        }
 
+       // Returns a bool indicating whether a PendingHTLCsForwardable event should be generated.
        pub(super) fn fail_htlc<L: Deref>(
                &self, source: &HTLCSource, payment_hash: &PaymentHash, onion_error: &HTLCFailReason,
                path: &Vec<RouteHop>, session_priv: &SecretKey, payment_id: &PaymentId,
                payment_params: &Option<PaymentParameters>, probing_cookie_secret: [u8; 32],
                secp_ctx: &Secp256k1<secp256k1::All>, pending_events: &Mutex<Vec<events::Event>>, logger: &L
-       ) where L::Target: Logger {
+       ) -> bool where L::Target: Logger {
                #[cfg(test)]
                let (network_update, short_channel_id, payment_retryable, onion_error_code, onion_error_data) = onion_error.decode_onion_failure(secp_ctx, logger, &source);
                #[cfg(not(test))]
@@ -1021,18 +1027,33 @@ impl OutboundPayments {
                let mut session_priv_bytes = [0; 32];
                session_priv_bytes.copy_from_slice(&session_priv[..]);
                let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+
+               // If any payments already need retry, there's no need to generate a redundant
+               // `PendingHTLCsForwardable`.
+               let already_awaiting_retry = outbounds.iter().any(|(_, pmt)| {
+                       let mut awaiting_retry = false;
+                       if pmt.is_auto_retryable_now() {
+                               if let PendingOutboundPayment::Retryable { pending_amt_msat, total_msat, .. } = pmt {
+                                       if pending_amt_msat < total_msat {
+                                               awaiting_retry = true;
+                                       }
+                               }
+                       }
+                       awaiting_retry
+               });
+
                let mut all_paths_failed = false;
                let mut full_failure_ev = None;
-               let mut pending_retry_ev = None;
+               let mut pending_retry_ev = false;
                let mut retry = None;
                let attempts_remaining = if let hash_map::Entry::Occupied(mut payment) = outbounds.entry(*payment_id) {
                        if !payment.get_mut().remove(&session_priv_bytes, Some(&path)) {
                                log_trace!(logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
-                               return
+                               return false
                        }
                        if payment.get().is_fulfilled() {
                                log_trace!(logger, "Received failure of HTLC with payment_hash {} after payment completion", log_bytes!(payment_hash.0));
-                               return
+                               return false
                        }
                        let mut is_retryable_now = payment.get().is_auto_retryable_now();
                        if let Some(scid) = short_channel_id {
@@ -1084,7 +1105,7 @@ impl OutboundPayments {
                        is_retryable_now
                } else {
                        log_trace!(logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
-                       return
+                       return false
                };
                core::mem::drop(outbounds);
                log_trace!(logger, "Failing outbound payment HTLC with payment_hash {}", log_bytes!(payment_hash.0));
@@ -1114,11 +1135,9 @@ impl OutboundPayments {
                                }
                                // If we miss abandoning the payment above, we *must* generate an event here or else the
                                // payment will sit in our outbounds forever.
-                               if attempts_remaining {
+                               if attempts_remaining && !already_awaiting_retry {
                                        debug_assert!(full_failure_ev.is_none());
-                                       pending_retry_ev = Some(events::Event::PendingHTLCsForwardable {
-                                               time_forwardable: Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS),
-                                       });
+                                       pending_retry_ev = true;
                                }
                                events::Event::PaymentPathFailed {
                                        payment_id: Some(*payment_id),
@@ -1139,7 +1158,7 @@ impl OutboundPayments {
                let mut pending_events = pending_events.lock().unwrap();
                pending_events.push(path_failure);
                if let Some(ev) = full_failure_ev { pending_events.push(ev); }
-               if let Some(ev) = pending_retry_ev { pending_events.push(ev); }
+               pending_retry_ev
        }
 
        pub(super) fn abandon_payment(
index ef9a9d3755bbe870b501e85fef7c66f50a66fa5d..96a04f417d031978c1a7f6116cf04e133c225d11 100644 (file)
@@ -253,8 +253,8 @@ fn no_pending_leak_on_initial_send_failure() {
 
        let (route, payment_hash, _, payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[1], 100_000);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        unwrap_send_err!(nodes[0].node.send_payment(&route, payment_hash, &Some(payment_secret), PaymentId(payment_hash.0)),
                true, APIError::ChannelUnavailable { ref err },
@@ -310,8 +310,8 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        // We relay the payment to nodes[1] while its disconnected from nodes[2], causing the payment
        // to be returned immediately to nodes[0], without having nodes[2] fail the inbound payment
        // which would prevent retry.
-       nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id(), false);
-       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[2].node.get_our_node_id());
+       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id());
 
        nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]);
        commitment_signed_dance!(nodes[1], nodes[0], payment_event.commitment_msg, false, true);
@@ -340,13 +340,13 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        assert_eq!(as_broadcasted_txn.len(), 1);
        assert_eq!(as_broadcasted_txn[0], as_commitment_tx);
 
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
 
        // Now nodes[1] should send a channel reestablish, which nodes[0] will respond to with an
        // error, as the channel has hit the chain.
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let bs_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
        nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &bs_reestablish);
        let as_err = nodes[0].node.get_and_clear_pending_msg_events();
@@ -496,7 +496,7 @@ fn do_test_completed_payment_not_retryable_on_reload(use_dust: bool) {
        let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
 
        reload_node!(nodes[0], test_default_channel_config(), nodes_0_serialized, &[&chan_0_monitor_serialized], first_persister, first_new_chain_monitor, first_nodes_0_deserialized);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        // On reload, the ChannelManager should realize it is stale compared to the ChannelMonitor and
        // force-close the channel.
@@ -505,12 +505,12 @@ fn do_test_completed_payment_not_retryable_on_reload(use_dust: bool) {
        assert!(nodes[0].node.has_pending_payments());
        assert_eq!(nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0).len(), 1);
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
 
        // Now nodes[1] should send a channel reestablish, which nodes[0] will respond to with an
        // error, as the channel has hit the chain.
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let bs_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
        nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &bs_reestablish);
        let as_err = nodes[0].node.get_and_clear_pending_msg_events();
@@ -591,7 +591,7 @@ fn do_test_completed_payment_not_retryable_on_reload(use_dust: bool) {
        assert!(!nodes[0].node.get_and_clear_pending_msg_events().is_empty());
 
        reload_node!(nodes[0], test_default_channel_config(), nodes_0_serialized, &[&chan_0_monitor_serialized, &chan_1_monitor_serialized], second_persister, second_new_chain_monitor, second_nodes_0_deserialized);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        reconnect_nodes(&nodes[0], &nodes[1], (true, true), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
@@ -615,7 +615,7 @@ fn do_test_completed_payment_not_retryable_on_reload(use_dust: bool) {
        // Check that after reload we can send the payment again (though we shouldn't, since it was
        // claimed previously).
        reload_node!(nodes[0], test_default_channel_config(), nodes_0_serialized, &[&chan_0_monitor_serialized, &chan_1_monitor_serialized], third_persister, third_new_chain_monitor, third_nodes_0_deserialized);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
@@ -660,8 +660,8 @@ fn do_test_dup_htlc_onchain_fails_on_reload(persist_manager_post_event: bool, co
        check_added_monitors!(nodes[0], 1);
        check_closed_event!(nodes[0], 1, ClosureReason::HolderForceClosed);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        // Connect blocks until the CLTV timeout is up so that we get an HTLC-Timeout transaction
        connect_blocks(&nodes[0], TEST_FINAL_CLTV + LATENCY_GRACE_PERIOD_BLOCKS + 1);
@@ -812,7 +812,7 @@ fn test_fulfill_restart_failure() {
        // Now reload nodes[1]...
        reload_node!(nodes[1], &chan_manager_serialized, &[&chan_0_monitor_serialized], persister, new_chain_monitor, nodes_1_deserialized);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
        nodes[1].node.fail_htlc_backwards(&payment_hash);
@@ -1721,8 +1721,9 @@ fn do_automatic_retries(test: AutoRetry) {
                let chan_1_monitor_serialized = get_monitor!(nodes[0], channel_id_1).encode();
                reload_node!(nodes[0], node_encoded, &[&chan_1_monitor_serialized], persister, new_chain_monitor, node_0_deserialized);
 
+               let mut events = nodes[0].node.get_and_clear_pending_events();
+               expect_pending_htlcs_forwardable_from_events!(nodes[0], events, true);
                // Make sure we don't retry again.
-               nodes[0].node.process_pending_htlc_forwards();
                let mut msg_events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(msg_events.len(), 0);
 
@@ -2348,7 +2349,7 @@ fn no_extra_retries_on_back_to_back_fail() {
        // Because we now retry payments as a batch, we simply return a single-path route in the
        // second, batched, request, have that fail, ensure the payment was abandoned.
        let mut events = nodes[0].node.get_and_clear_pending_events();
-       assert_eq!(events.len(), 4);
+       assert_eq!(events.len(), 3);
        match events[0] {
                Event::PaymentPathFailed { payment_hash: ev_payment_hash, payment_failed_permanently, ..  } => {
                        assert_eq!(payment_hash, ev_payment_hash);
@@ -2367,10 +2368,6 @@ fn no_extra_retries_on_back_to_back_fail() {
                },
                _ => panic!("Unexpected event"),
        }
-       match events[3] {
-               Event::PendingHTLCsForwardable { .. } => {},
-               _ => panic!("Unexpected event"),
-       }
 
        nodes[0].node.process_pending_htlc_forwards();
        let retry_htlc_updates = SendEvent::from_node(&nodes[0]);
index 8c3ab968af20f92e3a544d7d1fea024837bd26e6..e9eaf33e8840b021afe7fbcfc60afec3e0fa5b3d 100644 (file)
@@ -79,7 +79,7 @@ impl RoutingMessageHandler for IgnoringMessageHandler {
        fn get_next_channel_announcement(&self, _starting_point: u64) ->
                Option<(msgs::ChannelAnnouncement, Option<msgs::ChannelUpdate>, Option<msgs::ChannelUpdate>)> { None }
        fn get_next_node_announcement(&self, _starting_point: Option<&NodeId>) -> Option<msgs::NodeAnnouncement> { None }
-       fn peer_connected(&self, _their_node_id: &PublicKey, _init: &msgs::Init) -> Result<(), ()> { Ok(()) }
+       fn peer_connected(&self, _their_node_id: &PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
        fn handle_reply_channel_range(&self, _their_node_id: &PublicKey, _msg: msgs::ReplyChannelRange) -> Result<(), LightningError> { Ok(()) }
        fn handle_reply_short_channel_ids_end(&self, _their_node_id: &PublicKey, _msg: msgs::ReplyShortChannelIdsEnd) -> Result<(), LightningError> { Ok(()) }
        fn handle_query_channel_range(&self, _their_node_id: &PublicKey, _msg: msgs::QueryChannelRange) -> Result<(), LightningError> { Ok(()) }
@@ -95,8 +95,8 @@ impl OnionMessageProvider for IgnoringMessageHandler {
 }
 impl OnionMessageHandler for IgnoringMessageHandler {
        fn handle_onion_message(&self, _their_node_id: &PublicKey, _msg: &msgs::OnionMessage) {}
-       fn peer_connected(&self, _their_node_id: &PublicKey, _init: &msgs::Init) -> Result<(), ()> { Ok(()) }
-       fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {}
+       fn peer_connected(&self, _their_node_id: &PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
+       fn peer_disconnected(&self, _their_node_id: &PublicKey) {}
        fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() }
        fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures {
                InitFeatures::empty()
@@ -230,8 +230,8 @@ impl ChannelMessageHandler for ErroringMessageHandler {
        }
        // msgs::ChannelUpdate does not contain the channel_id field, so we just drop them.
        fn handle_channel_update(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelUpdate) {}
-       fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {}
-       fn peer_connected(&self, _their_node_id: &PublicKey, _init: &msgs::Init) -> Result<(), ()> { Ok(()) }
+       fn peer_disconnected(&self, _their_node_id: &PublicKey) {}
+       fn peer_connected(&self, _their_node_id: &PublicKey, _init: &msgs::Init, _inbound: bool) -> Result<(), ()> { Ok(()) }
        fn handle_error(&self, _their_node_id: &PublicKey, _msg: &msgs::ErrorMessage) {}
        fn provided_node_features(&self) -> NodeFeatures { NodeFeatures::empty() }
        fn provided_init_features(&self, _their_node_id: &PublicKey) -> InitFeatures {
@@ -322,16 +322,7 @@ pub trait SocketDescriptor : cmp::Eq + hash::Hash + Clone {
 /// generate no further read_event/write_buffer_space_avail/socket_disconnected calls for the
 /// descriptor.
 #[derive(Clone)]
-pub struct PeerHandleError {
-       /// Used to indicate that we probably can't make any future connections to this peer (e.g.
-       /// because we required features that our peer was missing, or vice versa).
-       ///
-       /// While LDK's [`ChannelManager`] will not do it automatically, you likely wish to force-close
-       /// any channels with this peer or check for new versions of LDK.
-       ///
-       /// [`ChannelManager`]: crate::ln::channelmanager::ChannelManager
-       pub no_connection_possible: bool,
-}
+pub struct PeerHandleError { }
 impl fmt::Debug for PeerHandleError {
        fn fmt(&self, formatter: &mut fmt::Formatter) -> Result<(), fmt::Error> {
                formatter.write_str("Peer Sent Invalid Data")
@@ -400,6 +391,12 @@ struct Peer {
        /// We cache a `NodeId` here to avoid serializing peers' keys every time we forward gossip
        /// messages in `PeerManager`. Use `Peer::set_their_node_id` to modify this field.
        their_node_id: Option<(PublicKey, NodeId)>,
+       /// The features provided in the peer's [`msgs::Init`] message.
+       ///
+       /// This is set only after we've processed the [`msgs::Init`] message and called relevant
+       /// `peer_connected` handler methods. Thus, this field is set *iff* we've finished our
+       /// handshake and can talk to this peer normally (though use [`Peer::handshake_complete`] to
+       /// check this.
        their_features: Option<InitFeatures>,
        their_net_address: Option<NetAddress>,
 
@@ -428,9 +425,18 @@ struct Peer {
        /// `channel_announcement` at all - we set this unconditionally but unset it every time we
        /// check if we're gossip-processing-backlogged).
        received_channel_announce_since_backlogged: bool,
+
+       inbound_connection: bool,
 }
 
 impl Peer {
+       /// True after we've processed the [`msgs::Init`] message and called relevant `peer_connected`
+       /// handler methods. Thus, this implies we've finished our handshake and can talk to this peer
+       /// normally.
+       fn handshake_complete(&self) -> bool {
+               self.their_features.is_some()
+       }
+
        /// Returns true if the channel announcements/updates for the given channel should be
        /// forwarded to this peer.
        /// If we are sending our routing table to this peer and we have not yet sent channel
@@ -438,6 +444,7 @@ impl Peer {
        /// point and we shouldn't send it yet to avoid sending duplicate updates. If we've already
        /// sent the old versions, we should send the update, and so return true here.
        fn should_forward_channel_announcement(&self, channel_id: u64) -> bool {
+               if !self.handshake_complete() { return false; }
                if self.their_features.as_ref().unwrap().supports_gossip_queries() &&
                        !self.sent_gossip_timestamp_filter {
                                return false;
@@ -451,6 +458,7 @@ impl Peer {
 
        /// Similar to the above, but for node announcements indexed by node_id.
        fn should_forward_node_announcement(&self, node_id: NodeId) -> bool {
+               if !self.handshake_complete() { return false; }
                if self.their_features.as_ref().unwrap().supports_gossip_queries() &&
                        !self.sent_gossip_timestamp_filter {
                                return false;
@@ -477,19 +485,20 @@ impl Peer {
        fn should_buffer_gossip_backfill(&self) -> bool {
                self.pending_outbound_buffer.is_empty() && self.gossip_broadcast_buffer.is_empty()
                        && self.msgs_sent_since_pong < BUFFER_DRAIN_MSGS_PER_TICK
+                       && self.handshake_complete()
        }
 
        /// Determines if we should push an onion message onto a peer's outbound buffer. This is checked
        /// every time the peer's buffer may have been drained.
        fn should_buffer_onion_message(&self) -> bool {
-               self.pending_outbound_buffer.is_empty()
+               self.pending_outbound_buffer.is_empty() && self.handshake_complete()
                        && self.msgs_sent_since_pong < BUFFER_DRAIN_MSGS_PER_TICK
        }
 
        /// Determines if we should push additional gossip broadcast messages onto a peer's outbound
        /// buffer. This is checked every time the peer's buffer may have been drained.
        fn should_buffer_gossip_broadcast(&self) -> bool {
-               self.pending_outbound_buffer.is_empty()
+               self.pending_outbound_buffer.is_empty() && self.handshake_complete()
                        && self.msgs_sent_since_pong < BUFFER_DRAIN_MSGS_PER_TICK
        }
 
@@ -771,8 +780,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                let peers = self.peers.read().unwrap();
                peers.values().filter_map(|peer_mutex| {
                        let p = peer_mutex.lock().unwrap();
-                       if !p.channel_encryptor.is_ready_for_encryption() || p.their_features.is_none() ||
-                               p.their_node_id.is_none() {
+                       if !p.handshake_complete() {
                                return None;
                        }
                        Some((p.their_node_id.unwrap().0, p.their_net_address.clone()))
@@ -830,6 +838,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        sent_gossip_timestamp_filter: false,
 
                        received_channel_announce_since_backlogged: false,
+                       inbound_connection: false,
                })).is_some() {
                        panic!("PeerManager driver duplicated descriptors!");
                };
@@ -879,6 +888,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        sent_gossip_timestamp_filter: false,
 
                        received_channel_announce_since_backlogged: false,
+                       inbound_connection: true,
                })).is_some() {
                        panic!("PeerManager driver duplicated descriptors!");
                };
@@ -1001,7 +1011,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                // This is most likely a simple race condition where the user found that the socket
                                // was writeable, then we told the user to `disconnect_socket()`, then they called
                                // this method. Return an error to make sure we get disconnected.
-                               return Err(PeerHandleError { no_connection_possible: false });
+                               return Err(PeerHandleError { });
                        },
                        Some(peer_mutex) => {
                                let mut peer = peer_mutex.lock().unwrap();
@@ -1034,7 +1044,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        Ok(res) => Ok(res),
                        Err(e) => {
                                log_trace!(self.logger, "Peer sent invalid data or we decided to disconnect due to a protocol error");
-                               self.disconnect_event_internal(peer_descriptor, e.no_connection_possible);
+                               self.disconnect_event_internal(peer_descriptor);
                                Err(e)
                        }
                }
@@ -1067,7 +1077,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                // This is most likely a simple race condition where the user read some bytes
                                // from the socket, then we told the user to `disconnect_socket()`, then they
                                // called this method. Return an error to make sure we get disconnected.
-                               return Err(PeerHandleError { no_connection_possible: false });
+                               return Err(PeerHandleError { });
                        },
                        Some(peer_mutex) => {
                                let mut read_pos = 0;
@@ -1081,7 +1091,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                                msgs::ErrorAction::DisconnectPeer { msg: _ } => {
                                                                                        //TODO: Try to push msg
                                                                                        log_debug!(self.logger, "Error handling message{}; disconnecting peer with: {}", OptionalFromDebugger(&peer_node_id), e.err);
-                                                                                       return Err(PeerHandleError{ no_connection_possible: false });
+                                                                                       return Err(PeerHandleError { });
                                                                                },
                                                                                msgs::ErrorAction::IgnoreAndLog(level) => {
                                                                                        log_given_level!(self.logger, level, "Error handling message{}; ignoring: {}", OptionalFromDebugger(&peer_node_id), e.err);
@@ -1134,7 +1144,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                        hash_map::Entry::Occupied(_) => {
                                                                                log_trace!(self.logger, "Got second connection with {}, closing", log_pubkey!(peer.their_node_id.unwrap().0));
                                                                                peer.their_node_id = None; // Unset so that we don't generate a peer_disconnected event
-                                                                               return Err(PeerHandleError{ no_connection_possible: false })
+                                                                               return Err(PeerHandleError { })
                                                                        },
                                                                        hash_map::Entry::Vacant(entry) => {
                                                                                log_debug!(self.logger, "Finished noise handshake for connection with {}", log_pubkey!(peer.their_node_id.unwrap().0));
@@ -1191,7 +1201,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                        if peer.pending_read_buffer.capacity() > 8192 { peer.pending_read_buffer = Vec::new(); }
                                                                        peer.pending_read_buffer.resize(msg_len as usize + 16, 0);
                                                                        if msg_len < 2 { // Need at least the message type tag
-                                                                               return Err(PeerHandleError{ no_connection_possible: false });
+                                                                               return Err(PeerHandleError { });
                                                                        }
                                                                        peer.pending_read_is_header = false;
                                                                } else {
@@ -1234,19 +1244,19 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                                                                (msgs::DecodeError::UnknownRequiredFeature, ty) => {
                                                                                                        log_gossip!(self.logger, "Received a message with an unknown required feature flag or TLV, you may want to update!");
                                                                                                        self.enqueue_message(peer, &msgs::WarningMessage { channel_id: [0; 32], data: format!("Received an unknown required feature/TLV in message type {:?}", ty) });
-                                                                                                       return Err(PeerHandleError { no_connection_possible: false });
+                                                                                                       return Err(PeerHandleError { });
                                                                                                }
-                                                                                               (msgs::DecodeError::UnknownVersion, _) => return Err(PeerHandleError { no_connection_possible: false }),
+                                                                                               (msgs::DecodeError::UnknownVersion, _) => return Err(PeerHandleError { }),
                                                                                                (msgs::DecodeError::InvalidValue, _) => {
                                                                                                        log_debug!(self.logger, "Got an invalid value while deserializing message");
-                                                                                                       return Err(PeerHandleError { no_connection_possible: false });
+                                                                                                       return Err(PeerHandleError { });
                                                                                                }
                                                                                                (msgs::DecodeError::ShortRead, _) => {
                                                                                                        log_debug!(self.logger, "Deserialization failed due to shortness of message");
-                                                                                                       return Err(PeerHandleError { no_connection_possible: false });
+                                                                                                       return Err(PeerHandleError { });
                                                                                                }
-                                                                                               (msgs::DecodeError::BadLengthDescriptor, _) => return Err(PeerHandleError { no_connection_possible: false }),
-                                                                                               (msgs::DecodeError::Io(_), _) => return Err(PeerHandleError { no_connection_possible: false }),
+                                                                                               (msgs::DecodeError::BadLengthDescriptor, _) => return Err(PeerHandleError { }),
+                                                                                               (msgs::DecodeError::Io(_), _) => return Err(PeerHandleError { }),
                                                                                        }
                                                                                }
                                                                        };
@@ -1298,10 +1308,10 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                if let wire::Message::Init(msg) = message {
                        if msg.features.requires_unknown_bits() {
                                log_debug!(self.logger, "Peer features required unknown version bits");
-                               return Err(PeerHandleError{ no_connection_possible: true }.into());
+                               return Err(PeerHandleError { }.into());
                        }
                        if peer_lock.their_features.is_some() {
-                               return Err(PeerHandleError{ no_connection_possible: false }.into());
+                               return Err(PeerHandleError { }.into());
                        }
 
                        log_info!(self.logger, "Received peer Init message from {}: {}", log_pubkey!(their_node_id), msg.features);
@@ -1311,24 +1321,24 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                peer_lock.sync_status = InitSyncTracker::ChannelsSyncing(0);
                        }
 
-                       if let Err(()) = self.message_handler.route_handler.peer_connected(&their_node_id, &msg) {
+                       if let Err(()) = self.message_handler.route_handler.peer_connected(&their_node_id, &msg, peer_lock.inbound_connection) {
                                log_debug!(self.logger, "Route Handler decided we couldn't communicate with peer {}", log_pubkey!(their_node_id));
-                               return Err(PeerHandleError{ no_connection_possible: true }.into());
+                               return Err(PeerHandleError { }.into());
                        }
-                       if let Err(()) = self.message_handler.chan_handler.peer_connected(&their_node_id, &msg) {
+                       if let Err(()) = self.message_handler.chan_handler.peer_connected(&their_node_id, &msg, peer_lock.inbound_connection) {
                                log_debug!(self.logger, "Channel Handler decided we couldn't communicate with peer {}", log_pubkey!(their_node_id));
-                               return Err(PeerHandleError{ no_connection_possible: true }.into());
+                               return Err(PeerHandleError { }.into());
                        }
-                       if let Err(()) = self.message_handler.onion_message_handler.peer_connected(&their_node_id, &msg) {
+                       if let Err(()) = self.message_handler.onion_message_handler.peer_connected(&their_node_id, &msg, peer_lock.inbound_connection) {
                                log_debug!(self.logger, "Onion Message Handler decided we couldn't communicate with peer {}", log_pubkey!(their_node_id));
-                               return Err(PeerHandleError{ no_connection_possible: true }.into());
+                               return Err(PeerHandleError { }.into());
                        }
 
                        peer_lock.their_features = Some(msg.features);
                        return Ok(None);
                } else if peer_lock.their_features.is_none() {
                        log_debug!(self.logger, "Peer {} sent non-Init first message", log_pubkey!(their_node_id));
-                       return Err(PeerHandleError{ no_connection_possible: false }.into());
+                       return Err(PeerHandleError { }.into());
                }
 
                if let wire::Message::GossipTimestampFilter(_msg) = message {
@@ -1380,7 +1390,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                }
                                self.message_handler.chan_handler.handle_error(&their_node_id, &msg);
                                if msg.channel_id == [0; 32] {
-                                       return Err(PeerHandleError{ no_connection_possible: true }.into());
+                                       return Err(PeerHandleError { }.into());
                                }
                        },
                        wire::Message::Warning(msg) => {
@@ -1510,8 +1520,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        // Unknown messages:
                        wire::Message::Unknown(type_id) if message.is_even() => {
                                log_debug!(self.logger, "Received unknown even message of type {}, disconnecting peer!", type_id);
-                               // Fail the channel if message is an even, unknown type as per BOLT #1.
-                               return Err(PeerHandleError{ no_connection_possible: true }.into());
+                               return Err(PeerHandleError { }.into());
                        },
                        wire::Message::Unknown(type_id) => {
                                log_trace!(self.logger, "Received unknown odd message of type {}, ignoring", type_id);
@@ -1531,10 +1540,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 
                                for (_, peer_mutex) in peers.iter() {
                                        let mut peer = peer_mutex.lock().unwrap();
-                                       if !peer.channel_encryptor.is_ready_for_encryption() || peer.their_features.is_none() ||
+                                       if !peer.handshake_complete() ||
                                                        !peer.should_forward_channel_announcement(msg.contents.short_channel_id) {
                                                continue
                                        }
+                                       debug_assert!(peer.their_node_id.is_some());
+                                       debug_assert!(peer.channel_encryptor.is_ready_for_encryption());
                                        if peer.buffer_full_drop_gossip_broadcast() {
                                                log_gossip!(self.logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id);
                                                continue;
@@ -1556,10 +1567,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 
                                for (_, peer_mutex) in peers.iter() {
                                        let mut peer = peer_mutex.lock().unwrap();
-                                       if !peer.channel_encryptor.is_ready_for_encryption() || peer.their_features.is_none() ||
+                                       if !peer.handshake_complete() ||
                                                        !peer.should_forward_node_announcement(msg.contents.node_id) {
                                                continue
                                        }
+                                       debug_assert!(peer.their_node_id.is_some());
+                                       debug_assert!(peer.channel_encryptor.is_ready_for_encryption());
                                        if peer.buffer_full_drop_gossip_broadcast() {
                                                log_gossip!(self.logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id);
                                                continue;
@@ -1581,10 +1594,12 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 
                                for (_, peer_mutex) in peers.iter() {
                                        let mut peer = peer_mutex.lock().unwrap();
-                                       if !peer.channel_encryptor.is_ready_for_encryption() || peer.their_features.is_none() ||
+                                       if !peer.handshake_complete() ||
                                                        !peer.should_forward_channel_announcement(msg.contents.short_channel_id)  {
                                                continue
                                        }
+                                       debug_assert!(peer.their_node_id.is_some());
+                                       debug_assert!(peer.channel_encryptor.is_ready_for_encryption());
                                        if peer.buffer_full_drop_gossip_broadcast() {
                                                log_gossip!(self.logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id);
                                                continue;
@@ -1664,7 +1679,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                                        Some(descriptor) => match peers.get(&descriptor) {
                                                                Some(peer_mutex) => {
                                                                        let peer_lock = peer_mutex.lock().unwrap();
-                                                                       if peer_lock.their_features.is_none() {
+                                                                       if !peer_lock.handshake_complete() {
                                                                                continue;
                                                                        }
                                                                        peer_lock
@@ -1884,24 +1899,21 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                // thread can be holding the peer lock if we have the global write
                                // lock).
 
-                               if let Some(mut descriptor) = self.node_id_to_descriptor.lock().unwrap().remove(&node_id) {
+                               let descriptor_opt = self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
+                               if let Some(mut descriptor) = descriptor_opt {
                                        if let Some(peer_mutex) = peers.remove(&descriptor) {
+                                               let mut peer = peer_mutex.lock().unwrap();
                                                if let Some(msg) = msg {
                                                        log_trace!(self.logger, "Handling DisconnectPeer HandleError event in peer_handler for node {} with message {}",
                                                                        log_pubkey!(node_id),
                                                                        msg.data);
-                                                       let mut peer = peer_mutex.lock().unwrap();
                                                        self.enqueue_message(&mut *peer, &msg);
                                                        // This isn't guaranteed to work, but if there is enough free
                                                        // room in the send buffer, put the error message there...
                                                        self.do_attempt_write_data(&mut descriptor, &mut *peer, false);
-                                               } else {
-                                                       log_trace!(self.logger, "Handling DisconnectPeer HandleError event in peer_handler for node {} with no message", log_pubkey!(node_id));
                                                }
+                                               self.do_disconnect(descriptor, &*peer, "DisconnectPeer HandleError");
                                        }
-                                       descriptor.disconnect_socket();
-                                       self.message_handler.chan_handler.peer_disconnected(&node_id, false);
-                                       self.message_handler.onion_message_handler.peer_disconnected(&node_id, false);
                                }
                        }
                }
@@ -1909,10 +1921,26 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 
        /// Indicates that the given socket descriptor's connection is now closed.
        pub fn socket_disconnected(&self, descriptor: &Descriptor) {
-               self.disconnect_event_internal(descriptor, false);
+               self.disconnect_event_internal(descriptor);
        }
 
-       fn disconnect_event_internal(&self, descriptor: &Descriptor, no_connection_possible: bool) {
+       fn do_disconnect(&self, mut descriptor: Descriptor, peer: &Peer, reason: &'static str) {
+               if !peer.handshake_complete() {
+                       log_trace!(self.logger, "Disconnecting peer which hasn't completed handshake due to {}", reason);
+                       descriptor.disconnect_socket();
+                       return;
+               }
+
+               debug_assert!(peer.their_node_id.is_some());
+               if let Some((node_id, _)) = peer.their_node_id {
+                       log_trace!(self.logger, "Disconnecting peer with id {} due to {}", node_id, reason);
+                       self.message_handler.chan_handler.peer_disconnected(&node_id);
+                       self.message_handler.onion_message_handler.peer_disconnected(&node_id);
+               }
+               descriptor.disconnect_socket();
+       }
+
+       fn disconnect_event_internal(&self, descriptor: &Descriptor) {
                let mut peers = self.peers.write().unwrap();
                let peer_option = peers.remove(descriptor);
                match peer_option {
@@ -1923,13 +1951,13 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                        },
                        Some(peer_lock) => {
                                let peer = peer_lock.lock().unwrap();
+                               if !peer.handshake_complete() { return; }
+                               debug_assert!(peer.their_node_id.is_some());
                                if let Some((node_id, _)) = peer.their_node_id {
-                                       log_trace!(self.logger,
-                                               "Handling disconnection of peer {}, with {}future connection to the peer possible.",
-                                               log_pubkey!(node_id), if no_connection_possible { "no " } else { "" });
+                                       log_trace!(self.logger, "Handling disconnection of peer {}", log_pubkey!(node_id));
                                        self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
-                                       self.message_handler.chan_handler.peer_disconnected(&node_id, no_connection_possible);
-                                       self.message_handler.onion_message_handler.peer_disconnected(&node_id, no_connection_possible);
+                                       self.message_handler.chan_handler.peer_disconnected(&node_id);
+                                       self.message_handler.onion_message_handler.peer_disconnected(&node_id);
                                }
                        }
                };
@@ -1937,21 +1965,17 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
 
        /// Disconnect a peer given its node id.
        ///
-       /// Set `no_connection_possible` to true to prevent any further connection with this peer,
-       /// force-closing any channels we have with it.
-       ///
        /// If a peer is connected, this will call [`disconnect_socket`] on the descriptor for the
        /// peer. Thus, be very careful about reentrancy issues.
        ///
        /// [`disconnect_socket`]: SocketDescriptor::disconnect_socket
-       pub fn disconnect_by_node_id(&self, node_id: PublicKey, no_connection_possible: bool) {
+       pub fn disconnect_by_node_id(&self, node_id: PublicKey) {
                let mut peers_lock = self.peers.write().unwrap();
-               if let Some(mut descriptor) = self.node_id_to_descriptor.lock().unwrap().remove(&node_id) {
-                       log_trace!(self.logger, "Disconnecting peer with id {} due to client request", node_id);
-                       peers_lock.remove(&descriptor);
-                       self.message_handler.chan_handler.peer_disconnected(&node_id, no_connection_possible);
-                       self.message_handler.onion_message_handler.peer_disconnected(&node_id, no_connection_possible);
-                       descriptor.disconnect_socket();
+               if let Some(descriptor) = self.node_id_to_descriptor.lock().unwrap().remove(&node_id) {
+                       let peer_opt = peers_lock.remove(&descriptor);
+                       if let Some(peer_mutex) = peer_opt {
+                               self.do_disconnect(descriptor, &*peer_mutex.lock().unwrap(), "client request");
+                       } else { debug_assert!(false, "node_id_to_descriptor thought we had a peer"); }
                }
        }
 
@@ -1962,13 +1986,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                let mut peers_lock = self.peers.write().unwrap();
                self.node_id_to_descriptor.lock().unwrap().clear();
                let peers = &mut *peers_lock;
-               for (mut descriptor, peer) in peers.drain() {
-                       if let Some((node_id, _)) = peer.lock().unwrap().their_node_id {
-                               log_trace!(self.logger, "Disconnecting peer with id {} due to client request to disconnect all peers", node_id);
-                               self.message_handler.chan_handler.peer_disconnected(&node_id, false);
-                               self.message_handler.onion_message_handler.peer_disconnected(&node_id, false);
-                       }
-                       descriptor.disconnect_socket();
+               for (descriptor, peer_mutex) in peers.drain() {
+                       self.do_disconnect(descriptor, &*peer_mutex.lock().unwrap(), "client request to disconnect all peers");
                }
        }
 
@@ -2009,7 +2028,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                let mut peer = peer_mutex.lock().unwrap();
                                if flush_read_disabled { peer.received_channel_announce_since_backlogged = false; }
 
-                               if !peer.channel_encryptor.is_ready_for_encryption() || peer.their_node_id.is_none() {
+                               if !peer.handshake_complete() {
                                        // The peer needs to complete its handshake before we can exchange messages. We
                                        // give peers one timer tick to complete handshake, reusing
                                        // `awaiting_pong_timer_tick_intervals` to track number of timer ticks taken
@@ -2021,6 +2040,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                                        }
                                        continue;
                                }
+                               debug_assert!(peer.channel_encryptor.is_ready_for_encryption());
+                               debug_assert!(peer.their_node_id.is_some());
 
                                loop { // Used as a `goto` to skip writing a Ping message.
                                        if peer.awaiting_pong_timer_tick_intervals == -1 {
@@ -2059,21 +2080,16 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, OM: Deref, L: Deref, CM
                if !descriptors_needing_disconnect.is_empty() {
                        {
                                let mut peers_lock = self.peers.write().unwrap();
-                               for descriptor in descriptors_needing_disconnect.iter() {
-                                       if let Some(peer) = peers_lock.remove(descriptor) {
-                                               if let Some((node_id, _)) = peer.lock().unwrap().their_node_id {
-                                                       log_trace!(self.logger, "Disconnecting peer with id {} due to ping timeout", node_id);
+                               for descriptor in descriptors_needing_disconnect {
+                                       if let Some(peer_mutex) = peers_lock.remove(&descriptor) {
+                                               let peer = peer_mutex.lock().unwrap();
+                                               if let Some((node_id, _)) = peer.their_node_id {
                                                        self.node_id_to_descriptor.lock().unwrap().remove(&node_id);
-                                                       self.message_handler.chan_handler.peer_disconnected(&node_id, false);
-                                                       self.message_handler.onion_message_handler.peer_disconnected(&node_id, false);
                                                }
+                                               self.do_disconnect(descriptor, &*peer, "ping timeout");
                                        }
                                }
                        }
-
-                       for mut descriptor in descriptors_needing_disconnect.drain(..) {
-                               descriptor.disconnect_socket();
-                       }
                }
        }
 
@@ -2161,6 +2177,7 @@ fn is_gossip_msg(type_id: u16) -> bool {
 #[cfg(test)]
 mod tests {
        use crate::chain::keysinterface::{NodeSigner, Recipient};
+       use crate::ln::peer_channel_encryptor::PeerChannelEncryptor;
        use crate::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler, filter_addresses};
        use crate::ln::{msgs, wire};
        use crate::ln::msgs::NetAddress;
@@ -2269,19 +2286,15 @@ mod tests {
                // Simple test which builds a network of PeerManager, connects and brings them to NoiseState::Finished and
                // push a DisconnectPeer event to remove the node flagged by id
                let cfgs = create_peermgr_cfgs(2);
-               let chan_handler = test_utils::TestChannelMessageHandler::new();
-               let mut peers = create_network(2, &cfgs);
+               let peers = create_network(2, &cfgs);
                establish_connection(&peers[0], &peers[1]);
                assert_eq!(peers[0].peers.read().unwrap().len(), 1);
 
                let their_id = peers[1].node_signer.get_node_id(Recipient::Node).unwrap();
-
-               chan_handler.pending_events.lock().unwrap().push(events::MessageSendEvent::HandleError {
+               cfgs[0].chan_handler.pending_events.lock().unwrap().push(events::MessageSendEvent::HandleError {
                        node_id: their_id,
                        action: msgs::ErrorAction::DisconnectPeer { msg: None },
                });
-               assert_eq!(chan_handler.pending_events.lock().unwrap().len(), 1);
-               peers[0].message_handler.chan_handler = &chan_handler;
 
                peers[0].process_events();
                assert_eq!(peers[0].peers.read().unwrap().len(), 0);
@@ -2315,6 +2328,35 @@ mod tests {
                assert_eq!(peers[1].read_event(&mut fd_b, &a_data).unwrap(), false);
        }
 
+       #[test]
+       fn test_non_init_first_msg() {
+               // Simple test of the first message received over a connection being something other than
+               // Init. This results in an immediate disconnection, which previously included a spurious
+               // peer_disconnected event handed to event handlers (which would panic in
+               // `TestChannelMessageHandler` here).
+               let cfgs = create_peermgr_cfgs(2);
+               let peers = create_network(2, &cfgs);
+
+               let mut fd_dup = FileDescriptor { fd: 3, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let addr_dup = NetAddress::IPv4{addr: [127, 0, 0, 1], port: 1003};
+               let id_a = cfgs[0].node_signer.get_node_id(Recipient::Node).unwrap();
+               peers[0].new_inbound_connection(fd_dup.clone(), Some(addr_dup.clone())).unwrap();
+
+               let mut dup_encryptor = PeerChannelEncryptor::new_outbound(id_a, SecretKey::from_slice(&[42; 32]).unwrap());
+               let initial_data = dup_encryptor.get_act_one(&peers[1].secp_ctx);
+               assert_eq!(peers[0].read_event(&mut fd_dup, &initial_data).unwrap(), false);
+               peers[0].process_events();
+
+               let a_data = fd_dup.outbound_data.lock().unwrap().split_off(0);
+               let (act_three, _) =
+                       dup_encryptor.process_act_two(&a_data[..], &&cfgs[1].node_signer).unwrap();
+               assert_eq!(peers[0].read_event(&mut fd_dup, &act_three).unwrap(), false);
+
+               let not_init_msg = msgs::Ping { ponglen: 4, byteslen: 0 };
+               let msg_bytes = dup_encryptor.encrypt_message(&not_init_msg);
+               assert!(peers[0].read_event(&mut fd_dup, &msg_bytes).is_err());
+       }
+
        #[test]
        fn test_disconnect_all_peer() {
                // Simple test which builds a network of PeerManager, connects and brings them to NoiseState::Finished and
index adc2b59abbf4b3180ab6650e5433305b1fb112ef..7636f5c63641edc5e18cb5d7f69d3375a9fa235b 100644 (file)
@@ -90,8 +90,8 @@ fn test_priv_forwarding_rejection() {
        // Now disconnect nodes[1] from its peers and restart with accept_forwards_to_priv_channels set
        // to true. Sadly there is currently no way to change it at runtime.
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id());
 
        let nodes_1_serialized = nodes[1].node.encode();
        let monitor_a_serialized = get_monitor!(nodes[1], chan_id_1).encode();
@@ -100,8 +100,8 @@ fn test_priv_forwarding_rejection() {
        no_announce_cfg.accept_forwards_to_priv_channels = true;
        reload_node!(nodes[1], no_announce_cfg, &nodes_1_serialized, &[&monitor_a_serialized, &monitor_b_serialized], persister, new_chain_monitor, nodes_1_deserialized);
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let as_reestablish = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
        let bs_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &as_reestablish);
@@ -109,8 +109,8 @@ fn test_priv_forwarding_rejection() {
        get_event_msg!(nodes[0], MessageSendEvent::SendChannelUpdate, nodes[1].node.get_our_node_id());
        get_event_msg!(nodes[1], MessageSendEvent::SendChannelUpdate, nodes[0].node.get_our_node_id());
 
-       nodes[1].node.peer_connected(&nodes[2].node.get_our_node_id(), &msgs::Init { features: nodes[2].node.init_features(), remote_network_address: None }).unwrap();
-       nodes[2].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[2].node.get_our_node_id(), &msgs::Init { features: nodes[2].node.init_features(), remote_network_address: None }, true).unwrap();
+       nodes[2].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, false).unwrap();
        let bs_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[2]).pop().unwrap();
        let cs_reestablish = get_chan_reestablish_msgs!(nodes[2], nodes[1]).pop().unwrap();
        nodes[2].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &bs_reestablish);
index 8e5056dc68e45750e4bc354df973b86d403720ae..14deefd3f93828181f6cfad952d48913c8e25586 100644 (file)
@@ -44,8 +44,8 @@ fn test_funding_peer_disconnect() {
        let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
        let tx = create_chan_between_nodes_with_value_init(&nodes[0], &nodes[1], 100000, 10001);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        confirm_transaction(&nodes[0], &tx);
        let events_1 = nodes[0].node.get_and_clear_pending_msg_events();
@@ -53,16 +53,16 @@ fn test_funding_peer_disconnect() {
 
        reconnect_nodes(&nodes[0], &nodes[1], (false, true), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        confirm_transaction(&nodes[1], &tx);
        let events_2 = nodes[1].node.get_and_clear_pending_msg_events();
        assert!(events_2.is_empty());
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let as_reestablish = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let bs_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
 
        // nodes[0] hasn't yet received a channel_ready, so it only sends that on reconnect.
@@ -169,7 +169,7 @@ fn test_funding_peer_disconnect() {
 
        // Check that after deserialization and reconnection we can still generate an identical
        // channel_announcement from the cached signatures.
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
 
@@ -190,15 +190,15 @@ fn test_no_txn_manager_serialize_deserialize() {
 
        let tx = create_chan_between_nodes_with_value_init(&nodes[0], &nodes[1], 100000, 10001);
 
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        let chan_0_monitor_serialized =
                get_monitor!(nodes[0], OutPoint { txid: tx.txid(), index: 0 }.to_channel_id()).encode();
        reload_node!(nodes[0], nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister, new_chain_monitor, nodes_0_deserialized);
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
 
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &reestablish_1[0]);
@@ -267,7 +267,7 @@ fn test_manager_serialize_deserialize_events() {
        let chan_0_monitor_serialized = get_monitor!(nodes[0], bs_funding_signed.channel_id).encode();
        reload_node!(nodes[0], nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister, new_chain_monitor, nodes_0_deserialized);
 
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        // After deserializing, make sure the funding_transaction is still held by the channel manager
        let events_4 = nodes[0].node.get_and_clear_pending_events();
@@ -278,9 +278,9 @@ fn test_manager_serialize_deserialize_events() {
        // Make sure the channel is functioning as though the de/serialization never happened
        assert_eq!(nodes[0].node.list_channels().len(), 1);
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let reestablish_2 = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
 
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &reestablish_1[0]);
@@ -313,7 +313,7 @@ fn test_simple_manager_serialize_deserialize() {
        let (our_payment_preimage, _, _) = route_payment(&nodes[0], &[&nodes[1]], 1000000);
        let (_, our_payment_hash, _) = route_payment(&nodes[0], &[&nodes[1]], 1000000);
 
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        let chan_0_monitor_serialized = get_monitor!(nodes[0], chan_id).encode();
        reload_node!(nodes[0], nodes[0].node.encode(), &[&chan_0_monitor_serialized], persister, new_chain_monitor, nodes_0_deserialized);
@@ -353,9 +353,9 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() {
        let nodes_0_serialized = nodes[0].node.encode();
 
        route_payment(&nodes[0], &[&nodes[3]], 1000000);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
-       nodes[2].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
-       nodes[3].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+       nodes[2].node.peer_disconnected(&nodes[0].node.get_our_node_id());
+       nodes[3].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        // Now the ChannelMonitor (which is now out-of-sync with ChannelManager for channel w/
        // nodes[3])
@@ -443,9 +443,9 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() {
        //... and we can even still claim the payment!
        claim_payment(&nodes[2], &[&nodes[0], &nodes[1]], our_payment_preimage);
 
-       nodes[3].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[3].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap();
        let reestablish = get_chan_reestablish_msgs!(nodes[3], nodes[0]).pop().unwrap();
-       nodes[0].node.peer_connected(&nodes[3].node.get_our_node_id(), &msgs::Init { features: nodes[3].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[3].node.get_our_node_id(), &msgs::Init { features: nodes[3].node.init_features(), remote_network_address: None }, false).unwrap();
        nodes[0].node.handle_channel_reestablish(&nodes[3].node.get_our_node_id(), &reestablish);
        let mut found_err = false;
        for msg_event in nodes[0].node.get_and_clear_pending_msg_events() {
@@ -488,14 +488,14 @@ fn do_test_data_loss_protect(reconnect_panicing: bool) {
        send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000);
        send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
        reload_node!(nodes[0], previous_node_state, &[&previous_chain_monitor_state], persister, new_chain_monitor, nodes_0_deserialized);
 
        if reconnect_panicing {
-               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
-               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
+               nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
 
                let reestablish_1 = get_chan_reestablish_msgs!(nodes[0], nodes[1]);
 
@@ -543,8 +543,8 @@ fn do_test_data_loss_protect(reconnect_panicing: bool) {
        // after the warning message sent by B, we should not able to
        // use the channel, or reconnect with success to the channel.
        assert!(nodes[0].node.list_usable_channels().is_empty());
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let retry_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]);
 
        nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &retry_reestablish[0]);
@@ -627,8 +627,8 @@ fn test_forwardable_regen() {
        assert!(nodes[1].node.get_and_clear_pending_events().is_empty());
 
        // Now restart nodes[1] and make sure it regenerates a single PendingHTLCsForwardable
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[2].node.peer_disconnected(&nodes[1].node.get_our_node_id());
 
        let chan_0_monitor_serialized = get_monitor!(nodes[1], chan_id_1).encode();
        let chan_1_monitor_serialized = get_monitor!(nodes[1], chan_id_2).encode();
@@ -753,8 +753,8 @@ fn do_test_partial_claim_before_restart(persist_both_monitors: bool) {
        assert!(get_monitor!(nodes[3], chan_id_persisted).get_stored_preimages().contains_key(&payment_hash));
        assert!(get_monitor!(nodes[3], chan_id_not_persisted).get_stored_preimages().contains_key(&payment_hash));
 
-       nodes[1].node.peer_disconnected(&nodes[3].node.get_our_node_id(), false);
-       nodes[2].node.peer_disconnected(&nodes[3].node.get_our_node_id(), false);
+       nodes[1].node.peer_disconnected(&nodes[3].node.get_our_node_id());
+       nodes[2].node.peer_disconnected(&nodes[3].node.get_our_node_id());
 
        // During deserialization, we should have closed one channel and broadcast its latest
        // commitment transaction. We should also still have the original PaymentClaimable event we
@@ -779,9 +779,9 @@ fn do_test_partial_claim_before_restart(persist_both_monitors: bool) {
        if !persist_both_monitors {
                // If one of the two channels is still live, reveal the payment preimage over it.
 
-               nodes[3].node.peer_connected(&nodes[2].node.get_our_node_id(), &msgs::Init { features: nodes[2].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[3].node.peer_connected(&nodes[2].node.get_our_node_id(), &msgs::Init { features: nodes[2].node.init_features(), remote_network_address: None }, true).unwrap();
                let reestablish_1 = get_chan_reestablish_msgs!(nodes[3], nodes[2]);
-               nodes[2].node.peer_connected(&nodes[3].node.get_our_node_id(), &msgs::Init { features: nodes[3].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[2].node.peer_connected(&nodes[3].node.get_our_node_id(), &msgs::Init { features: nodes[3].node.init_features(), remote_network_address: None }, false).unwrap();
                let reestablish_2 = get_chan_reestablish_msgs!(nodes[2], nodes[3]);
 
                nodes[2].node.handle_channel_reestablish(&nodes[3].node.get_our_node_id(), &reestablish_1[0]);
@@ -923,7 +923,7 @@ fn do_forwarded_payment_no_manager_persistence(use_cs_commitment: bool, claim_ht
        let bs_commitment_tx = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0);
        assert_eq!(bs_commitment_tx.len(), 1);
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), true);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
        if use_cs_commitment {
@@ -1038,7 +1038,7 @@ fn removed_payment_no_manager_persistence() {
        // Now that the ChannelManager has force-closed the channel which had the HTLC removed, it is
        // now forgotten everywhere. The ChannelManager should have, as a side-effect of reload,
        // learned that the HTLC is gone from the ChannelMonitor and added it to the to-fail-back set.
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), true);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
        reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
 
        expect_pending_htlcs_forwardable_and_htlc_handling_failed!(nodes[1], [HTLCDestination::NextHopChannel { node_id: Some(nodes[2].node.get_our_node_id()), channel_id: chan_id_2 }]);
index ce1f3b64750bafbd7343ff66f47e8b16677d30a4..69e0c9afa9e49c056c7d081ef8e42913d7492d55 100644 (file)
@@ -378,7 +378,7 @@ fn do_test_unconf_chan(reload_node: bool, reorg_after_reload: bool, use_funding_
                // If we dropped the channel before reloading the node, nodes[1] was also dropped from
                // nodes[0] storage, and hence not connected again on startup. We therefore need to
                // reconnect to the node before attempting to create a new channel.
-               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+               nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        }
        create_announced_chan_between_nodes(&nodes, 0, 1);
        send_payment(&nodes[0], &[&nodes[1]], 8000000);
index 5a5f054846f8d4e66d720aaa3bc77a29b1b7da6e..75eacb74bc2fe7352a99288e28fcef82eac4c81d 100644 (file)
@@ -243,12 +243,12 @@ fn do_test_shutdown_rebroadcast(recv_count: u8) {
                }
        }
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, true).unwrap();
        let node_0_reestablish = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, false).unwrap();
        let node_1_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
 
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &node_0_reestablish);
@@ -305,12 +305,12 @@ fn do_test_shutdown_rebroadcast(recv_count: u8) {
                assert!(node_0_2nd_closing_signed.is_some());
        }
 
-       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id(), false);
-       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false);
+       nodes[0].node.peer_disconnected(&nodes[1].node.get_our_node_id());
+       nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id());
 
-       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: nodes[0].node.init_features(), remote_network_address: None }, true).unwrap();
        let node_1_2nd_reestablish = get_chan_reestablish_msgs!(nodes[1], nodes[0]).pop().unwrap();
-       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }).unwrap();
+       nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: nodes[1].node.init_features(), remote_network_address: None }, false).unwrap();
        if recv_count == 0 {
                // If all closing_signeds weren't delivered we can just resume where we left off...
                let node_0_2nd_reestablish = get_chan_reestablish_msgs!(nodes[0], nodes[1]).pop().unwrap();
index 1309baf60eac7ed85a5493ec8ca102653cad6eb6..aee7f562defc5915517d696cf032353deee4e3d6 100644 (file)
@@ -85,8 +85,8 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
                let mut features = InitFeatures::empty();
                features.set_onion_messages_optional();
                let init_msg = msgs::Init { features, remote_network_address: None };
-               nodes[i].messenger.peer_connected(&nodes[i + 1].get_node_pk(), &init_msg.clone()).unwrap();
-               nodes[i + 1].messenger.peer_connected(&nodes[i].get_node_pk(), &init_msg.clone()).unwrap();
+               nodes[i].messenger.peer_connected(&nodes[i + 1].get_node_pk(), &init_msg.clone(), true).unwrap();
+               nodes[i + 1].messenger.peer_connected(&nodes[i].get_node_pk(), &init_msg.clone(), false).unwrap();
        }
        nodes
 }
index f7656e7949aa9df945010eadee9b0ba30eefb861..7814ccb6508c77a10abfa32ee189914493a62278 100644 (file)
@@ -417,7 +417,7 @@ impl<ES: Deref, NS: Deref, L: Deref, CMH: Deref> OnionMessageHandler for OnionMe
                };
        }
 
-       fn peer_connected(&self, their_node_id: &PublicKey, init: &msgs::Init) -> Result<(), ()> {
+       fn peer_connected(&self, their_node_id: &PublicKey, init: &msgs::Init, _inbound: bool) -> Result<(), ()> {
                if init.features.supports_onion_messages() {
                        let mut peers = self.pending_messages.lock().unwrap();
                        peers.insert(their_node_id.clone(), VecDeque::new());
@@ -425,7 +425,7 @@ impl<ES: Deref, NS: Deref, L: Deref, CMH: Deref> OnionMessageHandler for OnionMe
                Ok(())
        }
 
-       fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
+       fn peer_disconnected(&self, their_node_id: &PublicKey) {
                let mut pending_msgs = self.pending_messages.lock().unwrap();
                pending_msgs.remove(their_node_id);
        }
index 6c2d70bd6f309ed429af45876666673fc6421a35..cf51b6ab52827943626ede2cbc5d0e87e1f95366 100644 (file)
@@ -390,7 +390,7 @@ where U::Target: UtxoLookup, L::Target: Logger
        }
 
        fn get_next_channel_announcement(&self, starting_point: u64) -> Option<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
-               let channels = self.network_graph.channels.read().unwrap();
+               let mut channels = self.network_graph.channels.write().unwrap();
                for (_, ref chan) in channels.range(starting_point..) {
                        if chan.announcement_message.is_some() {
                                let chan_announcement = chan.announcement_message.clone().unwrap();
@@ -412,7 +412,7 @@ where U::Target: UtxoLookup, L::Target: Logger
        }
 
        fn get_next_node_announcement(&self, starting_point: Option<&NodeId>) -> Option<NodeAnnouncement> {
-               let nodes = self.network_graph.nodes.read().unwrap();
+               let mut nodes = self.network_graph.nodes.write().unwrap();
                let iter = if let Some(node_id) = starting_point {
                                nodes.range((Bound::Excluded(node_id), Bound::Unbounded))
                        } else {
@@ -437,7 +437,7 @@ where U::Target: UtxoLookup, L::Target: Logger
        /// to request gossip messages for each channel. The sync is considered complete
        /// when the final reply_scids_end message is received, though we are not
        /// tracking this directly.
-       fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &Init) -> Result<(), ()> {
+       fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &Init, _inbound: bool) -> Result<(), ()> {
                // We will only perform a sync with peers that support gossip_queries.
                if !init_msg.features.supports_gossip_queries() {
                        // Don't disconnect peers for not supporting gossip queries. We may wish to have
@@ -572,7 +572,7 @@ where U::Target: UtxoLookup, L::Target: Logger
                // (has at least one update). A peer may still want to know the channel
                // exists even if its not yet routable.
                let mut batches: Vec<Vec<u64>> = vec![Vec::with_capacity(MAX_SCIDS_PER_REPLY)];
-               let channels = self.network_graph.channels.read().unwrap();
+               let mut channels = self.network_graph.channels.write().unwrap();
                for (_, ref chan) in channels.range(inclusive_start_scid.unwrap()..exclusive_end_scid.unwrap()) {
                        if let Some(chan_announcement) = &chan.announcement_message {
                                // Construct a new batch if last one is full
@@ -2791,7 +2791,7 @@ pub(crate) mod tests {
                // It should ignore if gossip_queries feature is not enabled
                {
                        let init_msg = Init { features: InitFeatures::empty(), remote_network_address: None };
-                       gossip_sync.peer_connected(&node_id_1, &init_msg).unwrap();
+                       gossip_sync.peer_connected(&node_id_1, &init_msg, true).unwrap();
                        let events = gossip_sync.get_and_clear_pending_msg_events();
                        assert_eq!(events.len(), 0);
                }
@@ -2801,7 +2801,7 @@ pub(crate) mod tests {
                        let mut features = InitFeatures::empty();
                        features.set_gossip_queries_optional();
                        let init_msg = Init { features, remote_network_address: None };
-                       gossip_sync.peer_connected(&node_id_1, &init_msg).unwrap();
+                       gossip_sync.peer_connected(&node_id_1, &init_msg, true).unwrap();
                        let events = gossip_sync.get_and_clear_pending_msg_events();
                        assert_eq!(events.len(), 1);
                        match &events[0] {
index cccbfe7bc7a43434a12b8c7608dff2d935e5c1dc..3d45172517db3e706dcbf4884dd75923c872c8a5 100644 (file)
@@ -1,10 +1,11 @@
 //! This module has a map which can be iterated in a deterministic order. See the [`IndexedMap`].
 
 use crate::prelude::{HashMap, hash_map};
-use alloc::collections::{BTreeSet, btree_set};
+use alloc::vec::Vec;
+use alloc::slice::Iter;
 use core::hash::Hash;
 use core::cmp::Ord;
-use core::ops::RangeBounds;
+use core::ops::{Bound, RangeBounds};
 
 /// A map which can be iterated in a deterministic order.
 ///
@@ -21,11 +22,10 @@ use core::ops::RangeBounds;
 /// keys in the order defined by [`Ord`].
 ///
 /// [`BTreeMap`]: alloc::collections::BTreeMap
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, Eq)]
 pub struct IndexedMap<K: Hash + Ord, V> {
        map: HashMap<K, V>,
-       // TODO: Explore swapping this for a sorted vec (that is only sorted on first range() call)
-       keys: BTreeSet<K>,
+       keys: Vec<K>,
 }
 
 impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
@@ -33,7 +33,7 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
        pub fn new() -> Self {
                Self {
                        map: HashMap::new(),
-                       keys: BTreeSet::new(),
+                       keys: Vec::new(),
                }
        }
 
@@ -58,7 +58,8 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
        pub fn remove(&mut self, key: &K) -> Option<V> {
                let ret = self.map.remove(key);
                if let Some(_) = ret {
-                       assert!(self.keys.remove(key), "map and keys must be consistent");
+                       let idx = self.keys.iter().position(|k| k == key).expect("map and keys must be consistent");
+                       self.keys.remove(idx);
                }
                ret
        }
@@ -68,7 +69,7 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
        pub fn insert(&mut self, key: K, value: V) -> Option<V> {
                let ret = self.map.insert(key.clone(), value);
                if ret.is_none() {
-                       assert!(self.keys.insert(key), "map and keys must be consistent");
+                       self.keys.push(key);
                }
                ret
        }
@@ -109,9 +110,21 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
        }
 
        /// Returns an iterator which iterates over the `key`/`value` pairs in a given range.
-       pub fn range<R: RangeBounds<K>>(&self, range: R) -> Range<K, V> {
+       pub fn range<R: RangeBounds<K>>(&mut self, range: R) -> Range<K, V> {
+               self.keys.sort_unstable();
+               let start = match range.start_bound() {
+                       Bound::Unbounded => 0,
+                       Bound::Included(key) => self.keys.binary_search(key).unwrap_or_else(|index| index),
+                       Bound::Excluded(key) => self.keys.binary_search(key).and_then(|index| Ok(index + 1)).unwrap_or_else(|index| index),
+               };
+               let end = match range.end_bound() {
+                       Bound::Unbounded => self.keys.len(),
+                       Bound::Included(key) => self.keys.binary_search(key).and_then(|index| Ok(index + 1)).unwrap_or_else(|index| index),
+                       Bound::Excluded(key) => self.keys.binary_search(key).unwrap_or_else(|index| index),
+               };
+
                Range {
-                       inner_range: self.keys.range(range),
+                       inner_range: self.keys[start..end].iter(),
                        map: &self.map,
                }
        }
@@ -127,9 +140,15 @@ impl<K: Clone + Hash + Ord, V> IndexedMap<K, V> {
        }
 }
 
+impl<K: Hash + Ord + PartialEq, V: PartialEq> PartialEq for IndexedMap<K, V> {
+       fn eq(&self, other: &Self) -> bool {
+               self.map == other.map
+       }
+}
+
 /// An iterator over a range of values in an [`IndexedMap`]
 pub struct Range<'a, K: Hash + Ord, V> {
-       inner_range: btree_set::Range<'a, K>,
+       inner_range: Iter<'a, K>,
        map: &'a HashMap<K, V>,
 }
 impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> {
@@ -148,7 +167,7 @@ pub struct VacantEntry<'a, K: Hash + Ord, V> {
        #[cfg(not(feature = "hashbrown"))]
        underlying_entry: hash_map::VacantEntry<'a, K, V>,
        key: K,
-       keys: &'a mut BTreeSet<K>,
+       keys: &'a mut Vec<K>,
 }
 
 /// An [`Entry`] for an existing key-value pair
@@ -157,7 +176,7 @@ pub struct OccupiedEntry<'a, K: Hash + Ord, V> {
        underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>,
        #[cfg(not(feature = "hashbrown"))]
        underlying_entry: hash_map::OccupiedEntry<'a, K, V>,
-       keys: &'a mut BTreeSet<K>,
+       keys: &'a mut Vec<K>,
 }
 
 /// A mutable reference to a position in the map. This can be used to reference, add, or update the
@@ -172,7 +191,7 @@ pub enum Entry<'a, K: Hash + Ord, V> {
 impl<'a, K: Hash + Ord, V> VacantEntry<'a, K, V> {
        /// Insert a value into the position described by this entry.
        pub fn insert(self, value: V) -> &'a mut V {
-               assert!(self.keys.insert(self.key), "map and keys must be consistent");
+               self.keys.push(self.key);
                self.underlying_entry.insert(value)
        }
 }
@@ -181,7 +200,8 @@ impl<'a, K: Hash + Ord, V> OccupiedEntry<'a, K, V> {
        /// Remove the value at the position described by this entry.
        pub fn remove_entry(self) -> (K, V) {
                let res = self.underlying_entry.remove_entry();
-               assert!(self.keys.remove(&res.0), "map and keys must be consistent");
+               let idx = self.keys.iter().position(|k| k == &res.0).expect("map and keys must be consistent");
+               self.keys.remove(idx);
                res
        }
 
index b2b5cacfe97e1e25d0f1e3dc7a6463b771fdfcea..344f5b4c62d01736afd85654064735c68f0456bd 100644 (file)
@@ -331,6 +331,7 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster {
 pub struct TestChannelMessageHandler {
        pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
        expected_recv_msgs: Mutex<Option<Vec<wire::Message<()>>>>,
+       connected_peers: Mutex<HashSet<PublicKey>>,
 }
 
 impl TestChannelMessageHandler {
@@ -338,6 +339,7 @@ impl TestChannelMessageHandler {
                TestChannelMessageHandler {
                        pending_events: Mutex::new(Vec::new()),
                        expected_recv_msgs: Mutex::new(None),
+                       connected_peers: Mutex::new(HashSet::new()),
                }
        }
 
@@ -422,8 +424,11 @@ impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
        fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, msg: &msgs::ChannelReestablish) {
                self.received_msg(wire::Message::ChannelReestablish(msg.clone()));
        }
-       fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {}
-       fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) -> Result<(), ()> {
+       fn peer_disconnected(&self, their_node_id: &PublicKey) {
+               assert!(self.connected_peers.lock().unwrap().remove(their_node_id));
+       }
+       fn peer_connected(&self, their_node_id: &PublicKey, _msg: &msgs::Init, _inbound: bool) -> Result<(), ()> {
+               assert!(self.connected_peers.lock().unwrap().insert(their_node_id.clone()));
                // Don't bother with `received_msg` for Init as its auto-generated and we don't want to
                // bother re-generating the expected Init message in all tests.
                Ok(())
@@ -539,7 +544,7 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
                None
        }
 
-       fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init) -> Result<(), ()> {
+       fn peer_connected(&self, their_node_id: &PublicKey, init_msg: &msgs::Init, _inbound: bool) -> Result<(), ()> {
                if !init_msg.features.supports_gossip_queries() {
                        return Ok(());
                }