Remove duplicate generate_routes benchmark code
authorJeffrey Czyz <jkczyz@gmail.com>
Fri, 14 Jan 2022 18:28:30 +0000 (12:28 -0600)
committerJeffrey Czyz <jkczyz@gmail.com>
Tue, 25 Jan 2022 00:52:36 +0000 (18:52 -0600)
Refactor generate_routes and generate_mpp_routes into a single utility
for benchmarking. The utility is parameterized with features in order to
test both single path and multi-path routing. Additionally, it is
parameterized with a Score to be used with other scorers.

lightning/src/routing/router.rs

index 420bbaed0ca626af10c33efaba0fed01aee77790..397042f8e1248923ff4bf24edd844a4eba6ddc55 100644 (file)
@@ -4963,6 +4963,8 @@ pub(crate) mod test_utils {
 #[cfg(all(test, feature = "unstable", not(feature = "no-std")))]
 mod benches {
        use super::*;
+       use bitcoin::secp256k1::PublicKey;
+       use ln::features::InvoiceFeatures;
        use routing::scoring::Scorer;
        use util::logger::{Logger, Record};
 
@@ -4973,47 +4975,29 @@ mod benches {
                fn log(&self, _record: &Record) {}
        }
 
-       #[bench]
-       fn generate_routes(bench: &mut Bencher) {
+       fn read_network_graph() -> NetworkGraph {
                let mut d = test_utils::get_route_file().unwrap();
-               let graph = NetworkGraph::read(&mut d).unwrap();
-               let nodes = graph.read_only().nodes().clone();
-               let scorer = Scorer::with_fixed_penalty(0);
-
-               // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
-               let mut path_endpoints = Vec::new();
-               let mut seed: usize = 0xdeadbeef;
-               'load_endpoints: for _ in 0..100 {
-                       loop {
-                               seed *= 0xdeadbeef;
-                               let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               seed *= 0xdeadbeef;
-                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               let payment_params = PaymentParameters::from_node_id(dst);
-                               let amt = seed as u64 % 1_000_000;
-                               if get_route(&src, &payment_params, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok() {
-                                       path_endpoints.push((src, dst, amt));
-                                       continue 'load_endpoints;
-                               }
-                       }
-               }
+               NetworkGraph::read(&mut d).unwrap()
+       }
 
-               // ...then benchmark finding paths between the nodes we learned.
-               let mut idx = 0;
-               bench.iter(|| {
-                       let (src, dst, amt) = path_endpoints[idx % path_endpoints.len()];
-                       let payment_params = PaymentParameters::from_node_id(dst);
-                       assert!(get_route(&src, &payment_params, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok());
-                       idx += 1;
-               });
+       #[bench]
+       fn generate_routes_with_default_scorer(bench: &mut Bencher) {
+               let network_graph = read_network_graph();
+               let scorer = Scorer::default();
+               generate_routes(bench, &network_graph, scorer, InvoiceFeatures::empty());
        }
 
        #[bench]
-       fn generate_mpp_routes(bench: &mut Bencher) {
-               let mut d = test_utils::get_route_file().unwrap();
-               let graph = NetworkGraph::read(&mut d).unwrap();
+       fn generate_mpp_routes_with_default_scorer(bench: &mut Bencher) {
+               let network_graph = read_network_graph();
+               let scorer = Scorer::default();
+               generate_routes(bench, &network_graph, scorer, InvoiceFeatures::known());
+       }
+
+       fn generate_routes<S: Score>(
+               bench: &mut Bencher, graph: &NetworkGraph, scorer: S, features: InvoiceFeatures
+       ) {
                let nodes = graph.read_only().nodes().clone();
-               let scorer = Scorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
@@ -5024,9 +5008,9 @@ mod benches {
                                let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed *= 0xdeadbeef;
                                let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
-                               let payment_params = PaymentParameters::from_node_id(dst).with_features(InvoiceFeatures::known());
+                               let params = PaymentParameters::from_node_id(dst).with_features(features.clone());
                                let amt = seed as u64 % 1_000_000;
-                               if get_route(&src, &payment_params, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok() {
+                               if get_route(&src, &params, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok() {
                                        path_endpoints.push((src, dst, amt));
                                        continue 'load_endpoints;
                                }
@@ -5037,8 +5021,8 @@ mod benches {
                let mut idx = 0;
                bench.iter(|| {
                        let (src, dst, amt) = path_endpoints[idx % path_endpoints.len()];
-                       let payment_params = PaymentParameters::from_node_id(dst).with_features(InvoiceFeatures::known());
-                       assert!(get_route(&src, &payment_params, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok());
+                       let params = PaymentParameters::from_node_id(dst).with_features(features.clone());
+                       assert!(get_route(&src, &params, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok());
                        idx += 1;
                });
        }