Merge pull request #1940 from TheBlueMatt/2023-01-nostd-try-lock
[rust-lightning] / lightning-invoice / src / payment.rs
index d6c876d0c9fc90b535a1f7e4fe495470c9f2244a..9fb472d3b311fbc872bdab4d5c74455a4f4e6857 100644 (file)
@@ -44,7 +44,7 @@
 //! # use lightning::util::logger::{Logger, Record};
 //! # use lightning::util::ser::{Writeable, Writer};
 //! # use lightning_invoice::Invoice;
-//! # use lightning_invoice::payment::{InvoicePayer, Payer, Retry, ScoringRouter};
+//! # use lightning_invoice::payment::{InvoicePayer, Payer, Retry};
 //! # use secp256k1::PublicKey;
 //! # use std::cell::RefCell;
 //! # use std::ops::Deref;
 //! # impl Router for FakeRouter {
 //! #     fn find_route(
 //! #         &self, payer: &PublicKey, params: &RouteParameters,
-//! #         first_hops: Option<&[&ChannelDetails]>, _inflight_htlcs: InFlightHtlcs
+//! #         first_hops: Option<&[&ChannelDetails]>, _inflight_htlcs: &InFlightHtlcs
 //! #     ) -> Result<Route, LightningError> { unimplemented!() }
-//! # }
-//! # impl ScoringRouter for FakeRouter {
 //! #     fn notify_payment_path_failed(&self, path: &[&RouteHop], short_channel_id: u64) {  unimplemented!() }
 //! #     fn notify_payment_path_successful(&self, path: &[&RouteHop]) {  unimplemented!() }
 //! #     fn notify_payment_probe_successful(&self, path: &[&RouteHop]) {  unimplemented!() }
@@ -146,8 +144,7 @@ use crate::prelude::*;
 use lightning::ln::{PaymentHash, PaymentPreimage, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 use lightning::ln::msgs::LightningError;
-use lightning::routing::router::{InFlightHtlcs, PaymentParameters, Route, RouteHop, RouteParameters, Router};
-use lightning::util::errors::APIError;
+use lightning::routing::router::{InFlightHtlcs, PaymentParameters, Route, RouteParameters, Router};
 use lightning::util::events::{Event, EventHandler};
 use lightning::util::logger::Logger;
 use crate::time_utils::Time;
@@ -187,7 +184,7 @@ mod sealed {
 /// (C-not exported) generally all users should use the [`InvoicePayer`] type alias.
 pub struct InvoicePayerUsingTime<
        P: Deref,
-       R: ScoringRouter,
+       R: Router,
        L: Deref,
        E: sealed::BaseEventHandler,
        T: Time
@@ -200,26 +197,10 @@ pub struct InvoicePayerUsingTime<
        logger: L,
        event_handler: E,
        /// Caches the overall attempts at making a payment, which is updated prior to retrying.
-       payment_cache: Mutex<HashMap<PaymentHash, PaymentInfo<T>>>,
+       payment_cache: Mutex<HashMap<PaymentHash, PaymentAttempts<T>>>,
        retry: Retry,
 }
 
-/// Used by [`InvoicePayerUsingTime::payment_cache`] to track the payments that are either
-/// currently being made, or have outstanding paths that need retrying.
-struct PaymentInfo<T: Time> {
-       attempts: PaymentAttempts<T>,
-       paths: Vec<Vec<RouteHop>>,
-}
-
-impl<T: Time> PaymentInfo<T> {
-       fn new() -> Self {
-               PaymentInfo {
-                       attempts: PaymentAttempts::new(),
-                       paths: vec![],
-               }
-       }
-}
-
 /// Storing minimal payment attempts information required for determining if a outbound payment can
 /// be retried.
 #[derive(Clone, Copy)]
@@ -296,30 +277,6 @@ pub trait Payer {
        fn inflight_htlcs(&self) -> InFlightHtlcs;
 }
 
-/// A trait defining behavior for a [`Router`] implementation that also supports scoring channels
-/// based on payment and probe success/failure.
-///
-/// [`Router`]: lightning::routing::router::Router
-pub trait ScoringRouter: Router {
-       /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. Includes
-       /// `PaymentHash` and `PaymentId` to be able to correlate the request with a specific payment.
-       fn find_route_with_id(
-               &self, payer: &PublicKey, route_params: &RouteParameters,
-               first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs,
-               _payment_hash: PaymentHash, _payment_id: PaymentId
-       ) -> Result<Route, LightningError> {
-               self.find_route(payer, route_params, first_hops, inflight_htlcs)
-       }
-       /// Lets the router know that payment through a specific path has failed.
-       fn notify_payment_path_failed(&self, path: &[&RouteHop], short_channel_id: u64);
-       /// Lets the router know that payment through a specific path was successful.
-       fn notify_payment_path_successful(&self, path: &[&RouteHop]);
-       /// Lets the router know that a payment probe was successful.
-       fn notify_payment_probe_successful(&self, path: &[&RouteHop]);
-       /// Lets the router know that a payment probe failed.
-       fn notify_payment_probe_failed(&self, path: &[&RouteHop], short_channel_id: u64);
-}
-
 /// Strategies available to retry payment path failures for an [`Invoice`].
 ///
 #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
@@ -359,7 +316,7 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-impl<P: Deref, R: ScoringRouter, L: Deref, E: sealed::BaseEventHandler, T: Time>
+impl<P: Deref, R: Router, L: Deref, E: sealed::BaseEventHandler, T: Time>
        InvoicePayerUsingTime<P, R, L, E, T>
 where
        P::Target: Payer,
@@ -459,7 +416,7 @@ where
                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
                match self.payment_cache.lock().unwrap().entry(payment_hash) {
                        hash_map::Entry::Occupied(_) => return Err(PaymentError::Invoice("payment pending")),
-                       hash_map::Entry::Vacant(entry) => entry.insert(PaymentInfo::new()),
+                       hash_map::Entry::Vacant(entry) => entry.insert(PaymentAttempts::new()),
                };
 
                let payment_secret = Some(invoice.payment_secret().clone());
@@ -523,7 +480,7 @@ where
        ) -> Result<(), PaymentError> {
                match self.payment_cache.lock().unwrap().entry(payment_hash) {
                        hash_map::Entry::Occupied(_) => return Err(PaymentError::Invoice("payment pending")),
-                       hash_map::Entry::Vacant(entry) => entry.insert(PaymentInfo::new()),
+                       hash_map::Entry::Vacant(entry) => entry.insert(PaymentAttempts::new()),
                };
 
                let route_params = RouteParameters {
@@ -553,44 +510,27 @@ where
                let first_hops = self.payer.first_hops();
                let inflight_htlcs = self.payer.inflight_htlcs();
                let route = self.router.find_route(
-                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()), inflight_htlcs
+                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()), &inflight_htlcs
                ).map_err(|e| PaymentError::Routing(e))?;
 
                match send_payment(&route) {
-                       Ok(()) => {
-                               for path in route.paths {
-                                       self.process_path_inflight_htlcs(payment_hash, path);
-                               }
-                               Ok(())
-                       },
+                       Ok(()) => Ok(()),
                        Err(e) => match e {
                                PaymentSendFailure::ParameterError(_) => Err(e),
                                PaymentSendFailure::PathParameterError(_) => Err(e),
                                PaymentSendFailure::DuplicatePayment => Err(e),
                                PaymentSendFailure::AllFailedResendSafe(_) => {
                                        let mut payment_cache = self.payment_cache.lock().unwrap();
-                                       let payment_info = payment_cache.get_mut(&payment_hash).unwrap();
-                                       payment_info.attempts.count += 1;
-                                       if self.retry.is_retryable_now(&payment_info.attempts) {
+                                       let payment_attempts = payment_cache.get_mut(&payment_hash).unwrap();
+                                       payment_attempts.count += 1;
+                                       if self.retry.is_retryable_now(payment_attempts) {
                                                core::mem::drop(payment_cache);
                                                Ok(self.pay_internal(params, payment_hash, send_payment)?)
                                        } else {
                                                Err(e)
                                        }
                                },
-                               PaymentSendFailure::PartialFailure { failed_paths_retry, payment_id, results } => {
-                                       // If a `PartialFailure` event returns a result that is an `Ok()`, it means that
-                                       // part of our payment is retried. When we receive `MonitorUpdateInProgress`, it
-                                       // means that we are still waiting for our channel monitor update to be completed.
-                                       for (result, path) in results.iter().zip(route.paths.into_iter()) {
-                                               match result {
-                                                       Ok(_) | Err(APIError::MonitorUpdateInProgress) => {
-                                                               self.process_path_inflight_htlcs(payment_hash, path);
-                                                       },
-                                                       _ => {},
-                                               }
-                                       }
-
+                               PaymentSendFailure::PartialFailure { failed_paths_retry, payment_id, .. } => {
                                        if let Some(retry_data) = failed_paths_retry {
                                                // Some paths were sent, even if we failed to send the full MPP value our
                                                // recipient may misbehave and claim the funds, at which point we have to
@@ -610,36 +550,16 @@ where
                }.map_err(|e| PaymentError::Sending(e))
        }
 
-       // Takes in a path to have its information stored in `payment_cache`. This is done for paths
-       // that are pending retry.
-       fn process_path_inflight_htlcs(&self, payment_hash: PaymentHash, path: Vec<RouteHop>) {
-               self.payment_cache.lock().unwrap().entry(payment_hash)
-                       .or_insert_with(|| PaymentInfo::new())
-                       .paths.push(path);
-       }
-
-       // Find the path we want to remove in `payment_cache`. If it doesn't exist, do nothing.
-       fn remove_path_inflight_htlcs(&self, payment_hash: PaymentHash, path: &Vec<RouteHop>) {
-               self.payment_cache.lock().unwrap().entry(payment_hash)
-                       .and_modify(|payment_info| {
-                               if let Some(idx) = payment_info.paths.iter().position(|p| p == path) {
-                                       payment_info.paths.swap_remove(idx);
-                               }
-                       });
-       }
-
        fn retry_payment(
                &self, payment_id: PaymentId, payment_hash: PaymentHash, params: &RouteParameters
        ) -> Result<(), ()> {
-               let attempts = self.payment_cache.lock().unwrap().entry(payment_hash)
-                       .and_modify(|info| info.attempts.count += 1 )
-                       .or_insert_with(|| PaymentInfo {
-                               attempts: PaymentAttempts {
+               let attempts =
+                       *self.payment_cache.lock().unwrap().entry(payment_hash)
+                               .and_modify(|attempts| attempts.count += 1)
+                               .or_insert(PaymentAttempts {
                                        count: 1,
-                                       first_attempted_at: T::now(),
-                               },
-                               paths: vec![],
-                       }).attempts;
+                                       first_attempted_at: T::now()
+                               });
 
                if !self.retry.is_retryable_now(&attempts) {
                        log_trace!(self.logger, "Payment {} exceeded maximum attempts; not retrying ({})", log_bytes!(payment_hash.0), attempts);
@@ -658,7 +578,7 @@ where
                let inflight_htlcs = self.payer.inflight_htlcs();
 
                let route = self.router.find_route(
-                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()), inflight_htlcs
+                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()), &inflight_htlcs
                );
 
                if route.is_err() {
@@ -667,12 +587,7 @@ where
                }
 
                match self.payer.retry_payment(&route.as_ref().unwrap(), payment_id) {
-                       Ok(()) => {
-                               for path in route.unwrap().paths.into_iter() {
-                                       self.process_path_inflight_htlcs(payment_hash, path);
-                               }
-                               Ok(())
-                       },
+                       Ok(()) => Ok(()),
                        Err(PaymentSendFailure::ParameterError(_)) |
                        Err(PaymentSendFailure::PathParameterError(_)) => {
                                log_trace!(self.logger, "Failed to retry for payment {} due to bogus route/payment data, not retrying.", log_bytes!(payment_hash.0));
@@ -685,19 +600,7 @@ where
                                log_error!(self.logger, "Got a DuplicatePayment error when attempting to retry a payment, this shouldn't happen.");
                                Err(())
                        }
-                       Err(PaymentSendFailure::PartialFailure { failed_paths_retry, results, .. }) => {
-                               // If a `PartialFailure` error contains a result that is an `Ok()`, it means that
-                               // part of our payment is retried. When we receive `MonitorUpdateInProgress`, it
-                               // means that we are still waiting for our channel monitor update to complete.
-                               for (result, path) in results.iter().zip(route.unwrap().paths.into_iter()) {
-                                       match result {
-                                               Ok(_) | Err(APIError::MonitorUpdateInProgress) => {
-                                                       self.process_path_inflight_htlcs(payment_hash, path);
-                                               },
-                                               _ => {},
-                                       }
-                               }
-
+                       Err(PaymentSendFailure::PartialFailure { failed_paths_retry, .. }) => {
                                if let Some(retry) = failed_paths_retry {
                                        // Always return Ok for the same reason as noted in pay_internal.
                                        let _ = self.retry_payment(payment_id, payment_hash, &retry);
@@ -727,7 +630,7 @@ fn has_expired(route_params: &RouteParameters) -> bool {
        } else { false }
 }
 
-impl<P: Deref, R: ScoringRouter, L: Deref, E: sealed::BaseEventHandler, T: Time>
+impl<P: Deref, R: Router, L: Deref, E: sealed::BaseEventHandler, T: Time>
        InvoicePayerUsingTime<P, R, L, E, T>
 where
        P::Target: Payer,
@@ -736,16 +639,6 @@ where
        /// Returns a bool indicating whether the processed event should be forwarded to a user-provided
        /// event handler.
        fn handle_event_internal(&self, event: &Event) -> bool {
-               match event {
-                       Event::PaymentPathFailed { payment_hash, path, ..  }
-                       | Event::PaymentPathSuccessful { path, payment_hash: Some(payment_hash), .. }
-                       | Event::ProbeSuccessful { payment_hash, path, .. }
-                       | Event::ProbeFailed { payment_hash, path, .. } => {
-                               self.remove_path_inflight_htlcs(*payment_hash, path);
-                       },
-                       _ => {},
-               }
-
                match event {
                        Event::PaymentPathFailed {
                                payment_id, payment_hash, payment_failed_permanently, path, short_channel_id, retry, ..
@@ -781,7 +674,7 @@ where
                                let mut payment_cache = self.payment_cache.lock().unwrap();
                                let attempts = payment_cache
                                        .remove(payment_hash)
-                                       .map_or(1, |payment_info| payment_info.attempts.count + 1);
+                                       .map_or(1, |attempts| attempts.count + 1);
                                log_trace!(self.logger, "Payment {} succeeded (attempts: {})", log_bytes!(payment_hash.0), attempts);
                        },
                        Event::ProbeSuccessful { payment_hash, path, .. } => {
@@ -804,7 +697,7 @@ where
        }
 }
 
-impl<P: Deref, R: ScoringRouter, L: Deref, E: EventHandler, T: Time>
+impl<P: Deref, R: Router, L: Deref, E: EventHandler, T: Time>
        EventHandler for InvoicePayerUsingTime<P, R, L, E, T>
 where
        P::Target: Payer,
@@ -818,7 +711,7 @@ where
        }
 }
 
-impl<P: Deref, R: ScoringRouter, L: Deref, T: Time, F: Future, H: Fn(Event) -> F>
+impl<P: Deref, R: Router, L: Deref, T: Time, F: Future, H: Fn(Event) -> F>
        InvoicePayerUsingTime<P, R, L, H, T>
 where
        P::Target: Payer,
@@ -838,7 +731,7 @@ where
 mod tests {
        use super::*;
        use crate::{InvoiceBuilder, Currency};
-       use crate::utils::{ScorerAccountingForInFlightHtlcs, create_invoice_from_channelmanager_and_duration_since_epoch};
+       use crate::utils::create_invoice_from_channelmanager_and_duration_since_epoch;
        use bitcoin_hashes::sha256::Hash as Sha256;
        use lightning::ln::PaymentPreimage;
        use lightning::ln::channelmanager;
@@ -846,7 +739,7 @@ mod tests {
        use lightning::ln::functional_test_utils::*;
        use lightning::ln::msgs::{ChannelMessageHandler, ErrorAction, LightningError};
        use lightning::routing::gossip::{EffectiveCapacity, NodeId};
-       use lightning::routing::router::{InFlightHtlcs, PaymentParameters, Route, RouteHop, Router};
+       use lightning::routing::router::{InFlightHtlcs, PaymentParameters, Route, RouteHop, Router, ScorerAccountingForInFlightHtlcs};
        use lightning::routing::scoring::{ChannelUsage, LockableScore, Score};
        use lightning::util::test_utils::TestLogger;
        use lightning::util::errors::APIError;
@@ -1777,12 +1670,12 @@ mod tests {
        impl Router for TestRouter {
                fn find_route(
                        &self, payer: &PublicKey, route_params: &RouteParameters,
-                       _first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs
+                       _first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: &InFlightHtlcs
                ) -> Result<Route, LightningError> {
                        // Simulate calling the Scorer just as you would in find_route
                        let route = Self::route_for_value(route_params.final_value_msat);
-                       let mut locked_scorer = self.scorer.lock();
-                       let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer.deref_mut(), inflight_htlcs);
+                       let locked_scorer = self.scorer.lock();
+                       let scorer = ScorerAccountingForInFlightHtlcs::new(locked_scorer, inflight_htlcs);
                        for path in route.paths {
                                let mut aggregate_msat = 0u64;
                                for (idx, hop) in path.iter().rev().enumerate() {
@@ -1807,9 +1700,7 @@ mod tests {
                                payment_params: Some(route_params.payment_params.clone()), ..Self::route_for_value(route_params.final_value_msat)
                        })
                }
-       }
 
-       impl ScoringRouter for TestRouter {
                fn notify_payment_path_failed(&self, path: &[&RouteHop], short_channel_id: u64) {
                        self.scorer.lock().payment_path_failed(path, short_channel_id);
                }
@@ -1832,13 +1723,11 @@ mod tests {
        impl Router for FailingRouter {
                fn find_route(
                        &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>,
-                       _inflight_htlcs: InFlightHtlcs,
+                       _inflight_htlcs: &InFlightHtlcs,
                ) -> Result<Route, LightningError> {
                        Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError })
                }
-       }
 
-       impl ScoringRouter for FailingRouter {
                fn notify_payment_path_failed(&self, _path: &[&RouteHop], _short_channel_id: u64) {}
 
                fn notify_payment_path_successful(&self, _path: &[&RouteHop]) {}
@@ -2122,12 +2011,11 @@ mod tests {
        impl Router for ManualRouter {
                fn find_route(
                        &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>,
-                       _inflight_htlcs: InFlightHtlcs
+                       _inflight_htlcs: &InFlightHtlcs
                ) -> Result<Route, LightningError> {
                        self.0.borrow_mut().pop_front().unwrap()
                }
-       }
-       impl ScoringRouter for ManualRouter {
+
                fn notify_payment_path_failed(&self, _path: &[&RouteHop], _short_channel_id: u64) {}
 
                fn notify_payment_path_successful(&self, _path: &[&RouteHop]) {}