Merge pull request #2923 from tnull/2024-03-improve-best-block
diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs
index 45e5de51a0ef770dfbcf03d54705233b67dbf904..adf2768b4152700ed82172e9a44ab75c6c3475a2 100644
--- a/lightning/src/ln/functional_test_utils.rs
+++ b/lightning/src/ln/functional_test_utils.rs
@@ -487,16 +487,38 @@ impl<'a, 'b, 'c> Node<'a, 'b, 'c> {
        /// `release_commitment_secret` are affected by this setting.
        #[cfg(test)]
        pub fn set_channel_signer_available(&self, peer_id: &PublicKey, chan_id: &ChannelId, available: bool) {
+               use crate::sign::ChannelSigner;
+               log_debug!(self.logger, "Setting channel signer for {} as available={}", chan_id, available);
+
                let per_peer_state = self.node.per_peer_state.read().unwrap();
                let chan_lock = per_peer_state.get(peer_id).unwrap().lock().unwrap();
-               let signer = (|| {
-                       match chan_lock.channel_by_id.get(chan_id) {
-                               Some(phase) => phase.context().get_signer(),
-                               None => panic!("Couldn't find a channel with id {}", chan_id),
+
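+               // Flip availability on the signer held by the live channel (if any) and remember its
+               // channel_keys_id so the keys manager can be updated below.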
+               let mut channel_keys_id = None;
+               if let Some(chan) = chan_lock.channel_by_id.get(chan_id).map(|phase| phase.context()) {
+                       chan.get_signer().as_ecdsa().unwrap().set_available(available);
+                       channel_keys_id = Some(chan.channel_keys_id);
+               }
+
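+               // The ChannelMonitor keeps its own copy of the signer; look it up by funding outpoint so
+               // its availability can be toggled as well.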
+               let mut monitor = None;
+               for (funding_txo, channel_id) in self.chain_monitor.chain_monitor.list_monitors() {
+                       if *chan_id == channel_id {
+                               monitor = self.chain_monitor.chain_monitor.get_monitor(funding_txo).ok();
                        }
-               })();
-               log_debug!(self.logger, "Setting channel signer for {} as available={}", chan_id, available);
-               signer.as_ecdsa().unwrap().set_available(available);
+               }
+               if let Some(monitor) = monitor {
+                       monitor.do_signer_call(|signer| {
+                               channel_keys_id = channel_keys_id.or(Some(signer.inner.channel_keys_id()));
+                               signer.set_available(available)
+                       });
+               }
+
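+               // Mirror the state in the keys manager's `unavailable_signers` set so that signers
+               // derived later for this channel start out with the same availability.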
+               if available {
+                       self.keys_manager.unavailable_signers.lock().unwrap()
+                               .remove(channel_keys_id.as_ref().unwrap());
+               } else {
+                       self.keys_manager.unavailable_signers.lock().unwrap()
+                               .insert(channel_keys_id.unwrap());
+               }
        }
 }
 
@@ -618,7 +640,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                        // Before using all the new monitors to check the watch outpoints, use the full set of
                        // them to ensure we can write and reload our ChannelManager.
                        {
-                               let mut channel_monitors = HashMap::new();
+                               let mut channel_monitors = new_hash_map();
                                for monitor in deserialized_monitors.iter_mut() {
                                        channel_monitors.insert(monitor.get_funding_txo().0, monitor);
                                }
@@ -1049,7 +1071,7 @@ pub fn _reload_node<'a, 'b, 'c>(node: &'a Node<'a, 'b, 'c>, default_config: User
 
        let mut node_read = &chanman_encoded[..];
        let (_, node_deserialized) = {
-               let mut channel_monitors = HashMap::new();
+               let mut channel_monitors = new_hash_map();
                for monitor in monitors_read.iter_mut() {
                        assert!(channel_monitors.insert(monitor.get_funding_txo().0, monitor).is_none());
                }
@@ -1203,7 +1225,7 @@ pub fn open_zero_conf_channel<'a, 'b, 'c, 'd>(initiator: &'a Node<'b, 'c, 'd>, r
        };
 
        let accept_channel = get_event_msg!(receiver, MessageSendEvent::SendAcceptChannel, initiator.node.get_our_node_id());
-       assert_eq!(accept_channel.minimum_depth, 0);
+       assert_eq!(accept_channel.common_fields.minimum_depth, 0);
        initiator.node.handle_accept_channel(&receiver.node.get_our_node_id(), &accept_channel);
 
        let (temporary_channel_id, tx, _) = create_funding_transaction(&initiator, &receiver.node.get_our_node_id(), 100_000, 42);
@@ -1257,7 +1279,7 @@ pub fn open_zero_conf_channel<'a, 'b, 'c, 'd>(initiator: &'a Node<'b, 'c, 'd>, r
 pub fn exchange_open_accept_chan<'a, 'b, 'c>(node_a: &Node<'a, 'b, 'c>, node_b: &Node<'a, 'b, 'c>, channel_value: u64, push_msat: u64) -> ChannelId {
        let create_chan_id = node_a.node.create_channel(node_b.node.get_our_node_id(), channel_value, push_msat, 42, None, None).unwrap();
        let open_channel_msg = get_event_msg!(node_a, MessageSendEvent::SendOpenChannel, node_b.node.get_our_node_id());
-       assert_eq!(open_channel_msg.temporary_channel_id, create_chan_id);
+       assert_eq!(open_channel_msg.common_fields.temporary_channel_id, create_chan_id);
        assert_eq!(node_a.node.list_channels().iter().find(|channel| channel.channel_id == create_chan_id).unwrap().user_channel_id, 42);
        node_b.node.handle_open_channel(&node_a.node.get_our_node_id(), &open_channel_msg);
        if node_b.node.get_current_default_configuration().manually_accept_inbound_channels {
@@ -1270,7 +1292,7 @@ pub fn exchange_open_accept_chan<'a, 'b, 'c>(node_a: &Node<'a, 'b, 'c>, node_b:
                };
        }
        let accept_channel_msg = get_event_msg!(node_b, MessageSendEvent::SendAcceptChannel, node_a.node.get_our_node_id());
-       assert_eq!(accept_channel_msg.temporary_channel_id, create_chan_id);
+       assert_eq!(accept_channel_msg.common_fields.temporary_channel_id, create_chan_id);
        node_a.node.handle_accept_channel(&node_b.node.get_our_node_id(), &accept_channel_msg);
        assert_ne!(node_b.node.list_channels().iter().find(|channel| channel.channel_id == create_chan_id).unwrap().user_channel_id, 0);
 
@@ -2203,14 +2225,19 @@ macro_rules! expect_payment_path_successful {
 
 pub fn expect_payment_forwarded<CM: AChannelManager, H: NodeHolder<CM=CM>>(
        event: Event, node: &H, prev_node: &H, next_node: &H, expected_fee: Option<u64>,
-       upstream_force_closed: bool, downstream_force_closed: bool
+       expected_extra_fees_msat: Option<u64>, upstream_force_closed: bool,
+       downstream_force_closed: bool
 ) {
        match event {
                Event::PaymentForwarded {
                        total_fee_earned_msat, prev_channel_id, claim_from_onchain_tx, next_channel_id,
-                       outbound_amount_forwarded_msat: _
+                       outbound_amount_forwarded_msat: _, skimmed_fee_msat
                } => {
                        assert_eq!(total_fee_earned_msat, expected_fee);
+
+                       // Check that the (knowingly) withheld amount matches the extra fee we expected this
+                       // hop to skim.
+                       assert!(skimmed_fee_msat == expected_extra_fees_msat);
                        if !upstream_force_closed {
                                // Is the event prev_channel_id in one of the channels between the two nodes?
                                assert!(node.node().list_channels().iter().any(|x| x.counterparty.node_id == prev_node.node().get_our_node_id() && x.channel_id == prev_channel_id.unwrap()));
@@ -2226,13 +2253,15 @@ pub fn expect_payment_forwarded<CM: AChannelManager, H: NodeHolder<CM=CM>>(
        }
 }
 
+#[macro_export]
 macro_rules! expect_payment_forwarded {
        ($node: expr, $prev_node: expr, $next_node: expr, $expected_fee: expr, $upstream_force_closed: expr, $downstream_force_closed: expr) => {
                let mut events = $node.node.get_and_clear_pending_events();
                assert_eq!(events.len(), 1);
                $crate::ln::functional_test_utils::expect_payment_forwarded(
-                       events.pop().unwrap(), &$node, &$prev_node, &$next_node, $expected_fee,
-                       $upstream_force_closed, $downstream_force_closed);
+                       events.pop().unwrap(), &$node, &$prev_node, &$next_node, $expected_fee, None,
+                       $upstream_force_closed, $downstream_force_closed
+               );
        }
 }
 
@@ -2696,11 +2725,17 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
                                                        channel.context().config().forwarding_fee_base_msat
                                                }
                                        };
+
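+                                       // Extra (skimmed) fees are only expected at the hop with index 1, so only set
+                                       // `expected_extra_fee` there.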
+                                       let mut expected_extra_fee = None;
                                        if $idx == 1 {
                                                fee += expected_extra_fees[i];
                                                fee += expected_min_htlc_overpay[i];
+                                               expected_extra_fee = if expected_extra_fees[i] > 0 { Some(expected_extra_fees[i] as u64) } else { None };
                                        }
-                                       expect_payment_forwarded!(*$node, $next_node, $prev_node, Some(fee as u64), false, false);
+                                       let mut events = $node.node.get_and_clear_pending_events();
+                                       assert_eq!(events.len(), 1);
+                                       expect_payment_forwarded(events.pop().unwrap(), *$node, $next_node, $prev_node,
+                                               Some(fee as u64), expected_extra_fee, false, false);
                                        expected_total_fee_msat += fee as u64;
                                        check_added_monitors!($node, 1);
                                        let new_next_msgs = if $new_msgs {
@@ -2967,7 +3002,7 @@ pub fn create_node_cfgs_with_persisters<'a>(node_count: usize, chanmon_cfgs: &'a
                        tx_broadcaster: &chanmon_cfgs[i].tx_broadcaster,
                        fee_estimator: &chanmon_cfgs[i].fee_estimator,
                        router: test_utils::TestRouter::new(network_graph.clone(), &chanmon_cfgs[i].logger, &chanmon_cfgs[i].scorer),
-                       message_router: test_utils::TestMessageRouter::new(network_graph.clone()),
+                       message_router: test_utils::TestMessageRouter::new(network_graph.clone(), &chanmon_cfgs[i].keys_manager),
                        chain_monitor,
                        keys_manager: &chanmon_cfgs[i].keys_manager,
                        node_seed: seed,
@@ -3091,7 +3126,7 @@ pub enum HTLCType { NONE, TIMEOUT, SUCCESS }
 /// also fail.
 pub fn test_txn_broadcast<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, chan: &(msgs::ChannelUpdate, msgs::ChannelUpdate, ChannelId, Transaction), commitment_tx: Option<Transaction>, has_htlc_tx: HTLCType) -> Vec<Transaction>  {
        let mut node_txn = node.tx_broadcaster.txn_broadcasted.lock().unwrap();
-       let mut txn_seen = HashSet::new();
+       let mut txn_seen = new_hash_set();
        node_txn.retain(|tx| txn_seen.insert(tx.txid()));
        assert!(node_txn.len() >= if commitment_tx.is_some() { 0 } else { 1 } + if has_htlc_tx == HTLCType::NONE { 0 } else { 1 });
 
@@ -3156,7 +3191,7 @@ pub fn test_revoked_htlc_claim_txn_broadcast<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>
 
 pub fn check_preimage_claim<'a, 'b, 'c>(node: &Node<'a, 'b, 'c>, prev_txn: &Vec<Transaction>) -> Vec<Transaction>  {
        let mut node_txn = node.tx_broadcaster.txn_broadcasted.lock().unwrap();
-       let mut txn_seen = HashSet::new();
+       let mut txn_seen = new_hash_set();
        node_txn.retain(|tx| txn_seen.insert(tx.txid()));
 
        let mut found_prev = false;
@@ -3256,7 +3291,7 @@ macro_rules! get_channel_value_stat {
 macro_rules! get_chan_reestablish_msgs {
        ($src_node: expr, $dst_node: expr) => {
                {
-                       let mut announcements = $crate::prelude::HashSet::new();
+                       let mut announcements = $crate::prelude::new_hash_set();
                        let mut res = Vec::with_capacity(1);
                        for msg in $src_node.node.get_and_clear_pending_msg_events() {
                                if let MessageSendEvent::SendChannelReestablish { ref node_id, ref msg } = msg {