]> git.bitcoin.ninja Git - ln-routing-replay/commitdiff
Test both the LDK historical and LDK live-bounds models main
authorMatt Corallo <git@bluematt.me>
Mon, 25 Nov 2024 21:29:37 +0000 (21:29 +0000)
committerMatt Corallo <git@bluematt.me>
Mon, 25 Nov 2024 21:42:36 +0000 (21:42 +0000)
src/main.rs

index d776c1167af41ec39780e4940d817f42016b7ce6..e0127afebbfd2013a2f3eed69d5f9a359fecd20f 100644 (file)
@@ -3,6 +3,20 @@ use internal::DevNullLogger;
 
 use lightning::routing::gossip::{NetworkGraph, NodeId, ReadOnlyNetworkGraph};
 
+/// We demonstrate breaking out the tracking into several categories, allowing us to track the
+/// accuracy of two different models across different types of results.
+/// Ultimately we end up with `CATEGORIES` outputs, with averages across the types computed in
+/// [`results_complete`].
+const CATEGORIES: usize = 1 << 3;
+/// Entries with this flag are those for hops that succeeded.
+const SUCCESS: usize = 1;
+/// Entries with this flag fell back to probability estimation without any historical probing
+/// results for this channel, using only the channel's capacity to estimate probability.
+const NO_DATA: usize = 2;
+/// Entries with this flag are for the live bounds model. Entries without this flag are for the LDK
+/// historical model.
+const LIVE: usize = 4;
+
 /// The simulation state.
 ///
 /// You're free to put whatever you want here.
@@ -10,29 +24,21 @@ pub struct State<'a> {
        /// As a demonstration, the default run calculates the probability the LDK historical model
        /// assigns to results.
        scorer: lightning::routing::scoring::ProbabilisticScorer<&'a NetworkGraph<DevNullLogger>, DevNullLogger>,
-       // We demonstrate calculating log-loss of the LDK historical model
-       success_loss_sum: f64,
-       success_result_count: u64,
-       failure_loss_sum: f64,
-       failure_result_count: u64,
-       no_data_success_loss_sum: f64,
-       no_data_success_result_count: u64,
-       no_data_failure_loss_sum: f64,
-       no_data_failure_result_count: u64,
+       /// We demonstrate calculating log-loss of both the LDK historical model and the more naive
+       /// live bounds model.
+       ///
+       /// Each entry is defined as the total log-loss for the categories of outputs as defined by
+       /// the above flags.
+       log_loss_sum: [f64; CATEGORIES],
+       result_count: [u64; CATEGORIES],
 }
 
 /// Creates a new [`State`] before any probe results are processed.
 pub fn do_setup<'a>(graph: &'a NetworkGraph<DevNullLogger>) -> State {
        State {
                scorer: lightning::routing::scoring::ProbabilisticScorer::new(Default::default(), graph, internal::DevNullLogger),
-               success_loss_sum: 0.0,
-               success_result_count: 0,
-               failure_loss_sum: 0.0,
-               failure_result_count: 0,
-               no_data_success_loss_sum: 0.0,
-               no_data_success_result_count: 0,
-               no_data_failure_loss_sum: 0.0,
-               no_data_failure_result_count: 0,
+               log_loss_sum: [0.0; CATEGORIES],
+               result_count: [0; CATEGORIES],
        }
 }
 
@@ -49,48 +55,55 @@ pub fn process_probe_result(network_graph: ReadOnlyNetworkGraph, result: ProbeRe
        let cur_time = std::time::Duration::from_secs(result.timestamp);
        state.scorer.time_passed(cur_time);
 
-       // Evaluate the model
+       // For each hop in the path, we add new entries to the `log_loss_sum` and
+       // `result_count` state variables, updating entries with the given flags.
+       let mut update_data_with_result = |mut flags, mut probability: f64, success| {
+               if success {
+                       flags |= SUCCESS;
+               } else {
+                       flags &= !SUCCESS;
+               }
+               if !success { probability = 1.0 - probability; }
+               if probability < 0.01 {
+                       // While the model really needs to be tuned to avoid being so incredibly
+                       // overconfident, in the meantime we cheat a bit to avoid infinite results.
+                       probability = 0.01;
+               }
+               state.log_loss_sum[flags] -= probability.log2();
+               state.result_count[flags] += 1;
+       };
+
+       // At each hop, we add two new entries - one for the LDK historical model and one for the naive
+       // live bounds model.
+       let mut evaluate_hop = |hop: &DirectedChannel, success| {
+               let hist_model_probability =
+                       state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default(), true)
+                       .expect("We should have some estimated probability, even without history data");
+               let have_hist_results =
+                       state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default(), false)
+                       .is_some();
+               let flags = if have_hist_results { 0 } else { NO_DATA };
+               update_data_with_result(flags, hist_model_probability, success);
+
+               let live_model_probability =
+                       state.scorer.live_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default())
+                       .expect("We should have some estimated probability, even without past data");
+               let have_live_data = state.scorer.estimated_channel_liquidity_range(hop.short_channel_id, &hop.dst_node_id).is_some();
+               let flags = LIVE | if have_live_data { 0 } else { NO_DATA };
+               update_data_with_result(flags, live_model_probability, success);
+       };
+
+       // Evaluate the model by passing each hop which succeeded as well as the final failing hop to
+       // `evaluate_hop`.
        for hop in result.channels_with_sufficient_liquidity.iter() {
                // You can get additional information about the channel from the network_graph:
                let _chan = network_graph.channels().get(&hop.short_channel_id).unwrap();
-               let mut no_data = false;
-               let mut model_probability =
-                       state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default())
-                       .unwrap_or_else(|| {
-                               no_data = true;
-                               // If LDK doesn't have sufficient historical state it will fall back to (roughly) the live model.
-                               state.scorer
-                                       .live_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default())
-                                       .expect("We should have some estimated probability, even without history data")
-                       });
-               if model_probability < 0.01 { model_probability = 0.01; }
-               state.success_loss_sum -= model_probability.log2();
-               state.success_result_count += 1;
-               if no_data {
-                       state.no_data_success_loss_sum -= model_probability.log2();
-                       state.no_data_success_result_count += 1;
-               }
+               evaluate_hop(hop, true);
        }
        if let Some(hop) = &result.channel_that_rejected_payment {
                // You can get additional information about the channel from the network_graph:
                let _chan = network_graph.channels().get(&hop.short_channel_id).unwrap();
-               let mut no_data = false;
-               let mut model_probability =
-                       state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default())
-                       .unwrap_or_else(|| {
-                               no_data = true;
-                               // If LDK doesn't have sufficient historical state it will fall back to (roughly) the live model.
-                               state.scorer
-                                       .live_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default())
-                                       .expect("We should have some estimated probability, even without history data")
-                       });
-               if model_probability > 0.99 { model_probability = 0.99; }
-               state.failure_loss_sum -= (1.0 - model_probability).log2();
-               state.failure_result_count += 1;
-               if no_data {
-                       state.no_data_failure_loss_sum -= (1.0 - model_probability).log2();
-                       state.no_data_failure_result_count += 1;
-               }
+               evaluate_hop(hop, false);
        }
 
        // Update the model with the information we learned
@@ -125,27 +138,37 @@ pub fn results_complete(state: State) {
        // We break out log-loss for failure and success hops and print averages between the two
        // (rather than in aggregate) as there are substantially more succeeding hops than there are
        // failing hops.
-       let no_data_suc = state.no_data_success_loss_sum / (state.no_data_success_result_count as f64);
-       let no_data_fail = state.no_data_failure_loss_sum / (state.no_data_failure_result_count as f64);
-       println!("Avg no-data success log-loss            {}", no_data_suc);
-       println!("Avg no-data failure log-loss            {}", no_data_fail);
-       println!("Avg no-data success+failure log-loss    {}", (no_data_suc + no_data_fail) / 2.0);
-       println!();
-       let avg_hist_suc = (state.success_loss_sum - state.no_data_success_loss_sum) / ((state.success_result_count - state.no_data_success_result_count) as f64);
-       let avg_hist_fail = (state.failure_loss_sum - state.no_data_failure_loss_sum) / ((state.failure_result_count - state.no_data_failure_result_count) as f64);
-       println!("Avg historical data success log-loss    {}", avg_hist_suc);
-       println!("Avg historical data failure log-loss    {}", avg_hist_fail);
-       println!("Avg hist data suc+fail average log-loss {}", (avg_hist_suc + avg_hist_fail) / 2.0);
-       println!();
-       let avg_suc = state.success_loss_sum / (state.success_result_count as f64);
-       let avg_fail = state.failure_loss_sum / (state.failure_result_count as f64);
-       println!("Avg success log-loss                    {}", avg_suc);
-       println!("Avg failure log-loss                    {}", avg_fail);
-       println!("Avg success+failure average log-loss    {}", (avg_suc + avg_fail) / 2.0);
-       println!();
-       let loss_sum = state.success_loss_sum + state.failure_loss_sum;
-       let result_count = state.success_result_count + state.failure_result_count;
-       println!("Avg log-loss {}", loss_sum / (result_count as f64));
+       for category in 0..CATEGORIES / 4 {
+               let flags = category * 4;
+               let mut category_name = String::new();
+               if (flags & LIVE) != 0 {
+                       category_name += "Live Bounds Model";
+               } else {
+                       category_name += "Historical Model ";
+               }
+               for no_data in 0..2 {
+                       let flags = flags + no_data * NO_DATA;
+                       let fail_res = state.log_loss_sum[flags] / state.result_count[flags] as f64;
+                       let suc_res = state.log_loss_sum[flags|1] / state.result_count[flags|1] as f64;
+                       let mut category_name = category_name.clone();
+                       if (flags & NO_DATA) != 0 {
+                               category_name += " (w/ insufficient data)";
+                       } else {
+                               category_name += " (w/ some channel hist)";
+                       }
+                       println!("Avg {} success log-loss: {}", category_name, suc_res);
+                       println!("Avg {} failure log-loss: {}", category_name, fail_res);
+                       println!("Avg {} average log-loss: {}", category_name, (suc_res + fail_res) / 2.0);
+               }
+               let fail_res = (state.log_loss_sum[flags] + state.log_loss_sum[flags + NO_DATA])
+                       / (state.result_count[flags] + state.result_count[flags + NO_DATA]) as f64;
+               let suc_res = (state.log_loss_sum[flags + 1] + state.log_loss_sum[flags + NO_DATA + 1])
+                       / (state.result_count[flags + 1] + state.result_count[flags + NO_DATA + 1]) as f64;
+               println!("Avg {} success log-loss: {}", category_name, suc_res);
+               println!("Avg {} failure log-loss: {}", category_name, fail_res);
+               println!("Avg {} average log-loss: {}", category_name, (suc_res + fail_res) / 2.0);
+               println!();
+       }
 }
 
 /// A hop in a route, consisting of a channel and the source public key, as well as the amount