X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Foffers%2Fmerkle.rs;h=f7c33902c51441cd3e2a358637c9ddb4d2cc31e6;hb=4fb5708eec5a0683039c7877a0b3d452e21735c9;hp=94a1eac0ca416bac20ad7a456c224d84706a73d0;hpb=f30dc859e72763618249b1a1686229013178d057;p=rust-lightning diff --git a/lightning/src/offers/merkle.rs b/lightning/src/offers/merkle.rs index 94a1eac0..f7c33902 100644 --- a/lightning/src/offers/merkle.rs +++ b/lightning/src/offers/merkle.rs @@ -143,28 +143,38 @@ fn tagged_branch_hash_from_engine( /// [`Iterator`] over a sequence of bytes yielding [`TlvRecord`]s. The input is assumed to be a /// well-formed TLV stream. -struct TlvStream<'a> { +#[derive(Clone)] +pub(super) struct TlvStream<'a> { data: io::Cursor<&'a [u8]>, } impl<'a> TlvStream<'a> { - fn new(data: &'a [u8]) -> Self { + pub fn new(data: &'a [u8]) -> Self { Self { data: io::Cursor::new(data), } } + pub fn range(self, types: T) -> impl core::iter::Iterator> + where + T: core::ops::RangeBounds + Clone, + { + let take_range = types.clone(); + self.skip_while(move |record| !types.contains(&record.r#type)) + .take_while(move |record| take_range.contains(&record.r#type)) + } + fn skip_signatures(self) -> core::iter::Filter, fn(&TlvRecord) -> bool> { self.filter(|record| !SIGNATURE_TYPES.contains(&record.r#type)) } } /// A slice into a [`TlvStream`] for a record. -struct TlvRecord<'a> { - r#type: u64, +pub(super) struct TlvRecord<'a> { + pub(super) r#type: u64, type_bytes: &'a [u8], // The entire TLV record. - record_bytes: &'a [u8], + pub(super) record_bytes: &'a [u8], } impl<'a> Iterator for TlvStream<'a> { @@ -212,10 +222,11 @@ impl<'a> Writeable for WithoutSignatures<'a> { #[cfg(test)] mod tests { - use super::{TlvStream, WithoutSignatures}; + use super::{SIGNATURE_TYPES, TlvStream, WithoutSignatures}; use bitcoin::hashes::{Hash, sha256}; use bitcoin::secp256k1::{KeyPair, Secp256k1, SecretKey}; + use bitcoin::secp256k1::schnorr::Signature; use core::convert::Infallible; use crate::offers::offer::{Amount, OfferBuilder}; use crate::offers::invoice_request::InvoiceRequest; @@ -270,6 +281,10 @@ mod tests { super::root_hash(&invoice_request.bytes[..]), sha256::Hash::from_slice(&hex::decode("608407c18ad9a94d9ea2bcdbe170b6c20c462a7833a197621c916f78cf18e624").unwrap()).unwrap(), ); + assert_eq!( + invoice_request.signature(), + Signature::from_slice(&hex::decode("b8f83ea3288cfd6ea510cdb481472575141e8d8744157f98562d162cc1c472526fdb24befefbdebab4dbb726bbd1b7d8aec057f8fa805187e5950d2bbe0e5642").unwrap()).unwrap(), + ); } #[test] @@ -302,6 +317,38 @@ mod tests { ); } + #[test] + fn iterates_over_tlv_stream_range() { + let secp_ctx = Secp256k1::new(); + let recipient_pubkey = { + let secret_key = SecretKey::from_slice(&[41; 32]).unwrap(); + KeyPair::from_secret_key(&secp_ctx, &secret_key).public_key() + }; + let payer_keys = { + let secret_key = SecretKey::from_slice(&[42; 32]).unwrap(); + KeyPair::from_secret_key(&secp_ctx, &secret_key) + }; + + let invoice_request = OfferBuilder::new("foo".into(), recipient_pubkey) + .amount_msats(100) + .build_unchecked() + .request_invoice(vec![0; 8], payer_keys.public_key()).unwrap() + .build_unchecked() + .sign::<_, Infallible>(|digest| Ok(secp_ctx.sign_schnorr_no_aux_rand(digest, &payer_keys))) + .unwrap(); + + let tlv_stream = TlvStream::new(&invoice_request.bytes).range(0..1) + .chain(TlvStream::new(&invoice_request.bytes).range(1..80)) + .chain(TlvStream::new(&invoice_request.bytes).range(80..160)) + .chain(TlvStream::new(&invoice_request.bytes).range(160..240)) + .chain(TlvStream::new(&invoice_request.bytes).range(SIGNATURE_TYPES)) + .map(|r| r.record_bytes.to_vec()) + .flatten() + .collect::>(); + + assert_eq!(tlv_stream, invoice_request.bytes); + } + impl AsRef<[u8]> for InvoiceRequest { fn as_ref(&self) -> &[u8] { &self.bytes