From: Matt Corallo Date: Fri, 12 Nov 2021 03:52:58 +0000 (+0000) Subject: Provide `Score` the HTLC amount and channel capacity X-Git-Tag: v0.0.104~34^2~2 X-Git-Url: http://git.bitcoin.ninja/?a=commitdiff_plain;h=8dc7cfab3af7e26cddd8c65a913d5ee021907442;p=rust-lightning Provide `Score` the HTLC amount and channel capacity This should allow `Score` implementations to make substantially better decisions, including of the form "willing to pay X to avoid routing over this channel which may have a high failure rate". --- diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index 85786a90a..4099afbaa 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -73,7 +73,7 @@ //! # struct FakeScorer {}; //! # impl routing::Score for FakeScorer { //! # fn channel_penalty_msat( -//! # &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId +//! # &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, _target: &NodeId //! # ) -> u64 { 0 } //! # fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {} //! # } @@ -1227,7 +1227,7 @@ mod tests { impl routing::Score for TestScorer { fn channel_penalty_msat( - &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId + &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, _target: &NodeId ) -> u64 { 0 } fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) { diff --git a/lightning/src/routing/mod.rs b/lightning/src/routing/mod.rs index 3a48ffe93..91478bafc 100644 --- a/lightning/src/routing/mod.rs +++ b/lightning/src/routing/mod.rs @@ -24,9 +24,19 @@ use sync::{Mutex, MutexGuard}; /// /// Scoring is in terms of fees willing to be paid in order to avoid routing through a channel. pub trait Score { - /// Returns the fee in msats willing to be paid to avoid routing through the given channel - /// in the direction from `source` to `target`. - fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64; + /// Returns the fee in msats willing to be paid to avoid routing `send_amt_msat` through the + /// given channel in the direction from `source` to `target`. + /// + /// The channel's capacity (less any other MPP parts which are also being considered for use in + /// the same payment) is given by `channel_capacity_msat`. It may be guessed from various + /// sources or assumed from no data at all. + /// + /// For hints provided in the invoice, we assume the channel has sufficient capacity to accept + /// the invoice's full amount, and provide a `channel_capacity_msat` of `None`. In all other + /// cases it is set to `Some`, even if we're guessing at the channel value. + /// + /// Your code should be overflow-safe through a `channel_capacity_msat` of 21 million BTC. + fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option, source: &NodeId, target: &NodeId) -> u64; /// Handles updating channel penalties after failing to route through a channel. fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64); @@ -65,8 +75,8 @@ impl<'a, T: 'a + Score> LockableScore<'a> for RefCell { } impl> Score for T { - fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 { - self.deref().channel_penalty_msat(short_channel_id, source, target) + fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option, source: &NodeId, target: &NodeId) -> u64 { + self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, channel_capacity_msat, source, target) } fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) { diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 974ae74e4..a98fb9912 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -892,9 +892,9 @@ where L::Target: Logger { } } - let path_penalty_msat = $next_hops_path_penalty_msat - .checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id)) - .unwrap_or_else(|| u64::max_value()); + let path_penalty_msat = $next_hops_path_penalty_msat.checked_add( + scorer.channel_penalty_msat($chan_id.clone(), amount_to_transfer_over_msat, Some(*available_liquidity_msat), + &$src_node_id, &$dest_node_id)).unwrap_or_else(|| u64::max_value()); let new_graph_node = RouteGraphNode { node_id: $src_node_id, lowest_fee_to_peer_through_node: total_fee_msat, @@ -1121,7 +1121,7 @@ where L::Target: Logger { let src_node_id = NodeId::from_pubkey(&hop.src_node_id); let dest_node_id = NodeId::from_pubkey(&prev_hop_id); aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat - .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id)) + .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, final_value_msat, None, &src_node_id, &dest_node_id)) .unwrap_or_else(|| u64::max_value()); // We assume that the recipient only included route hints for routes which had @@ -4550,7 +4550,7 @@ mod tests { } impl routing::Score for BadChannelScorer { - fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 { + fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, _target: &NodeId) -> u64 { if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 } } @@ -4562,7 +4562,7 @@ mod tests { } impl routing::Score for BadNodeScorer { - fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 { + fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option, _source: &NodeId, target: &NodeId) -> u64 { if *target == self.node_id { u64::max_value() } else { 0 } } diff --git a/lightning/src/routing/scorer.rs b/lightning/src/routing/scorer.rs index df744ce68..573527540 100644 --- a/lightning/src/routing/scorer.rs +++ b/lightning/src/routing/scorer.rs @@ -211,7 +211,7 @@ impl Default for ScoringParameters { impl routing::Score for ScorerUsingTime { fn channel_penalty_msat( - &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId + &self, short_channel_id: u64, _send_amt_msat: u64, _chan_capacity_msat: Option, _source: &NodeId, _target: &NodeId ) -> u64 { let failure_penalty_msat = self.channel_failures .get(&short_channel_id) @@ -417,10 +417,10 @@ mod tests { }); let source = source_node_id(); let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); SinceEpoch::advance(Duration::from_secs(1)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); } #[test] @@ -432,16 +432,16 @@ mod tests { }); let source = source_node_id(); let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_064); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_064); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_128); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_192); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_192); } #[test] @@ -453,25 +453,25 @@ mod tests { }); let source = source_node_id(); let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); SinceEpoch::advance(Duration::from_secs(9)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); SinceEpoch::advance(Duration::from_secs(1)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_256); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); SinceEpoch::advance(Duration::from_secs(10 * 8)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_001); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_001); SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); } #[test] @@ -483,19 +483,19 @@ mod tests { }); let source = source_node_id(); let target = target_node_id(); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_256); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_768); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_768); SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_384); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_384); } #[test] @@ -509,20 +509,20 @@ mod tests { let target = target_node_id(); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_256); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); scorer.payment_path_failed(&[], 43); - assert_eq!(scorer.channel_penalty_msat(43, &source, &target), 1_512); + assert_eq!(scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512); let mut serialized_scorer = Vec::new(); scorer.write(&mut serialized_scorer).unwrap(); let deserialized_scorer = ::read(&mut io::Cursor::new(&serialized_scorer)).unwrap(); - assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target), 1_256); - assert_eq!(deserialized_scorer.channel_penalty_msat(43, &source, &target), 1_512); + assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); + assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512); } #[test] @@ -536,7 +536,7 @@ mod tests { let target = target_node_id(); scorer.payment_path_failed(&[], 42); - assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512); + assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512); let mut serialized_scorer = Vec::new(); scorer.write(&mut serialized_scorer).unwrap(); @@ -544,9 +544,9 @@ mod tests { SinceEpoch::advance(Duration::from_secs(10)); let deserialized_scorer = ::read(&mut io::Cursor::new(&serialized_scorer)).unwrap(); - assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target), 1_256); + assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256); SinceEpoch::advance(Duration::from_secs(10)); - assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target), 1_128); + assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128); } }