Pass `InFlightHltcs` to the scorer by ownership rather than ref
authorMatt Corallo <git@bluematt.me>
Tue, 18 Jul 2023 19:41:07 +0000 (19:41 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 20 Jul 2023 19:49:43 +0000 (19:49 +0000)
Given we build `InFlightHtlcs` per route-fetch call, there's no
reason to pass them out by reference rather than simply giving the
user the full object. This also allows them to tweak the in-flight
set before fetching a route.

fuzz/src/chanmon_consistency.rs
fuzz/src/full_stack.rs
lightning/src/ln/outbound_payment.rs
lightning/src/ln/payment_tests.rs
lightning/src/routing/router.rs
lightning/src/util/test_utils.rs

index cd9c484da2a82f75a924b72150c68d602d5b47ce..e923ef882f26ce643aa7c9478f15ecc648c931d8 100644 (file)
@@ -89,7 +89,7 @@ struct FuzzRouter {}
 impl Router for FuzzRouter {
        fn find_route(
                &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>,
-               _inflight_htlcs: &InFlightHtlcs
+               _inflight_htlcs: InFlightHtlcs
        ) -> Result<Route, msgs::LightningError> {
                Err(msgs::LightningError {
                        err: String::from("Not implemented"),
index 9caf91040346c45436671cb0f4c3e74ad1cbb373..1fbd7dbec8834b9e9dee38f1c19ff02e770f79d1 100644 (file)
@@ -131,7 +131,7 @@ struct FuzzRouter {}
 impl Router for FuzzRouter {
        fn find_route(
                &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>,
-               _inflight_htlcs: &InFlightHtlcs
+               _inflight_htlcs: InFlightHtlcs
        ) -> Result<Route, msgs::LightningError> {
                Err(msgs::LightningError {
                        err: String::from("Not implemented"),
index 546dc6c5bcd93171411263ff1c77229a08ea2ae1..30e718dccd65321cd5fee3f3d2fd632a0fce6d61 100644 (file)
@@ -669,7 +669,7 @@ impl OutboundPayments {
 
                let route = router.find_route_with_id(
                        &node_signer.get_node_id(Recipient::Node).unwrap(), &route_params,
-                       Some(&first_hops.iter().collect::<Vec<_>>()), &inflight_htlcs(),
+                       Some(&first_hops.iter().collect::<Vec<_>>()), inflight_htlcs(),
                        payment_hash, payment_id,
                ).map_err(|_| RetryableSendFailure::RouteNotFound)?;
 
@@ -712,7 +712,7 @@ impl OutboundPayments {
 
                let route = match router.find_route_with_id(
                        &node_signer.get_node_id(Recipient::Node).unwrap(), &route_params,
-                       Some(&first_hops.iter().collect::<Vec<_>>()), &inflight_htlcs(),
+                       Some(&first_hops.iter().collect::<Vec<_>>()), inflight_htlcs(),
                        payment_hash, payment_id,
                ) {
                        Ok(route) => route,
index b23b6364182da0f8445f1da0f108dddfbf8ad86c..2ae606106c92a7591be57770bf6627a3bac9dc95 100644 (file)
@@ -3218,7 +3218,7 @@ fn do_claim_from_closed_chan(fail_payment: bool) {
                final_value_msat: 10_000_000,
        };
        let mut route = nodes[0].router.find_route(&nodes[0].node.get_our_node_id(), &route_params,
-               None, &nodes[0].node.compute_inflight_htlcs()).unwrap();
+               None, nodes[0].node.compute_inflight_htlcs()).unwrap();
        // Make sure the route is ordered as the B->D path before C->D
        route.paths.sort_by(|a, _| if a.hops[0].pubkey == nodes[1].node.get_our_node_id() {
                std::cmp::Ordering::Less } else { std::cmp::Ordering::Greater });
index 9242e98b1ab5061eca462e31b479f3e34602a8e3..f0822ed5b38a6d69c1f281270ec3502dc2860d4e 100644 (file)
@@ -64,7 +64,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sco
                payer: &PublicKey,
                params: &RouteParameters,
                first_hops: Option<&[&ChannelDetails]>,
-               inflight_htlcs: &InFlightHtlcs
+               inflight_htlcs: InFlightHtlcs
        ) -> Result<Route, LightningError> {
                let random_seed_bytes = {
                        let mut locked_random_seed_bytes = self.random_seed_bytes.lock().unwrap();
@@ -73,7 +73,7 @@ impl< G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref, SP: Sized, Sc: Sco
                };
                find_route(
                        payer, params, &self.network_graph, first_hops, &*self.logger,
-                       &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), inflight_htlcs),
+                       &ScorerAccountingForInFlightHtlcs::new(self.scorer.lock().deref_mut(), &inflight_htlcs),
                        &self.score_params,
                        &random_seed_bytes
                )
@@ -85,13 +85,13 @@ pub trait Router {
        /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values.
        fn find_route(
                &self, payer: &PublicKey, route_params: &RouteParameters,
-               first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs
+               first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs
        ) -> Result<Route, LightningError>;
        /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. Includes
        /// `PaymentHash` and `PaymentId` to be able to correlate the request with a specific payment.
        fn find_route_with_id(
                &self, payer: &PublicKey, route_params: &RouteParameters,
-               first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs,
+               first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs,
                _payment_hash: PaymentHash, _payment_id: PaymentId
        ) -> Result<Route, LightningError> {
                self.find_route(payer, route_params, first_hops, inflight_htlcs)
@@ -112,7 +112,7 @@ pub struct ScorerAccountingForInFlightHtlcs<'a, S: Score<ScoreParams = SP>, SP:
 
 impl<'a, S: Score<ScoreParams = SP>, SP: Sized> ScorerAccountingForInFlightHtlcs<'a, S, SP> {
        /// Initialize a new `ScorerAccountingForInFlightHtlcs`.
-       pub fn new(scorer:  &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
+       pub fn new(scorer: &'a mut S, inflight_htlcs: &'a InFlightHtlcs) -> Self {
                ScorerAccountingForInFlightHtlcs {
                        scorer,
                        inflight_htlcs
index 3cc08e046283c0480b81952ec98464565b550605..65c0483a59c9906e36b4a42587698362bd77702c 100644 (file)
@@ -112,13 +112,13 @@ impl<'a> TestRouter<'a> {
 impl<'a> Router for TestRouter<'a> {
        fn find_route(
                &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&channelmanager::ChannelDetails]>,
-               inflight_htlcs: &InFlightHtlcs
+               inflight_htlcs: InFlightHtlcs
        ) -> Result<Route, msgs::LightningError> {
                if let Some((find_route_query, find_route_res)) = self.next_routes.lock().unwrap().pop_front() {
                        assert_eq!(find_route_query, *params);
                        if let Ok(ref route) = find_route_res {
                                let mut binding = self.scorer.lock().unwrap();
-                               let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), inflight_htlcs);
+                               let scorer = ScorerAccountingForInFlightHtlcs::new(binding.deref_mut(), &inflight_htlcs);
                                for path in &route.paths {
                                        let mut aggregate_msat = 0u64;
                                        for (idx, hop) in path.hops.iter().rev().enumerate() {