From d2c20ecc2d2eec46b5f1d29173bc9e29cd5179a0 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 18 Jul 2023 19:41:07 +0000 Subject: [PATCH] Pass `InFlightHltcs` to the scorer by ownership rather than ref Given we build `InFlightHtlcs` per route-fetch call, there's no reason to pass them out by reference rather than simply giving the user the full object. This also allows them to tweak the in-flight set before fetching a route. --- fuzz/src/chanmon_consistency.rs | 2 +- fuzz/src/full_stack.rs | 2 +- lightning/src/ln/outbound_payment.rs | 4 ++-- lightning/src/ln/payment_tests.rs | 2 +- lightning/src/routing/router.rs | 10 +++++----- lightning/src/util/test_utils.rs | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/fuzz/src/chanmon_consistency.rs b/fuzz/src/chanmon_consistency.rs index cd9c484d..e923ef88 100644 --- a/fuzz/src/chanmon_consistency.rs +++ b/fuzz/src/chanmon_consistency.rs @@ -89,7 +89,7 @@ struct FuzzRouter {} impl Router for FuzzRouter { fn find_route( &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>, - _inflight_htlcs: &InFlightHtlcs + _inflight_htlcs: InFlightHtlcs ) -> Result { Err(msgs::LightningError { err: String::from("Not implemented"), diff --git a/fuzz/src/full_stack.rs b/fuzz/src/full_stack.rs index 9caf9104..1fbd7dbe 100644 --- a/fuzz/src/full_stack.rs +++ b/fuzz/src/full_stack.rs @@ -131,7 +131,7 @@ struct FuzzRouter {} impl Router for FuzzRouter { fn find_route( &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>, - _inflight_htlcs: &InFlightHtlcs + _inflight_htlcs: InFlightHtlcs ) -> Result { Err(msgs::LightningError { err: String::from("Not implemented"), diff --git a/lightning/src/ln/outbound_payment.rs b/lightning/src/ln/outbound_payment.rs index 546dc6c5..30e718dc 100644 --- a/lightning/src/ln/outbound_payment.rs +++ b/lightning/src/ln/outbound_payment.rs @@ -669,7 +669,7 @@ impl OutboundPayments { let route = router.find_route_with_id( &node_signer.get_node_id(Recipient::Node).unwrap(), &route_params, - Some(&first_hops.iter().collect::>()), &inflight_htlcs(), + Some(&first_hops.iter().collect::>()), inflight_htlcs(), payment_hash, payment_id, ).map_err(|_| RetryableSendFailure::RouteNotFound)?; @@ -712,7 +712,7 @@ impl OutboundPayments { let route = match router.find_route_with_id( &node_signer.get_node_id(Recipient::Node).unwrap(), &route_params, - Some(&first_hops.iter().collect::>()), &inflight_htlcs(), + Some(&first_hops.iter().collect::>()), inflight_htlcs(), payment_hash, payment_id, ) { Ok(route) => route, diff --git a/lightning/src/ln/payment_tests.rs b/lightning/src/ln/payment_tests.rs index b23b6364..2ae60610 100644 --- a/lightning/src/ln/payment_tests.rs +++ b/lightning/src/ln/payment_tests.rs @@ -3218,7 +3218,7 @@ fn do_claim_from_closed_chan(fail_payment: bool) { final_value_msat: 10_000_000, }; let mut route = nodes[0].router.find_route(&nodes[0].node.get_our_node_id(), &route_params, - None, &nodes[0].node.compute_inflight_htlcs()).unwrap(); + None, nodes[0].node.compute_inflight_htlcs()).unwrap(); // Make sure the route is ordered as the B->D path before C->D route.paths.sort_by(|a, _| if a.hops[0].pubkey == nodes[1].node.get_our_node_id() { std::cmp::Ordering::Less } else { std::cmp::Ordering::Greater }); diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 9242e98b..f0822ed5 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -64,7 +64,7 @@ impl< G: Deref>, L: Deref, S: Deref, SP: Sized, Sc: Sco payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, - inflight_htlcs: &InFlightHtlcs + inflight_htlcs: InFlightHtlcs ) -> Result { let random_seed_bytes = { let mut locked_random_seed_bytes = self.random_seed_bytes.lock().unwrap(); @@ -73,7 +73,7 @@ impl< G: Deref>, L: Deref, S: Deref, SP: Sized, Sc: Sco }; find_route( payer, params, &self.network_graph, first_hops, &*self.logger, - &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs), + &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), &inflight_htlcs), &self.score_params, &random_seed_bytes ) @@ -85,13 +85,13 @@ pub trait Router { /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. fn find_route( &self, payer: &PublicKey, route_params: &RouteParameters, - first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs + first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs ) -> Result; /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. Includes /// `PaymentHash` and `PaymentId` to be able to correlate the request with a specific payment. fn find_route_with_id( &self, payer: &PublicKey, route_params: &RouteParameters, - first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs, + first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs, _payment_hash: PaymentHash, _payment_id: PaymentId ) -> Result { self.find_route(payer, route_params, first_hops, inflight_htlcs) @@ -112,7 +112,7 @@ pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score, SP: impl<'a, S: Score, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> { /// Initialize a new `ScorerAccountingForInFlightHtlcs`. - pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self { + pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self { ScorerAccountingForInFlightHtlcs { scorer, inflight_htlcs diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 3cc08e04..65c0483a 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -112,13 +112,13 @@ impl<'a> TestRouter<'a> { impl<'a> Router for TestRouter<'a> { fn find_route( &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&channelmanager::ChannelDetails]>, - inflight_htlcs: &InFlightHtlcs + inflight_htlcs: InFlightHtlcs ) -> Result { if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() { assert_eq!(find_route_query, *params); if let Ok(ref route) = find_route_res { let mut binding = self.scorer.lock().unwrap(); - let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs); + let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), &inflight_htlcs); for path in &route.paths { let mut aggregate_msat = 0u64; for (idx, hop) in path.hops.iter().rev().enumerate() { -- 2.30.2