BOLT 12 invoice: expose common helper methods and fields
authorValentine Wallace <vwallace@protonmail.com>
Thu, 9 May 2024 18:23:40 +0000 (14:23 -0400)
committerValentine Wallace <vwallace@protonmail.com>
Tue, 11 Jun 2024 18:55:06 +0000 (14:55 -0400)
Useful for static invoice support.

lightning/src/offers/invoice.rs

index ddda3b2919f16840376feb3e581c5e89671de8d4..1890337585862e67cef449e0a2471267c150c80d 100644 (file)
@@ -200,7 +200,7 @@ pub struct ExplicitSigningPubkey {}
 /// [`Bolt12Invoice::signing_pubkey`] was derived.
 ///
 /// This is not exported to bindings users as builder patterns don't map outside of move semantics.
-pub struct DerivedSigningPubkey(Keypair);
+pub struct DerivedSigningPubkey(pub(super) Keypair);
 
 impl SigningPubkeyStrategy for ExplicitSigningPubkey {}
 impl SigningPubkeyStrategy for DerivedSigningPubkey {}
@@ -958,14 +958,7 @@ impl InvoiceContents {
 
        #[cfg(feature = "std")]
        fn is_expired(&self) -> bool {
-               let absolute_expiry = self.created_at().checked_add(self.relative_expiry());
-               match absolute_expiry {
-                       Some(seconds_from_epoch) => match SystemTime::UNIX_EPOCH.elapsed() {
-                               Ok(elapsed) => elapsed > seconds_from_epoch,
-                               Err(_) => false,
-                       },
-                       None => false,
-               }
+               is_expired(self.created_at(), self.relative_expiry())
        }
 
        fn payment_hash(&self) -> PaymentHash {
@@ -977,36 +970,9 @@ impl InvoiceContents {
        }
 
        fn fallbacks(&self) -> Vec<Address> {
-               let chain = self.chain();
-               let network = if chain == ChainHash::using_genesis_block(Network::Bitcoin) {
-                       Network::Bitcoin
-               } else if chain == ChainHash::using_genesis_block(Network::Testnet) {
-                       Network::Testnet
-               } else if chain == ChainHash::using_genesis_block(Network::Signet) {
-                       Network::Signet
-               } else if chain == ChainHash::using_genesis_block(Network::Regtest) {
-                       Network::Regtest
-               } else {
-                       return Vec::new()
-               };
-
-               let to_valid_address = |address: &FallbackAddress| {
-                       let version = match WitnessVersion::try_from(address.version) {
-                               Ok(version) => version,
-                               Err(_) => return None,
-                       };
-
-                       let program = address.program.clone();
-                       let witness_program = match WitnessProgram::new(version, program) {
-                               Ok(witness_program) => witness_program,
-                               Err(_) => return None,
-                       };
-                       Some(Address::new(network, Payload::WitnessProgram(witness_program)))
-               };
-
                self.fields().fallbacks
                        .as_ref()
-                       .map(|fallbacks| fallbacks.iter().filter_map(to_valid_address).collect())
+                       .map(|fallbacks| filter_fallbacks(self.chain(), fallbacks))
                        .unwrap_or_else(Vec::new)
        }
 
@@ -1075,6 +1041,50 @@ impl InvoiceContents {
        }
 }
 
+#[cfg(feature = "std")]
+pub(super) fn is_expired(created_at: Duration, relative_expiry: Duration) -> bool {
+       let absolute_expiry = created_at.checked_add(relative_expiry);
+       match absolute_expiry {
+               Some(seconds_from_epoch) => match SystemTime::UNIX_EPOCH.elapsed() {
+                       Ok(elapsed) => elapsed > seconds_from_epoch,
+                       Err(_) => false,
+               },
+               None => false,
+       }
+}
+
+pub(super) fn filter_fallbacks(
+       chain: ChainHash, fallbacks: &Vec<FallbackAddress>
+) -> Vec<Address> {
+       let network = if chain == ChainHash::using_genesis_block(Network::Bitcoin) {
+               Network::Bitcoin
+       } else if chain == ChainHash::using_genesis_block(Network::Testnet) {
+               Network::Testnet
+       } else if chain == ChainHash::using_genesis_block(Network::Signet) {
+               Network::Signet
+       } else if chain == ChainHash::using_genesis_block(Network::Regtest) {
+               Network::Regtest
+       } else {
+               return Vec::new()
+       };
+
+       let to_valid_address = |address: &FallbackAddress| {
+               let version = match WitnessVersion::try_from(address.version) {
+                       Ok(version) => version,
+                       Err(_) => return None,
+               };
+
+               let program = address.program.clone();
+               let witness_program = match WitnessProgram::new(version, program) {
+                       Ok(witness_program) => witness_program,
+                       Err(_) => return None,
+               };
+               Some(Address::new(network, Payload::WitnessProgram(witness_program)))
+       };
+
+       fallbacks.iter().filter_map(to_valid_address).collect()
+}
+
 impl InvoiceFields {
        fn as_tlv_stream(&self) -> InvoiceTlvStreamRef {
                let features = {
@@ -1154,12 +1164,12 @@ tlv_stream!(InvoiceTlvStream, InvoiceTlvStreamRef, 160..240, {
        (176, node_id: PublicKey),
 });
 
-type BlindedPathIter<'a> = core::iter::Map<
+pub(super) type BlindedPathIter<'a> = core::iter::Map<
        core::slice::Iter<'a, (BlindedPayInfo, BlindedPath)>,
        for<'r> fn(&'r (BlindedPayInfo, BlindedPath)) -> &'r BlindedPath,
 >;
 
-type BlindedPayInfoIter<'a> = core::iter::Map<
+pub(super) type BlindedPayInfoIter<'a> = core::iter::Map<
        core::slice::Iter<'a, (BlindedPayInfo, BlindedPath)>,
        for<'r> fn(&'r (BlindedPayInfo, BlindedPath)) -> &'r BlindedPayInfo,
 >;
@@ -1205,8 +1215,8 @@ impl_writeable!(BlindedPayInfo, {
 /// Wire representation for an on-chain fallback address.
 #[derive(Clone, Debug, PartialEq)]
 pub(super) struct FallbackAddress {
-       version: u8,
-       program: Vec<u8>,
+       pub(super) version: u8,
+       pub(super) program: Vec<u8>,
 }
 
 impl_writeable!(FallbackAddress, { version, program });
@@ -1294,17 +1304,7 @@ impl TryFrom<PartialInvoiceTlvStream> for InvoiceContents {
                        },
                ) = tlv_stream;
 
-               let payment_paths = match (blindedpay, paths) {
-                       (_, None) => return Err(Bolt12SemanticError::MissingPaths),
-                       (None, _) => return Err(Bolt12SemanticError::InvalidPayInfo),
-                       (_, Some(paths)) if paths.is_empty() => return Err(Bolt12SemanticError::MissingPaths),
-                       (Some(blindedpay), Some(paths)) if paths.len() != blindedpay.len() => {
-                               return Err(Bolt12SemanticError::InvalidPayInfo);
-                       },
-                       (Some(blindedpay), Some(paths)) => {
-                               blindedpay.into_iter().zip(paths.into_iter()).collect::<Vec<_>>()
-                       },
-               };
+               let payment_paths = construct_payment_paths(blindedpay, paths)?;
 
                let created_at = match created_at {
                        None => return Err(Bolt12SemanticError::MissingCreationTime),
@@ -1372,6 +1372,22 @@ impl TryFrom<PartialInvoiceTlvStream> for InvoiceContents {
        }
 }
 
+pub(super) fn construct_payment_paths(
+       blinded_payinfos: Option<Vec<BlindedPayInfo>>, blinded_paths: Option<Vec<BlindedPath>>
+) -> Result<Vec<(BlindedPayInfo, BlindedPath)>, Bolt12SemanticError> {
+       match (blinded_payinfos, blinded_paths) {
+               (_, None) => Err(Bolt12SemanticError::MissingPaths),
+               (None, _) => Err(Bolt12SemanticError::InvalidPayInfo),
+               (_, Some(paths)) if paths.is_empty() => Err(Bolt12SemanticError::MissingPaths),
+               (Some(blindedpay), Some(paths)) if paths.len() != blindedpay.len() => {
+                       Err(Bolt12SemanticError::InvalidPayInfo)
+               },
+               (Some(blindedpay), Some(paths)) => {
+                       Ok(blindedpay.into_iter().zip(paths.into_iter()).collect::<Vec<_>>())
+               },
+       }
+}
+
 #[cfg(test)]
 mod tests {
        use super::{Bolt12Invoice, DEFAULT_RELATIVE_EXPIRY, FallbackAddress, FullInvoiceTlvStreamRef, InvoiceTlvStreamRef, SIGNATURE_TAG, UnsignedBolt12Invoice};