X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Frouting%2Frouter.rs;h=c8468798eecb4ba955fea2256b66d470f16170cf;hb=6a54098dd125798d921898647243b1e89290724e;hp=46010475cebfd56550ef06ede0cba4824265684a;hpb=0ecb4b093ac1134cbf06275cc48dd79ef7524774;p=rust-lightning diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 46010475..c8468798 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -18,7 +18,7 @@ use crate::ln::PaymentHash; use crate::ln::channelmanager::{ChannelDetails, PaymentId}; use crate::ln::features::{Bolt12InvoiceFeatures, ChannelFeatures, InvoiceFeatures, NodeFeatures}; use crate::ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT}; -use crate::offers::invoice::BlindedPayInfo; +use crate::offers::invoice::{BlindedPayInfo, Invoice as Bolt12Invoice}; use crate::routing::gossip::{DirectedChannelInfo, EffectiveCapacity, ReadOnlyNetworkGraph, NetworkGraph, NodeId, RoutingFees}; use crate::routing::scoring::{ChannelUsage, LockableScore, Score}; use crate::util::ser::{Writeable, Readable, ReadableArgs, Writer}; @@ -27,39 +27,43 @@ use crate::util::chacha20::ChaCha20; use crate::io; use crate::prelude::*; -use crate::sync::Mutex; +use crate::sync::{Mutex, MutexGuard}; use alloc::collections::BinaryHeap; use core::{cmp, fmt}; use core::ops::Deref; /// A [`Router`] implemented using [`find_route`]. -pub struct DefaultRouter>, L: Deref, S: Deref> where +pub struct DefaultRouter>, L: Deref, S: Deref, SP: Sized, Sc: Score> where L::Target: Logger, - S::Target: for <'a> LockableScore<'a>, + S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>, { network_graph: G, logger: L, random_seed_bytes: Mutex<[u8; 32]>, - scorer: S + scorer: S, + score_params: SP } -impl>, L: Deref, S: Deref> DefaultRouter where +impl>, L: Deref, S: Deref, SP: Sized, Sc: Score> DefaultRouter where L::Target: Logger, - S::Target: for <'a> LockableScore<'a>, + S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>, { /// Creates a new router. - pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S) -> Self { + pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S, score_params: SP) -> Self { let random_seed_bytes = Mutex::new(random_seed_bytes); - Self { network_graph, logger, random_seed_bytes, scorer } + Self { network_graph, logger, random_seed_bytes, scorer, score_params } } } -impl>, L: Deref, S: Deref> Router for DefaultRouter where +impl< G: Deref>, L: Deref, S: Deref, SP: Sized, Sc: Score> Router for DefaultRouter where L::Target: Logger, - S::Target: for <'a> LockableScore<'a>, + S::Target: for <'a> LockableScore<'a, Locked = MutexGuard<'a, Sc>>, { fn find_route( - &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, + &self, + payer: &PublicKey, + params: &RouteParameters, + first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs ) -> Result { let random_seed_bytes = { @@ -67,10 +71,10 @@ impl>, L: Deref, S: Deref> Router for DefaultR *locked_random_seed_bytes = Sha256::hash(&*locked_random_seed_bytes).into_inner(); *locked_random_seed_bytes }; - find_route( payer, params, &self.network_graph, first_hops, &*self.logger, &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock(), inflight_htlcs), + &self.score_params, &random_seed_bytes ) } @@ -122,7 +126,8 @@ impl<'a, S: Score> Writeable for ScorerAccountingForInFlightHtlcs<'a, S> { } impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> { - fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage) -> u64 { + type ScoreParams = S::ScoreParams; + fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage, score_params: &Self::ScoreParams) -> u64 { if let Some(used_liquidity) = self.inflight_htlcs.used_liquidity_msat( source, target, short_channel_id ) { @@ -131,9 +136,9 @@ impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> { ..usage }; - self.scorer.channel_penalty_msat(short_channel_id, source, target, usage) + self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params) } else { - self.scorer.channel_penalty_msat(short_channel_id, source, target, usage) + self.scorer.channel_penalty_msat(short_channel_id, source, target, usage, score_params) } } @@ -616,12 +621,53 @@ impl PaymentParameters { /// /// The `final_cltv_expiry_delta` should match the expected final CLTV delta the recipient has /// provided. - pub fn for_keysend(payee_pubkey: PublicKey, final_cltv_expiry_delta: u32) -> Self { - Self::from_node_id(payee_pubkey, final_cltv_expiry_delta).with_bolt11_features(InvoiceFeatures::for_keysend()).expect("PaymentParameters::from_node_id should always initialize the payee as unblinded") + /// + /// Note that MPP keysend is not widely supported yet. The `allow_mpp` lets you choose + /// whether your router will be allowed to find a multi-part route for this payment. If you + /// set `allow_mpp` to true, you should ensure a payment secret is set on send, likely via + /// [`RecipientOnionFields::secret_only`]. + /// + /// [`RecipientOnionFields::secret_only`]: crate::ln::channelmanager::RecipientOnionFields::secret_only + pub fn for_keysend(payee_pubkey: PublicKey, final_cltv_expiry_delta: u32, allow_mpp: bool) -> Self { + Self::from_node_id(payee_pubkey, final_cltv_expiry_delta) + .with_bolt11_features(InvoiceFeatures::for_keysend(allow_mpp)) + .expect("PaymentParameters::from_node_id should always initialize the payee as unblinded") + } + + /// Creates parameters for paying to a blinded payee from the provided invoice. Sets + /// [`Payee::Blinded::route_hints`], [`Payee::Blinded::features`], and + /// [`PaymentParameters::expiry_time`]. + pub fn from_bolt12_invoice(invoice: &Bolt12Invoice) -> Self { + Self::blinded(invoice.payment_paths().to_vec()) + .with_bolt12_features(invoice.features().clone()).unwrap() + .with_expiry_time(invoice.created_at().as_secs().saturating_add(invoice.relative_expiry().as_secs())) + } + + fn blinded(blinded_route_hints: Vec<(BlindedPayInfo, BlindedPath)>) -> Self { + Self { + payee: Payee::Blinded { route_hints: blinded_route_hints, features: None }, + expiry_time: None, + max_total_cltv_expiry_delta: DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, + max_path_count: DEFAULT_MAX_PATH_COUNT, + max_channel_saturation_power_of_half: 2, + previously_failed_channels: Vec::new(), + } } - /// Includes the payee's features. Errors if the parameters were initialized with blinded payment - /// paths. + /// Includes the payee's features. Errors if the parameters were not initialized with + /// [`PaymentParameters::from_bolt12_invoice`]. + /// + /// This is not exported to bindings users since bindings don't support move semantics + pub fn with_bolt12_features(self, features: Bolt12InvoiceFeatures) -> Result { + match self.payee { + Payee::Clear { .. } => Err(()), + Payee::Blinded { route_hints, .. } => + Ok(Self { payee: Payee::Blinded { route_hints, features: Some(features) }, ..self }) + } + } + + /// Includes the payee's features. Errors if the parameters were initialized with + /// [`PaymentParameters::from_bolt12_invoice`]. /// /// This is not exported to bindings users since bindings don't support move semantics pub fn with_bolt11_features(self, features: InvoiceFeatures) -> Result { @@ -637,7 +683,7 @@ impl PaymentParameters { } /// Includes hints for routing to the payee. Errors if the parameters were initialized with - /// blinded payment paths. + /// [`PaymentParameters::from_bolt12_invoice`]. /// /// This is not exported to bindings users since bindings don't support move semantics pub fn with_route_hints(self, route_hints: Vec) -> Result { @@ -673,7 +719,8 @@ impl PaymentParameters { Self { max_path_count, ..self } } - /// Includes a limit for the maximum number of payment paths that may be used. + /// Includes a limit for the maximum share of a channel's total capacity that can be sent over, as + /// a power of 1/2. See [`PaymentParameters::max_channel_saturation_power_of_half`]. /// /// This is not exported to bindings users since bindings don't support move semantics pub fn with_max_channel_saturation_power_of_half(self, max_channel_saturation_power_of_half: u8) -> Self { @@ -789,7 +836,7 @@ impl ReadableArgs for Features { } /// A list of hops along a payment path terminating with a channel to the recipient. -#[derive(Clone, Debug, Hash, Eq, PartialEq)] +#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] pub struct RouteHint(pub Vec); impl Writeable for RouteHint { @@ -814,7 +861,7 @@ impl Readable for RouteHint { } /// A channel descriptor for a hop along a payment path. -#[derive(Clone, Debug, Hash, Eq, PartialEq)] +#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)] pub struct RouteHintHop { /// The node_id of the non-target end of the route pub src_node_id: PublicKey, @@ -924,7 +971,7 @@ impl<'a> CandidateRouteHop<'a> { fn htlc_minimum_msat(&self) -> u64 { match self { - CandidateRouteHop::FirstHop { .. } => 0, + CandidateRouteHop::FirstHop { details } => details.next_outbound_htlc_minimum_msat, CandidateRouteHop::PublicHop { info, .. } => info.direction().htlc_minimum_msat, CandidateRouteHop::PrivateHop { hint } => hint.htlc_minimum_msat.unwrap_or(0), } @@ -946,7 +993,10 @@ impl<'a> CandidateRouteHop<'a> { liquidity_msat: details.next_outbound_htlc_limit_msat, }, CandidateRouteHop::PublicHop { info, .. } => info.effective_capacity(), - CandidateRouteHop::PrivateHop { .. } => EffectiveCapacity::Infinite, + CandidateRouteHop::PrivateHop { hint: RouteHintHop { htlc_maximum_msat: Some(max), .. }} => + EffectiveCapacity::HintMaxHTLC { amount_msat: *max }, + CandidateRouteHop::PrivateHop { hint: RouteHintHop { htlc_maximum_msat: None, .. }} => + EffectiveCapacity::Infinite, } } } @@ -958,8 +1008,11 @@ fn max_htlc_from_capacity(capacity: EffectiveCapacity, max_channel_saturation_po EffectiveCapacity::ExactLiquidity { liquidity_msat } => liquidity_msat, EffectiveCapacity::Infinite => u64::max_value(), EffectiveCapacity::Unknown => EffectiveCapacity::Unknown.as_msat(), - EffectiveCapacity::MaximumHTLC { amount_msat } => + EffectiveCapacity::AdvertisedMaxHTLC { amount_msat } => amount_msat.checked_shr(saturation_shift).unwrap_or(0), + // Treat htlc_maximum_msat from a route hint as an exact liquidity amount, since the invoice is + // expected to have been generated from up-to-date capacity information. + EffectiveCapacity::HintMaxHTLC { amount_msat } => amount_msat, EffectiveCapacity::Total { capacity_msat, htlc_maximum_msat } => cmp::min(capacity_msat.checked_shr(saturation_shift).unwrap_or(0), htlc_maximum_msat), } @@ -1010,7 +1063,7 @@ struct PathBuildingHop<'a> { /// decrease as well. Thus, we have to explicitly track which nodes have been processed and /// avoid processing them again. was_processed: bool, - #[cfg(all(not(feature = "_bench_unstable"), any(test, fuzzing)))] + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] // In tests, we apply further sanity checks on cases where we skip nodes we already processed // to ensure it is specifically in cases where the fee has gone down because of a decrease in // value_contribution_msat, which requires tracking it here. See comments below where it is @@ -1031,7 +1084,7 @@ impl<'a> core::fmt::Debug for PathBuildingHop<'a> { .field("path_penalty_msat", &self.path_penalty_msat) .field("path_htlc_minimum_msat", &self.path_htlc_minimum_msat) .field("cltv_expiry_delta", &self.candidate.cltv_expiry_delta()); - #[cfg(all(not(feature = "_bench_unstable"), any(test, fuzzing)))] + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] let debug_struct = debug_struct .field("value_contribution_msat", &self.value_contribution_msat); debug_struct.finish() @@ -1187,6 +1240,41 @@ impl fmt::Display for LoggedPayeePubkey { } } +#[inline] +fn sort_first_hop_channels( + channels: &mut Vec<&ChannelDetails>, used_channel_liquidities: &HashMap<(u64, bool), u64>, + recommended_value_msat: u64, our_node_pubkey: &PublicKey +) { + // Sort the first_hops channels to the same node(s) in priority order of which channel we'd + // most like to use. + // + // First, if channels are below `recommended_value_msat`, sort them in descending order, + // preferring larger channels to avoid splitting the payment into more MPP parts than is + // required. + // + // Second, because simply always sorting in descending order would always use our largest + // available outbound capacity, needlessly fragmenting our available channel capacities, + // sort channels above `recommended_value_msat` in ascending order, preferring channels + // which have enough, but not too much, capacity for the payment. + // + // Available outbound balances factor in liquidity already reserved for previously found paths. + channels.sort_unstable_by(|chan_a, chan_b| { + let chan_a_outbound_limit_msat = chan_a.next_outbound_htlc_limit_msat + .saturating_sub(*used_channel_liquidities.get(&(chan_a.get_outbound_payment_scid().unwrap(), + our_node_pubkey < &chan_a.counterparty.node_id)).unwrap_or(&0)); + let chan_b_outbound_limit_msat = chan_b.next_outbound_htlc_limit_msat + .saturating_sub(*used_channel_liquidities.get(&(chan_b.get_outbound_payment_scid().unwrap(), + our_node_pubkey < &chan_b.counterparty.node_id)).unwrap_or(&0)); + if chan_b_outbound_limit_msat < recommended_value_msat || chan_a_outbound_limit_msat < recommended_value_msat { + // Sort in descending order + chan_b_outbound_limit_msat.cmp(&chan_a_outbound_limit_msat) + } else { + // Sort in ascending order + chan_a_outbound_limit_msat.cmp(&chan_b_outbound_limit_msat) + } + }); +} + /// Finds a route from us (payer) to the given target node (payee). /// /// If the payee provided features in their invoice, they should be provided via `params.payee`. @@ -1219,12 +1307,12 @@ impl fmt::Display for LoggedPayeePubkey { pub fn find_route( our_node_pubkey: &PublicKey, route_params: &RouteParameters, network_graph: &NetworkGraph, first_hops: Option<&[&ChannelDetails]>, logger: L, - scorer: &S, random_seed_bytes: &[u8; 32] + scorer: &S, score_params: &S::ScoreParams, random_seed_bytes: &[u8; 32] ) -> Result where L::Target: Logger, GL::Target: Logger { let graph_lock = network_graph.read_only(); let mut route = get_route(our_node_pubkey, &route_params.payment_params, &graph_lock, first_hops, - route_params.final_value_msat, logger, scorer, + route_params.final_value_msat, logger, scorer, score_params, random_seed_bytes)?; add_random_cltv_offset(&mut route, &route_params.payment_params, &graph_lock, random_seed_bytes); Ok(route) @@ -1232,7 +1320,7 @@ where L::Target: Logger, GL::Target: Logger { pub(crate) fn get_route( our_node_pubkey: &PublicKey, payment_params: &PaymentParameters, network_graph: &ReadOnlyNetworkGraph, - first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, logger: L, scorer: &S, + first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, logger: L, scorer: &S, score_params: &S::ScoreParams, _random_seed_bytes: &[u8; 32] ) -> Result where L::Target: Logger { @@ -1240,7 +1328,7 @@ where L::Target: Logger { // unblinded payee id as an option. We also need a non-optional "payee id" for path construction, // so use a dummy id for this in the blinded case. let payee_node_id_opt = payment_params.payee.node_id().map(|pk| NodeId::from_pubkey(&pk)); - const DUMMY_BLINDED_PAYEE_ID: [u8; 33] = [42u8; 33]; + const DUMMY_BLINDED_PAYEE_ID: [u8; 33] = [2; 33]; let maybe_dummy_payee_pk = payment_params.payee.node_id().unwrap_or_else(|| PublicKey::from_slice(&DUMMY_BLINDED_PAYEE_ID).unwrap()); let maybe_dummy_payee_node_id = NodeId::from_pubkey(&maybe_dummy_payee_pk); let our_node_id = NodeId::from_pubkey(&our_node_pubkey); @@ -1432,26 +1520,8 @@ where L::Target: Logger { let mut already_collected_value_msat = 0; for (_, channels) in first_hop_targets.iter_mut() { - // Sort the first_hops channels to the same node(s) in priority order of which channel we'd - // most like to use. - // - // First, if channels are below `recommended_value_msat`, sort them in descending order, - // preferring larger channels to avoid splitting the payment into more MPP parts than is - // required. - // - // Second, because simply always sorting in descending order would always use our largest - // available outbound capacity, needlessly fragmenting our available channel capacities, - // sort channels above `recommended_value_msat` in ascending order, preferring channels - // which have enough, but not too much, capacity for the payment. - channels.sort_unstable_by(|chan_a, chan_b| { - if chan_b.next_outbound_htlc_limit_msat < recommended_value_msat || chan_a.next_outbound_htlc_limit_msat < recommended_value_msat { - // Sort in descending order - chan_b.next_outbound_htlc_limit_msat.cmp(&chan_a.next_outbound_htlc_limit_msat) - } else { - // Sort in ascending order - chan_a.next_outbound_htlc_limit_msat.cmp(&chan_b.next_outbound_htlc_limit_msat) - } - }); + sort_first_hop_channels(channels, &used_channel_liquidities, recommended_value_msat, + our_node_pubkey); } log_trace!(logger, "Building path from {} to payer {} for value {} msat.", @@ -1465,8 +1535,9 @@ where L::Target: Logger { ( $candidate: expr, $src_node_id: expr, $dest_node_id: expr, $next_hops_fee_msat: expr, $next_hops_value_contribution: expr, $next_hops_path_htlc_minimum_msat: expr, $next_hops_path_penalty_msat: expr, $next_hops_cltv_delta: expr, $next_hops_path_length: expr ) => { { - // We "return" whether we updated the path at the end, via this: - let mut did_add_update_path_to_src_node = false; + // We "return" whether we updated the path at the end, and how much we can route via + // this channel, via this: + let mut did_add_update_path_to_src_node = None; // Channels to self should not be used. This is more of belt-and-suspenders, because in // practice these cases should be caught earlier: // - for regular channels at channel announcement (TODO) @@ -1565,14 +1636,14 @@ where L::Target: Logger { path_htlc_minimum_msat, path_penalty_msat: u64::max_value(), was_processed: false, - #[cfg(all(not(feature = "_bench_unstable"), any(test, fuzzing)))] + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] value_contribution_msat, } }); #[allow(unused_mut)] // We only use the mut in cfg(test) let mut should_process = !old_entry.was_processed; - #[cfg(all(not(feature = "_bench_unstable"), any(test, fuzzing)))] + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] { // In test/fuzzing builds, we do extra checks to make sure the skipping // of already-seen nodes only happens in cases we expect (see below). @@ -1598,7 +1669,7 @@ where L::Target: Logger { effective_capacity, }; let channel_penalty_msat = scorer.channel_penalty_msat( - short_channel_id, &$src_node_id, &$dest_node_id, channel_usage + short_channel_id, &$src_node_id, &$dest_node_id, channel_usage, score_params ); let path_penalty_msat = $next_hops_path_penalty_msat .saturating_add(channel_penalty_msat); @@ -1643,13 +1714,13 @@ where L::Target: Logger { old_entry.fee_msat = 0; // This value will be later filled with hop_use_fee_msat of the following channel old_entry.path_htlc_minimum_msat = path_htlc_minimum_msat; old_entry.path_penalty_msat = path_penalty_msat; - #[cfg(all(not(feature = "_bench_unstable"), any(test, fuzzing)))] + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] { old_entry.value_contribution_msat = value_contribution_msat; } - did_add_update_path_to_src_node = true; + did_add_update_path_to_src_node = Some(value_contribution_msat); } else if old_entry.was_processed && new_cost < old_cost { - #[cfg(all(not(feature = "_bench_unstable"), any(test, fuzzing)))] + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] { // If we're skipping processing a node which was previously // processed even though we found another path to it with a @@ -1768,7 +1839,7 @@ where L::Target: Logger { for details in first_channels { let candidate = CandidateRouteHop::FirstHop { details }; let added = add_entry!(candidate, our_node_id, payee, 0, path_value_msat, - 0, 0u64, 0, 0); + 0, 0u64, 0, 0).is_some(); log_trace!(logger, "{} direct route to payee via SCID {}", if added { "Added" } else { "Skipped" }, candidate.short_channel_id()); } @@ -1815,6 +1886,7 @@ where L::Target: Logger { let mut aggregate_next_hops_path_penalty_msat: u64 = 0; let mut aggregate_next_hops_cltv_delta: u32 = 0; let mut aggregate_next_hops_path_length: u8 = 0; + let mut aggregate_path_contribution_msat = path_value_msat; for (idx, (hop, prev_hop_id)) in hop_iter.zip(prev_hop_iter).enumerate() { let source = NodeId::from_pubkey(&hop.src_node_id); @@ -1828,10 +1900,13 @@ where L::Target: Logger { }) .unwrap_or_else(|| CandidateRouteHop::PrivateHop { hint: hop }); - if !add_entry!(candidate, source, target, aggregate_next_hops_fee_msat, - path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, - aggregate_next_hops_path_penalty_msat, - aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length) { + if let Some(hop_used_msat) = add_entry!(candidate, source, target, + aggregate_next_hops_fee_msat, aggregate_path_contribution_msat, + aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, + aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length) + { + aggregate_path_contribution_msat = hop_used_msat; + } else { // If this hop was not used then there is no use checking the preceding // hops in the RouteHint. We can break by just searching for a direct // channel between last checked hop and first_hop_targets. @@ -1846,7 +1921,7 @@ where L::Target: Logger { effective_capacity: candidate.effective_capacity(), }; let channel_penalty_msat = scorer.channel_penalty_msat( - hop.short_channel_id, &source, &target, channel_usage + hop.short_channel_id, &source, &target, channel_usage, score_params ); aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat .saturating_add(channel_penalty_msat); @@ -1858,14 +1933,15 @@ where L::Target: Logger { .saturating_add(1); // Searching for a direct channel between last checked hop and first_hop_targets - if let Some(first_channels) = first_hop_targets.get(&NodeId::from_pubkey(&prev_hop_id)) { + if let Some(first_channels) = first_hop_targets.get_mut(&NodeId::from_pubkey(&prev_hop_id)) { + sort_first_hop_channels(first_channels, &used_channel_liquidities, + recommended_value_msat, our_node_pubkey); for details in first_channels { - let candidate = CandidateRouteHop::FirstHop { details }; - add_entry!(candidate, our_node_id, NodeId::from_pubkey(&prev_hop_id), - aggregate_next_hops_fee_msat, path_value_msat, - aggregate_next_hops_path_htlc_minimum_msat, - aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta, - aggregate_next_hops_path_length); + let first_hop_candidate = CandidateRouteHop::FirstHop { details }; + add_entry!(first_hop_candidate, our_node_id, NodeId::from_pubkey(&prev_hop_id), + aggregate_next_hops_fee_msat, aggregate_path_contribution_msat, + aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, + aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length); } } @@ -1898,12 +1974,15 @@ where L::Target: Logger { // Note that we *must* check if the last hop was added as `add_entry` // always assumes that the third argument is a node to which we have a // path. - if let Some(first_channels) = first_hop_targets.get(&NodeId::from_pubkey(&hop.src_node_id)) { + if let Some(first_channels) = first_hop_targets.get_mut(&NodeId::from_pubkey(&hop.src_node_id)) { + sort_first_hop_channels(first_channels, &used_channel_liquidities, + recommended_value_msat, our_node_pubkey); for details in first_channels { - let candidate = CandidateRouteHop::FirstHop { details }; - add_entry!(candidate, our_node_id, + let first_hop_candidate = CandidateRouteHop::FirstHop { details }; + add_entry!(first_hop_candidate, our_node_id, NodeId::from_pubkey(&hop.src_node_id), - aggregate_next_hops_fee_msat, path_value_msat, + aggregate_next_hops_fee_msat, + aggregate_path_contribution_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta, @@ -2343,8 +2422,9 @@ fn build_route_from_hops_internal( } impl Score for HopScorer { + type ScoreParams = (); fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId, - _usage: ChannelUsage) -> u64 + _usage: ChannelUsage, _score_params: &Self::ScoreParams) -> u64 { let mut cur_id = self.our_node_id; for i in 0..self.hop_ids.len() { @@ -2389,7 +2469,7 @@ fn build_route_from_hops_internal( let scorer = HopScorer { our_node_id, hop_ids }; get_route(our_node_pubkey, payment_params, network_graph, None, final_value_msat, - logger, &scorer, random_seed_bytes) + logger, &scorer, &(), random_seed_bytes) } #[cfg(test)] @@ -2400,7 +2480,7 @@ mod tests { use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features, BlindedTail, InFlightHtlcs, Path, PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees, DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE}; - use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, Score, ProbabilisticScorer, ProbabilisticScoringParameters}; + use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, Score, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters}; use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel}; use crate::chain::transaction::OutPoint; use crate::sign::EntropySource; @@ -2454,6 +2534,7 @@ mod tests { balance_msat: 0, outbound_capacity_msat, next_outbound_htlc_limit_msat: outbound_capacity_msat, + next_outbound_htlc_minimum_msat: 0, inbound_capacity_msat: 42, unspendable_punishment_reserve: None, confirmations_required: None, @@ -2479,11 +2560,11 @@ mod tests { // Simple route to 2 via 1 - if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 0, Arc::clone(&logger), &scorer, &random_seed_bytes) { + if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 0, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Cannot send a payment of 0 msat"); } else { panic!(); } - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 2); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -2515,11 +2596,11 @@ mod tests { let our_chans = vec![get_channel_details(Some(2), our_id, InitFeatures::from_le_bytes(vec![0b11]), 100000)]; if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = - get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &random_seed_bytes) { + get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "First hop cannot have our_node_pubkey as a destination."); } else { panic!(); } - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 2); } @@ -2627,7 +2708,7 @@ mod tests { }); // Not possible to send 199_999_999, because the minimum on channel=2 is 200_000_000. - if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 199_999_999, Arc::clone(&logger), &scorer, &random_seed_bytes) { + if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 199_999_999, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a path to the given destination"); } else { panic!(); } @@ -2646,7 +2727,7 @@ mod tests { }); // A payment above the minimum should pass - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 199_999_999, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 199_999_999, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 2); } @@ -2728,7 +2809,7 @@ mod tests { excess_data: Vec::new() }); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 60_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 60_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); // Overpay fees to hit htlc_minimum_msat. let overpaid_fees = route.paths[0].hops[0].fee_msat + route.paths[1].hops[0].fee_msat; // TODO: this could be better balanced to overpay 10k and not 15k. @@ -2773,14 +2854,14 @@ mod tests { excess_data: Vec::new() }); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 60_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 60_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); // Fine to overpay for htlc_minimum_msat if it allows us to save fee. assert_eq!(route.paths.len(), 1); assert_eq!(route.paths[0].hops[0].short_channel_id, 12); let fees = route.paths[0].hops[0].fee_msat; assert_eq!(fees, 5_000); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 50_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 50_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); // Not fine to overpay for htlc_minimum_msat if it requires paying more than fee on // the other channel. assert_eq!(route.paths.len(), 1); @@ -2825,13 +2906,13 @@ mod tests { }); // If all the channels require some features we don't understand, route should fail - if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes) { + if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a path to the given destination"); } else { panic!(); } // If we specify a channel to node7, that overrides our local channel view and that gets used let our_chans = vec![get_channel_details(Some(42), nodes[7].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)]; - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 2); assert_eq!(route.paths[0].hops[0].pubkey, nodes[7]); @@ -2866,13 +2947,13 @@ mod tests { add_or_update_node(&gossip_sync, &secp_ctx, &privkeys[7], unknown_features.clone(), 1); // If all nodes require some features we don't understand, route should fail - if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes) { + if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a path to the given destination"); } else { panic!(); } // If we specify a channel to node7, that overrides our local channel view and that gets used let our_chans = vec![get_channel_details(Some(42), nodes[7].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)]; - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 2); assert_eq!(route.paths[0].hops[0].pubkey, nodes[7]); @@ -2904,7 +2985,7 @@ mod tests { // Route to 1 via 2 and 3 because our channel to 1 is disabled let payment_params = PaymentParameters::from_node_id(nodes[0], 42); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 3); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -2931,7 +3012,7 @@ mod tests { // If we specify a channel to node7, that overrides our local channel view and that gets used let payment_params = PaymentParameters::from_node_id(nodes[2], 42); let our_chans = vec![get_channel_details(Some(42), nodes[7].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)]; - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 2); assert_eq!(route.paths[0].hops[0].pubkey, nodes[7]); @@ -3054,13 +3135,13 @@ mod tests { invalid_last_hops.push(invalid_last_hop); { let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(invalid_last_hops).unwrap(); - if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes) { + if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Route hint cannot have the payee as the source."); } else { panic!(); } } let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops_multi_private_channels(&nodes)).unwrap(); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 5); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -3136,7 +3217,7 @@ mod tests { // Test handling of an empty RouteHint passed in Invoice. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 5); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -3242,7 +3323,7 @@ mod tests { excess_data: Vec::new() }); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 4); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -3314,7 +3395,7 @@ mod tests { excess_data: Vec::new() }); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &[42u8; 32]).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &[42u8; 32]).unwrap(); assert_eq!(route.paths[0].hops.len(), 4); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -3396,7 +3477,7 @@ mod tests { // This test shows that public routes can be present in the invoice // which would be handled in the same manner. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 5); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -3449,7 +3530,7 @@ mod tests { let our_chans = vec![get_channel_details(Some(42), nodes[3].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)]; let mut last_hops = last_hops(&nodes); let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops.clone()).unwrap(); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 2); assert_eq!(route.paths[0].hops[0].pubkey, nodes[3]); @@ -3470,7 +3551,7 @@ mod tests { // Revert to via 6 as the fee on 8 goes up let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops).unwrap(); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 4); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -3504,7 +3585,7 @@ mod tests { assert_eq!(route.paths[0].hops[3].channel_features.le_flags(), &Vec::::new()); // We can't learn any flags from invoices, sadly // ...but still use 8 for larger payments as 6 has a variable feerate - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 2000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 2000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths[0].hops.len(), 5); assert_eq!(route.paths[0].hops[0].pubkey, nodes[1]); @@ -3570,7 +3651,7 @@ mod tests { let logger = ln_test_utils::TestLogger::new(); let network_graph = NetworkGraph::new(Network::Testnet, &logger); let route = get_route(&source_node_id, &payment_params, &network_graph.read_only(), - Some(&our_chans.iter().collect::>()), route_val, &logger, &scorer, &random_seed_bytes); + Some(&our_chans.iter().collect::>()), route_val, &logger, &scorer, &(), &random_seed_bytes); route } @@ -3692,14 +3773,14 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 250_000_001, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 250_000_001, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } { // Now, attempt to route an exact amount we have should be fine. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 250_000_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 250_000_000, Arc::clone(&logger), &scorer, &(),&random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let path = route.paths.last().unwrap(); assert_eq!(path.hops.len(), 2); @@ -3728,14 +3809,14 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 200_000_001, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 200_000_001, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } { // Now, attempt to route an exact amount we have should be fine. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 200_000_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&our_chans.iter().collect::>()), 200_000_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let path = route.paths.last().unwrap(); assert_eq!(path.hops.len(), 2); @@ -3775,14 +3856,14 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 15_001, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 15_001, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } { // Now, attempt to route an exact amount we have should be fine. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 15_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 15_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let path = route.paths.last().unwrap(); assert_eq!(path.hops.len(), 2); @@ -3846,14 +3927,14 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 15_001, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 15_001, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } { // Now, attempt to route an exact amount we have should be fine. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 15_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 15_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let path = route.paths.last().unwrap(); assert_eq!(path.hops.len(), 2); @@ -3878,14 +3959,14 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 10_001, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 10_001, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } { // Now, attempt to route an exact amount we have should be fine. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 10_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 10_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let path = route.paths.last().unwrap(); assert_eq!(path.hops.len(), 2); @@ -3990,14 +4071,14 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 60_000, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 60_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } { // Now, attempt to route 49 sats (just a bit below the capacity). - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 49_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 49_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let mut total_amount_paid_msat = 0; for path in &route.paths { @@ -4010,7 +4091,7 @@ mod tests { { // Attempt to route an exact amount is also fine - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 50_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 50_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let mut total_amount_paid_msat = 0; for path in &route.paths { @@ -4058,7 +4139,7 @@ mod tests { }); { - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 50_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 50_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let mut total_amount_paid_msat = 0; for path in &route.paths { @@ -4173,7 +4254,7 @@ mod tests { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( &our_id, &payment_params, &network_graph.read_only(), None, 300_000, - Arc::clone(&logger), &scorer, &random_seed_bytes) { + Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } @@ -4183,7 +4264,7 @@ mod tests { let zero_payment_params = payment_params.clone().with_max_path_count(0); if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( &our_id, &zero_payment_params, &network_graph.read_only(), None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes) { + Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Can't find a route with no paths allowed."); } else { panic!(); } } @@ -4195,7 +4276,7 @@ mod tests { let fail_payment_params = payment_params.clone().with_max_path_count(3); if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( &our_id, &fail_payment_params, &network_graph.read_only(), None, 250_000, - Arc::clone(&logger), &scorer, &random_seed_bytes) { + Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } @@ -4204,7 +4285,7 @@ mod tests { // Now, attempt to route 250 sats (just a bit below the capacity). // Our algorithm should provide us with these 3 paths. let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, - 250_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + 250_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 3); let mut total_amount_paid_msat = 0; for path in &route.paths { @@ -4218,7 +4299,7 @@ mod tests { { // Attempt to route an exact amount is also fine let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, - 290_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + 290_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 3); let mut total_amount_paid_msat = 0; for path in &route.paths { @@ -4374,7 +4455,7 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 350_000, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 350_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } @@ -4382,7 +4463,7 @@ mod tests { { // Now, attempt to route 300 sats (exact amount we can route). // Our algorithm should provide us with these 3 paths, 100 sats each. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 300_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 300_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 3); let mut total_amount_paid_msat = 0; @@ -4543,7 +4624,7 @@ mod tests { { // Now, attempt to route 180 sats. // Our algorithm should provide us with these 2 paths. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 180_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 180_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 2); let mut total_value_transferred_msat = 0; @@ -4714,14 +4795,14 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 210_000, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 210_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } { // Now, attempt to route 200 sats (exact amount we can route). - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 200_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 200_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 2); let mut total_amount_paid_msat = 0; @@ -4821,7 +4902,7 @@ mod tests { // Get a route for 100 sats and check that we found the MPP route no problem and didn't // overpay at all. - let mut route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let mut route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 2); route.paths.sort_by_key(|path| path.hops[0].short_channel_id); // Paths are manually ordered ordered by SCID, so: @@ -4940,7 +5021,7 @@ mod tests { { // Attempt to route more than available results in a failure. if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &payment_params, &network_graph.read_only(), None, 150_000, Arc::clone(&logger), &scorer, &random_seed_bytes) { + &our_id, &payment_params, &network_graph.read_only(), None, 150_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { assert_eq!(err, "Failed to find a sufficient route to the given destination"); } else { panic!(); } } @@ -4948,7 +5029,7 @@ mod tests { { // Now, attempt to route 125 sats (just a bit below the capacity of 3 channels). // Our algorithm should provide us with these 3 paths. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 125_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 125_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 3); let mut total_amount_paid_msat = 0; for path in &route.paths { @@ -4961,7 +5042,7 @@ mod tests { { // Attempt to route without the last small cheap channel - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 90_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 90_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 2); let mut total_amount_paid_msat = 0; for path in &route.paths { @@ -5100,7 +5181,7 @@ mod tests { { // Now ensure the route flows simply over nodes 1 and 4 to 6. - let route = get_route(&our_id, &payment_params, &network.read_only(), None, 10_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network.read_only(), None, 10_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); assert_eq!(route.paths[0].hops.len(), 3); @@ -5171,7 +5252,7 @@ mod tests { { // Now, attempt to route 90 sats, which is exactly 90 sats at the last hop, plus the // 200% fee charged channel 13 in the 1-to-2 direction. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 90_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 90_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); assert_eq!(route.paths[0].hops.len(), 2); @@ -5237,7 +5318,7 @@ mod tests { // Now, attempt to route 90 sats, hitting the htlc_minimum on channel 4, but // overshooting the htlc_maximum on channel 2. Thus, we should pick the (absurdly // expensive) channels 12-13 path. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 90_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 90_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); assert_eq!(route.paths[0].hops.len(), 2); @@ -5279,7 +5360,7 @@ mod tests { let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&[ &get_channel_details(Some(3), nodes[0], channelmanager::provided_init_features(&config), 200_000), &get_channel_details(Some(2), nodes[0], channelmanager::provided_init_features(&config), 10_000), - ]), 100_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + ]), 100_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); assert_eq!(route.paths[0].hops.len(), 1); @@ -5291,7 +5372,7 @@ mod tests { let route = get_route(&our_id, &payment_params, &network_graph.read_only(), Some(&[ &get_channel_details(Some(3), nodes[0], channelmanager::provided_init_features(&config), 50_000), &get_channel_details(Some(2), nodes[0], channelmanager::provided_init_features(&config), 50_000), - ]), 100_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + ]), 100_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 2); assert_eq!(route.paths[0].hops.len(), 1); assert_eq!(route.paths[1].hops.len(), 1); @@ -5323,7 +5404,7 @@ mod tests { &get_channel_details(Some(8), nodes[0], channelmanager::provided_init_features(&config), 50_000), &get_channel_details(Some(9), nodes[0], channelmanager::provided_init_features(&config), 50_000), &get_channel_details(Some(4), nodes[0], channelmanager::provided_init_features(&config), 1_000_000), - ]), 100_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + ]), 100_000, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); assert_eq!(route.paths[0].hops.len(), 1); @@ -5345,7 +5426,7 @@ mod tests { let random_seed_bytes = keys_manager.get_secure_random_bytes(); let route = get_route( &our_id, &payment_params, &network_graph.read_only(), None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes + Arc::clone(&logger), &scorer, &(), &random_seed_bytes ).unwrap(); let path = route.paths[0].hops.iter().map(|hop| hop.short_channel_id).collect::>(); @@ -5358,7 +5439,7 @@ mod tests { let scorer = FixedPenaltyScorer::with_penalty(100); let route = get_route( &our_id, &payment_params, &network_graph.read_only(), None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes + Arc::clone(&logger), &scorer, &(), &random_seed_bytes ).unwrap(); let path = route.paths[0].hops.iter().map(|hop| hop.short_channel_id).collect::>(); @@ -5376,7 +5457,8 @@ mod tests { fn write(&self, _w: &mut W) -> Result<(), crate::io::Error> { unimplemented!() } } impl Score for BadChannelScorer { - fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage) -> u64 { + type ScoreParams = (); + fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 { if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 } } @@ -5396,7 +5478,8 @@ mod tests { } impl Score for BadNodeScorer { - fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage) -> u64 { + type ScoreParams = (); + fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage, _score_params:&Self::ScoreParams) -> u64 { if *target == self.node_id { u64::max_value() } else { 0 } } @@ -5419,7 +5502,7 @@ mod tests { let random_seed_bytes = keys_manager.get_secure_random_bytes(); let route = get_route( &our_id, &payment_params, &network_graph, None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes + Arc::clone(&logger), &scorer, &(), &random_seed_bytes ).unwrap(); let path = route.paths[0].hops.iter().map(|hop| hop.short_channel_id).collect::>(); @@ -5431,7 +5514,7 @@ mod tests { let scorer = BadChannelScorer { short_channel_id: 6 }; let route = get_route( &our_id, &payment_params, &network_graph, None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes + Arc::clone(&logger), &scorer, &(), &random_seed_bytes ).unwrap(); let path = route.paths[0].hops.iter().map(|hop| hop.short_channel_id).collect::>(); @@ -5443,7 +5526,7 @@ mod tests { let scorer = BadNodeScorer { node_id: NodeId::from_pubkey(&nodes[2]) }; match get_route( &our_id, &payment_params, &network_graph, None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes + Arc::clone(&logger), &scorer, &(), &random_seed_bytes ) { Err(LightningError { err, .. } ) => { assert_eq!(err, "Failed to find a path to the given destination"); @@ -5537,7 +5620,7 @@ mod tests { .with_max_total_cltv_expiry_delta(feasible_max_total_cltv_delta); let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); let random_seed_bytes = keys_manager.get_secure_random_bytes(); - let route = get_route(&our_id, &feasible_payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &feasible_payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); let path = route.paths[0].hops.iter().map(|hop| hop.short_channel_id).collect::>(); assert_ne!(path.len(), 0); @@ -5545,7 +5628,7 @@ mod tests { let fail_max_total_cltv_delta = 23; let fail_payment_params = PaymentParameters::from_node_id(nodes[6], 0).with_route_hints(last_hops(&nodes)).unwrap() .with_max_total_cltv_expiry_delta(fail_max_total_cltv_delta); - match get_route(&our_id, &fail_payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes) + match get_route(&our_id, &fail_payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { Err(LightningError { err, .. } ) => { assert_eq!(err, "Failed to find a path to the given destination"); @@ -5570,9 +5653,9 @@ mod tests { // We should be able to find a route initially, and then after we fail a few random // channels eventually we won't be able to any longer. - assert!(get_route(&our_id, &payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).is_ok()); + assert!(get_route(&our_id, &payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).is_ok()); loop { - if let Ok(route) = get_route(&our_id, &payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes) { + if let Ok(route) = get_route(&our_id, &payment_params, &network_graph, None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { for chan in route.paths[0].hops.iter() { assert!(!payment_params.previously_failed_channels.contains(&chan.short_channel_id)); } @@ -5596,14 +5679,14 @@ mod tests { // First check we can actually create a long route on this graph. let feasible_payment_params = PaymentParameters::from_node_id(nodes[18], 0); let route = get_route(&our_id, &feasible_payment_params, &network_graph, None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); let path = route.paths[0].hops.iter().map(|hop| hop.short_channel_id).collect::>(); assert!(path.len() == MAX_PATH_LENGTH_ESTIMATE.into()); // But we can't create a path surpassing the MAX_PATH_LENGTH_ESTIMATE limit. let fail_payment_params = PaymentParameters::from_node_id(nodes[19], 0); match get_route(&our_id, &fail_payment_params, &network_graph, None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes) + Arc::clone(&logger), &scorer, &(), &random_seed_bytes) { Err(LightningError { err, .. } ) => { assert_eq!(err, "Failed to find a path to the given destination"); @@ -5622,7 +5705,7 @@ mod tests { let payment_params = PaymentParameters::from_node_id(nodes[6], 42).with_route_hints(last_hops(&nodes)).unwrap(); let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); let random_seed_bytes = keys_manager.get_secure_random_bytes(); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 1); let cltv_expiry_deltas_before = route.paths[0].hops.iter().map(|h| h.cltv_expiry_delta).collect::>(); @@ -5657,7 +5740,7 @@ mod tests { let random_seed_bytes = keys_manager.get_secure_random_bytes(); let mut route = get_route(&our_id, &payment_params, &network_graph, None, 100, - Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); add_random_cltv_offset(&mut route, &payment_params, &network_graph, &random_seed_bytes); let mut path_plausibility = vec![]; @@ -5734,8 +5817,8 @@ mod tests { fn avoids_saturating_channels() { let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph(); let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - - let scorer = ProbabilisticScorer::new(Default::default(), &*network_graph, Arc::clone(&logger)); + let decay_params = ProbabilisticScoringDecayParameters::default(); + let scorer = ProbabilisticScorer::new(decay_params, &*network_graph, Arc::clone(&logger)); // Set the fee on channel 13 to 100% to match channel 4 giving us two equivalent paths (us // -> node 7 -> node2 and us -> node 1 -> node 2) which we should balance over. @@ -5769,7 +5852,7 @@ mod tests { let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); let random_seed_bytes = keys_manager.get_secure_random_bytes(); // 100,000 sats is less than the available liquidity on each channel, set above. - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100_000_000, Arc::clone(&logger), &scorer, &random_seed_bytes).unwrap(); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100_000_000, Arc::clone(&logger), &scorer, &ProbabilisticScoringFeeParameters::default(), &random_seed_bytes).unwrap(); assert_eq!(route.paths.len(), 2); assert!((route.paths[0].hops[1].short_channel_id == 4 && route.paths[1].hops[1].short_channel_id == 13) || (route.paths[1].hops[1].short_channel_id == 4 && route.paths[0].hops[1].short_channel_id == 13)); @@ -5783,82 +5866,68 @@ mod tests { println!("Using seed of {}", seed); seed } - #[cfg(not(feature = "no-std"))] - use crate::util::ser::ReadableArgs; #[test] #[cfg(not(feature = "no-std"))] fn generate_routes() { - use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters}; + use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringFeeParameters}; - let mut d = match super::bench_utils::get_route_file() { + let logger = ln_test_utils::TestLogger::new(); + let graph = match super::bench_utils::read_network_graph(&logger) { Ok(f) => f, Err(e) => { eprintln!("{}", e); return; }, }; - let logger = ln_test_utils::TestLogger::new(); - let graph = NetworkGraph::read(&mut d, &logger).unwrap(); - let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); - let random_seed_bytes = keys_manager.get_secure_random_bytes(); - // First, get 100 (source, destination) pairs for which route-getting actually succeeds... - let mut seed = random_init_seed() as usize; - let nodes = graph.read_only().nodes().clone(); - 'load_endpoints: for _ in 0..10 { - loop { - seed = seed.overflowing_mul(0xdeadbeef).0; - let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); - seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); - let payment_params = PaymentParameters::from_node_id(dst, 42); - let amt = seed as u64 % 200_000_000; - let params = ProbabilisticScoringParameters::default(); - let scorer = ProbabilisticScorer::new(params, &graph, &logger); - if get_route(src, &payment_params, &graph.read_only(), None, amt, &logger, &scorer, &random_seed_bytes).is_ok() { - continue 'load_endpoints; - } - } - } + let params = ProbabilisticScoringFeeParameters::default(); + let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger); + let features = super::InvoiceFeatures::empty(); + + super::bench_utils::generate_test_routes(&graph, &mut scorer, ¶ms, features, random_init_seed(), 0, 2); } #[test] #[cfg(not(feature = "no-std"))] fn generate_routes_mpp() { - use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringParameters}; + use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringFeeParameters}; - let mut d = match super::bench_utils::get_route_file() { + let logger = ln_test_utils::TestLogger::new(); + let graph = match super::bench_utils::read_network_graph(&logger) { Ok(f) => f, Err(e) => { eprintln!("{}", e); return; }, }; + + let params = ProbabilisticScoringFeeParameters::default(); + let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger); + let features = channelmanager::provided_invoice_features(&UserConfig::default()); + + super::bench_utils::generate_test_routes(&graph, &mut scorer, ¶ms, features, random_init_seed(), 0, 2); + } + + #[test] + #[cfg(not(feature = "no-std"))] + fn generate_large_mpp_routes() { + use crate::routing::scoring::{ProbabilisticScorer, ProbabilisticScoringFeeParameters}; + let logger = ln_test_utils::TestLogger::new(); - let graph = NetworkGraph::read(&mut d, &logger).unwrap(); - let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); - let random_seed_bytes = keys_manager.get_secure_random_bytes(); - let config = UserConfig::default(); + let graph = match super::bench_utils::read_network_graph(&logger) { + Ok(f) => f, + Err(e) => { + eprintln!("{}", e); + return; + }, + }; - // First, get 100 (source, destination) pairs for which route-getting actually succeeds... - let mut seed = random_init_seed() as usize; - let nodes = graph.read_only().nodes().clone(); - 'load_endpoints: for _ in 0..10 { - loop { - seed = seed.overflowing_mul(0xdeadbeef).0; - let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); - seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); - let payment_params = PaymentParameters::from_node_id(dst, 42).with_bolt11_features(channelmanager::provided_invoice_features(&config)).unwrap(); - let amt = seed as u64 % 200_000_000; - let params = ProbabilisticScoringParameters::default(); - let scorer = ProbabilisticScorer::new(params, &graph, &logger); - if get_route(src, &payment_params, &graph.read_only(), None, amt, &logger, &scorer, &random_seed_bytes).is_ok() { - continue 'load_endpoints; - } - } - } + let params = ProbabilisticScoringFeeParameters::default(); + let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &graph, &logger); + let features = channelmanager::provided_invoice_features(&UserConfig::default()); + + super::bench_utils::generate_test_routes(&graph, &mut scorer, ¶ms, features, random_init_seed(), 1_000_000, 2); } #[test] @@ -5869,8 +5938,8 @@ mod tests { let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); let random_seed_bytes = keys_manager.get_secure_random_bytes(); - let scorer_params = ProbabilisticScoringParameters::default(); - let mut scorer = ProbabilisticScorer::new(scorer_params, Arc::clone(&network_graph), Arc::clone(&logger)); + let mut scorer_params = ProbabilisticScoringFeeParameters::default(); + let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), Arc::clone(&network_graph), Arc::clone(&logger)); // First check set manual penalties are returned by the scorer. let usage = ChannelUsage { @@ -5878,26 +5947,147 @@ mod tests { inflight_htlc_msat: 0, effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_024_000, htlc_maximum_msat: 1_000 }, }; - scorer.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123); - scorer.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456); - assert_eq!(scorer.channel_penalty_msat(42, &NodeId::from_pubkey(&nodes[3]), &NodeId::from_pubkey(&nodes[4]), usage), 456); + scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[3]), 123); + scorer_params.set_manual_penalty(&NodeId::from_pubkey(&nodes[4]), 456); + assert_eq!(scorer.channel_penalty_msat(42, &NodeId::from_pubkey(&nodes[3]), &NodeId::from_pubkey(&nodes[4]), usage, &scorer_params), 456); // Then check we can get a normal route let payment_params = PaymentParameters::from_node_id(nodes[10], 42); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes); assert!(route.is_ok()); // Then check that we can't get a route if we ban an intermediate node. - scorer.add_banned(&NodeId::from_pubkey(&nodes[3])); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes); + scorer_params.add_banned(&NodeId::from_pubkey(&nodes[3])); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes); assert!(route.is_err()); // Finally make sure we can route again, when we remove the ban. - scorer.remove_banned(&NodeId::from_pubkey(&nodes[3])); - let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &random_seed_bytes); + scorer_params.remove_banned(&NodeId::from_pubkey(&nodes[3])); + let route = get_route(&our_id, &payment_params, &network_graph.read_only(), None, 100, Arc::clone(&logger), &scorer, &scorer_params,&random_seed_bytes); assert!(route.is_ok()); } + #[test] + fn abide_by_route_hint_max_htlc() { + // Check that we abide by any htlc_maximum_msat provided in the route hints of the payment + // params in the final route. + let (secp_ctx, network_graph, _, _, logger) = build_graph(); + let netgraph = network_graph.read_only(); + let (_, our_id, _, nodes) = get_nodes(&secp_ctx); + let scorer = ln_test_utils::TestScorer::new(); + let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); + let random_seed_bytes = keys_manager.get_secure_random_bytes(); + let config = UserConfig::default(); + + let max_htlc_msat = 50_000; + let route_hint_1 = RouteHint(vec![RouteHintHop { + src_node_id: nodes[2], + short_channel_id: 42, + fees: RoutingFees { + base_msat: 100, + proportional_millionths: 0, + }, + cltv_expiry_delta: 10, + htlc_minimum_msat: None, + htlc_maximum_msat: Some(max_htlc_msat), + }]); + let dest_node_id = ln_test_utils::pubkey(42); + let payment_params = PaymentParameters::from_node_id(dest_node_id, 42) + .with_route_hints(vec![route_hint_1.clone()]).unwrap() + .with_bolt11_features(channelmanager::provided_invoice_features(&config)).unwrap(); + + // Make sure we'll error if our route hints don't have enough liquidity according to their + // htlc_maximum_msat. + if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, + &payment_params, &netgraph, None, max_htlc_msat + 1, Arc::clone(&logger), &scorer, &(), + &random_seed_bytes) + { + assert_eq!(err, "Failed to find a sufficient route to the given destination"); + } else { panic!(); } + + // Make sure we'll split an MPP payment across route hints if their htlc_maximum_msat warrants. + let mut route_hint_2 = route_hint_1.clone(); + route_hint_2.0[0].short_channel_id = 43; + let payment_params = PaymentParameters::from_node_id(dest_node_id, 42) + .with_route_hints(vec![route_hint_1, route_hint_2]).unwrap() + .with_bolt11_features(channelmanager::provided_invoice_features(&config)).unwrap(); + let route = get_route(&our_id, &payment_params, &netgraph, None, max_htlc_msat + 1, + Arc::clone(&logger), &scorer, &(), &random_seed_bytes).unwrap(); + assert_eq!(route.paths.len(), 2); + assert!(route.paths[0].hops.last().unwrap().fee_msat <= max_htlc_msat); + assert!(route.paths[1].hops.last().unwrap().fee_msat <= max_htlc_msat); + } + + #[test] + fn direct_channel_to_hints_with_max_htlc() { + // Check that if we have a first hop channel peer that's connected to multiple provided route + // hints, that we properly split the payment between the route hints if needed. + let logger = Arc::new(ln_test_utils::TestLogger::new()); + let network_graph = Arc::new(NetworkGraph::new(Network::Testnet, Arc::clone(&logger))); + let scorer = ln_test_utils::TestScorer::new(); + let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet); + let random_seed_bytes = keys_manager.get_secure_random_bytes(); + let config = UserConfig::default(); + + let our_node_id = ln_test_utils::pubkey(42); + let intermed_node_id = ln_test_utils::pubkey(43); + let first_hop = vec![get_channel_details(Some(42), intermed_node_id, InitFeatures::from_le_bytes(vec![0b11]), 10_000_000)]; + + let amt_msat = 900_000; + let max_htlc_msat = 500_000; + let route_hint_1 = RouteHint(vec![RouteHintHop { + src_node_id: intermed_node_id, + short_channel_id: 44, + fees: RoutingFees { + base_msat: 100, + proportional_millionths: 0, + }, + cltv_expiry_delta: 10, + htlc_minimum_msat: None, + htlc_maximum_msat: Some(max_htlc_msat), + }, RouteHintHop { + src_node_id: intermed_node_id, + short_channel_id: 45, + fees: RoutingFees { + base_msat: 100, + proportional_millionths: 0, + }, + cltv_expiry_delta: 10, + htlc_minimum_msat: None, + // Check that later route hint max htlcs don't override earlier ones + htlc_maximum_msat: Some(max_htlc_msat - 50), + }]); + let mut route_hint_2 = route_hint_1.clone(); + route_hint_2.0[0].short_channel_id = 46; + route_hint_2.0[1].short_channel_id = 47; + let dest_node_id = ln_test_utils::pubkey(44); + let payment_params = PaymentParameters::from_node_id(dest_node_id, 42) + .with_route_hints(vec![route_hint_1, route_hint_2]).unwrap() + .with_bolt11_features(channelmanager::provided_invoice_features(&config)).unwrap(); + + let route = get_route(&our_node_id, &payment_params, &network_graph.read_only(), + Some(&first_hop.iter().collect::>()), amt_msat, Arc::clone(&logger), &scorer, &(), + &random_seed_bytes).unwrap(); + assert_eq!(route.paths.len(), 2); + assert!(route.paths[0].hops.last().unwrap().fee_msat <= max_htlc_msat); + assert!(route.paths[1].hops.last().unwrap().fee_msat <= max_htlc_msat); + assert_eq!(route.get_total_amount(), amt_msat); + + // Re-run but with two first hop channels connected to the same route hint peers that must be + // split between. + let first_hops = vec![ + get_channel_details(Some(42), intermed_node_id, InitFeatures::from_le_bytes(vec![0b11]), amt_msat - 10), + get_channel_details(Some(43), intermed_node_id, InitFeatures::from_le_bytes(vec![0b11]), amt_msat - 10), + ]; + let route = get_route(&our_node_id, &payment_params, &network_graph.read_only(), + Some(&first_hops.iter().collect::>()), amt_msat, Arc::clone(&logger), &scorer, &(), + &random_seed_bytes).unwrap(); + assert_eq!(route.paths.len(), 2); + assert!(route.paths[0].hops.last().unwrap().fee_msat <= max_htlc_msat); + assert!(route.paths[1].hops.last().unwrap().fee_msat <= max_htlc_msat); + assert_eq!(route.get_total_amount(), amt_msat); + } + #[test] fn blinded_route_ser() { let blinded_path_1 = BlindedPath { @@ -6044,9 +6234,23 @@ mod tests { } } -#[cfg(all(test, not(feature = "no-std")))] +#[cfg(all(any(test, ldk_bench), not(feature = "no-std")))] pub(crate) mod bench_utils { + use super::*; use std::fs::File; + + use bitcoin::hashes::Hash; + use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; + + use crate::chain::transaction::OutPoint; + use crate::sign::{EntropySource, KeysManager}; + use crate::ln::channelmanager::{self, ChannelCounterparty, ChannelDetails}; + use crate::ln::features::InvoiceFeatures; + use crate::routing::gossip::NetworkGraph; + use crate::util::config::UserConfig; + use crate::util::ser::ReadableArgs; + use crate::util::test_utils::TestLogger; + /// Tries to open a network graph file, or panics with a URL to fetch it. pub(crate) fn get_route_file() -> Result { let res = File::open("net_graph-2023-01-18.bin") // By default we're run in RL/lightning @@ -6060,7 +6264,18 @@ pub(crate) mod bench_utils { path.pop(); // target path.push("lightning"); path.push("net_graph-2023-01-18.bin"); - eprintln!("{}", path.to_str().unwrap()); + File::open(path) + }) + .or_else(|_| { // Fall back to guessing based on the binary location for a subcrate + // path is likely something like .../rust-lightning/bench/target/debug/deps/bench.. + let mut path = std::env::current_exe().unwrap(); + path.pop(); // bench... + path.pop(); // deps + path.pop(); // debug + path.pop(); // target + path.pop(); // bench + path.push("lightning"); + path.push("net_graph-2023-01-18.bin"); File::open(path) }) .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin and place it at lightning/net_graph-2023-01-18.bin"); @@ -6069,42 +6284,18 @@ pub(crate) mod bench_utils { #[cfg(not(require_route_graph_test))] return res; } -} - -#[cfg(all(test, feature = "_bench_unstable", not(feature = "no-std")))] -mod benches { - use super::*; - use bitcoin::hashes::Hash; - use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; - use crate::chain::transaction::OutPoint; - use crate::sign::{EntropySource, KeysManager}; - use crate::ln::channelmanager::{self, ChannelCounterparty, ChannelDetails}; - use crate::ln::features::InvoiceFeatures; - use crate::routing::gossip::NetworkGraph; - use crate::routing::scoring::{FixedPenaltyScorer, ProbabilisticScorer, ProbabilisticScoringParameters}; - use crate::util::config::UserConfig; - use crate::util::logger::{Logger, Record}; - use crate::util::ser::ReadableArgs; - - use test::Bencher; - - struct DummyLogger {} - impl Logger for DummyLogger { - fn log(&self, _record: &Record) {} - } - fn read_network_graph(logger: &DummyLogger) -> NetworkGraph<&DummyLogger> { - let mut d = bench_utils::get_route_file().unwrap(); - NetworkGraph::read(&mut d, logger).unwrap() + pub(crate) fn read_network_graph(logger: &TestLogger) -> Result, &'static str> { + get_route_file().map(|mut f| NetworkGraph::read(&mut f, logger).unwrap()) } - fn payer_pubkey() -> PublicKey { + pub(crate) fn payer_pubkey() -> PublicKey { let secp_ctx = Secp256k1::new(); PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap()) } #[inline] - fn first_hop(node_id: PublicKey) -> ChannelDetails { + pub(crate) fn first_hop(node_id: PublicKey) -> ChannelDetails { ChannelDetails { channel_id: [0; 32], counterparty: ChannelCounterparty { @@ -6122,11 +6313,12 @@ mod benches { short_channel_id: Some(1), inbound_scid_alias: None, outbound_scid_alias: None, - channel_value_satoshis: 10_000_000, + channel_value_satoshis: 10_000_000_000, user_channel_id: 0, - balance_msat: 10_000_000, - outbound_capacity_msat: 10_000_000, - next_outbound_htlc_limit_msat: 10_000_000, + balance_msat: 10_000_000_000, + outbound_capacity_msat: 10_000_000_000, + next_outbound_htlc_minimum_msat: 0, + next_outbound_htlc_limit_msat: 10_000_000_000, inbound_capacity_msat: 0, unspendable_punishment_reserve: None, confirmations_required: None, @@ -6143,100 +6335,167 @@ mod benches { } } - #[bench] - fn generate_routes_with_zero_penalty_scorer(bench: &mut Bencher) { - let logger = DummyLogger {}; - let network_graph = read_network_graph(&logger); - let scorer = FixedPenaltyScorer::with_penalty(0); - generate_routes(bench, &network_graph, scorer, InvoiceFeatures::empty()); - } - - #[bench] - fn generate_mpp_routes_with_zero_penalty_scorer(bench: &mut Bencher) { - let logger = DummyLogger {}; - let network_graph = read_network_graph(&logger); - let scorer = FixedPenaltyScorer::with_penalty(0); - generate_routes(bench, &network_graph, scorer, channelmanager::provided_invoice_features(&UserConfig::default())); - } - - #[bench] - fn generate_routes_with_probabilistic_scorer(bench: &mut Bencher) { - let logger = DummyLogger {}; - let network_graph = read_network_graph(&logger); - let params = ProbabilisticScoringParameters::default(); - let scorer = ProbabilisticScorer::new(params, &network_graph, &logger); - generate_routes(bench, &network_graph, scorer, InvoiceFeatures::empty()); - } - - #[bench] - fn generate_mpp_routes_with_probabilistic_scorer(bench: &mut Bencher) { - let logger = DummyLogger {}; - let network_graph = read_network_graph(&logger); - let params = ProbabilisticScoringParameters::default(); - let scorer = ProbabilisticScorer::new(params, &network_graph, &logger); - generate_routes(bench, &network_graph, scorer, channelmanager::provided_invoice_features(&UserConfig::default())); - } - - fn generate_routes( - bench: &mut Bencher, graph: &NetworkGraph<&DummyLogger>, mut scorer: S, - features: InvoiceFeatures - ) { - let nodes = graph.read_only().nodes().clone(); + pub(crate) fn generate_test_routes(graph: &NetworkGraph<&TestLogger>, scorer: &mut S, + score_params: &S::ScoreParams, features: InvoiceFeatures, mut seed: u64, + starting_amount: u64, route_count: usize, + ) -> Vec<(ChannelDetails, PaymentParameters, u64)> { let payer = payer_pubkey(); let keys_manager = KeysManager::new(&[0u8; 32], 42, 42); let random_seed_bytes = keys_manager.get_secure_random_bytes(); - // First, get 100 (source, destination) pairs for which route-getting actually succeeds... - let mut routes = Vec::new(); + let nodes = graph.read_only().nodes().clone(); let mut route_endpoints = Vec::new(); - let mut seed: usize = 0xdeadbeef; - 'load_endpoints: for _ in 0..150 { + // Fetch 1.5x more routes than we need as after we do some scorer updates we may end up + // with some routes we picked being un-routable. + for _ in 0..route_count * 3 / 2 { loop { - seed *= 0xdeadbeef; - let src = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); - seed *= 0xdeadbeef; - let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); - let params = PaymentParameters::from_node_id(dst, 42).with_bolt11_features(features.clone()).unwrap(); + seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0; + let src = PublicKey::from_slice(nodes.unordered_keys() + .skip((seed as usize) % nodes.len()).next().unwrap().as_slice()).unwrap(); + seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0; + let dst = PublicKey::from_slice(nodes.unordered_keys() + .skip((seed as usize) % nodes.len()).next().unwrap().as_slice()).unwrap(); + let params = PaymentParameters::from_node_id(dst, 42) + .with_bolt11_features(features.clone()).unwrap(); let first_hop = first_hop(src); - let amt = seed as u64 % 1_000_000; - if let Ok(route) = get_route(&payer, ¶ms, &graph.read_only(), Some(&[&first_hop]), amt, &DummyLogger{}, &scorer, &random_seed_bytes) { - routes.push(route); - route_endpoints.push((first_hop, params, amt)); - continue 'load_endpoints; - } - } - } + let amt = starting_amount + seed % 1_000_000; + let path_exists = + get_route(&payer, ¶ms, &graph.read_only(), Some(&[&first_hop]), + amt, &TestLogger::new(), &scorer, score_params, &random_seed_bytes).is_ok(); + if path_exists { + // ...and seed the scorer with success and failure data... + seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0; + let mut score_amt = seed % 1_000_000_000; + loop { + // Generate fail/success paths for a wider range of potential amounts with + // MPP enabled to give us a chance to apply penalties for more potential + // routes. + let mpp_features = channelmanager::provided_invoice_features(&UserConfig::default()); + let params = PaymentParameters::from_node_id(dst, 42) + .with_bolt11_features(mpp_features).unwrap(); + + let route_res = get_route(&payer, ¶ms, &graph.read_only(), + Some(&[&first_hop]), score_amt, &TestLogger::new(), &scorer, + score_params, &random_seed_bytes); + if let Ok(route) = route_res { + for path in route.paths { + if seed & 0x80 == 0 { + scorer.payment_path_successful(&path); + } else { + let short_channel_id = path.hops[path.hops.len() / 2].short_channel_id; + scorer.payment_path_failed(&path, short_channel_id); + } + seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0; + } + break; + } + // If we couldn't find a path with a higer amount, reduce and try again. + score_amt /= 100; + } - // ...and seed the scorer with success and failure data... - for route in routes { - let amount = route.get_total_amount(); - if amount < 250_000 { - for path in route.paths { - scorer.payment_path_successful(&path); - } - } else if amount > 750_000 { - for path in route.paths { - let short_channel_id = path.hops[path.hops.len() / 2].short_channel_id; - scorer.payment_path_failed(&path, short_channel_id); + route_endpoints.push((first_hop, params, amt)); + break; } } } - // Because we've changed channel scores, its possible we'll take different routes to the + // Because we've changed channel scores, it's possible we'll take different routes to the // selected destinations, possibly causing us to fail because, eg, the newly-selected path // requires a too-high CLTV delta. route_endpoints.retain(|(first_hop, params, amt)| { - get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt, &DummyLogger{}, &scorer, &random_seed_bytes).is_ok() + get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt, + &TestLogger::new(), &scorer, score_params, &random_seed_bytes).is_ok() }); - route_endpoints.truncate(100); - assert_eq!(route_endpoints.len(), 100); + route_endpoints.truncate(route_count); + assert_eq!(route_endpoints.len(), route_count); + route_endpoints + } +} + +#[cfg(ldk_bench)] +pub mod benches { + use super::*; + use crate::sign::{EntropySource, KeysManager}; + use crate::ln::channelmanager; + use crate::ln::features::InvoiceFeatures; + use crate::routing::gossip::NetworkGraph; + use crate::routing::scoring::{FixedPenaltyScorer, ProbabilisticScorer, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters}; + use crate::util::config::UserConfig; + use crate::util::logger::{Logger, Record}; + use crate::util::test_utils::TestLogger; + + use criterion::Criterion; + + struct DummyLogger {} + impl Logger for DummyLogger { + fn log(&self, _record: &Record) {} + } + + pub fn generate_routes_with_zero_penalty_scorer(bench: &mut Criterion) { + let logger = TestLogger::new(); + let network_graph = bench_utils::read_network_graph(&logger).unwrap(); + let scorer = FixedPenaltyScorer::with_penalty(0); + generate_routes(bench, &network_graph, scorer, &(), InvoiceFeatures::empty(), 0, + "generate_routes_with_zero_penalty_scorer"); + } + + pub fn generate_mpp_routes_with_zero_penalty_scorer(bench: &mut Criterion) { + let logger = TestLogger::new(); + let network_graph = bench_utils::read_network_graph(&logger).unwrap(); + let scorer = FixedPenaltyScorer::with_penalty(0); + generate_routes(bench, &network_graph, scorer, &(), + channelmanager::provided_invoice_features(&UserConfig::default()), 0, + "generate_mpp_routes_with_zero_penalty_scorer"); + } + + pub fn generate_routes_with_probabilistic_scorer(bench: &mut Criterion) { + let logger = TestLogger::new(); + let network_graph = bench_utils::read_network_graph(&logger).unwrap(); + let params = ProbabilisticScoringFeeParameters::default(); + let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger); + generate_routes(bench, &network_graph, scorer, ¶ms, InvoiceFeatures::empty(), 0, + "generate_routes_with_probabilistic_scorer"); + } + + pub fn generate_mpp_routes_with_probabilistic_scorer(bench: &mut Criterion) { + let logger = TestLogger::new(); + let network_graph = bench_utils::read_network_graph(&logger).unwrap(); + let params = ProbabilisticScoringFeeParameters::default(); + let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger); + generate_routes(bench, &network_graph, scorer, ¶ms, + channelmanager::provided_invoice_features(&UserConfig::default()), 0, + "generate_mpp_routes_with_probabilistic_scorer"); + } + + pub fn generate_large_mpp_routes_with_probabilistic_scorer(bench: &mut Criterion) { + let logger = TestLogger::new(); + let network_graph = bench_utils::read_network_graph(&logger).unwrap(); + let params = ProbabilisticScoringFeeParameters::default(); + let scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger); + generate_routes(bench, &network_graph, scorer, ¶ms, + channelmanager::provided_invoice_features(&UserConfig::default()), 100_000_000, + "generate_large_mpp_routes_with_probabilistic_scorer"); + } + + fn generate_routes( + bench: &mut Criterion, graph: &NetworkGraph<&TestLogger>, mut scorer: S, + score_params: &S::ScoreParams, features: InvoiceFeatures, starting_amount: u64, + bench_name: &'static str, + ) { + let payer = bench_utils::payer_pubkey(); + let keys_manager = KeysManager::new(&[0u8; 32], 42, 42); + let random_seed_bytes = keys_manager.get_secure_random_bytes(); + + // First, get 100 (source, destination) pairs for which route-getting actually succeeds... + let route_endpoints = bench_utils::generate_test_routes(graph, &mut scorer, score_params, features, 0xdeadbeef, starting_amount, 50); // ...then benchmark finding paths between the nodes we learned. let mut idx = 0; - bench.iter(|| { + bench.bench_function(bench_name, |b| b.iter(|| { let (first_hop, params, amt) = &route_endpoints[idx % route_endpoints.len()]; - assert!(get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt, &DummyLogger{}, &scorer, &random_seed_bytes).is_ok()); + assert!(get_route(&payer, params, &graph.read_only(), Some(&[first_hop]), *amt, + &DummyLogger{}, &scorer, score_params, &random_seed_bytes).is_ok()); idx += 1; - }); + })); } }