Add an additional test/bench for routing larger amounts, score more
[rust-lightning] / lightning / src / routing / router.rs
index fc65bc990005073a8384040d5452734474db40c8..420fc3d202d282f41c71f265d6dafb1e4f314df5 100644 (file)
@@ -5810,7 +5810,7 @@ mod tests {
                let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger);
                let features = super::InvoiceFeatures::empty();
 
-               super::bench_utils::generate_test_routes(&graph, &mut scorer, &params, features, random_init_seed() as usize, 2);
+               super::bench_utils::generate_test_routes(&graph, &mut scorer, &params, features, random_init_seed(), 0, 2);
        }
 
        #[test]
@@ -5831,7 +5831,28 @@ mod tests {
                let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger);
                let features = channelmanager::provided_invoice_features(&UserConfig::default());
 
-               super::bench_utils::generate_test_routes(&graph, &mut scorer, &params, features, random_init_seed() as usize, 2);
+               super::bench_utils::generate_test_routes(&graph, &mut scorer, &params, features, random_init_seed(), 0, 2);
+       }
+
+       #[test]
+       #[cfg(not(feature = "no-std"))]
+       fn generate_large_mpp_routes() {
+               use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringFeeParameters};
+
+               let logger = ln_test_utils::TestLogger::new();
+               let graph = match super::bench_utils::read_network_graph(&logger) {
+                       Ok(f) => f,
+                       Err(e) => {
+                               eprintln!("{}", e);
+                               return;
+                       },
+               };
+
+               let params = ProbabilisticScoringFeeParameters::default();
+               let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger);
+               let features = channelmanager::provided_invoice_features(&UserConfig::default());
+
+               super::bench_utils::generate_test_routes(&graph, &mut scorer, &params, features, random_init_seed(), 1_000_000, 2);
        }
 
        #[test]
@@ -6085,11 +6106,11 @@ pub(crate) mod bench_utils {
                        short_channel_id: Some(1),
                        inbound_scid_alias: None,
                        outbound_scid_alias: None,
-                       channel_value_satoshis: 10_000_000,
+                       channel_value_satoshis: 10_000_000_000,
                        user_channel_id: 0,
-                       balance_msat: 10_000_000,
-                       outbound_capacity_msat: 10_000_000,
-                       next_outbound_htlc_limit_msat: 10_000_000,
+                       balance_msat: 10_000_000_000,
+                       outbound_capacity_msat: 10_000_000_000,
+                       next_outbound_htlc_limit_msat: 10_000_000_000,
                        inbound_capacity_msat: 0,
                        unspendable_punishment_reserve: None,
                        confirmations_required: None,
@@ -6107,7 +6128,8 @@ pub(crate) mod bench_utils {
        }
 
        pub(crate) fn generate_test_routes<S: Score>(graph: &NetworkGraph<&TestLogger>, scorer: &mut S,
-               score_params: &S::ScoreParams, features: InvoiceFeatures, mut seed: usize, route_count: usize,
+               score_params: &S::ScoreParams, features: InvoiceFeatures, mut seed: u64,
+               starting_amount: u64, route_count: usize,
        ) -> Vec<(ChannelDetails, PaymentParameters, u64)> {
                let payer = payer_pubkey();
                let keys_manager = KeysManager::new(&[0u8; 32], 42, 42);
@@ -6115,38 +6137,56 @@ pub(crate) mod bench_utils {
 
                let nodes = graph.read_only().nodes().clone();
                let mut route_endpoints = Vec::new();
-               let mut routes = Vec::new();
-
-               'load_endpoints: for _ in 0..route_count * 3 /2 {
+               // Fetch 1.5x more routes than we need as after we do some scorer updates we may end up
+               // with some routes we picked being un-routable.
+               for _ in 0..route_count * 3 / 2 {
                        loop {
-                               seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               let params = PaymentParameters::from_node_id(dst, 42).with_bolt11_features(features.clone()).unwrap();
+                               seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
+                               let src = PublicKey::from_slice(nodes.unordered_keys()
+                                       .skip((seed as usize) % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
+                               let dst = PublicKey::from_slice(nodes.unordered_keys()
+                                       .skip((seed as usize) % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let params = PaymentParameters::from_node_id(dst, 42)
+                                       .with_bolt11_features(features.clone()).unwrap();
                                let first_hop = first_hop(src);
-                               let amt = seed as u64 % 1_000_000;
-                               if let Ok(route) = get_route(&payer, &params, &graph.read_only(), Some(&[&first_hop]),
-                                       amt, &TestLogger::new(), &scorer, score_params, &random_seed_bytes,
-                               ) {
-                                       routes.push(route);
-                                       route_endpoints.push((first_hop, params, amt));
-                                       continue 'load_endpoints;
-                               }
-                       }
-               }
+                               let amt = starting_amount + seed % 1_000_000;
+                               let path_exists =
+                                       get_route(&payer, &params, &graph.read_only(), Some(&[&first_hop]),
+                                               amt, &TestLogger::new(), &scorer, score_params, &random_seed_bytes).is_ok();
+                               if path_exists {
+                                       // ...and seed the scorer with success and failure data...
+                                       seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
+                                       let mut score_amt = seed % 1_000_000_000;
+                                       loop {
+                                               // Generate fail/success paths for a wider range of potential amounts with
+                                               // MPP enabled to give us a chance to apply penalties for more potential
+                                               // routes.
+                                               let mpp_features = channelmanager::provided_invoice_features(&UserConfig::default());
+                                               let params = PaymentParameters::from_node_id(dst, 42)
+                                                       .with_bolt11_features(mpp_features).unwrap();
+
+                                               let route_res = get_route(&payer, &params, &graph.read_only(),
+                                                       Some(&[&first_hop]), score_amt, &TestLogger::new(), &scorer,
+                                                       score_params, &random_seed_bytes);
+                                               if let Ok(route) = route_res {
+                                                       for path in route.paths {
+                                                               if seed & 0x80 == 0 {
+                                                                       scorer.payment_path_successful(&path);
+                                                               } else {
+                                                                       let short_channel_id = path.hops[path.hops.len() / 2].short_channel_id;
+                                                                       scorer.payment_path_failed(&path, short_channel_id);
+                                                               }
+                                                               seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
+                                                       }
+                                                       break;
+                                               }
+                                               // If we couldn't find a path with a higer amount, reduce and try again.
+                                               score_amt /= 100;
+                                       }
 
-               // ...and seed the scorer with success and failure data...
-               for route in routes {
-                       let amount = route.get_total_amount();
-                       if amount < 250_000 {
-                               for path in route.paths {
-                                       scorer.payment_path_successful(&path);
-                               }
-                       } else if amount > 750_000 {
-                               for path in route.paths {
-                                       let short_channel_id = path.hops[path.hops.len() / 2].short_channel_id;
-                                       scorer.payment_path_failed(&path, short_channel_id);
+                                       route_endpoints.push((first_hop, params, amt));
+                                       break;
                                }
                        }
                }
@@ -6189,7 +6229,7 @@ mod benches {
                let logger = TestLogger::new();
                let network_graph = bench_utils::read_network_graph(&logger).unwrap();
                let scorer = FixedPenaltyScorer::with_penalty(0);
-               generate_routes(bench, &network_graph, scorer, &(), InvoiceFeatures::empty());
+               generate_routes(bench, &network_graph, scorer, &(), InvoiceFeatures::empty(), 0);
        }
 
        #[bench]
@@ -6197,7 +6237,7 @@ mod benches {
                let logger = TestLogger::new();
                let network_graph = bench_utils::read_network_graph(&logger).unwrap();
                let scorer = FixedPenaltyScorer::with_penalty(0);
-               generate_routes(bench, &network_graph, scorer, &(), channelmanager::provided_invoice_features(&UserConfig::default()));
+               generate_routes(bench, &network_graph, scorer, &(), channelmanager::provided_invoice_features(&UserConfig::default()), 0);
        }
 
        #[bench]
@@ -6206,7 +6246,7 @@ mod benches {
                let network_graph = bench_utils::read_network_graph(&logger).unwrap();
                let params = ProbabilisticScoringFeeParameters::default();
                let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
-               generate_routes(bench, &network_graph, scorer, &params, InvoiceFeatures::empty());
+               generate_routes(bench, &network_graph, scorer, &params, InvoiceFeatures::empty(), 0);
        }
 
        #[bench]
@@ -6215,19 +6255,28 @@ mod benches {
                let network_graph = bench_utils::read_network_graph(&logger).unwrap();
                let params = ProbabilisticScoringFeeParameters::default();
                let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
-               generate_routes(bench, &network_graph, scorer, &params, channelmanager::provided_invoice_features(&UserConfig::default()));
+               generate_routes(bench, &network_graph, scorer, &params, channelmanager::provided_invoice_features(&UserConfig::default()), 0);
+       }
+
+       #[bench]
+       fn generate_large_mpp_routes_with_probabilistic_scorer(bench: &mut Bencher) {
+               let logger = TestLogger::new();
+               let network_graph = bench_utils::read_network_graph(&logger).unwrap();
+               let params = ProbabilisticScoringFeeParameters::default();
+               let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
+               generate_routes(bench, &network_graph, scorer, &params, channelmanager::provided_invoice_features(&UserConfig::default()), 100_000_000);
        }
 
        fn generate_routes<S: Score>(
                bench: &mut Bencher, graph: &NetworkGraph<&TestLogger>, mut scorer: S,
-               score_params: &S::ScoreParams, features: InvoiceFeatures,
+               score_params: &S::ScoreParams, features: InvoiceFeatures, starting_amount: u64,
        ) {
                let payer = bench_utils::payer_pubkey();
                let keys_manager = KeysManager::new(&[0u8; 32], 42, 42);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
-               let route_endpoints = bench_utils::generate_test_routes(graph, &mut scorer, score_params, features, 0xdeadbeef, 100);
+               let route_endpoints = bench_utils::generate_test_routes(graph, &mut scorer, score_params, features, 0xdeadbeef, starting_amount, 50);
 
                // ...then benchmark finding paths between the nodes we learned.
                let mut idx = 0;