Include Offer context in blinded payment paths
[rust-lightning] / lightning / src / blinded_path / payment.rs
index 6467af568886bd1891a7e8c52ae6b5956736f6bd..99979ecf4a37140ef5f7375fdc2c0cc67e0eaa7b 100644 (file)
@@ -12,10 +12,11 @@ use crate::ln::channelmanager::CounterpartyForwardingInfo;
 use crate::ln::features::BlindedHopFeatures;
 use crate::ln::msgs::DecodeError;
 use crate::offers::invoice::BlindedPayInfo;
-use crate::prelude::*;
+use crate::offers::offer::OfferId;
 use crate::util::ser::{HighZeroBytesDroppedBigSize, Readable, Writeable, Writer};
 
-use core::convert::TryFrom;
+#[allow(unused_imports)]
+use crate::prelude::*;
 
 /// An intermediate node, its outbound channel, and relay parameters.
 #[derive(Clone, Debug)]
@@ -53,6 +54,8 @@ pub struct ReceiveTlvs {
        pub payment_secret: PaymentSecret,
        /// Constraints for the receiver of this payment.
        pub payment_constraints: PaymentConstraints,
+       /// Context for the receiver of this payment.
+       pub payment_context: PaymentContext,
 }
 
 /// Data to construct a [`BlindedHop`] for sending a payment over.
@@ -97,6 +100,43 @@ pub struct PaymentConstraints {
        pub htlc_minimum_msat: u64,
 }
 
+/// The context of an inbound payment, which is included in a [`BlindedPath`] via [`ReceiveTlvs`]
+/// and surfaced in [`PaymentPurpose`].
+///
+/// [`BlindedPath`]: crate::blinded_path::BlindedPath
+/// [`PaymentPurpose`]: crate::events::PaymentPurpose
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum PaymentContext {
+       /// The payment context was unknown.
+       Unknown(UnknownPaymentContext),
+
+       /// The payment was made for an invoice requested from a BOLT 12 [`Offer`].
+       ///
+       /// [`Offer`]: crate::offers::offer::Offer
+       Bolt12Offer(Bolt12OfferContext),
+}
+
+/// An unknown payment context.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct UnknownPaymentContext(());
+
+/// The context of a payment made for an invoice requested from a BOLT 12 [`Offer`].
+///
+/// [`Offer`]: crate::offers::offer::Offer
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Bolt12OfferContext {
+       /// The identifier of the [`Offer`].
+       ///
+       /// [`Offer`]: crate::offers::offer::Offer
+       pub offer_id: OfferId,
+}
+
+impl PaymentContext {
+       pub(crate) fn unknown() -> Self {
+               PaymentContext::Unknown(UnknownPaymentContext(()))
+       }
+}
+
 impl TryFrom<CounterpartyForwardingInfo> for PaymentRelay {
        type Error = ();
 
@@ -137,7 +177,8 @@ impl Writeable for ReceiveTlvs {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                encode_tlv_stream!(w, {
                        (12, self.payment_constraints, required),
-                       (65536, self.payment_secret, required)
+                       (65536, self.payment_secret, required),
+                       (65537, self.payment_context, required)
                });
                Ok(())
        }
@@ -163,11 +204,14 @@ impl Readable for BlindedPaymentTlvs {
                        (12, payment_constraints, required),
                        (14, features, option),
                        (65536, payment_secret, option),
+                       (65537, payment_context, (default_value, PaymentContext::unknown())),
                });
                let _padding: Option<utils::Padding> = _padding;
 
                if let Some(short_channel_id) = scid {
-                       if payment_secret.is_some() { return Err(DecodeError::InvalidValue) }
+                       if payment_secret.is_some() {
+                               return Err(DecodeError::InvalidValue)
+                       }
                        Ok(BlindedPaymentTlvs::Forward(ForwardTlvs {
                                short_channel_id,
                                payment_relay: payment_relay.ok_or(DecodeError::InvalidValue)?,
@@ -179,6 +223,7 @@ impl Readable for BlindedPaymentTlvs {
                        Ok(BlindedPaymentTlvs::Receive(ReceiveTlvs {
                                payment_secret: payment_secret.ok_or(DecodeError::InvalidValue)?,
                                payment_constraints: payment_constraints.0.unwrap(),
+                               payment_context: payment_context.0.unwrap(),
                        }))
                }
        }
@@ -309,10 +354,32 @@ impl Readable for PaymentConstraints {
        }
 }
 
+impl_writeable_tlv_based_enum!(PaymentContext,
+       ;
+       (0, Unknown),
+       (1, Bolt12Offer),
+);
+
+impl Writeable for UnknownPaymentContext {
+       fn write<W: Writer>(&self, _w: &mut W) -> Result<(), io::Error> {
+               Ok(())
+       }
+}
+
+impl Readable for UnknownPaymentContext {
+       fn read<R: io::Read>(_r: &mut R) -> Result<Self, DecodeError> {
+               Ok(UnknownPaymentContext(()))
+       }
+}
+
+impl_writeable_tlv_based!(Bolt12OfferContext, {
+       (0, offer_id, required),
+});
+
 #[cfg(test)]
 mod tests {
        use bitcoin::secp256k1::PublicKey;
-       use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, ReceiveTlvs, PaymentConstraints, PaymentRelay};
+       use crate::blinded_path::payment::{ForwardNode, ForwardTlvs, ReceiveTlvs, PaymentConstraints, PaymentContext, PaymentRelay};
        use crate::ln::PaymentSecret;
        use crate::ln::features::BlindedHopFeatures;
        use crate::ln::functional_test_utils::TEST_FINAL_CLTV;
@@ -361,6 +428,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let htlc_maximum_msat = 100_000;
                let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, 12).unwrap();
@@ -379,6 +447,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let blinded_payinfo = super::compute_payinfo(&[], &recv_tlvs, 4242, TEST_FINAL_CLTV as u16).unwrap();
                assert_eq!(blinded_payinfo.fee_base_msat, 0);
@@ -432,6 +501,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 3,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let htlc_maximum_msat = 100_000;
                let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_maximum_msat, TEST_FINAL_CLTV as u16).unwrap();
@@ -482,6 +552,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
                let htlc_minimum_msat = 3798;
                assert!(super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, htlc_minimum_msat - 1, TEST_FINAL_CLTV as u16).is_err());
@@ -536,6 +607,7 @@ mod tests {
                                max_cltv_expiry: 0,
                                htlc_minimum_msat: 1,
                        },
+                       payment_context: PaymentContext::unknown(),
                };
 
                let blinded_payinfo = super::compute_payinfo(&intermediate_nodes[..], &recv_tlvs, 10_000, TEST_FINAL_CLTV as u16).unwrap();