Introduce `RouteParameters::max_total_routing_fee_msat`
[rust-lightning] / lightning / src / ln / outbound_payment.rs
index 023412e1afb56cdcdd0a2d9195064e5744ccab6c..2f1e3b0dd0e3d4358f7b43b5e01d1706f3ec005f 100644 (file)
@@ -54,10 +54,14 @@ pub(crate) enum PendingOutboundPayment {
        AwaitingInvoice {
                timer_ticks_without_response: u8,
                retry_strategy: Retry,
+               max_total_routing_fee_msat: Option<u64>,
        },
        InvoiceReceived {
                payment_hash: PaymentHash,
                retry_strategy: Retry,
+               // Note this field is currently just replicated from AwaitingInvoice but not actually
+               // used anywhere.
+               max_total_routing_fee_msat: Option<u64>,
        },
        Retryable {
                retry_strategy: Option<Retry>,
@@ -76,6 +80,7 @@ pub(crate) enum PendingOutboundPayment {
                total_msat: u64,
                /// Our best known block height at the time this payment was initiated.
                starting_block_height: u32,
+               remaining_max_total_routing_fee_msat: Option<u64>,
        },
        /// When a pending payment is fulfilled, we continue tracking it until all pending HTLCs have
        /// been resolved. This ensures we don't look up pending payments in ChannelMonitors on restart
@@ -731,12 +736,15 @@ impl OutboundPayments {
                SP: Fn(SendAlongPathArgs) -> Result<(), APIError>,
        {
                let payment_hash = invoice.payment_hash();
+               let mut max_total_routing_fee_msat = None;
                match self.pending_outbound_payments.lock().unwrap().entry(payment_id) {
                        hash_map::Entry::Occupied(entry) => match entry.get() {
-                               PendingOutboundPayment::AwaitingInvoice { retry_strategy, .. } => {
+                               PendingOutboundPayment::AwaitingInvoice { retry_strategy, max_total_routing_fee_msat: max_total_fee, .. } => {
+                                       max_total_routing_fee_msat = *max_total_fee;
                                        *entry.into_mut() = PendingOutboundPayment::InvoiceReceived {
                                                payment_hash,
                                                retry_strategy: *retry_strategy,
+                                               max_total_routing_fee_msat,
                                        };
                                },
                                _ => return Err(Bolt12PaymentError::DuplicateInvoice),
@@ -747,6 +755,7 @@ impl OutboundPayments {
                let route_params = RouteParameters {
                        payment_params: PaymentParameters::from_bolt12_invoice(&invoice),
                        final_value_msat: invoice.amount_msats(),
+                       max_total_routing_fee_msat,
                };
 
                self.find_route_and_send_payment(
@@ -779,11 +788,12 @@ impl OutboundPayments {
                        let mut retry_id_route_params = None;
                        for (pmt_id, pmt) in outbounds.iter_mut() {
                                if pmt.is_auto_retryable_now() {
-                                       if let PendingOutboundPayment::Retryable { pending_amt_msat, total_msat, payment_params: Some(params), payment_hash, .. } = pmt {
+                                       if let PendingOutboundPayment::Retryable { pending_amt_msat, total_msat, payment_params: Some(params), payment_hash, remaining_max_total_routing_fee_msat, .. } = pmt {
                                                if pending_amt_msat < total_msat {
                                                        retry_id_route_params = Some((*payment_hash, *pmt_id, RouteParameters {
                                                                final_value_msat: *total_msat - *pending_amt_msat,
                                                                payment_params: params.clone(),
+                                                               max_total_routing_fee_msat: *remaining_max_total_routing_fee_msat,
                                                        }));
                                                        break
                                                }
@@ -987,7 +997,7 @@ impl OutboundPayments {
                                                        log_error!(logger, "Payment not yet sent");
                                                        return
                                                },
-                                               PendingOutboundPayment::InvoiceReceived { payment_hash, retry_strategy } => {
+                                               PendingOutboundPayment::InvoiceReceived { payment_hash, retry_strategy, .. } => {
                                                        let total_amount = route_params.final_value_msat;
                                                        let recipient_onion = RecipientOnionFields {
                                                                payment_secret: None,
@@ -1207,6 +1217,8 @@ impl OutboundPayments {
                        custom_tlvs: recipient_onion.custom_tlvs,
                        starting_block_height: best_block_height,
                        total_msat: route.get_total_amount(),
+                       remaining_max_total_routing_fee_msat:
+                               route.route_params.as_ref().and_then(|p| p.max_total_routing_fee_msat),
                };
 
                for (path, session_priv_bytes) in route.paths.iter().zip(onion_session_privs.iter()) {
@@ -1218,7 +1230,7 @@ impl OutboundPayments {
 
        #[allow(unused)]
        pub(super) fn add_new_awaiting_invoice(
-               &self, payment_id: PaymentId, retry_strategy: Retry
+               &self, payment_id: PaymentId, retry_strategy: Retry, max_total_routing_fee_msat: Option<u64>
        ) -> Result<(), ()> {
                let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap();
                match pending_outbounds.entry(payment_id) {
@@ -1227,6 +1239,7 @@ impl OutboundPayments {
                                entry.insert(PendingOutboundPayment::AwaitingInvoice {
                                        timer_ticks_without_response: 0,
                                        retry_strategy,
+                                       max_total_routing_fee_msat,
                                });
 
                                Ok(())
@@ -1328,8 +1341,9 @@ impl OutboundPayments {
                                failed_paths_retry: if pending_amt_unsent != 0 {
                                        if let Some(payment_params) = route.route_params.as_ref().map(|p| p.payment_params.clone()) {
                                                Some(RouteParameters {
-                                                       payment_params: payment_params,
+                                                       payment_params,
                                                        final_value_msat: pending_amt_unsent,
+                                                       max_total_routing_fee_msat: None,
                                                })
                                        } else { None }
                                } else { None },
@@ -1689,6 +1703,7 @@ impl_writeable_tlv_based_enum_upgradable!(PendingOutboundPayment,
                (8, pending_amt_msat, required),
                (9, custom_tlvs, optional_vec),
                (10, starting_block_height, required),
+               (11, remaining_max_total_routing_fee_msat, option),
                (not_written, retry_strategy, (static_value, None)),
                (not_written, attempts, (static_value, PaymentAttempts::new())),
        },
@@ -1700,10 +1715,12 @@ impl_writeable_tlv_based_enum_upgradable!(PendingOutboundPayment,
        (5, AwaitingInvoice) => {
                (0, timer_ticks_without_response, required),
                (2, retry_strategy, required),
+               (4, max_total_routing_fee_msat, option),
        },
        (7, InvoiceReceived) => {
                (0, payment_hash, required),
                (2, retry_strategy, required),
+               (4, max_total_routing_fee_msat, option),
        },
 );
 
@@ -1926,7 +1943,9 @@ mod tests {
                let payment_id = PaymentId([0; 32]);
 
                assert!(!outbound_payments.has_pending_payments());
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_ok());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), None).is_ok()
+               );
                assert!(outbound_payments.has_pending_payments());
 
                for _ in 0..INVOICE_REQUEST_TIMEOUT_TICKS {
@@ -1944,10 +1963,15 @@ mod tests {
                );
                assert!(pending_events.lock().unwrap().is_empty());
 
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_ok());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), None).is_ok()
+               );
                assert!(outbound_payments.has_pending_payments());
 
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_err());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), None)
+                               .is_err()
+               );
        }
 
        #[test]
@@ -1957,7 +1981,9 @@ mod tests {
                let payment_id = PaymentId([0; 32]);
 
                assert!(!outbound_payments.has_pending_payments());
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_ok());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), None).is_ok()
+               );
                assert!(outbound_payments.has_pending_payments());
 
                outbound_payments.abandon_payment(
@@ -1985,7 +2011,9 @@ mod tests {
                let outbound_payments = OutboundPayments::new();
                let payment_id = PaymentId([0; 32]);
 
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_ok());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), None).is_ok()
+               );
                assert!(outbound_payments.has_pending_payments());
 
                let created_at = now() - DEFAULT_RELATIVE_EXPIRY;
@@ -2031,7 +2059,9 @@ mod tests {
                let outbound_payments = OutboundPayments::new();
                let payment_id = PaymentId([0; 32]);
 
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_ok());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), None).is_ok()
+               );
                assert!(outbound_payments.has_pending_payments());
 
                let invoice = OfferBuilder::new("foo".into(), recipient_pubkey())
@@ -2045,10 +2075,10 @@ mod tests {
                        .sign(recipient_sign).unwrap();
 
                router.expect_find_route(
-                       RouteParameters {
-                               payment_params: PaymentParameters::from_bolt12_invoice(&invoice),
-                               final_value_msat: invoice.amount_msats(),
-                       },
+                       RouteParameters::from_payment_params_and_value(
+                               PaymentParameters::from_bolt12_invoice(&invoice),
+                               invoice.amount_msats(),
+                       ),
                        Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError }),
                );
 
@@ -2084,7 +2114,9 @@ mod tests {
                let outbound_payments = OutboundPayments::new();
                let payment_id = PaymentId([0; 32]);
 
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_ok());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), None).is_ok()
+               );
                assert!(outbound_payments.has_pending_payments());
 
                let invoice = OfferBuilder::new("foo".into(), recipient_pubkey())
@@ -2097,10 +2129,10 @@ mod tests {
                        .build().unwrap()
                        .sign(recipient_sign).unwrap();
 
-               let route_params = RouteParameters {
-                       payment_params: PaymentParameters::from_bolt12_invoice(&invoice),
-                       final_value_msat: invoice.amount_msats(),
-               };
+               let route_params = RouteParameters::from_payment_params_and_value(
+                       PaymentParameters::from_bolt12_invoice(&invoice),
+                       invoice.amount_msats(),
+               );
                router.expect_find_route(
                        route_params.clone(), Ok(Route { paths: vec![], route_params: Some(route_params) })
                );
@@ -2150,6 +2182,7 @@ mod tests {
                let route_params = RouteParameters {
                        payment_params: PaymentParameters::from_bolt12_invoice(&invoice),
                        final_value_msat: invoice.amount_msats(),
+                       max_total_routing_fee_msat: Some(1234),
                };
                router.expect_find_route(
                        route_params.clone(),
@@ -2185,7 +2218,9 @@ mod tests {
                assert!(!outbound_payments.has_pending_payments());
                assert!(pending_events.lock().unwrap().is_empty());
 
-               assert!(outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0)).is_ok());
+               assert!(
+                       outbound_payments.add_new_awaiting_invoice(payment_id, Retry::Attempts(0), Some(1234)).is_ok()
+               );
                assert!(outbound_payments.has_pending_payments());
 
                assert_eq!(