Store channels per peer
[rust-lightning] / lightning / src / ln / payment_tests.rs
index 59f6909b837597cfec9c6a5bcf4d2dd17803505e..e69b628fc2dbd1ad0ce40c3931e1be1acf28826d 100644 (file)
 use crate::chain::{ChannelMonitorUpdateStatus, Confirm, Listen, Watch};
 use crate::chain::channelmonitor::{ANTI_REORG_DELAY, LATENCY_GRACE_PERIOD_BLOCKS};
 use crate::chain::transaction::OutPoint;
-use crate::chain::keysinterface::KeysInterface;
+use crate::chain::keysinterface::{EntropySource, KeysInterface};
 use crate::ln::channel::EXPIRE_PREV_CONFIG_TICKS;
-use crate::ln::channelmanager::{self, BREAKDOWN_TIMEOUT, ChannelManager, InterceptId, MPP_TIMEOUT_TICKS, MIN_CLTV_EXPIRY_DELTA, PaymentId, PaymentSendFailure, IDEMPOTENCY_TIMEOUT_TICKS};
+use crate::ln::channelmanager::{self, BREAKDOWN_TIMEOUT, ChannelManager, MPP_TIMEOUT_TICKS, MIN_CLTV_EXPIRY_DELTA, PaymentId, PaymentSendFailure, IDEMPOTENCY_TIMEOUT_TICKS};
 use crate::ln::msgs;
 use crate::ln::msgs::ChannelMessageHandler;
 use crate::routing::gossip::RoutingFees;
-use crate::routing::router::{find_route, get_route, PaymentParameters, RouteHint, RouteHintHop, RouteParameters};
+use crate::routing::router::{get_route, PaymentParameters, RouteHint, RouteHintHop, RouteParameters};
 use crate::util::events::{ClosureReason, Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider};
 use crate::util::test_utils;
 use crate::util::errors::APIError;
@@ -367,7 +367,7 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]);
        let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
-       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let mut nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
        let chan_id = create_announced_chan_between_nodes(&nodes, 0, 1, channelmanager::provided_init_features(), channelmanager::provided_init_features()).2;
@@ -500,8 +500,10 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        // and not the original fee. We also update node[1]'s relevant config as
        // do_claim_payment_along_route expects us to never overpay.
        {
-               let mut channel_state = nodes[1].node.channel_state.lock().unwrap();
-               let mut channel = channel_state.by_id.get_mut(&chan_id_2).unwrap();
+               let per_peer_state = nodes[1].node.per_peer_state.read().unwrap();
+               let mut peer_state = per_peer_state.get(&nodes[2].node.get_our_node_id())
+                       .unwrap().lock().unwrap();
+               let mut channel = peer_state.channel_by_id.get_mut(&chan_id_2).unwrap();
                let mut new_config = channel.config();
                new_config.forwarding_fee_base_msat += 100_000;
                channel.update_config(&new_config);
@@ -545,13 +547,13 @@ fn do_test_completed_payment_not_retryable_on_reload(use_dust: bool) {
 
        let first_persister: test_utils::TestPersister;
        let first_new_chain_monitor: test_utils::TestChainMonitor;
-       let first_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let first_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let second_persister: test_utils::TestPersister;
        let second_new_chain_monitor: test_utils::TestChainMonitor;
-       let second_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let second_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let third_persister: test_utils::TestPersister;
        let third_new_chain_monitor: test_utils::TestChainMonitor;
-       let third_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let third_nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
 
        let mut nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
@@ -716,7 +718,7 @@ fn do_test_dup_htlc_onchain_fails_on_reload(persist_manager_post_event: bool, co
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
-       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let nodes_0_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 
        let (_, _, chan_id, funding_tx) = create_announced_chan_between_nodes(&nodes, 0, 1, channelmanager::provided_init_features(), channelmanager::provided_init_features());
@@ -751,10 +753,8 @@ fn do_test_dup_htlc_onchain_fails_on_reload(persist_manager_post_event: bool, co
        check_added_monitors!(nodes[1], 1);
        check_closed_event!(nodes[1], 1, ClosureReason::CommitmentTxConfirmed);
        let claim_txn = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0);
-       assert_eq!(claim_txn.len(), 3);
+       assert_eq!(claim_txn.len(), 1);
        check_spends!(claim_txn[0], node_txn[1]);
-       check_spends!(claim_txn[1], funding_tx);
-       check_spends!(claim_txn[2], claim_txn[1]);
 
        header.prev_blockhash = nodes[0].best_block_hash();
        connect_block(&nodes[0], &Block { header, txdata: vec![node_txn[1].clone()]});
@@ -861,7 +861,7 @@ fn test_fulfill_restart_failure() {
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
-       let nodes_1_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
+       let nodes_1_deserialized: ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>;
        let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 
        let chan_id = create_announced_chan_between_nodes(&nodes, 0, 1, channelmanager::provided_init_features(), channelmanager::provided_init_features()).2;
@@ -1243,6 +1243,13 @@ fn abandoned_send_payment_idempotent() {
        claim_payment(&nodes[0], &[&nodes[1]], second_payment_preimage);
 }
 
+#[derive(PartialEq)]
+enum InterceptTest {
+       Forward,
+       Fail,
+       Timeout,
+}
+
 #[test]
 fn test_trivial_inflight_htlc_tracking(){
        // In this test, we test three scenarios:
@@ -1262,10 +1269,12 @@ fn test_trivial_inflight_htlc_tracking(){
        {
                let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
-               let node_0_channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let node_1_channel_lock = nodes[1].node.channel_state.lock().unwrap();
-               let channel_1 = node_0_channel_lock.by_id.get(&chan_1_id).unwrap();
-               let channel_2 = node_1_channel_lock.by_id.get(&chan_2_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1287,10 +1296,12 @@ fn test_trivial_inflight_htlc_tracking(){
        {
                let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
-               let node_0_channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let node_1_channel_lock = nodes[1].node.channel_state.lock().unwrap();
-               let channel_1 = node_0_channel_lock.by_id.get(&chan_1_id).unwrap();
-               let channel_2 = node_1_channel_lock.by_id.get(&chan_2_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1313,10 +1324,12 @@ fn test_trivial_inflight_htlc_tracking(){
        {
                let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
-               let node_0_channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let node_1_channel_lock = nodes[1].node.channel_state.lock().unwrap();
-               let channel_1 = node_0_channel_lock.by_id.get(&chan_1_id).unwrap();
-               let channel_2 = node_1_channel_lock.by_id.get(&chan_2_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let mut node_1_per_peer_lock;
+               let mut node_1_peer_state_lock;
+               let channel_1 =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id);
+               let channel_2 =  get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id);
 
                let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1357,8 +1370,9 @@ fn test_holding_cell_inflight_htlcs() {
        let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs();
 
        {
-               let channel_lock = nodes[0].node.channel_state.lock().unwrap();
-               let channel = channel_lock.by_id.get(&channel_id).unwrap();
+               let mut node_0_per_peer_lock;
+               let mut node_0_peer_state_lock;
+               let channel =  get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id);
 
                let used_liquidity = inflight_htlcs.used_liquidity_msat(
                        &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) ,
@@ -1378,16 +1392,22 @@ fn intercepted_payment() {
        // Test that detecting an intercept scid on payment forward will signal LDK to generate an
        // intercept event, which the LSP can then use to either (a) open a JIT channel to forward the
        // payment or (b) fail the payment.
-       do_test_intercepted_payment(false);
-       do_test_intercepted_payment(true);
+       do_test_intercepted_payment(InterceptTest::Forward);
+       do_test_intercepted_payment(InterceptTest::Fail);
+       // Make sure that intercepted payments will be automatically failed back if too many blocks pass.
+       do_test_intercepted_payment(InterceptTest::Timeout);
 }
 
-fn do_test_intercepted_payment(fail_intercept: bool) {
+fn do_test_intercepted_payment(test: InterceptTest) {
        let chanmon_cfgs = create_chanmon_cfgs(3);
        let node_cfgs = create_node_cfgs(3, &chanmon_cfgs);
-       let mut chan_config = test_default_channel_config();
-       chan_config.manually_accept_inbound_channels = true;
-       let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, Some(chan_config)]);
+
+       let mut zero_conf_chan_config = test_default_channel_config();
+       zero_conf_chan_config.manually_accept_inbound_channels = true;
+       let mut intercept_forwards_config = test_default_channel_config();
+       intercept_forwards_config.accept_intercept_htlcs = true;
+       let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, Some(intercept_forwards_config), Some(zero_conf_chan_config)]);
+
        let nodes = create_network(3, &node_cfgs, &node_chanmgrs);
        let scorer = test_utils::TestScorer::with_penalty(0);
        let random_seed_bytes = chanmon_cfgs[0].keys_manager.get_secure_random_bytes();
@@ -1416,9 +1436,10 @@ fn do_test_intercepted_payment(fail_intercept: bool) {
                final_value_msat: amt_msat,
                final_cltv_expiry_delta: TEST_FINAL_CLTV,
        };
-       let route = find_route(
-               &nodes[0].node.get_our_node_id(), &route_params, &nodes[0].network_graph, None, nodes[0].logger,
-               &scorer, &random_seed_bytes
+       let route = get_route(
+               &nodes[0].node.get_our_node_id(), &route_params.payment_params,
+               &nodes[0].network_graph.read_only(), None, route_params.final_value_msat,
+               route_params.final_cltv_expiry_delta, nodes[0].logger, &scorer, &random_seed_bytes
        ).unwrap();
 
        let (payment_hash, payment_secret) = nodes[2].node.create_inbound_payment(Some(amt_msat), 60 * 60).unwrap();
@@ -1453,9 +1474,9 @@ fn do_test_intercepted_payment(fail_intercept: bool) {
 
        // Check for unknown channel id error.
        let unknown_chan_id_err = nodes[1].node.forward_intercepted_htlc(intercept_id, &[42; 32], nodes[2].node.get_our_node_id(), expected_outbound_amount_msat).unwrap_err();
-       assert_eq!(unknown_chan_id_err , APIError::APIMisuseError { err: format!("Channel with id {:?} not found", [42; 32]) });
+       assert_eq!(unknown_chan_id_err , APIError::ChannelUnavailable  { err: format!("Channel with id {} not found", log_bytes!([42; 32])) });
 
-       if fail_intercept {
+       if test == InterceptTest::Fail {
                // Ensure we can fail the intercepted payment back.
                nodes[1].node.fail_intercepted_htlc(intercept_id).unwrap();
                expect_pending_htlcs_forwardable_and_htlc_handling_failed_ignore!(nodes[1], vec![HTLCDestination::UnknownNextHop { requested_forward_scid: intercept_scid }]);
@@ -1473,15 +1494,16 @@ fn do_test_intercepted_payment(fail_intercept: bool) {
                        .blamed_chan_closed(true)
                        .expected_htlc_error_data(0x4000 | 10, &[]);
                expect_payment_failed_conditions(&nodes[0], payment_hash, false, fail_conditions);
-       } else {
+       } else if test == InterceptTest::Forward {
+               // Check that we'll fail as expected when sending to a channel that isn't in `ChannelReady` yet.
+               let temp_chan_id = nodes[1].node.create_channel(nodes[2].node.get_our_node_id(), 100_000, 0, 42, None).unwrap();
+               let unusable_chan_err = nodes[1].node.forward_intercepted_htlc(intercept_id, &temp_chan_id, nodes[2].node.get_our_node_id(), expected_outbound_amount_msat).unwrap_err();
+               assert_eq!(unusable_chan_err , APIError::ChannelUnavailable { err: format!("Channel with id {} not fully established", log_bytes!(temp_chan_id)) });
+               assert_eq!(nodes[1].node.get_and_clear_pending_msg_events().len(), 1);
+
                // Open the just-in-time channel so the payment can then be forwarded.
                let (_, channel_id) = open_zero_conf_channel(&nodes[1], &nodes[2], None);
 
-               // Check for unknown intercept id error.
-               let unknown_intercept_id = InterceptId([42; 32]);
-               let unknown_intercept_id_err = nodes[1].node.forward_intercepted_htlc(unknown_intercept_id, &channel_id, nodes[2].node.get_our_node_id(), expected_outbound_amount_msat).unwrap_err();
-               assert_eq!(unknown_intercept_id_err , APIError::APIMisuseError { err: format!("Payment with intercept id {:?} not found", unknown_intercept_id.0) });
-
                // Finally, forward the intercepted payment through and claim it.
                nodes[1].node.forward_intercepted_htlc(intercept_id, &channel_id, nodes[2].node.get_our_node_id(), expected_outbound_amount_msat).unwrap();
                expect_pending_htlcs_forwardable!(nodes[1]);
@@ -1501,7 +1523,7 @@ fn do_test_intercepted_payment(fail_intercept: bool) {
                expect_pending_htlcs_forwardable!(nodes[2]);
 
                let payment_preimage = nodes[2].node.get_payment_preimage(payment_hash, payment_secret).unwrap();
-               expect_payment_received!(&nodes[2], payment_hash, payment_secret, amt_msat, Some(payment_preimage), nodes[2].node.get_our_node_id());
+               expect_payment_claimable!(&nodes[2], payment_hash, payment_secret, amt_msat, Some(payment_preimage), nodes[2].node.get_our_node_id());
                do_claim_payment_along_route(&nodes[0], &vec!(&vec!(&nodes[1], &nodes[2])[..]), false, payment_preimage);
                let events = nodes[0].node.get_and_clear_pending_events();
                assert_eq!(events.len(), 2);
@@ -1519,5 +1541,35 @@ fn do_test_intercepted_payment(fail_intercept: bool) {
                        },
                        _ => panic!("Unexpected event")
                }
+       } else if test == InterceptTest::Timeout {
+               let mut block = Block {
+                       header: BlockHeader { version: 0x20000000, prev_blockhash: nodes[0].best_block_hash(), merkle_root: TxMerkleNode::all_zeros(), time: 42, bits: 42, nonce: 42 },
+                       txdata: vec![],
+               };
+               connect_block(&nodes[0], &block);
+               connect_block(&nodes[1], &block);
+               for _ in 0..TEST_FINAL_CLTV {
+                       block.header.prev_blockhash = block.block_hash();
+                       connect_block(&nodes[0], &block);
+                       connect_block(&nodes[1], &block);
+               }
+               expect_pending_htlcs_forwardable_and_htlc_handling_failed!(nodes[1], vec![HTLCDestination::InvalidForward { requested_forward_scid: intercept_scid }]);
+               check_added_monitors!(nodes[1], 1);
+               let htlc_timeout_updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+               assert!(htlc_timeout_updates.update_add_htlcs.is_empty());
+               assert_eq!(htlc_timeout_updates.update_fail_htlcs.len(), 1);
+               assert!(htlc_timeout_updates.update_fail_malformed_htlcs.is_empty());
+               assert!(htlc_timeout_updates.update_fee.is_none());
+
+               nodes[0].node.handle_update_fail_htlc(&nodes[1].node.get_our_node_id(), &htlc_timeout_updates.update_fail_htlcs[0]);
+               commitment_signed_dance!(nodes[0], nodes[1], htlc_timeout_updates.commitment_signed, false);
+               expect_payment_failed!(nodes[0], payment_hash, false, 0x2000 | 2, []);
+
+               // Check for unknown intercept id error.
+               let (_, channel_id) = open_zero_conf_channel(&nodes[1], &nodes[2], None);
+               let unknown_intercept_id_err = nodes[1].node.forward_intercepted_htlc(intercept_id, &channel_id, nodes[2].node.get_our_node_id(), expected_outbound_amount_msat).unwrap_err();
+               assert_eq!(unknown_intercept_id_err , APIError::APIMisuseError { err: format!("Payment with intercept id {} not found", log_bytes!(intercept_id.0)) });
+               let unknown_intercept_id_err = nodes[1].node.fail_intercepted_htlc(intercept_id).unwrap_err();
+               assert_eq!(unknown_intercept_id_err , APIError::APIMisuseError { err: format!("Payment with intercept id {} not found", log_bytes!(intercept_id.0)) });
        }
 }