]> git.bitcoin.ninja Git - rust-lightning/commitdiff
test_utils: parameterize TestRouter by TestScorer
authorValentine Wallace <vwallace@protonmail.com>
Tue, 7 Feb 2023 19:04:55 +0000 (14:04 -0500)
committerValentine Wallace <vwallace@protonmail.com>
Tue, 14 Feb 2023 19:20:48 +0000 (14:20 -0500)
This allows us set scoring expectations and ensure in-flight htlcs are factored
into scoring

12 files changed:
lightning-invoice/src/utils.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/onion_route_tests.rs
lightning/src/ln/outbound_payment.rs
lightning/src/ln/payment_tests.rs
lightning/src/ln/shutdown_tests.rs
lightning/src/routing/gossip.rs
lightning/src/routing/router.rs
lightning/src/routing/scoring.rs
lightning/src/util/test_utils.rs

index 83f72e461455d3e3d06001121b7fa2bf8c6414f5..f6cc87fa72af2d0645b3b1538199b28f4265d5a3 100644 (file)
@@ -691,7 +691,7 @@ mod test {
                let first_hops = nodes[0].node.list_usable_channels();
                let network_graph = &node_cfgs[0].network_graph;
                let logger = test_utils::TestLogger::new();
-               let scorer = test_utils::TestScorer::with_penalty(0);
+               let scorer = test_utils::TestScorer::new();
                let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
                        &nodes[0].node.get_our_node_id(), &route_params, &network_graph,
@@ -1055,7 +1055,7 @@ mod test {
                let first_hops = nodes[0].node.list_usable_channels();
                let network_graph = &node_cfgs[0].network_graph;
                let logger = test_utils::TestLogger::new();
-               let scorer = test_utils::TestScorer::with_penalty(0);
+               let scorer = test_utils::TestScorer::new();
                let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
                        &nodes[0].node.get_our_node_id(), &params, &network_graph,
index 6ab5149f4c548267aad161d2dfc660c98078990b..cfb462b0120cac784c2393e4e27450c14747a6bc 100644 (file)
@@ -7966,7 +7966,7 @@ mod tests {
                let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
                let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
                create_announced_chan_between_nodes(&nodes, 0, 1);
-               let scorer = test_utils::TestScorer::with_penalty(0);
+               let scorer = test_utils::TestScorer::new();
                let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
 
                // To start (1), send a regular payment but don't claim it.
@@ -8074,7 +8074,7 @@ mod tests {
                };
                let network_graph = nodes[0].network_graph.clone();
                let first_hops = nodes[0].node.list_usable_channels();
-               let scorer = test_utils::TestScorer::with_penalty(0);
+               let scorer = test_utils::TestScorer::new();
                let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
                        &payer_pubkey, &route_params, &network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
@@ -8119,7 +8119,7 @@ mod tests {
                };
                let network_graph = nodes[0].network_graph.clone();
                let first_hops = nodes[0].node.list_usable_channels();
-               let scorer = test_utils::TestScorer::with_penalty(0);
+               let scorer = test_utils::TestScorer::new();
                let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
                        &payer_pubkey, &route_params, &network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
@@ -8589,7 +8589,8 @@ pub mod bench {
                let tx_broadcaster = test_utils::TestBroadcaster{txn_broadcasted: Mutex::new(Vec::new()), blocks: Arc::new(Mutex::new(Vec::new()))};
                let fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) };
                let logger_a = test_utils::TestLogger::with_id("node a".to_owned());
-               let router = test_utils::TestRouter::new(Arc::new(NetworkGraph::new(genesis_hash, &logger_a)));
+               let scorer = Mutex::new(test_utils::TestScorer::new());
+               let router = test_utils::TestRouter::new(Arc::new(NetworkGraph::new(genesis_hash, &logger_a)), &scorer);
 
                let mut config: UserConfig = Default::default();
                config.channel_handshake_config.minimum_depth = 1;
@@ -8680,7 +8681,7 @@ pub mod bench {
                                let usable_channels = $node_a.list_usable_channels();
                                let payment_params = PaymentParameters::from_node_id($node_b.get_our_node_id(), TEST_FINAL_CLTV)
                                        .with_features($node_b.invoice_features());
-                               let scorer = test_utils::TestScorer::with_penalty(0);
+                               let scorer = test_utils::TestScorer::new();
                                let seed = [3u8; 32];
                                let keys_manager = KeysManager::new(&seed, 42, 42);
                                let random_seed_bytes = keys_manager.get_secure_random_bytes();
index 7bf51df5c7a28bb553dd67fd44de672e0f98142d..f7fadc1ecee3179deb799d05334a51b0dde941a4 100644 (file)
@@ -305,6 +305,7 @@ pub struct TestChanMonCfg {
        pub persister: test_utils::TestPersister,
        pub logger: test_utils::TestLogger,
        pub keys_manager: test_utils::TestKeysInterface,
+       pub scorer: Mutex<test_utils::TestScorer>,
 }
 
 pub struct NodeCfg<'a> {
@@ -427,6 +428,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        channel_monitors.insert(monitor.get_funding_txo().0, monitor);
                                }
 
+                               let scorer = Mutex::new(test_utils::TestScorer::new());
                                let mut w = test_utils::TestVecWriter(Vec::new());
                                self.node.write(&mut w).unwrap();
                                <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>)>::read(&mut io::Cursor::new(w.0), ChannelManagerReadArgs {
@@ -435,7 +437,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        node_signer: self.keys_manager,
                                        signer_provider: self.keys_manager,
                                        fee_estimator: &test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) },
-                                       router: &test_utils::TestRouter::new(Arc::new(network_graph)),
+                                       router: &test_utils::TestRouter::new(Arc::new(network_graph), &scorer),
                                        chain_monitor: self.chain_monitor,
                                        tx_broadcaster: &broadcaster,
                                        logger: &self.logger,
@@ -1570,7 +1572,7 @@ macro_rules! get_payment_preimage_hash {
 macro_rules! get_route {
        ($send_node: expr, $payment_params: expr, $recv_value: expr, $cltv: expr) => {{
                use $crate::chain::keysinterface::EntropySource;
-               let scorer = $crate::util::test_utils::TestScorer::with_penalty(0);
+               let scorer = $crate::util::test_utils::TestScorer::new();
                let keys_manager = $crate::util::test_utils::TestKeysInterface::new(&[0u8; 32], bitcoin::network::constants::Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                $crate::routing::router::get_route(
@@ -2124,7 +2126,7 @@ pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou
        let payment_params = PaymentParameters::from_node_id(expected_route.last().unwrap().node.get_our_node_id(), TEST_FINAL_CLTV)
                .with_features(expected_route.last().unwrap().node.invoice_features());
        let network_graph = origin_node.network_graph.read_only();
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let seed = [0u8; 32];
        let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
        let random_seed_bytes = keys_manager.get_secure_random_bytes();
@@ -2289,8 +2291,9 @@ pub fn create_chanmon_cfgs(node_count: usize) -> Vec<TestChanMonCfg> {
                let persister = test_utils::TestPersister::new();
                let seed = [i as u8; 32];
                let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
+               let scorer = Mutex::new(test_utils::TestScorer::new());
 
-               chan_mon_cfgs.push(TestChanMonCfg { tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager });
+               chan_mon_cfgs.push(TestChanMonCfg { tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager, scorer });
        }
 
        chan_mon_cfgs
@@ -2308,7 +2311,7 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec<TestChanMon
                        logger: &chanmon_cfgs[i].logger,
                        tx_broadcaster: &chanmon_cfgs[i].tx_broadcaster,
                        fee_estimator: &chanmon_cfgs[i].fee_estimator,
-                       router: test_utils::TestRouter::new(network_graph.clone()),
+                       router: test_utils::TestRouter::new(network_graph.clone(), &chanmon_cfgs[i].scorer),
                        chain_monitor,
                        keys_manager: &chanmon_cfgs[i].keys_manager,
                        node_seed: seed,
index bf0c3396cedb81e746d8bcdbc648ccd9139a218a..c8473048892ac58ecf74839faa4233e2343a064d 100644 (file)
@@ -5257,7 +5257,8 @@ fn test_key_derivation_params() {
        let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
        let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &chanmon_cfgs[0].persister, &keys_manager);
        let network_graph = Arc::new(NetworkGraph::new(chanmon_cfgs[0].chain_source.genesis_hash, &chanmon_cfgs[0].logger));
-       let router = test_utils::TestRouter::new(network_graph.clone());
+       let scorer = Mutex::new(test_utils::TestScorer::new());
+       let router = test_utils::TestRouter::new(network_graph.clone(), &scorer);
        let node = NodeCfg { chain_source: &chanmon_cfgs[0].chain_source, logger: &chanmon_cfgs[0].logger, tx_broadcaster: &chanmon_cfgs[0].tx_broadcaster, fee_estimator: &chanmon_cfgs[0].fee_estimator, router, chain_monitor, keys_manager: &keys_manager, network_graph, node_seed: seed, override_init_features: alloc::rc::Rc::new(core::cell::RefCell::new(None)) };
        let mut node_cfgs = create_node_cfgs(3, &chanmon_cfgs);
        node_cfgs.remove(0);
@@ -6925,7 +6926,7 @@ fn test_check_htlc_underpaying() {
        // Create some initial channels
        create_announced_chan_between_nodes(&nodes, 0, 1);
 
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
        let payment_params = PaymentParameters::from_node_id(nodes[1].node.get_our_node_id(), TEST_FINAL_CLTV).with_features(nodes[1].node.invoice_features());
        let route = get_route(&nodes[0].node.get_our_node_id(), &payment_params, &nodes[0].network_graph.read_only(), None, 10_000, TEST_FINAL_CLTV, nodes[0].logger, &scorer, &random_seed_bytes).unwrap();
@@ -7175,7 +7176,7 @@ fn test_bump_penalty_txn_on_revoked_htlcs() {
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1000000, 59000000);
        // Lock HTLC in both directions (using a slightly lower CLTV delay to provide timely RBF bumps)
        let payment_params = PaymentParameters::from_node_id(nodes[1].node.get_our_node_id(), 50).with_features(nodes[1].node.invoice_features());
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
        let route = get_route(&nodes[0].node.get_our_node_id(), &payment_params, &nodes[0].network_graph.read_only(), None,
                3_000_000, 50, nodes[0].logger, &scorer, &random_seed_bytes).unwrap();
@@ -9186,7 +9187,7 @@ fn test_keysend_payments_to_public_node() {
                final_value_msat: 10000,
                final_cltv_expiry_delta: 40,
        };
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
        let route = find_route(&payer_pubkey, &route_params, &network_graph, None, nodes[0].logger, &scorer, &random_seed_bytes).unwrap();
 
@@ -9221,7 +9222,7 @@ fn test_keysend_payments_to_private_node() {
        };
        let network_graph = nodes[0].network_graph.clone();
        let first_hops = nodes[0].node.list_usable_channels();
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
        let route = find_route(
                &payer_pubkey, &route_params, &network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
index 1ed10a4a61dc784b291d8b514b84a290c4ff6fa7..fbc8e1c9e4ec22cfe805154e04f2af7546ff89f9 100644 (file)
@@ -928,7 +928,7 @@ macro_rules! get_phantom_route {
                                                htlc_maximum_msat: None,
                                        }
                ])]);
-               let scorer = test_utils::TestScorer::with_penalty(0);
+               let scorer = test_utils::TestScorer::new();
                let network_graph = $nodes[0].network_graph.read_only();
                (get_route(
                        &$nodes[0].node.get_our_node_id(), &payment_params, &network_graph,
index a9ced49f647a75e13a1a2a535c38fa86ef29a1c9..b1b610bf8c9703445bea47aff84c4214bcd0852f 100644 (file)
@@ -1202,7 +1202,7 @@ mod tests {
        use crate::ln::outbound_payment::{OutboundPayments, Retry};
        use crate::routing::gossip::NetworkGraph;
        use crate::routing::router::{InFlightHtlcs, PaymentParameters, Route, RouteParameters};
-       use crate::sync::Arc;
+       use crate::sync::{Arc, Mutex};
        use crate::util::errors::APIError;
        use crate::util::test_utils;
 
@@ -1218,7 +1218,8 @@ mod tests {
                let logger = test_utils::TestLogger::new();
                let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
                let network_graph = Arc::new(NetworkGraph::new(genesis_hash, &logger));
-               let router = test_utils::TestRouter::new(network_graph);
+               let scorer = Mutex::new(test_utils::TestScorer::new());
+               let router = test_utils::TestRouter::new(network_graph, &scorer);
                let secp_ctx = Secp256k1::new();
                let keys_manager = test_utils::TestKeysInterface::new(&[0; 32], Network::Testnet);
 
@@ -1257,7 +1258,8 @@ mod tests {
                let logger = test_utils::TestLogger::new();
                let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
                let network_graph = Arc::new(NetworkGraph::new(genesis_hash, &logger));
-               let router = test_utils::TestRouter::new(network_graph);
+               let scorer = Mutex::new(test_utils::TestScorer::new());
+               let router = test_utils::TestRouter::new(network_graph, &scorer);
                let secp_ctx = Secp256k1::new();
                let keys_manager = test_utils::TestKeysInterface::new(&[0; 32], Network::Testnet);
 
index c6b170364d6f8eb1dceba3b3e9ef2ff7adf4ad4d..c82534ec8fc27567a594ee570630af9e888aa765 100644 (file)
@@ -921,7 +921,7 @@ fn get_ldk_payment_preimage() {
 
        let payment_params = PaymentParameters::from_node_id(nodes[1].node.get_our_node_id(), TEST_FINAL_CLTV)
                .with_features(nodes[1].node.invoice_features());
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
        let random_seed_bytes = keys_manager.get_secure_random_bytes();
        let route = get_route(
@@ -1442,7 +1442,7 @@ fn do_test_intercepted_payment(test: InterceptTest) {
        let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, Some(intercept_forwards_config), Some(zero_conf_chan_config)]);
 
        let nodes = create_network(3, &node_cfgs, &node_chanmgrs);
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let random_seed_bytes = chanmon_cfgs[0].keys_manager.get_secure_random_bytes();
 
        let _ = create_announced_chan_between_nodes(&nodes, 0, 1).2;
index c1fe23a62a65406180e3dd5356a6fef06340411e..5a5f054846f8d4e66d720aaa3bc77a29b1b7da6e 100644 (file)
@@ -76,7 +76,7 @@ fn updates_shutdown_wait() {
        let chan_1 = create_announced_chan_between_nodes(&nodes, 0, 1);
        let chan_2 = create_announced_chan_between_nodes(&nodes, 1, 2);
        let logger = test_utils::TestLogger::new();
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
        let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
index a499532e6a94a8c46016afb3d367ceab4461f633..6c2d70bd6f309ed429af45876666673fc6421a35 100644 (file)
@@ -968,7 +968,7 @@ impl<'a> fmt::Debug for DirectedChannelInfo<'a> {
 ///
 /// While this may be smaller than the actual channel capacity, amounts greater than
 /// [`Self::as_msat`] should not be routed through the channel.
-#[derive(Clone, Copy, Debug)]
+#[derive(Clone, Copy, Debug, PartialEq)]
 pub enum EffectiveCapacity {
        /// The available liquidity in the channel known from being a channel counterparty, and thus a
        /// direct hop.
index 0e51b2bf86b9d366386b736c5dacc28e82740592..79838e185bc7f31ff38741a858089853c02ea263 100644 (file)
@@ -2116,7 +2116,7 @@ mod tests {
        use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
                PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
                DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE};
-       use crate::routing::scoring::{ChannelUsage, Score, ProbabilisticScorer, ProbabilisticScoringParameters};
+       use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, Score, ProbabilisticScorer, ProbabilisticScoringParameters};
        use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel};
        use crate::chain::transaction::OutPoint;
        use crate::chain::keysinterface::EntropySource;
@@ -2186,7 +2186,7 @@ mod tests {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2219,7 +2219,7 @@ mod tests {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2241,7 +2241,7 @@ mod tests {
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2369,7 +2369,7 @@ mod tests {
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let config = UserConfig::default();
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42).with_features(channelmanager::provided_invoice_features(&config));
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2507,7 +2507,7 @@ mod tests {
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2567,7 +2567,7 @@ mod tests {
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2611,7 +2611,7 @@ mod tests {
        fn our_chans_test() {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2742,7 +2742,7 @@ mod tests {
        fn partial_route_hint_test() {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2843,7 +2843,7 @@ mod tests {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(empty_last_hop(&nodes));
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -2923,7 +2923,7 @@ mod tests {
                let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let last_hops = multi_hop_last_hops_hint([nodes[2], nodes[3]]);
                let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops.clone());
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                // Test through channels 2, 3, 0xff00, 0xff01.
@@ -2997,7 +2997,7 @@ mod tests {
 
                let last_hops = multi_hop_last_hops_hint([nodes[2], non_announced_pubkey]);
                let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops.clone());
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                // Test through channels 2, 3, 0xff00, 0xff01.
                // Test shows that multiple hop hints are considered.
 
@@ -3103,7 +3103,7 @@ mod tests {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops_with_public_channel(&nodes));
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                // This test shows that public routes can be present in the invoice
@@ -3154,7 +3154,7 @@ mod tests {
        fn our_chans_last_hop_connect_test() {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -3277,7 +3277,7 @@ mod tests {
                }]);
                let payment_params = PaymentParameters::from_node_id(target_node_id, 42).with_route_hints(vec![last_hops]);
                let our_chans = vec![get_channel_details(Some(42), middle_node_id, InitFeatures::from_le_bytes(vec![0b11]), outbound_capacity_msat)];
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
@@ -3338,7 +3338,7 @@ mod tests {
 
                let (secp_ctx, network_graph, mut gossip_sync, chain_monitor, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -3614,7 +3614,7 @@ mod tests {
                // one of the latter hops is limited.
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -3740,7 +3740,7 @@ mod tests {
        fn ignore_fee_first_hop_test() {
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42);
@@ -3788,7 +3788,7 @@ mod tests {
        fn simple_mpp_route_test() {
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -3948,7 +3948,7 @@ mod tests {
        fn long_mpp_route_test() {
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -4113,7 +4113,7 @@ mod tests {
        fn mpp_cheaper_route_test() {
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -4283,7 +4283,7 @@ mod tests {
                // if the fee is not properly accounted for, the behavior is different.
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -4465,7 +4465,7 @@ mod tests {
                // This bug appeared in production in some specific channel configurations.
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -4557,7 +4557,7 @@ mod tests {
                // path finding we realize that we found more capacity than we need.
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -4719,7 +4719,7 @@ mod tests {
                let network = Arc::new(NetworkGraph::new(genesis_hash, Arc::clone(&logger)));
                let gossip_sync = P2PGossipSync::new(Arc::clone(&network), None, Arc::clone(&logger));
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let payment_params = PaymentParameters::from_node_id(nodes[6], 42);
@@ -4850,7 +4850,7 @@ mod tests {
                // we calculated fees on a higher value, resulting in us ignoring such paths.
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let payment_params = PaymentParameters::from_node_id(nodes[2], 42);
@@ -4914,7 +4914,7 @@ mod tests {
                // resulting in us thinking there is no possible path, even if other paths exist.
                let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let config = UserConfig::default();
@@ -4985,7 +4985,7 @@ mod tests {
                let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
                let logger = Arc::new(ln_test_utils::TestLogger::new());
                let network_graph = NetworkGraph::new(genesis_hash, Arc::clone(&logger));
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let config = UserConfig::default();
                let payment_params = PaymentParameters::from_node_id(nodes[0], 42).with_features(channelmanager::provided_invoice_features(&config));
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
@@ -5056,7 +5056,7 @@ mod tests {
                let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops(&nodes));
 
                // Without penalizing each hop 100 msats, a longer path with lower fees is chosen.
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let route = get_route(
@@ -5071,7 +5071,7 @@ mod tests {
 
                // Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6]
                // from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper.
-               let scorer = ln_test_utils::TestScorer::with_penalty(100);
+               let scorer = FixedPenaltyScorer::with_penalty(100);
                let route = get_route(
                        &our_id, &payment_params, &network_graph.read_only(), None, 100, 42,
                        Arc::clone(&logger), &scorer, &random_seed_bytes
@@ -5130,7 +5130,7 @@ mod tests {
                let network_graph = network.read_only();
 
                // A path to nodes[6] exists when no penalties are applied to any channel.
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                let route = get_route(
@@ -5245,7 +5245,7 @@ mod tests {
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let network_graph = network.read_only();
 
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
 
                // Make sure that generally there is at least one route available
                let feasible_max_total_cltv_delta = 1008;
@@ -5278,7 +5278,7 @@ mod tests {
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let network_graph = network.read_only();
 
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let mut payment_params = PaymentParameters::from_node_id(nodes[6], 0).with_route_hints(last_hops(&nodes))
                        .with_max_path_count(1);
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
@@ -5305,7 +5305,7 @@ mod tests {
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let network_graph = network.read_only();
 
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
@@ -5333,7 +5333,7 @@ mod tests {
                let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
 
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
 
                let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops(&nodes));
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
@@ -5367,7 +5367,7 @@ mod tests {
                let network_graph = network.read_only();
                let network_nodes = network_graph.nodes();
                let network_channels = network_graph.channels();
-               let scorer = ln_test_utils::TestScorer::with_penalty(0);
+               let scorer = ln_test_utils::TestScorer::new();
                let payment_params = PaymentParameters::from_node_id(nodes[3], 0);
                let keys_manager = ln_test_utils::TestKeysInterface::new(&[4u8; 32], Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
index f16dd2d5b401e521c7299cfa3e2b647cb9d105ab..7504e0840d2683125cf75144fd445b070c6f08da 100644 (file)
@@ -260,7 +260,7 @@ impl<'a, S: Writeable> Writeable for MutexGuard<'a, S> {
 }
 
 /// Proposed use of a channel passed as a parameter to [`Score::channel_penalty_msat`].
-#[derive(Clone, Copy, Debug)]
+#[derive(Clone, Copy, Debug, PartialEq)]
 pub struct ChannelUsage {
        /// The amount to send through the channel, denominated in millisatoshis.
        pub amount_msat: u64,
index 83647a733853029b2501606bc41dc8cec266f1c0..2752212009832732c34fbc202a63ad40bbba50fd 100644 (file)
@@ -22,10 +22,10 @@ use crate::ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
 use crate::ln::{msgs, wire};
 use crate::ln::msgs::LightningError;
 use crate::ln::script::ShutdownScript;
-use crate::routing::gossip::{NetworkGraph, NodeId};
+use crate::routing::gossip::{EffectiveCapacity, NetworkGraph, NodeId};
 use crate::routing::utxo::{UtxoLookup, UtxoLookupError, UtxoResult};
 use crate::routing::router::{find_route, InFlightHtlcs, Route, RouteHop, RouteParameters, Router, ScorerAccountingForInFlightHtlcs};
-use crate::routing::scoring::FixedPenaltyScorer;
+use crate::routing::scoring::{ChannelUsage, Score};
 use crate::util::config::UserConfig;
 use crate::util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
 use crate::util::events;
@@ -48,6 +48,7 @@ use regex;
 
 use crate::io;
 use crate::prelude::*;
+use core::cell::RefCell;
 use core::time::Duration;
 use crate::sync::{Mutex, Arc};
 use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@@ -79,11 +80,12 @@ impl chaininterface::FeeEstimator for TestFeeEstimator {
 pub struct TestRouter<'a> {
        pub network_graph: Arc<NetworkGraph<&'a TestLogger>>,
        pub next_routes: Mutex<VecDeque<(RouteParameters, Result<Route, LightningError>)>>,
+       pub scorer: &'a Mutex<TestScorer>,
 }
 
 impl<'a> TestRouter<'a> {
-       pub fn new(network_graph: Arc<NetworkGraph<&'a TestLogger>>) -> Self {
-               Self { network_graph, next_routes: Mutex::new(VecDeque::new()), }
+       pub fn new(network_graph: Arc<NetworkGraph<&'a TestLogger>>, scorer: &'a Mutex<TestScorer>) -> Self {
+               Self { network_graph, next_routes: Mutex::new(VecDeque::new()), scorer }
        }
 
        pub fn expect_find_route(&self, query: RouteParameters, result: Result<Route, LightningError>) {
@@ -99,12 +101,37 @@ impl<'a> Router for TestRouter<'a> {
        ) -> Result<Route, msgs::LightningError> {
                if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
                        assert_eq!(find_route_query, *params);
+                       if let Ok(ref route) = find_route_res {
+                               let locked_scorer = self.scorer.lock().unwrap();
+                               let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
+                               for path in &route.paths {
+                                       let mut aggregate_msat = 0u64;
+                                       for (idx, hop) in path.iter().rev().enumerate() {
+                                               aggregate_msat += hop.fee_msat;
+                                               let usage = ChannelUsage {
+                                                       amount_msat: aggregate_msat,
+                                                       inflight_htlc_msat: 0,
+                                                       effective_capacity: EffectiveCapacity::Unknown,
+                                               };
+
+                                               // Since the path is reversed, the last element in our iteration is the first
+                                               // hop.
+                                               if idx == path.len() - 1 {
+                                                       scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(payer), &NodeId::from_pubkey(&hop.pubkey), usage);
+                                               } else {
+                                                       let curr_hop_path_idx = path.len() - 1 - idx;
+                                                       scorer.channel_penalty_msat(hop.short_channel_id, &NodeId::from_pubkey(&path[curr_hop_path_idx - 1].pubkey), &NodeId::from_pubkey(&hop.pubkey), usage);
+                                               }
+                                       }
+                               }
+                       }
                        return find_route_res;
                }
                let logger = TestLogger::new();
+               let scorer = self.scorer.lock().unwrap();
                find_route(
                        payer, params, &self.network_graph, first_hops, &logger,
-                       &ScorerAccountingForInFlightHtlcs::new(TestScorer::with_penalty(0), &inflight_htlcs),
+                       &ScorerAccountingForInFlightHtlcs::new(scorer, &inflight_htlcs),
                        &[42; 32]
                )
        }
@@ -889,5 +916,65 @@ impl Drop for TestChainSource {
        }
 }
 
-/// A scorer useful in testing, when the passage of time isn't a concern.
-pub type TestScorer = FixedPenaltyScorer;
+pub struct TestScorer {
+       /// Stores a tuple of (scid, ChannelUsage)
+       scorer_expectations: RefCell<Option<VecDeque<(u64, ChannelUsage)>>>,
+}
+
+impl TestScorer {
+       pub fn new() -> Self {
+               Self {
+                       scorer_expectations: RefCell::new(None),
+               }
+       }
+
+       pub fn expect_usage(&self, scid: u64, expectation: ChannelUsage) {
+               self.scorer_expectations.borrow_mut().get_or_insert_with(|| VecDeque::new()).push_back((scid, expectation));
+       }
+}
+
+#[cfg(c_bindings)]
+impl crate::util::ser::Writeable for TestScorer {
+       fn write<W: crate::util::ser::Writer>(&self, _: &mut W) -> Result<(), crate::io::Error> { unreachable!(); }
+}
+
+impl Score for TestScorer {
+       fn channel_penalty_msat(
+               &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId, usage: ChannelUsage
+       ) -> u64 {
+               if let Some(scorer_expectations) = self.scorer_expectations.borrow_mut().as_mut() {
+                       match scorer_expectations.pop_front() {
+                               Some((scid, expectation)) => {
+                                       assert_eq!(expectation, usage);
+                                       assert_eq!(scid, short_channel_id);
+                               },
+                               None => {},
+                       }
+               }
+               0
+       }
+
+       fn payment_path_failed(&mut self, _actual_path: &[&RouteHop], _actual_short_channel_id: u64) {}
+
+       fn payment_path_successful(&mut self, _actual_path: &[&RouteHop]) {}
+
+       fn probe_failed(&mut self, _actual_path: &[&RouteHop], _: u64) {}
+
+       fn probe_successful(&mut self, _actual_path: &[&RouteHop]) {}
+}
+
+impl Drop for TestScorer {
+       fn drop(&mut self) {
+               #[cfg(feature = "std")] {
+                       if std::thread::panicking() {
+                               return;
+                       }
+               }
+
+               if let Some(scorer_expectations) = self.scorer_expectations.borrow().as_ref() {
+                       if !scorer_expectations.is_empty() {
+                               panic!("Unsatisfied scorer expectations: {:?}", scorer_expectations)
+                       }
+               }
+       }
+}