From 72be6850bd3093eea6cc3c15ee9d2f9ebdddc81b Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Mon, 25 Nov 2024 21:29:37 +0000 Subject: [PATCH] Test both the LDK historical and LDK live-bounds models --- src/main.rs | 169 +++++++++++++++++++++++++++++----------------------- 1 file changed, 96 insertions(+), 73 deletions(-) diff --git a/src/main.rs b/src/main.rs index d776c11..e0127af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,20 @@ use internal::DevNullLogger; use lightning::routing::gossip::{NetworkGraph, NodeId, ReadOnlyNetworkGraph}; +/// We demonstrate breaking out the tracking into several categories, allowing us to track the +/// accuracy of two different models across different types of results. +/// Ultimately we end up with `CATEGORIES` outputs, with averages across the types computed in +/// [`results_complete`]. +const CATEGORIES: usize = 1 << 3; +/// Entries with this flag are those for hops that succeeded. +const SUCCESS: usize = 1; +/// Entries with this flag fell back to probability estimation without any historical probing +/// results for this channel, using only the channel's capacity to estimate probability. +const NO_DATA: usize = 2; +/// Entries with this flag are for the live bounds model. Entries without this flag are for the LDK +/// historical model. +const LIVE: usize = 4; + /// The simulation state. /// /// You're free to put whatever you want here. @@ -10,29 +24,21 @@ pub struct State<'a> { /// As a demonstration, the default run calculates the probability the LDK historical model /// assigns to results. scorer: lightning::routing::scoring::ProbabilisticScorer<&'a NetworkGraph, DevNullLogger>, - // We demonstrate calculating log-loss of the LDK historical model - success_loss_sum: f64, - success_result_count: u64, - failure_loss_sum: f64, - failure_result_count: u64, - no_data_success_loss_sum: f64, - no_data_success_result_count: u64, - no_data_failure_loss_sum: f64, - no_data_failure_result_count: u64, + /// We demonstrate calculating log-loss of the both the LDK historical model and the more naive + /// live bounds model. + /// + /// Each entry is defined as the total log-loss for the categories out outputs as defined by + /// the above flags. + log_loss_sum: [f64; CATEGORIES], + result_count: [u64; CATEGORIES], } /// Creates a new [`State`] before any probe results are processed. pub fn do_setup<'a>(graph: &'a NetworkGraph) -> State { State { scorer: lightning::routing::scoring::ProbabilisticScorer::new(Default::default(), graph, internal::DevNullLogger), - success_loss_sum: 0.0, - success_result_count: 0, - failure_loss_sum: 0.0, - failure_result_count: 0, - no_data_success_loss_sum: 0.0, - no_data_success_result_count: 0, - no_data_failure_loss_sum: 0.0, - no_data_failure_result_count: 0, + log_loss_sum: [0.0; CATEGORIES], + result_count: [0; CATEGORIES], } } @@ -49,48 +55,55 @@ pub fn process_probe_result(network_graph: ReadOnlyNetworkGraph, result: ProbeRe let cur_time = std::time::Duration::from_secs(result.timestamp); state.scorer.time_passed(cur_time); - // Evaluate the model + // For each hop in the path, we add new entries to the `log_loss_sum` and + // `result_count` state variables, updating entries with the given flags. + let mut update_data_with_result = |mut flags, mut probability: f64, success| { + if success { + flags |= SUCCESS; + } else { + flags &= !SUCCESS; + } + if !success { probability = 1.0 - probability; } + if probability < 0.01 { + // While the model really needs to be tuned to avoid being so incredibly + // overconfident, in the mean time we cheat a bit to avoid infinite results. + probability = 0.01; + } + state.log_loss_sum[flags] -= probability.log2(); + state.result_count[flags] += 1; + }; + + // At each hop, we add two new entries - one for the LDK historical model and one for the naive + // live bounds model. + let mut evaluate_hop = |hop: &DirectedChannel, success| { + let hist_model_probability = + state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default(), true) + .expect("We should have some estimated probability, even without history data"); + let have_hist_results = + state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default(), false) + .is_some(); + let flags = if have_hist_results { 0 } else { NO_DATA }; + update_data_with_result(flags, hist_model_probability, success); + + let live_model_probability = + state.scorer.live_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default()) + .expect("We should have some estimated probability, even without past data"); + let have_live_data = state.scorer.estimated_channel_liquidity_range(hop.short_channel_id, &hop.dst_node_id).is_some(); + let flags = LIVE | if have_live_data { 0 } else { NO_DATA }; + update_data_with_result(flags, live_model_probability, success); + }; + + // Evaluate the model by passing each hop which succeeded as well as the final failing hop to + // `evaluate_hop`. for hop in result.channels_with_sufficient_liquidity.iter() { // You can get additional information about the channel from the network_graph: let _chan = network_graph.channels().get(&hop.short_channel_id).unwrap(); - let mut no_data = false; - let mut model_probability = - state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default()) - .unwrap_or_else(|| { - no_data = true; - // If LDK doesn't have sufficient historical state it will fall back to (roughly) the live model. - state.scorer - .live_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default()) - .expect("We should have some estimated probability, even without history data") - }); - if model_probability < 0.01 { model_probability = 0.01; } - state.success_loss_sum -= model_probability.log2(); - state.success_result_count += 1; - if no_data { - state.no_data_success_loss_sum -= model_probability.log2(); - state.no_data_success_result_count += 1; - } + evaluate_hop(hop, true); } if let Some(hop) = &result.channel_that_rejected_payment { // You can get additional information about the channel from the network_graph: let _chan = network_graph.channels().get(&hop.short_channel_id).unwrap(); - let mut no_data = false; - let mut model_probability = - state.scorer.historical_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default()) - .unwrap_or_else(|| { - no_data = true; - // If LDK doesn't have sufficient historical state it will fall back to (roughly) the live model. - state.scorer - .live_estimated_payment_success_probability(hop.short_channel_id, &hop.dst_node_id, hop.amount_msat, &Default::default()) - .expect("We should have some estimated probability, even without history data") - }); - if model_probability > 0.99 { model_probability = 0.99; } - state.failure_loss_sum -= (1.0 - model_probability).log2(); - state.failure_result_count += 1; - if no_data { - state.no_data_failure_loss_sum -= (1.0 - model_probability).log2(); - state.no_data_failure_result_count += 1; - } + evaluate_hop(hop, false); } // Update the model with the information we learned @@ -125,27 +138,37 @@ pub fn results_complete(state: State) { // We break out log-loss for failure and success hops and print averages between the two // (rather than in aggregate) as there are substantially more succeeding hops than there are // failing hops. - let no_data_suc = state.no_data_success_loss_sum / (state.no_data_success_result_count as f64); - let no_data_fail = state.no_data_failure_loss_sum / (state.no_data_failure_result_count as f64); - println!("Avg no-data success log-loss {}", no_data_suc); - println!("Avg no-data failure log-loss {}", no_data_fail); - println!("Avg no-data success+failure log-loss {}", (no_data_suc + no_data_fail) / 2.0); - println!(); - let avg_hist_suc = (state.success_loss_sum - state.no_data_success_loss_sum) / ((state.success_result_count - state.no_data_success_result_count) as f64); - let avg_hist_fail = (state.failure_loss_sum - state.no_data_failure_loss_sum) / ((state.failure_result_count - state.no_data_failure_result_count) as f64); - println!("Avg historical data success log-loss {}", avg_hist_suc); - println!("Avg historical data failure log-loss {}", avg_hist_fail); - println!("Avg hist data suc+fail average log-loss {}", (avg_hist_suc + avg_hist_fail) / 2.0); - println!(); - let avg_suc = state.success_loss_sum / (state.success_result_count as f64); - let avg_fail = state.failure_loss_sum / (state.failure_result_count as f64); - println!("Avg success log-loss {}", avg_suc); - println!("Avg failure log-loss {}", avg_fail); - println!("Avg success+failure average log-loss {}", (avg_suc + avg_fail) / 2.0); - println!(); - let loss_sum = state.success_loss_sum + state.failure_loss_sum; - let result_count = state.success_result_count + state.failure_result_count; - println!("Avg log-loss {}", loss_sum / (result_count as f64)); + for category in 0..CATEGORIES / 4 { + let flags = category * 4; + let mut category_name = String::new(); + if (flags & LIVE) != 0 { + category_name += "Live Bounds Model"; + } else { + category_name += "Historical Model "; + } + for no_data in 0..2 { + let flags = flags + no_data * NO_DATA; + let fail_res = state.log_loss_sum[flags] / state.result_count[flags] as f64; + let suc_res = state.log_loss_sum[flags|1] / state.result_count[flags|1] as f64; + let mut category_name = category_name.clone(); + if (flags & NO_DATA) != 0 { + category_name += " (w/ insufficient data)"; + } else { + category_name += " (w/ some channel hist)"; + } + println!("Avg {} success log-loss: {}", category_name, suc_res); + println!("Avg {} failure log-loss: {}", category_name, fail_res); + println!("Avg {} average log-loss: {}", category_name, (suc_res + fail_res) / 2.0); + } + let fail_res = (state.log_loss_sum[flags] + state.log_loss_sum[flags + NO_DATA]) + / (state.result_count[flags] + state.result_count[flags + NO_DATA]) as f64; + let suc_res = (state.log_loss_sum[flags + 1] + state.log_loss_sum[flags + NO_DATA + 1]) + / (state.result_count[flags + 1] + state.result_count[flags + NO_DATA + 1]) as f64; + println!("Avg {} success log-loss: {}", category_name, suc_res); + println!("Avg {} failure log-loss: {}", category_name, fail_res); + println!("Avg {} average log-loss: {}", category_name, (suc_res + fail_res) / 2.0); + println!(); + } } /// A hop in a route, consisting of a channel and the source public key, as well as the amount -- 2.39.5