]> git.bitcoin.ninja Git - rust-lightning/commitdiff
TlvStream range iterator
authorJeffrey Czyz <jkczyz@gmail.com>
Wed, 25 Jan 2023 17:34:43 +0000 (11:34 -0600)
committerJeffrey Czyz <jkczyz@gmail.com>
Thu, 20 Apr 2023 02:08:05 +0000 (21:08 -0500)
Add an iterator that yields TlvRecords over a range of a TlvStream.
Useful for verifying that, e.g., an InvoiceRequest was sent in response
to an Offer constructed by the intended recipient.

lightning/src/offers/merkle.rs

index 94a1eac0ca416bac20ad7a456c224d84706a73d0..f682746742050749463c181340fdad72f3d2d554 100644 (file)
@@ -143,28 +143,37 @@ fn tagged_branch_hash_from_engine(
 
 /// [`Iterator`] over a sequence of bytes yielding [`TlvRecord`]s. The input is assumed to be a
 /// well-formed TLV stream.
-struct TlvStream<'a> {
+pub(super) struct TlvStream<'a> {
        data: io::Cursor<&'a [u8]>,
 }
 
 impl<'a> TlvStream<'a> {
-       fn new(data: &'a [u8]) -> Self {
+       pub fn new(data: &'a [u8]) -> Self {
                Self {
                        data: io::Cursor::new(data),
                }
        }
 
+       pub fn range<T>(self, types: T) -> impl core::iter::Iterator<Item = TlvRecord<'a>>
+       where
+               T: core::ops::RangeBounds<u64> + Clone,
+       {
+               let take_range = types.clone();
+               self.skip_while(move |record| !types.contains(&record.r#type))
+                       .take_while(move |record| take_range.contains(&record.r#type))
+       }
+
        fn skip_signatures(self) -> core::iter::Filter<TlvStream<'a>, fn(&TlvRecord) -> bool> {
                self.filter(|record| !SIGNATURE_TYPES.contains(&record.r#type))
        }
 }
 
 /// A slice into a [`TlvStream`] for a record.
-struct TlvRecord<'a> {
-       r#type: u64,
+pub(super) struct TlvRecord<'a> {
+       pub(super) r#type: u64,
        type_bytes: &'a [u8],
        // The entire TLV record.
-       record_bytes: &'a [u8],
+       pub(super) record_bytes: &'a [u8],
 }
 
 impl<'a> Iterator for TlvStream<'a> {
@@ -212,7 +221,7 @@ impl<'a> Writeable for WithoutSignatures<'a> {
 
 #[cfg(test)]
 mod tests {
-       use super::{TlvStream, WithoutSignatures};
+       use super::{SIGNATURE_TYPES, TlvStream, WithoutSignatures};
 
        use bitcoin::hashes::{Hash, sha256};
        use bitcoin::secp256k1::{KeyPair, Secp256k1, SecretKey};
@@ -302,6 +311,38 @@ mod tests {
                );
        }
 
+       #[test]
+       fn iterates_over_tlv_stream_range() {
+               let secp_ctx = Secp256k1::new();
+               let recipient_pubkey = {
+                       let secret_key = SecretKey::from_slice(&[41; 32]).unwrap();
+                       KeyPair::from_secret_key(&secp_ctx, &secret_key).public_key()
+               };
+               let payer_keys = {
+                       let secret_key = SecretKey::from_slice(&[42; 32]).unwrap();
+                       KeyPair::from_secret_key(&secp_ctx, &secret_key)
+               };
+
+               let invoice_request = OfferBuilder::new("foo".into(), recipient_pubkey)
+                       .amount_msats(100)
+                       .build_unchecked()
+                       .request_invoice(vec![0; 8], payer_keys.public_key()).unwrap()
+                       .build_unchecked()
+                       .sign::<_, Infallible>(|digest| Ok(secp_ctx.sign_schnorr_no_aux_rand(digest, &payer_keys)))
+                       .unwrap();
+
+               let tlv_stream = TlvStream::new(&invoice_request.bytes).range(0..1)
+                       .chain(TlvStream::new(&invoice_request.bytes).range(1..80))
+                       .chain(TlvStream::new(&invoice_request.bytes).range(80..160))
+                       .chain(TlvStream::new(&invoice_request.bytes).range(160..240))
+                       .chain(TlvStream::new(&invoice_request.bytes).range(SIGNATURE_TYPES))
+                       .map(|r| r.record_bytes.to_vec())
+                       .flatten()
+                       .collect::<Vec<u8>>();
+
+               assert_eq!(tlv_stream, invoice_request.bytes);
+       }
+
        impl AsRef<[u8]> for InvoiceRequest {
                fn as_ref(&self) -> &[u8] {
                        &self.bytes