]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Add utilities for getting a path's final value and cltv delta
authorValentine Wallace <vwallace@protonmail.com>
Wed, 19 Apr 2023 01:18:44 +0000 (21:18 -0400)
committerValentine Wallace <vwallace@protonmail.com>
Fri, 21 Apr 2023 15:48:27 +0000 (11:48 -0400)
lightning/src/ln/channelmanager.rs
lightning/src/ln/outbound_payment.rs
lightning/src/routing/router.rs
lightning/src/routing/scoring.rs

index 3fa8fa5ba947080d52cb888f505f40b568adb2a1..a3a2f80e3980679c0cbecbebffd3e34920f76589 100644 (file)
@@ -7025,7 +7025,7 @@ impl Readable for HTLCSource {
                                let path = path.unwrap();
                                if let Some(params) = payment_params.as_mut() {
                                        if params.final_cltv_expiry_delta == 0 {
-                                               params.final_cltv_expiry_delta = path.last().unwrap().cltv_expiry_delta;
+                                               params.final_cltv_expiry_delta = path.final_cltv_expiry_delta();
                                        }
                                }
                                Ok(HTLCSource::OutboundRoute {
@@ -7725,7 +7725,7 @@ where
                                                                return Err(DecodeError::InvalidValue);
                                                        }
 
-                                                       let path_amt = path.last().unwrap().fee_msat;
+                                                       let path_amt = path.final_value_msat();
                                                        let mut session_priv_bytes = [0; 32];
                                                        session_priv_bytes[..].copy_from_slice(&session_priv[..]);
                                                        match pending_outbounds.pending_outbound_payments.lock().unwrap().entry(payment_id) {
@@ -7735,7 +7735,7 @@ where
                                                                                if newly_added { "Added" } else { "Had" }, path_amt, log_bytes!(session_priv_bytes), log_bytes!(htlc.payment_hash.0));
                                                                },
                                                                hash_map::Entry::Vacant(entry) => {
-                                                                       let path_fee = path.get_path_fees();
+                                                                       let path_fee = path.fee_msat();
                                                                        entry.insert(PendingOutboundPayment::Retryable {
                                                                                retry_strategy: None,
                                                                                attempts: PaymentAttempts::new(),
index 4758b7ea4f1bb52286990af3c584af024a0b9df5..8b17bdbcba6770a3c639dffef53a76652bc4e587 100644 (file)
@@ -172,10 +172,9 @@ impl PendingOutboundPayment {
                if remove_res {
                        if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, ref mut pending_fee_msat, .. } = self {
                                let path = path.expect("Fulfilling a payment should always come with a path");
-                               let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
-                               *pending_amt_msat -= path_last_hop.fee_msat;
+                               *pending_amt_msat -= path.final_value_msat();
                                if let Some(fee_msat) = pending_fee_msat.as_mut() {
-                                       *fee_msat -= path.get_path_fees();
+                                       *fee_msat -= path.fee_msat();
                                }
                        }
                }
@@ -193,10 +192,9 @@ impl PendingOutboundPayment {
                };
                if insert_res {
                        if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, ref mut pending_fee_msat, .. } = self {
-                               let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
-                               *pending_amt_msat += path_last_hop.fee_msat;
+                               *pending_amt_msat += path.final_value_msat();
                                if let Some(fee_msat) = pending_fee_msat.as_mut() {
-                                       *fee_msat += path.get_path_fees();
+                                       *fee_msat += path.fee_msat();
                                }
                        }
                }
@@ -756,7 +754,7 @@ impl OutboundPayments {
                                                PendingOutboundPayment::Retryable {
                                                        total_msat, keysend_preimage, payment_secret, payment_metadata, pending_amt_msat, ..
                                                } => {
-                                                       let retry_amt_msat: u64 = route.paths.iter().map(|path| path.last().unwrap().fee_msat).sum();
+                                                       let retry_amt_msat = route.get_total_amount();
                                                        if retry_amt_msat + *pending_amt_msat > *total_msat * (100 + RETRY_OVERFLOW_PERCENTAGE) / 100 {
                                                                log_error!(logger, "retry_amt_msat of {} will put pending_amt_msat (currently: {}) more than 10% over total_payment_amt_msat of {}", retry_amt_msat, pending_amt_msat, total_msat);
                                                                abandon_with_entry!(payment, PaymentFailureReason::UnexpectedError);
@@ -1008,7 +1006,7 @@ impl OutboundPayments {
                                        continue 'path_check;
                                }
                        }
-                       total_value += path.last().unwrap().fee_msat;
+                       total_value += path.final_value_msat();
                        path_errs.push(Ok(()));
                }
                if path_errs.iter().any(|e| e.is_err()) {
@@ -1056,7 +1054,7 @@ impl OutboundPayments {
                                has_err = true;
                                has_ok = true;
                        } else if res.is_err() {
-                               pending_amt_unsent += path.last().unwrap().fee_msat;
+                               pending_amt_unsent += path.final_value_msat();
                        }
                }
                if has_err && has_ok {
index 0910985418ed8b23a958efa5f9056bef8414c3e7..0410c5c554b52bed6797e7b7cc0de8568e1efc51 100644 (file)
@@ -261,17 +261,43 @@ pub struct Route {
        pub payment_params: Option<PaymentParameters>,
 }
 
+// This trait is deleted in the next commit
 pub(crate) trait RoutePath {
        /// Gets the fees for a given path, excluding any excess paid to the recipient.
-       fn get_path_fees(&self) -> u64;
+       fn fee_msat(&self) -> u64;
+
+       /// Gets the total amount paid on this path, excluding the fees.
+       fn final_value_msat(&self) -> u64;
+
+       /// Gets the final hop's CLTV expiry delta.
+       fn final_cltv_expiry_delta(&self) -> u32;
 }
 impl RoutePath for Vec<RouteHop> {
-       fn get_path_fees(&self) -> u64 {
+       fn fee_msat(&self) -> u64 {
                // Do not count last hop of each path since that's the full value of the payment
                self.split_last().map(|(_, path_prefix)| path_prefix).unwrap_or(&[])
                        .iter().map(|hop| &hop.fee_msat)
                        .sum()
        }
+       fn final_value_msat(&self) -> u64 {
+               self.last().map_or(0, |hop| hop.fee_msat)
+       }
+       fn final_cltv_expiry_delta(&self) -> u32 {
+               self.last().map_or(0, |hop| hop.cltv_expiry_delta)
+       }
+}
+impl RoutePath for &[&RouteHop] {
+       fn fee_msat(&self) -> u64 {
+               self.split_last().map(|(_, path_prefix)| path_prefix).unwrap_or(&[])
+                       .iter().map(|hop| &hop.fee_msat)
+                       .sum()
+       }
+       fn final_value_msat(&self) -> u64 {
+               self.last().map_or(0, |hop| hop.fee_msat)
+       }
+       fn final_cltv_expiry_delta(&self) -> u32 {
+               self.last().map_or(0, |hop| hop.cltv_expiry_delta)
+       }
 }
 
 impl Route {
@@ -280,15 +306,13 @@ impl Route {
        /// This doesn't include any extra payment made to the recipient, which can happen in excess of
        /// the amount passed to [`find_route`]'s `params.final_value_msat`.
        pub fn get_total_fees(&self) -> u64 {
-               self.paths.iter().map(|path| path.get_path_fees()).sum()
+               self.paths.iter().map(|path| path.fee_msat()).sum()
        }
 
        /// Returns the total amount paid on this [`Route`], excluding the fees. Might be more than
        /// requested if we had to reach htlc_minimum_msat.
        pub fn get_total_amount(&self) -> u64 {
-               return self.paths.iter()
-                       .map(|path| path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0))
-                       .sum();
+               self.paths.iter().map(|path| path.final_value_msat()).sum()
        }
 }
 
@@ -2183,7 +2207,7 @@ mod tests {
        use crate::routing::gossip::{NetworkGraph, P2PGossipSync, NodeId, EffectiveCapacity};
        use crate::routing::utxo::UtxoResult;
        use crate::routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
-               PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
+               PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees, RoutePath,
                DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE};
        use crate::routing::scoring::{ChannelUsage, FixedPenaltyScorer, Score, ProbabilisticScorer, ProbabilisticScoringParameters};
        use crate::routing::test_utils::{add_channel, add_or_update_node, build_graph, build_line_graph, id_to_feature_flags, get_nodes, update_channel};
@@ -3487,7 +3511,7 @@ mod tests {
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
                        assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                       assert_eq!(path.last().unwrap().fee_msat, 250_000_000);
+                       assert_eq!(path.final_value_msat(), 250_000_000);
                }
 
                // Check that setting next_outbound_htlc_limit_msat in first_hops limits the channels.
@@ -3523,7 +3547,7 @@ mod tests {
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
                        assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                       assert_eq!(path.last().unwrap().fee_msat, 200_000_000);
+                       assert_eq!(path.final_value_msat(), 200_000_000);
                }
 
                // Enable channel #1 back.
@@ -3570,7 +3594,7 @@ mod tests {
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
                        assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                       assert_eq!(path.last().unwrap().fee_msat, 15_000);
+                       assert_eq!(path.final_value_msat(), 15_000);
                }
 
                // Now let's see if routing works if we know only capacity from the UTXO.
@@ -3641,7 +3665,7 @@ mod tests {
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
                        assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                       assert_eq!(path.last().unwrap().fee_msat, 15_000);
+                       assert_eq!(path.final_value_msat(), 15_000);
                }
 
                // Now let's see if routing chooses htlc_maximum_msat over UTXO capacity.
@@ -3673,7 +3697,7 @@ mod tests {
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
                        assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                       assert_eq!(path.last().unwrap().fee_msat, 10_000);
+                       assert_eq!(path.final_value_msat(), 10_000);
                }
        }
 
@@ -3786,7 +3810,7 @@ mod tests {
                        for path in &route.paths {
                                assert_eq!(path.len(), 4);
                                assert_eq!(path.last().unwrap().pubkey, nodes[3]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 49_000);
                }
@@ -3799,7 +3823,7 @@ mod tests {
                        for path in &route.paths {
                                assert_eq!(path.len(), 4);
                                assert_eq!(path.last().unwrap().pubkey, nodes[3]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 50_000);
                }
@@ -3847,7 +3871,7 @@ mod tests {
                        for path in &route.paths {
                                assert_eq!(path.len(), 2);
                                assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 50_000);
                }
@@ -3993,7 +4017,7 @@ mod tests {
                        for path in &route.paths {
                                assert_eq!(path.len(), 2);
                                assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 250_000);
                }
@@ -4007,7 +4031,7 @@ mod tests {
                        for path in &route.paths {
                                assert_eq!(path.len(), 2);
                                assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 290_000);
                }
@@ -4171,7 +4195,7 @@ mod tests {
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
                                assert_eq!(path.last().unwrap().pubkey, nodes[3]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 300_000);
                }
@@ -4333,7 +4357,7 @@ mod tests {
                        let mut total_paid_msat = 0;
                        for path in &route.paths {
                                assert_eq!(path.last().unwrap().pubkey, nodes[3]);
-                               total_value_transferred_msat += path.last().unwrap().fee_msat;
+                               total_value_transferred_msat += path.final_value_msat();
                                for hop in path {
                                        total_paid_msat += hop.fee_msat;
                                }
@@ -4510,7 +4534,7 @@ mod tests {
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
                                assert_eq!(path.last().unwrap().pubkey, nodes[3]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 200_000);
                        assert_eq!(route.get_total_fees(), 150_000);
@@ -4737,7 +4761,7 @@ mod tests {
                        for path in &route.paths {
                                assert_eq!(path.len(), 2);
                                assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 125_000);
                }
@@ -4750,7 +4774,7 @@ mod tests {
                        for path in &route.paths {
                                assert_eq!(path.len(), 2);
                                assert_eq!(path.last().unwrap().pubkey, nodes[2]);
-                               total_amount_paid_msat += path.last().unwrap().fee_msat;
+                               total_amount_paid_msat += path.final_value_msat();
                        }
                        assert_eq!(total_amount_paid_msat, 90_000);
                }
index 4d342562bea75afbe091af9244584dcbb2124ffb..2ef2a8557da4ab6edde968d723463b1431cdb4a7 100644 (file)
@@ -56,7 +56,7 @@
 
 use crate::ln::msgs::DecodeError;
 use crate::routing::gossip::{EffectiveCapacity, NetworkGraph, NodeId};
-use crate::routing::router::RouteHop;
+use crate::routing::router::{RouteHop, RoutePath};
 use crate::util::ser::{Readable, ReadableArgs, Writeable, Writer};
 use crate::util::logger::Logger;
 use crate::util::time::Time;
@@ -1234,7 +1234,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
        }
 
        fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
-               let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
+               let amount_msat = path.final_value_msat();
                log_trace!(self.logger, "Scoring path through to SCID {} as having failed at {} msat", short_channel_id, amount_msat);
                let network_graph = self.network_graph.read_only();
                for (hop_idx, hop) in path.iter().enumerate() {
@@ -1273,7 +1273,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, T: Time> Score for Probabilis
        }
 
        fn payment_path_successful(&mut self, path: &[&RouteHop]) {
-               let amount_msat = path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0);
+               let amount_msat = path.final_value_msat();
                log_trace!(self.logger, "Scoring path through SCID {} as having succeeded at {} msat.",
                        path.split_last().map(|(hop, _)| hop.short_channel_id).unwrap_or(0), amount_msat);
                let network_graph = self.network_graph.read_only();