Remove one tab level when accessing a `peer_state`
[rust-lightning] / lightning / src / ln / payment_tests.rs
index faac5e53180ff9ed3927acf51d67c8e76d51731b..af31b4c012c1489a4ef9144abe866a5959423faf 100644 (file)
@@ -367,7 +367,7 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]);
        let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
-       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let mut nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
        let chan_id = create_announced_chan_between_nodes(&nodes, 0, 1, channelmanager::provided_init_features(), channelmanager::provided_init_features()).2;
@@ -436,7 +436,7 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
                MessageSendEvent::HandleError { node_id, action: msgs::ErrorAction::SendErrorMessage { ref msg } } => {
                        assert_eq!(node_id, nodes[1].node.get_our_node_id());
                        nodes[1].node.handle_error(&nodes[0].node.get_our_node_id(), msg);
-                       check_closed_event!(nodes[1], 1, ClosureReason::CounterpartyForceClosed { peer_msg: "Failed to find corresponding channel".to_string() });
+                       check_closed_event!(nodes[1], 1, ClosureReason::CounterpartyForceClosed { peer_msg: format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", &nodes[1].node.get_our_node_id()) });
                        check_added_monitors!(nodes[1], 1);
                        assert_eq!(nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0).len(), 1);
                },
@@ -500,8 +500,10 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        // and not the original fee. We also update node[1]'s relevant config as
        // do_claim_payment_along_route expects us to never overpay.
        {
-               let mut channel_state = nodes[1].node.channel_state.lock().unwrap();
-               let mut channel = channel_state.by_id.get_mut(&chan_id_2).unwrap();
+               let per_peer_state = nodes[1].node.per_peer_state.read().unwrap();
+               let mut peer_state = per_peer_state.get(&nodes[2].node.get_our_node_id())
+                       .unwrap().lock().unwrap();
+               let mut channel = peer_state.channel_by_id.get_mut(&chan_id_2).unwrap();
                let mut new_config = channel.config();
                new_config.forwarding_fee_base_msat += 100_000;
                channel.update_config(&new_config);
@@ -545,13 +547,13 @@ fn do_test_completed_payment_not_retryable_on_reload(use_dust: bool) {
 
        let first_persister: test_utils::TestPersister;
        let first_new_chain_monitor: test_utils::TestChainMonitor;
-       let first_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let first_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let second_persister: test_utils::TestPersister;
        let second_new_chain_monitor: test_utils::TestChainMonitor;
-       let second_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let second_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let third_persister: test_utils::TestPersister;
        let third_new_chain_monitor: test_utils::TestChainMonitor;
-       let third_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let third_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
 
        let mut nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
@@ -599,7 +601,7 @@ fn do_test_completed_payment_not_retryable_on_reload(use_dust: bool) {
                MessageSendEvent::HandleError { node_id, action: msgs::ErrorAction::SendErrorMessage { ref msg } } => {
                        assert_eq!(node_id, nodes[1].node.get_our_node_id());
                        nodes[1].node.handle_error(&nodes[0].node.get_our_node_id(), msg);
-                       check_closed_event!(nodes[1], 1, ClosureReason::CounterpartyForceClosed { peer_msg: "Failed to find corresponding channel".to_string() });
+                       check_closed_event!(nodes[1], 1, ClosureReason::CounterpartyForceClosed { peer_msg: format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", &nodes[1].node.get_our_node_id()) });
                        check_added_monitors!(nodes[1], 1);
                        bs_commitment_tx = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0);
                },
@@ -716,7 +718,7 @@ fn do_test_dup_htlc_onchain_fails_on_reload(persist_manager_post_event: bool, co
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
-       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 
        let (_, _, chan_id, funding_tx) = create_announced_chan_between_nodes(&nodes, 0, 1, channelmanager::provided_init_features(), channelmanager::provided_init_features());
@@ -751,10 +753,8 @@ fn do_test_dup_htlc_onchain_fails_on_reload(persist_manager_post_event: bool, co
        check_added_monitors!(nodes[1], 1);
        check_closed_event!(nodes[1], 1, ClosureReason::CommitmentTxConfirmed);
        let claim_txn = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0);
-       assert_eq!(claim_txn.len(), 3);
+       assert_eq!(claim_txn.len(), 1);
        check_spends!(claim_txn[0], node_txn[1]);
-       check_spends!(claim_txn[1], funding_tx);
-       check_spends!(claim_txn[2], claim_txn[1]);
 
        header.prev_blockhash = nodes[0].best_block_hash();
        connect_block(&nodes[0], &Block { header, txdata: vec![node_txn[1].clone()]});
@@ -861,7 +861,7 @@ fn test_fulfill_restart_failure() {
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
-       let nodes_1_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let nodes_1_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 
        let chan_id = create_announced_chan_between_nodes(&nodes, 0, 1, channelmanager::provided_init_features(), channelmanager::provided_init_features()).2;
@@ -1269,10 +1269,12 @@ fn test_trivial_inflight_htlc_tracking(){
        {
                let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
-               let node_0_channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let node_1_channel_lock = nodes[1].node.channel_state.lock().unwrap();
-               let channel_1 = node_0_channel_lock.by_id.get(&chan_1_id).unwrap();
-               let channel_2 = node_1_channel_lock.by_id.get(&chan_2_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1294,10 +1296,12 @@ fn test_trivial_inflight_htlc_tracking(){
        {
                let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
-               let node_0_channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let node_1_channel_lock = nodes[1].node.channel_state.lock().unwrap();
-               let channel_1 = node_0_channel_lock.by_id.get(&chan_1_id).unwrap();
-               let channel_2 = node_1_channel_lock.by_id.get(&chan_2_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1320,10 +1324,12 @@ fn test_trivial_inflight_htlc_tracking(){
        {
                let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
-               let node_0_channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let node_1_channel_lock = nodes[1].node.channel_state.lock().unwrap();
-               let channel_1 = node_0_channel_lock.by_id.get(&chan_1_id).unwrap();
-               let channel_2 = node_1_channel_lock.by_id.get(&chan_2_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1364,8 +1370,9 @@ fn test_holding_cell_inflight_htlcs() {
        let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
        {
-               let channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let channel = channel_lock.by_id.get(&channel_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let channel =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id);
 
                let used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1467,7 +1474,7 @@ fn do_test_intercepted_payment(test: InterceptTest) {
 
        // Check for unknown channel id error.
        let unknown_chan_id_err = nodes[1].node.forward_intercepted_htlc(intercept_id, &[42; 32], nodes[2].node.get_our_node_id(), expected_outbound_amount_msat).unwrap_err();
-       assert_eq!(unknown_chan_id_err , APIError::ChannelUnavailable  { err: format!("Channel with id {} not found", log_bytes!([42; 32])) });
+       assert_eq!(unknown_chan_id_err , APIError::ChannelUnavailable  { err: format!("Channel with id {} not found for the passed counterparty node_id {}", log_bytes!([42; 32]), nodes[2].node.get_our_node_id()) });
 
        if test == InterceptTest::Fail {
                // Ensure we can fail the intercepted payment back.