Macro-ize InvoiceBuilder
[rust-lightning] / lightning / src / offers / invoice.rs
index 9e5922692ba5ba0c832966e3cee4d06297484a0d..457f4ac3038862165c34a8cf15fbf802638566ac 100644 (file)
@@ -169,7 +169,7 @@ pub struct DerivedSigningPubkey(KeyPair);
 impl SigningPubkeyStrategy for ExplicitSigningPubkey {}
 impl SigningPubkeyStrategy for DerivedSigningPubkey {}
 
-impl<'a> InvoiceBuilder<'a, ExplicitSigningPubkey> {
+macro_rules! invoice_explicit_signing_pubkey_builder_methods { ($self: ident, $self_type: ty) => {
        pub(super) fn for_offer(
                invoice_request: &'a InvoiceRequest, payment_paths: Vec<(BlindedPayInfo, BlindedPath)>,
                created_at: Duration, payment_hash: PaymentHash
@@ -203,25 +203,25 @@ impl<'a> InvoiceBuilder<'a, ExplicitSigningPubkey> {
 
        /// Builds an unsigned [`Bolt12Invoice`] after checking for valid semantics. It can be signed by
        /// [`UnsignedBolt12Invoice::sign`].
-       pub fn build(self) -> Result<UnsignedBolt12Invoice, Bolt12SemanticError> {
+       pub fn build($self: $self_type) -> Result<UnsignedBolt12Invoice, Bolt12SemanticError> {
                #[cfg(feature = "std")] {
-                       if self.invoice.is_offer_or_refund_expired() {
+                       if $self.invoice.is_offer_or_refund_expired() {
                                return Err(Bolt12SemanticError::AlreadyExpired);
                        }
                }
 
                #[cfg(not(feature = "std"))] {
-                       if self.invoice.is_offer_or_refund_expired_no_std(self.invoice.created_at()) {
+                       if $self.invoice.is_offer_or_refund_expired_no_std($self.invoice.created_at()) {
                                return Err(Bolt12SemanticError::AlreadyExpired);
                        }
                }
 
-               let InvoiceBuilder { invreq_bytes, invoice, .. } = self;
+               let InvoiceBuilder { invreq_bytes, invoice, .. } = $self;
                Ok(UnsignedBolt12Invoice::new(invreq_bytes, invoice))
        }
-}
+} }
 
-impl<'a> InvoiceBuilder<'a, DerivedSigningPubkey> {
+macro_rules! invoice_derived_signing_pubkey_builder_methods { ($self: ident, $self_type: ty) => {
        pub(super) fn for_offer_using_keys(
                invoice_request: &'a InvoiceRequest, payment_paths: Vec<(BlindedPayInfo, BlindedPath)>,
                created_at: Duration, payment_hash: PaymentHash, keys: KeyPair
@@ -256,23 +256,23 @@ impl<'a> InvoiceBuilder<'a, DerivedSigningPubkey> {
 
        /// Builds a signed [`Bolt12Invoice`] after checking for valid semantics.
        pub fn build_and_sign<T: secp256k1::Signing>(
-               self, secp_ctx: &Secp256k1<T>
+               $self: $self_type, secp_ctx: &Secp256k1<T>
        ) -> Result<Bolt12Invoice, Bolt12SemanticError> {
                #[cfg(feature = "std")] {
-                       if self.invoice.is_offer_or_refund_expired() {
+                       if $self.invoice.is_offer_or_refund_expired() {
                                return Err(Bolt12SemanticError::AlreadyExpired);
                        }
                }
 
                #[cfg(not(feature = "std"))] {
-                       if self.invoice.is_offer_or_refund_expired_no_std(self.invoice.created_at()) {
+                       if $self.invoice.is_offer_or_refund_expired_no_std($self.invoice.created_at()) {
                                return Err(Bolt12SemanticError::AlreadyExpired);
                        }
                }
 
                let InvoiceBuilder {
                        invreq_bytes, invoice, signing_pubkey_strategy: DerivedSigningPubkey(keys)
-               } = self;
+               } = $self;
                let unsigned_invoice = UnsignedBolt12Invoice::new(invreq_bytes, invoice);
 
                let invoice = unsigned_invoice
@@ -282,9 +282,11 @@ impl<'a> InvoiceBuilder<'a, DerivedSigningPubkey> {
                        .unwrap();
                Ok(invoice)
        }
-}
+} }
 
-impl<'a, S: SigningPubkeyStrategy> InvoiceBuilder<'a, S> {
+macro_rules! invoice_builder_methods { (
+       $self: ident, $self_type: ty, $return_type: ty, $return_value: expr
+) => {
        pub(crate) fn amount_msats(
                invoice_request: &InvoiceRequest
        ) -> Result<u64, Bolt12SemanticError> {
@@ -326,57 +328,69 @@ impl<'a, S: SigningPubkeyStrategy> InvoiceBuilder<'a, S> {
        /// [`Bolt12Invoice::is_expired`].
        ///
        /// Successive calls to this method will override the previous setting.
-       pub fn relative_expiry(mut self, relative_expiry_secs: u32) -> Self {
+       pub fn relative_expiry(mut $self: $self_type, relative_expiry_secs: u32) -> $return_type {
                let relative_expiry = Duration::from_secs(relative_expiry_secs as u64);
-               self.invoice.fields_mut().relative_expiry = Some(relative_expiry);
-               self
+               $self.invoice.fields_mut().relative_expiry = Some(relative_expiry);
+               $return_value
        }
 
        /// Adds a P2WSH address to [`Bolt12Invoice::fallbacks`].
        ///
        /// Successive calls to this method will add another address. Caller is responsible for not
        /// adding duplicate addresses and only calling if capable of receiving to P2WSH addresses.
-       pub fn fallback_v0_p2wsh(mut self, script_hash: &WScriptHash) -> Self {
+       pub fn fallback_v0_p2wsh(mut $self: $self_type, script_hash: &WScriptHash) -> $return_type {
                let address = FallbackAddress {
                        version: WitnessVersion::V0.to_num(),
                        program: Vec::from(script_hash.to_byte_array()),
                };
-               self.invoice.fields_mut().fallbacks.get_or_insert_with(Vec::new).push(address);
-               self
+               $self.invoice.fields_mut().fallbacks.get_or_insert_with(Vec::new).push(address);
+               $return_value
        }
 
        /// Adds a P2WPKH address to [`Bolt12Invoice::fallbacks`].
        ///
        /// Successive calls to this method will add another address. Caller is responsible for not
        /// adding duplicate addresses and only calling if capable of receiving to P2WPKH addresses.
-       pub fn fallback_v0_p2wpkh(mut self, pubkey_hash: &WPubkeyHash) -> Self {
+       pub fn fallback_v0_p2wpkh(mut $self: $self_type, pubkey_hash: &WPubkeyHash) -> $return_type {
                let address = FallbackAddress {
                        version: WitnessVersion::V0.to_num(),
                        program: Vec::from(pubkey_hash.to_byte_array()),
                };
-               self.invoice.fields_mut().fallbacks.get_or_insert_with(Vec::new).push(address);
-               self
+               $self.invoice.fields_mut().fallbacks.get_or_insert_with(Vec::new).push(address);
+               $return_value
        }
 
        /// Adds a P2TR address to [`Bolt12Invoice::fallbacks`].
        ///
        /// Successive calls to this method will add another address. Caller is responsible for not
        /// adding duplicate addresses and only calling if capable of receiving to P2TR addresses.
-       pub fn fallback_v1_p2tr_tweaked(mut self, output_key: &TweakedPublicKey) -> Self {
+       pub fn fallback_v1_p2tr_tweaked(mut $self: $self_type, output_key: &TweakedPublicKey) -> $return_type {
                let address = FallbackAddress {
                        version: WitnessVersion::V1.to_num(),
                        program: Vec::from(&output_key.serialize()[..]),
                };
-               self.invoice.fields_mut().fallbacks.get_or_insert_with(Vec::new).push(address);
-               self
+               $self.invoice.fields_mut().fallbacks.get_or_insert_with(Vec::new).push(address);
+               $return_value
        }
 
        /// Sets [`Bolt12Invoice::invoice_features`] to indicate MPP may be used. Otherwise, MPP is
        /// disallowed.
-       pub fn allow_mpp(mut self) -> Self {
-               self.invoice.fields_mut().features.set_basic_mpp_optional();
-               self
+       pub fn allow_mpp(mut $self: $self_type) -> $return_type {
+               $self.invoice.fields_mut().features.set_basic_mpp_optional();
+               $return_value
        }
+} }
+
+impl<'a> InvoiceBuilder<'a, ExplicitSigningPubkey> {
+       invoice_explicit_signing_pubkey_builder_methods!(self, Self);
+}
+
+impl<'a> InvoiceBuilder<'a, DerivedSigningPubkey> {
+       invoice_derived_signing_pubkey_builder_methods!(self, Self);
+}
+
+impl<'a, S: SigningPubkeyStrategy> InvoiceBuilder<'a, S> {
+       invoice_builder_methods!(self, Self, Self, self);
 }
 
 /// A semantically valid [`Bolt12Invoice`] that hasn't been signed.
@@ -412,32 +426,38 @@ impl UnsignedBolt12Invoice {
        pub fn tagged_hash(&self) -> &TaggedHash {
                &self.tagged_hash
        }
+}
 
+macro_rules! unsigned_invoice_sign_method { ($self: ident, $self_type: ty) => {
        /// Signs the [`TaggedHash`] of the invoice using the given function.
        ///
        /// Note: The hash computation may have included unknown, odd TLV records.
        ///
        /// This is not exported to bindings users as functions aren't currently mapped.
-       pub fn sign<F, E>(mut self, sign: F) -> Result<Bolt12Invoice, SignError<E>>
+       pub fn sign<F, E>(mut $self: $self_type, sign: F) -> Result<Bolt12Invoice, SignError<E>>
        where
                F: FnOnce(&Self) -> Result<Signature, E>
        {
-               let pubkey = self.contents.fields().signing_pubkey;
-               let signature = merkle::sign_message(sign, &self, pubkey)?;
+               let pubkey = $self.contents.fields().signing_pubkey;
+               let signature = merkle::sign_message(sign, &$self, pubkey)?;
 
                // Append the signature TLV record to the bytes.
                let signature_tlv_stream = SignatureTlvStreamRef {
                        signature: Some(&signature),
                };
-               signature_tlv_stream.write(&mut self.bytes).unwrap();
+               signature_tlv_stream.write(&mut $self.bytes).unwrap();
 
                Ok(Bolt12Invoice {
-                       bytes: self.bytes,
-                       contents: self.contents,
+                       bytes: $self.bytes,
+                       contents: $self.contents,
                        signature,
-                       tagged_hash: self.tagged_hash,
+                       tagged_hash: $self.tagged_hash,
                })
        }
+} }
+
+impl UnsignedBolt12Invoice {
+       unsigned_invoice_sign_method!(self, Self);
 }
 
 impl AsRef<TaggedHash> for UnsignedBolt12Invoice {