Merge pull request #1726 from jkczyz/2022-09-offer-parsing
[rust-lightning] / lightning / src / util / ser_macros.rs
index 9a11dac7a4922bfb5e6ddf2911d9974f2a1e9470..3e1d8a9280d43da344076854739a5a81436f450b 100644 (file)
@@ -201,6 +201,17 @@ macro_rules! decode_tlv {
 // `Ok(false)` if the message type is unknown, and `Err(DecodeError)` if parsing fails.
 macro_rules! decode_tlv_stream {
        ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
+        $(, $decode_custom_tlv: expr)?) => { {
+               let rewind = |_, _| { unreachable!() };
+               use core::ops::RangeBounds;
+               decode_tlv_stream_range!(
+                       $stream, .., rewind, {$(($type, $field, $fieldty)),*} $(, $decode_custom_tlv)?
+               );
+       } }
+}
+
+macro_rules! decode_tlv_stream_range {
+       ($stream: expr, $range: expr, $rewind: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
         $(, $decode_custom_tlv: expr)?) => { {
                use $crate::ln::msgs::DecodeError;
                let mut last_seen_type: Option<u64> = None;
@@ -215,7 +226,7 @@ macro_rules! decode_tlv_stream {
                                // UnexpectedEof. This should in every case be largely cosmetic, but its nice to
                                // pass the TLV test vectors exactly, which requre this distinction.
                                let mut tracking_reader = ser::ReadTrackingReader::new(&mut stream_ref);
-                               match $crate::util::ser::Readable::read(&mut tracking_reader) {
+                               match <$crate::util::ser::BigSize as $crate::util::ser::Readable>::read(&mut tracking_reader) {
                                        Err(DecodeError::ShortRead) => {
                                                if !tracking_reader.have_read {
                                                        break 'tlv_read;
@@ -224,7 +235,15 @@ macro_rules! decode_tlv_stream {
                                                }
                                        },
                                        Err(e) => return Err(e),
-                                       Ok(t) => t,
+                                       Ok(t) => if $range.contains(&t.0) { t } else {
+                                               drop(tracking_reader);
+
+                                               // Assumes the type id is minimally encoded, which is enforced on read.
+                                               use $crate::util::ser::Writeable;
+                                               let bytes_read = t.serialized_length();
+                                               $rewind(stream_ref, bytes_read);
+                                               break 'tlv_read;
+                                       },
                                }
                        };
 
@@ -481,11 +500,11 @@ macro_rules! impl_writeable_tlv_based {
 /// [`Readable`]: crate::util::ser::Readable
 /// [`Writeable`]: crate::util::ser::Writeable
 macro_rules! tlv_stream {
-       ($name:ident, $nameref:ident, {
+       ($name:ident, $nameref:ident, $range:expr, {
                $(($type:expr, $field:ident : $fieldty:tt)),* $(,)*
        }) => {
                #[derive(Debug)]
-               struct $name {
+               pub(crate) struct $name {
                        $(
                                $field: Option<tlv_record_type!($fieldty)>,
                        )*
@@ -506,12 +525,15 @@ macro_rules! tlv_stream {
                        }
                }
 
-               impl $crate::util::ser::Readable for $name {
-                       fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Self, $crate::ln::msgs::DecodeError> {
+               impl $crate::util::ser::SeekReadable for $name {
+                       fn read<R: $crate::io::Read + $crate::io::Seek>(reader: &mut R) -> Result<Self, $crate::ln::msgs::DecodeError> {
                                $(
                                        init_tlv_field_var!($field, option);
                                )*
-                               decode_tlv_stream!(reader, {
+                               let rewind = |cursor: &mut R, offset: usize| {
+                                       cursor.seek($crate::io::SeekFrom::Current(-(offset as i64))).expect("");
+                               };
+                               decode_tlv_stream_range!(reader, $range, rewind, {
                                        $(($type, $field, (option, encoding: $fieldty))),*
                                });