Provide `Score` the HTLC amount and channel capacity
authorMatt Corallo <git@bluematt.me>
Fri, 12 Nov 2021 03:52:58 +0000 (03:52 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 16 Nov 2021 20:58:04 +0000 (20:58 +0000)
This should allow `Score` implementations to make substantially
better decisions, including of the form "willing to pay X to avoid
routing over this channel which may have a high failure rate".

lightning-invoice/src/payment.rs
lightning/src/routing/mod.rs
lightning/src/routing/router.rs
lightning/src/routing/scorer.rs

index 85786a90a3450bec756328889e6de002a8108658..4099afbaa4c74b60e0427755f5dd0f42dbf9ab77 100644 (file)
@@ -73,7 +73,7 @@
 //! # struct FakeScorer {};
 //! # impl routing::Score for FakeScorer {
 //! #     fn channel_penalty_msat(
-//! #         &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+//! #         &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option<u64>, _source: &NodeId, _target: &NodeId
 //! #     ) -> u64 { 0 }
 //! #     fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
 //! # }
@@ -1227,7 +1227,7 @@ mod tests {
 
        impl routing::Score for TestScorer {
                fn channel_penalty_msat(
-                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+                       &self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option<u64>, _source: &NodeId, _target: &NodeId
                ) -> u64 { 0 }
 
                fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) {
index 3a48ffe93ddf42c77fb3b3c9b17d727dba14c8f0..91478bafc0468254a4fa13e8882258b549923bed 100644 (file)
@@ -24,9 +24,19 @@ use sync::{Mutex, MutexGuard};
 ///
 ///    Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
 pub trait Score {
-       /// Returns the fee in msats willing to be paid to avoid routing through the given channel
-       /// in the direction from `source` to `target`.
-       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64;
+       /// Returns the fee in msats willing to be paid to avoid routing `send_amt_msat` through the
+       /// given channel in the direction from `source` to `target`.
+       ///
+       /// The channel's capacity (less any other MPP parts which are also being considered for use in
+       /// the same payment) is given by `channel_capacity_msat`. It may be guessed from various
+       /// sources or assumed from no data at all.
+       ///
+       /// For hints provided in the invoice, we assume the channel has sufficient capacity to accept
+       /// the invoice's full amount, and provide a `channel_capacity_msat` of `None`. In all other
+       /// cases it is set to `Some`, even if we're guessing at the channel value.
+       ///
+       /// Your code should be overflow-safe through a `channel_capacity_msat` of 21 million BTC.
+       fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option<u64>, source: &NodeId, target: &NodeId) -> u64;
 
        /// Handles updating channel penalties after failing to route through a channel.
        fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64);
@@ -65,8 +75,8 @@ impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
 }
 
 impl<S: Score, T: DerefMut<Target=S>> Score for T {
-       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 {
-               self.deref().channel_penalty_msat(short_channel_id, source, target)
+       fn channel_penalty_msat(&self, short_channel_id: u64, send_amt_msat: u64, channel_capacity_msat: Option<u64>, source: &NodeId, target: &NodeId) -> u64 {
+               self.deref().channel_penalty_msat(short_channel_id, send_amt_msat, channel_capacity_msat, source, target)
        }
 
        fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
index 974ae74e4960f0c636670f721106aa4cb8d87d53..a98fb99124690566040561dee77893ba0c99cda1 100644 (file)
@@ -892,9 +892,9 @@ where L::Target: Logger {
                                                                }
                                                        }
 
-                                                       let path_penalty_msat = $next_hops_path_penalty_msat
-                                                               .checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id))
-                                                               .unwrap_or_else(|| u64::max_value());
+                                                       let path_penalty_msat = $next_hops_path_penalty_msat.checked_add(
+                                                               scorer.channel_penalty_msat($chan_id.clone(), amount_to_transfer_over_msat, Some(*available_liquidity_msat),
+                                                                       &$src_node_id, &$dest_node_id)).unwrap_or_else(|| u64::max_value());
                                                        let new_graph_node = RouteGraphNode {
                                                                node_id: $src_node_id,
                                                                lowest_fee_to_peer_through_node: total_fee_msat,
@@ -1121,7 +1121,7 @@ where L::Target: Logger {
                                        let src_node_id = NodeId::from_pubkey(&hop.src_node_id);
                                        let dest_node_id = NodeId::from_pubkey(&prev_hop_id);
                                        aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
-                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id))
+                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, final_value_msat, None, &src_node_id, &dest_node_id))
                                                .unwrap_or_else(|| u64::max_value());
 
                                        // We assume that the recipient only included route hints for routes which had
@@ -4550,7 +4550,7 @@ mod tests {
        }
 
        impl routing::Score for BadChannelScorer {
-               fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _chan_amt: Option<u64>, _source: &NodeId, _target: &NodeId) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
 
@@ -4562,7 +4562,7 @@ mod tests {
        }
 
        impl routing::Score for BadNodeScorer {
-               fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option<u64>, _source: &NodeId, target: &NodeId) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
 
index df744ce686b55bc4777e7ddceb824625653af0ee..573527540b2e569209335536470c5d2f1da55c5f 100644 (file)
@@ -211,7 +211,7 @@ impl Default for ScoringParameters {
 
 impl<T: Time> routing::Score for ScorerUsingTime<T> {
        fn channel_penalty_msat(
-               &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId
+               &self, short_channel_id: u64, _send_amt_msat: u64, _chan_capacity_msat: Option<u64>, _source: &NodeId, _target: &NodeId
        ) -> u64 {
                let failure_penalty_msat = self.channel_failures
                        .get(&short_channel_id)
@@ -417,10 +417,10 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000);
 
                SinceEpoch::advance(Duration::from_secs(1));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000);
        }
 
        #[test]
@@ -432,16 +432,16 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_064);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_064);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_128);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_192);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_192);
        }
 
        #[test]
@@ -453,25 +453,25 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(9));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(1));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_256);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256);
 
                SinceEpoch::advance(Duration::from_secs(10 * 8));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_001);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_001);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000);
        }
 
        #[test]
@@ -483,19 +483,19 @@ mod tests {
                });
                let source = source_node_id();
                let target = target_node_id();
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_000);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_000);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_256);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256);
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_768);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_768);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_384);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_384);
        }
 
        #[test]
@@ -509,20 +509,20 @@ mod tests {
                let target = target_node_id();
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_256);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256);
 
                scorer.payment_path_failed(&[], 43);
-               assert_eq!(scorer.channel_penalty_msat(43, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512);
 
                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();
 
                let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target), 1_256);
-               assert_eq!(deserialized_scorer.channel_penalty_msat(43, &source, &target), 1_512);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(43, 1, Some(1), &source, &target), 1_512);
        }
 
        #[test]
@@ -536,7 +536,7 @@ mod tests {
                let target = target_node_id();
 
                scorer.payment_path_failed(&[], 42);
-               assert_eq!(scorer.channel_penalty_msat(42, &source, &target), 1_512);
+               assert_eq!(scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_512);
 
                let mut serialized_scorer = Vec::new();
                scorer.write(&mut serialized_scorer).unwrap();
@@ -544,9 +544,9 @@ mod tests {
                SinceEpoch::advance(Duration::from_secs(10));
 
                let deserialized_scorer = <Scorer>::read(&mut io::Cursor::new(&serialized_scorer)).unwrap();
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target), 1_256);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_256);
 
                SinceEpoch::advance(Duration::from_secs(10));
-               assert_eq!(deserialized_scorer.channel_penalty_msat(42, &source, &target), 1_128);
+               assert_eq!(deserialized_scorer.channel_penalty_msat(42, 1, Some(1), &source, &target), 1_128);
        }
 }