]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Unify route benchmarking with route tests
authorMatt Corallo <git@bluematt.me>
Thu, 11 May 2023 05:34:00 +0000 (05:34 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 11 May 2023 05:42:21 +0000 (05:42 +0000)
There's a few route tests which do the same thing as the benchmarks
as they're also a good test. However, they didn't share code, which
is somewhat wasteful, so we fix that here.

lightning/src/routing/router.rs

index b33e021ab4fc30357a1124885c45b076419aed61..fc65bc990005073a8384040d5452734474db40c8 100644 (file)
@@ -5791,44 +5791,26 @@ mod tests {
                println!("Using seed of {}", seed);
                seed
        }
-       #[cfg(not(feature = "no-std"))]
-       use crate::util::ser::ReadableArgs;
 
        #[test]
        #[cfg(not(feature = "no-std"))]
        fn generate_routes() {
                use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringFeeParameters};
 
-               let mut d = match super::bench_utils::get_route_file() {
+               let logger = ln_test_utils::TestLogger::new();
+               let graph = match super::bench_utils::read_network_graph(&logger) {
                        Ok(f) => f,
                        Err(e) => {
                                eprintln!("{}", e);
                                return;
                        },
                };
-               let logger = ln_test_utils::TestLogger::new();
-               let graph = NetworkGraph::read(&mut d, &logger).unwrap();
-               let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
-               let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
-               // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
-               let mut seed = random_init_seed() as usize;
-               let nodes = graph.read_only().nodes().clone();
-               'load_endpoints: for _ in 0..10 {
-                       loop {
-                               seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               let payment_params = PaymentParameters::from_node_id(dst, 42);
-                               let amt = seed as u64 % 200_000_000;
-                               let params = ProbabilisticScoringFeeParameters::default();
-                               let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger);
-                               if get_route(src, &payment_params, &graph.read_only(), None, amt, &logger, &scorer, &params, &random_seed_bytes).is_ok() {
-                                       continue 'load_endpoints;
-                               }
-                       }
-               }
+               let params = ProbabilisticScoringFeeParameters::default();
+               let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger);
+               let features = super::InvoiceFeatures::empty();
+
+               super::bench_utils::generate_test_routes(&graph, &mut scorer, &params, features, random_init_seed() as usize, 2);
        }
 
        #[test]
@@ -5836,37 +5818,20 @@ mod tests {
        fn generate_routes_mpp() {
                use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringFeeParameters};
 
-               let mut d = match super::bench_utils::get_route_file() {
+               let logger = ln_test_utils::TestLogger::new();
+               let graph = match super::bench_utils::read_network_graph(&logger) {
                        Ok(f) => f,
                        Err(e) => {
                                eprintln!("{}", e);
                                return;
                        },
                };
-               let logger = ln_test_utils::TestLogger::new();
-               let graph = NetworkGraph::read(&mut d, &logger).unwrap();
-               let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
-               let random_seed_bytes = keys_manager.get_secure_random_bytes();
-               let config = UserConfig::default();
 
-               // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
-               let mut seed = random_init_seed() as usize;
-               let nodes = graph.read_only().nodes().clone();
-               'load_endpoints: for _ in 0..10 {
-                       loop {
-                               seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               let payment_params = PaymentParameters::from_node_id(dst, 42).with_bolt11_features(channelmanager::provided_invoice_features(&config)).unwrap();
-                               let amt = seed as u64 % 200_000_000;
-                               let params = ProbabilisticScoringFeeParameters::default();
-                               let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger);
-                               if get_route(src, &payment_params, &graph.read_only(), None, amt, &logger, &scorer, &params, &random_seed_bytes).is_ok() {
-                                       continue 'load_endpoints;
-                               }
-                       }
-               }
+               let params = ProbabilisticScoringFeeParameters::default();
+               let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger);
+               let features = channelmanager::provided_invoice_features(&UserConfig::default());
+
+               super::bench_utils::generate_test_routes(&graph, &mut scorer, &params, features, random_init_seed() as usize, 2);
        }
 
        #[test]
@@ -6054,7 +6019,21 @@ mod tests {
 
 #[cfg(all(test, not(feature = "no-std")))]
 pub(crate) mod bench_utils {
+       use super::*;
        use std::fs::File;
+
+       use bitcoin::hashes::Hash;
+       use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
+
+       use crate::chain::transaction::OutPoint;
+       use crate::sign::{EntropySource, KeysManager};
+       use crate::ln::channelmanager::{self, ChannelCounterparty, ChannelDetails};
+       use crate::ln::features::InvoiceFeatures;
+       use crate::routing::gossip::NetworkGraph;
+       use crate::util::config::UserConfig;
+       use crate::util::ser::ReadableArgs;
+       use crate::util::test_utils::TestLogger;
+
        /// Tries to open a network graph file, or panics with a URL to fetch it.
        pub(crate) fn get_route_file() -> Result<std::fs::File, &'static str> {
                let res = File::open("net_graph-2023-01-18.bin") // By default we're run in RL/lightning
@@ -6077,42 +6056,18 @@ pub(crate) mod bench_utils {
                #[cfg(not(require_route_graph_test))]
                return res;
        }
-}
-
-#[cfg(all(test, feature = "_bench_unstable", not(feature = "no-std")))]
-mod benches {
-       use super::*;
-       use bitcoin::hashes::Hash;
-       use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
-       use crate::chain::transaction::OutPoint;
-       use crate::sign::{EntropySource, KeysManager};
-       use crate::ln::channelmanager::{self, ChannelCounterparty, ChannelDetails};
-       use crate::ln::features::InvoiceFeatures;
-       use crate::routing::gossip::NetworkGraph;
-       use crate::routing::scoring::{FixedPenaltyScorer, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters};
-       use crate::util::config::UserConfig;
-       use crate::util::logger::{Logger, Record};
-       use crate::util::ser::ReadableArgs;
-
-       use test::Bencher;
-
-       struct DummyLogger {}
-       impl Logger for DummyLogger {
-               fn log(&self, _record: &Record) {}
-       }
 
-       fn read_network_graph(logger: &DummyLogger) -> NetworkGraph<&DummyLogger> {
-               let mut d = bench_utils::get_route_file().unwrap();
-               NetworkGraph::read(&mut d, logger).unwrap()
+       pub(crate) fn read_network_graph(logger: &TestLogger) -> Result<NetworkGraph<&TestLogger>, &'static str> {
+               get_route_file().map(|mut f| NetworkGraph::read(&mut f, logger).unwrap())
        }
 
-       fn payer_pubkey() -> PublicKey {
+       pub(crate) fn payer_pubkey() -> PublicKey {
                let secp_ctx = Secp256k1::new();
                PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap())
        }
 
        #[inline]
-       fn first_hop(node_id: PublicKey) -> ChannelDetails {
+       pub(crate) fn first_hop(node_id: PublicKey) -> ChannelDetails {
                ChannelDetails {
                        channel_id: [0; 32],
                        counterparty: ChannelCounterparty {
@@ -6151,63 +6106,29 @@ mod benches {
                }
        }
 
-       #[bench]
-       fn generate_routes_with_zero_penalty_scorer(bench: &mut Bencher) {
-               let logger = DummyLogger {};
-               let network_graph = read_network_graph(&logger);
-               let scorer = FixedPenaltyScorer::with_penalty(0);
-               generate_routes(bench, &network_graph, scorer, &(), InvoiceFeatures::empty());
-       }
-
-       #[bench]
-       fn generate_mpp_routes_with_zero_penalty_scorer(bench: &mut Bencher) {
-               let logger = DummyLogger {};
-               let network_graph = read_network_graph(&logger);
-               let scorer = FixedPenaltyScorer::with_penalty(0);
-               generate_routes(bench, &network_graph, scorer, &(), channelmanager::provided_invoice_features(&UserConfig::default()));
-       }
-
-       #[bench]
-       fn generate_routes_with_probabilistic_scorer(bench: &mut Bencher) {
-               let logger = DummyLogger {};
-               let network_graph = read_network_graph(&logger);
-               let params = ProbabilisticScoringFeeParameters::default();
-               let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
-               generate_routes(bench, &network_graph, scorer, &params, InvoiceFeatures::empty());
-       }
-
-       #[bench]
-       fn generate_mpp_routes_with_probabilistic_scorer(bench: &mut Bencher) {
-               let logger = DummyLogger {};
-               let network_graph = read_network_graph(&logger);
-               let params = ProbabilisticScoringFeeParameters::default();
-               let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
-               generate_routes(bench, &network_graph, scorer, &params, channelmanager::provided_invoice_features(&UserConfig::default()));
-       }
-
-       fn generate_routes<S: Score>(
-               bench: &mut Bencher, graph: &NetworkGraph<&DummyLogger>, mut scorer: S, score_params: &S::ScoreParams,
-               features: InvoiceFeatures
-       ) {
-               let nodes = graph.read_only().nodes().clone();
+       pub(crate) fn generate_test_routes<S: Score>(graph: &NetworkGraph<&TestLogger>, scorer: &mut S,
+               score_params: &S::ScoreParams, features: InvoiceFeatures, mut seed: usize, route_count: usize,
+       ) -> Vec<(ChannelDetails, PaymentParameters, u64)> {
                let payer = payer_pubkey();
                let keys_manager = KeysManager::new(&[0u8; 32], 42, 42);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
 
-               // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
-               let mut routes = Vec::new();
+               let nodes = graph.read_only().nodes().clone();
                let mut route_endpoints = Vec::new();
-               let mut seed: usize = 0xdeadbeef;
-               'load_endpoints: for _ in 0..150 {
+               let mut routes = Vec::new();
+
+               'load_endpoints: for _ in 0..route_count * 3 /2 {
                        loop {
-                               seed *= 0xdeadbeef;
+                               seed = seed.overflowing_mul(0xdeadbeef).0;
                                let src = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               seed *= 0xdeadbeef;
+                               seed = seed.overflowing_mul(0xdeadbeef).0;
                                let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                let params = PaymentParameters::from_node_id(dst, 42).with_bolt11_features(features.clone()).unwrap();
                                let first_hop = first_hop(src);
                                let amt = seed as u64 % 1_000_000;
-                               if let Ok(route) = get_route(&payer, &params, &graph.read_only(), Some(&[&first_hop]), amt, &DummyLogger{}, &scorer, score_params, &random_seed_bytes) {
+                               if let Ok(route) = get_route(&payer, &params, &graph.read_only(), Some(&[&first_hop]),
+                                       amt, &TestLogger::new(), &scorer, score_params, &random_seed_bytes,
+                               ) {
                                        routes.push(route);
                                        route_endpoints.push((first_hop, params, amt));
                                        continue 'load_endpoints;
@@ -6230,20 +6151,90 @@ mod benches {
                        }
                }
 
-               // Because we've changed channel scores, its possible we'll take different routes to the
+               // Because we've changed channel scores, it's possible we'll take different routes to the
                // selected destinations, possibly causing us to fail because, eg, the newly-selected path
                // requires a too-high CLTV delta.
                route_endpoints.retain(|(first_hop, params, amt)| {
-                       get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt, &DummyLogger{}, &scorer, score_params, &random_seed_bytes).is_ok()
+                       get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt,
+                               &TestLogger::new(), &scorer, score_params, &random_seed_bytes).is_ok()
                });
-               route_endpoints.truncate(100);
-               assert_eq!(route_endpoints.len(), 100);
+               route_endpoints.truncate(route_count);
+               assert_eq!(route_endpoints.len(), route_count);
+               route_endpoints
+       }
+}
+
+#[cfg(all(test, feature = "_bench_unstable", not(feature = "no-std")))]
+mod benches {
+       use super::*;
+       use crate::sign::{EntropySource, KeysManager};
+       use crate::ln::channelmanager;
+       use crate::ln::features::InvoiceFeatures;
+       use crate::routing::gossip::NetworkGraph;
+       use crate::routing::scoring::{FixedPenaltyScorer, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters};
+       use crate::util::config::UserConfig;
+       use crate::util::logger::{Logger, Record};
+       use crate::util::test_utils::TestLogger;
+
+       use test::Bencher;
+
+       struct DummyLogger {}
+       impl Logger for DummyLogger {
+               fn log(&self, _record: &Record) {}
+       }
+
+
+       #[bench]
+       fn generate_routes_with_zero_penalty_scorer(bench: &mut Bencher) {
+               let logger = TestLogger::new();
+               let network_graph = bench_utils::read_network_graph(&logger).unwrap();
+               let scorer = FixedPenaltyScorer::with_penalty(0);
+               generate_routes(bench, &network_graph, scorer, &(), InvoiceFeatures::empty());
+       }
+
+       #[bench]
+       fn generate_mpp_routes_with_zero_penalty_scorer(bench: &mut Bencher) {
+               let logger = TestLogger::new();
+               let network_graph = bench_utils::read_network_graph(&logger).unwrap();
+               let scorer = FixedPenaltyScorer::with_penalty(0);
+               generate_routes(bench, &network_graph, scorer, &(), channelmanager::provided_invoice_features(&UserConfig::default()));
+       }
+
+       #[bench]
+       fn generate_routes_with_probabilistic_scorer(bench: &mut Bencher) {
+               let logger = TestLogger::new();
+               let network_graph = bench_utils::read_network_graph(&logger).unwrap();
+               let params = ProbabilisticScoringFeeParameters::default();
+               let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
+               generate_routes(bench, &network_graph, scorer, &params, InvoiceFeatures::empty());
+       }
+
+       #[bench]
+       fn generate_mpp_routes_with_probabilistic_scorer(bench: &mut Bencher) {
+               let logger = TestLogger::new();
+               let network_graph = bench_utils::read_network_graph(&logger).unwrap();
+               let params = ProbabilisticScoringFeeParameters::default();
+               let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
+               generate_routes(bench, &network_graph, scorer, &params, channelmanager::provided_invoice_features(&UserConfig::default()));
+       }
+
+       fn generate_routes<S: Score>(
+               bench: &mut Bencher, graph: &NetworkGraph<&TestLogger>, mut scorer: S,
+               score_params: &S::ScoreParams, features: InvoiceFeatures,
+       ) {
+               let payer = bench_utils::payer_pubkey();
+               let keys_manager = KeysManager::new(&[0u8; 32], 42, 42);
+               let random_seed_bytes = keys_manager.get_secure_random_bytes();
+
+               // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
+               let route_endpoints = bench_utils::generate_test_routes(graph, &mut scorer, score_params, features, 0xdeadbeef, 100);
 
                // ...then benchmark finding paths between the nodes we learned.
                let mut idx = 0;
                bench.iter(|| {
                        let (first_hop, params, amt) = &route_endpoints[idx % route_endpoints.len()];
-                       assert!(get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt, &DummyLogger{}, &scorer, score_params, &random_seed_bytes).is_ok());
+                       assert!(get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt,
+                               &DummyLogger{}, &scorer, score_params, &random_seed_bytes).is_ok());
                        idx += 1;
                });
        }