]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Introduce RetryableInvoiceRequest in AwaitingInvoice
authorshaavan <shaavan.github@gmail.com>
Mon, 10 Jun 2024 11:52:09 +0000 (17:22 +0530)
committershaavan <shaavan.github@gmail.com>
Thu, 12 Sep 2024 13:17:33 +0000 (18:47 +0530)
1. To enable the retry of the Invoice Request message, it's necessary
   to store the essential data required to recreate the message.
2. A new struct is introduced to manage this data, ensuring the
   InvoiceRequest message can be reliably recreated for retries.
3. The addition of an `awaiting_invoice` flag allows tracking of
   retryable invoice requests, preventing the need to lock the
   `pending_outbound_payment` mutex.

lightning/src/ln/channelmanager.rs
lightning/src/ln/outbound_payment.rs
lightning/src/offers/invoice_request.rs

index 6414d8a28368014790dd5fd4831950557328c828..d82d35f542cfbd9a75cc3b5299d5d55247788b21 100644 (file)
@@ -61,7 +61,7 @@ use crate::ln::onion_utils::{HTLCFailReason, INVALID_ONION_BLINDING};
 use crate::ln::msgs::{ChannelMessageHandler, DecodeError, LightningError};
 #[cfg(test)]
 use crate::ln::outbound_payment;
-use crate::ln::outbound_payment::{OutboundPayments, PaymentAttempts, PendingOutboundPayment, SendAlongPathArgs, StaleExpiration};
+use crate::ln::outbound_payment::{OutboundPayments, PaymentAttempts, PendingOutboundPayment, RetryableInvoiceRequest, SendAlongPathArgs, StaleExpiration};
 use crate::ln::wire::Encode;
 use crate::offers::invoice::{Bolt12Invoice, DEFAULT_RELATIVE_EXPIRY, DerivedSigningPubkey, ExplicitSigningPubkey, InvoiceBuilder, UnsignedBolt12Invoice};
 use crate::offers::invoice_error::InvoiceError;
@@ -3105,7 +3105,7 @@ where
 
                        outbound_scid_aliases: Mutex::new(new_hash_set()),
                        pending_inbound_payments: Mutex::new(new_hash_map()),
-                       pending_outbound_payments: OutboundPayments::new(),
+                       pending_outbound_payments: OutboundPayments::new(new_hash_map()),
                        forward_htlcs: Mutex::new(new_hash_map()),
                        decode_update_add_htlcs: Mutex::new(new_hash_map()),
                        claimable_payments: Mutex::new(ClaimablePayments { claimable_payments: new_hash_map(), pending_claiming_payments: new_hash_map() }),
@@ -9005,7 +9005,7 @@ macro_rules! create_refund_builder { ($self: ident, $builder: ty) => {
                let expiration = StaleExpiration::AbsoluteTimeout(absolute_expiry);
                $self.pending_outbound_payments
                        .add_new_awaiting_invoice(
-                               payment_id, expiration, retry_strategy, max_total_routing_fee_msat,
+                               payment_id, expiration, retry_strategy, max_total_routing_fee_msat, None,
                        )
                        .map_err(|_| Bolt12SemanticError::DuplicatePaymentId)?;
 
@@ -9131,9 +9131,14 @@ where
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(self);
 
                let expiration = StaleExpiration::TimerTicks(1);
+               let retryable_invoice_request = RetryableInvoiceRequest {
+                       invoice_request: invoice_request.clone(),
+                       nonce,
+               };
                self.pending_outbound_payments
                        .add_new_awaiting_invoice(
-                               payment_id, expiration, retry_strategy, max_total_routing_fee_msat
+                               payment_id, expiration, retry_strategy, max_total_routing_fee_msat,
+                               Some(retryable_invoice_request)
                        )
                        .map_err(|_| Bolt12SemanticError::DuplicatePaymentId)?;
 
@@ -12227,10 +12232,7 @@ where
                        }
                        pending_outbound_payments = Some(outbounds);
                }
-               let pending_outbounds = OutboundPayments {
-                       pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()),
-                       retry_lock: Mutex::new(())
-               };
+               let pending_outbounds = OutboundPayments::new(pending_outbound_payments.unwrap());
 
                // We have to replay (or skip, if they were completed after we wrote the `ChannelManager`)
                // each `ChannelMonitorUpdate` in `in_flight_monitor_updates`. After doing so, we have to
index 69e97aa35730e84cd3efac831a3bd22fad48c3d3..ca0d7c17d99b722bbab622b03dec303f32ec21cd 100644 (file)
@@ -22,6 +22,8 @@ use crate::ln::features::Bolt12InvoiceFeatures;
 use crate::ln::onion_utils;
 use crate::ln::onion_utils::{DecodedOnionFailure, HTLCFailReason};
 use crate::offers::invoice::Bolt12Invoice;
+use crate::offers::invoice_request::InvoiceRequest;
+use crate::offers::nonce::Nonce;
 use crate::routing::router::{BlindedTail, InFlightHtlcs, Path, PaymentParameters, Route, RouteParameters, Router};
 use crate::sign::{EntropySource, NodeSigner, Recipient};
 use crate::util::errors::APIError;
@@ -32,6 +34,7 @@ use crate::util::ser::ReadableArgs;
 
 use core::fmt::{self, Display, Formatter};
 use core::ops::Deref;
+use core::sync::atomic::{AtomicBool, Ordering};
 use core::time::Duration;
 
 use crate::prelude::*;
@@ -53,6 +56,7 @@ pub(crate) enum PendingOutboundPayment {
                expiration: StaleExpiration,
                retry_strategy: Retry,
                max_total_routing_fee_msat: Option<u64>,
+               retryable_invoice_request: Option<RetryableInvoiceRequest>
        },
        InvoiceReceived {
                payment_hash: PaymentHash,
@@ -100,6 +104,16 @@ pub(crate) enum PendingOutboundPayment {
        },
 }
 
+pub(crate) struct RetryableInvoiceRequest {
+       pub(crate) invoice_request: InvoiceRequest,
+       pub(crate) nonce: Nonce,
+}
+
+impl_writeable_tlv_based!(RetryableInvoiceRequest, {
+       (0, invoice_request, required),
+       (2, nonce, required),
+});
+
 impl PendingOutboundPayment {
        fn increment_attempts(&mut self) {
                if let PendingOutboundPayment::Retryable { attempts, .. } = self {
@@ -666,13 +680,19 @@ pub(super) struct SendAlongPathArgs<'a> {
 
 pub(super) struct OutboundPayments {
        pub(super) pending_outbound_payments: Mutex<HashMap<PaymentId, PendingOutboundPayment>>,
-       pub(super) retry_lock: Mutex<()>,
+       awaiting_invoice: AtomicBool,
+       retry_lock: Mutex<()>,
 }
 
 impl OutboundPayments {
-       pub(super) fn new() -> Self {
+       pub(super) fn new(pending_outbound_payments: HashMap<PaymentId, PendingOutboundPayment>) -> Self {
+               let has_invoice_requests = pending_outbound_payments.values().any(|payment| {
+                       matches!(payment, PendingOutboundPayment::AwaitingInvoice { retryable_invoice_request: Some(_), .. })
+               });
+
                Self {
-                       pending_outbound_payments: Mutex::new(new_hash_map()),
+                       pending_outbound_payments: Mutex::new(pending_outbound_payments),
+                       awaiting_invoice: AtomicBool::new(has_invoice_requests),
                        retry_lock: Mutex::new(()),
                }
        }
@@ -1393,16 +1413,20 @@ impl OutboundPayments {
 
        pub(super) fn add_new_awaiting_invoice(
                &self, payment_id: PaymentId, expiration: StaleExpiration, retry_strategy: Retry,
-               max_total_routing_fee_msat: Option<u64>
+               max_total_routing_fee_msat: Option<u64>, retryable_invoice_request: Option<RetryableInvoiceRequest>
        ) -> Result<(), ()> {
                let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap();
                match pending_outbounds.entry(payment_id) {
                        hash_map::Entry::Occupied(_) => Err(()),
                        hash_map::Entry::Vacant(entry) => {
+                               if retryable_invoice_request.is_some() {
+                                       self.awaiting_invoice.store(true, Ordering::Release);
+                               }
                                entry.insert(PendingOutboundPayment::AwaitingInvoice {
                                        expiration,
                                        retry_strategy,
                                        max_total_routing_fee_msat,
+                                       retryable_invoice_request,
                                });
 
                                Ok(())
@@ -1874,6 +1898,31 @@ impl OutboundPayments {
        pub fn clear_pending_payments(&self) {
                self.pending_outbound_payments.lock().unwrap().clear()
        }
+
+       pub fn release_invoice_requests_awaiting_invoice(&self) -> Vec<(PaymentId, RetryableInvoiceRequest)> {
+               if !self.awaiting_invoice.load(Ordering::Acquire) {
+                       return vec![];
+               }
+
+               let mut pending_outbound_payments = self.pending_outbound_payments.lock().unwrap();
+               let invoice_requests = pending_outbound_payments
+                       .iter_mut()
+                       .filter_map(|(payment_id, payment)| {
+                               if let PendingOutboundPayment::AwaitingInvoice {
+                                       retryable_invoice_request, ..
+                               } = payment {
+                                       retryable_invoice_request.take().map(|retryable_invoice_request| {
+                                               (*payment_id, retryable_invoice_request)
+                                       })
+                               } else {
+                                       None
+                               }
+                       })
+                       .collect();
+
+               self.awaiting_invoice.store(false, Ordering::Release);
+               invoice_requests
+       }
 }
 
 /// Returns whether a payment with the given [`PaymentHash`] and [`PaymentId`] is, in fact, a
@@ -1929,6 +1978,7 @@ impl_writeable_tlv_based_enum_upgradable!(PendingOutboundPayment,
                (0, expiration, required),
                (2, retry_strategy, required),
                (4, max_total_routing_fee_msat, option),
+               (5, retryable_invoice_request, option),
        },
        (7, InvoiceReceived) => {
                (0, payment_hash, required),
@@ -1959,6 +2009,7 @@ mod tests {
        use crate::routing::router::{InFlightHtlcs, Path, PaymentParameters, Route, RouteHop, RouteParameters};
        use crate::sync::{Arc, Mutex, RwLock};
        use crate::util::errors::APIError;
+       use crate::util::hash_tables::new_hash_map;
        use crate::util::test_utils;
 
        use alloc::collections::VecDeque;
@@ -1993,7 +2044,7 @@ mod tests {
        }
        #[cfg(feature = "std")]
        fn do_fails_paying_after_expiration(on_retry: bool) {
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let logger = test_utils::TestLogger::new();
                let network_graph = Arc::new(NetworkGraph::new(Network::Testnet, &logger));
                let scorer = RwLock::new(test_utils::TestScorer::new());
@@ -2037,7 +2088,7 @@ mod tests {
                do_find_route_error(true);
        }
        fn do_find_route_error(on_retry: bool) {
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let logger = test_utils::TestLogger::new();
                let network_graph = Arc::new(NetworkGraph::new(Network::Testnet, &logger));
                let scorer = RwLock::new(test_utils::TestScorer::new());
@@ -2076,7 +2127,7 @@ mod tests {
 
        #[test]
        fn initial_send_payment_path_failed_evs() {
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let logger = test_utils::TestLogger::new();
                let network_graph = Arc::new(NetworkGraph::new(Network::Testnet, &logger));
                let scorer = RwLock::new(test_utils::TestScorer::new());
@@ -2158,7 +2209,7 @@ mod tests {
        #[test]
        fn removes_stale_awaiting_invoice_using_absolute_timeout() {
                let pending_events = Mutex::new(VecDeque::new());
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let payment_id = PaymentId([0; 32]);
                let absolute_expiry = 100;
                let tick_interval = 10;
@@ -2167,7 +2218,7 @@ mod tests {
                assert!(!outbound_payments.has_pending_payments());
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
@@ -2197,14 +2248,14 @@ mod tests {
 
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
 
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_err()
                );
        }
@@ -2212,7 +2263,7 @@ mod tests {
        #[test]
        fn removes_stale_awaiting_invoice_using_timer_ticks() {
                let pending_events = Mutex::new(VecDeque::new());
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let payment_id = PaymentId([0; 32]);
                let timer_ticks = 3;
                let expiration = StaleExpiration::TimerTicks(timer_ticks);
@@ -2220,7 +2271,7 @@ mod tests {
                assert!(!outbound_payments.has_pending_payments());
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
@@ -2250,14 +2301,14 @@ mod tests {
 
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
 
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_err()
                );
        }
@@ -2265,14 +2316,14 @@ mod tests {
        #[test]
        fn removes_abandoned_awaiting_invoice() {
                let pending_events = Mutex::new(VecDeque::new());
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let payment_id = PaymentId([0; 32]);
                let expiration = StaleExpiration::AbsoluteTimeout(Duration::from_secs(100));
 
                assert!(!outbound_payments.has_pending_payments());
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
@@ -2302,13 +2353,13 @@ mod tests {
                let keys_manager = test_utils::TestKeysInterface::new(&[0; 32], Network::Testnet);
 
                let pending_events = Mutex::new(VecDeque::new());
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let payment_id = PaymentId([0; 32]);
                let expiration = StaleExpiration::AbsoluteTimeout(Duration::from_secs(100));
 
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), None
+                               payment_id, expiration, Retry::Attempts(0), None, None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
@@ -2355,7 +2406,7 @@ mod tests {
                let keys_manager = test_utils::TestKeysInterface::new(&[0; 32], Network::Testnet);
 
                let pending_events = Mutex::new(VecDeque::new());
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let payment_id = PaymentId([0; 32]);
                let expiration = StaleExpiration::AbsoluteTimeout(Duration::from_secs(100));
 
@@ -2372,7 +2423,7 @@ mod tests {
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
                                payment_id, expiration, Retry::Attempts(0),
-                               Some(invoice.amount_msats() / 100 + 50_000)
+                               Some(invoice.amount_msats() / 100 + 50_000), None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
@@ -2416,7 +2467,7 @@ mod tests {
                let keys_manager = test_utils::TestKeysInterface::new(&[0; 32], Network::Testnet);
 
                let pending_events = Mutex::new(VecDeque::new());
-               let outbound_payments = OutboundPayments::new();
+               let outbound_payments = OutboundPayments::new(new_hash_map());
                let payment_id = PaymentId([0; 32]);
                let expiration = StaleExpiration::AbsoluteTimeout(Duration::from_secs(100));
 
@@ -2472,7 +2523,7 @@ mod tests {
 
                assert!(
                        outbound_payments.add_new_awaiting_invoice(
-                               payment_id, expiration, Retry::Attempts(0), Some(1234)
+                               payment_id, expiration, Retry::Attempts(0), Some(1234), None,
                        ).is_ok()
                );
                assert!(outbound_payments.has_pending_payments());
index fa3d9161b8221567a694c4792632f4e3fb8df63b..32d05249cfad048b7117ff576c4e7c6eef5ab40f 100644 (file)
@@ -1039,6 +1039,13 @@ impl Writeable for InvoiceRequestContents {
        }
 }
 
+impl Readable for InvoiceRequest {
+       fn read<R: io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
+               let bytes: WithoutLength<Vec<u8>> = Readable::read(reader)?;
+               Self::try_from(bytes.0).map_err(|_| DecodeError::InvalidValue)
+       }
+}
+
 /// Valid type range for invoice_request TLV records.
 pub(super) const INVOICE_REQUEST_TYPES: core::ops::Range<u64> = 80..160;