Test utils: allow queueing >2 persistence update results
[rust-lightning] / lightning / src / util / ser_macros.rs
index 3ec0848680ffe62a2532e03d47d1fc5b62d7f4b8..d37a841eaff31945ecc2af27d624ca8570552006 100644 (file)
@@ -11,6 +11,9 @@ macro_rules! encode_tlv {
        ($stream: expr, $type: expr, $field: expr, (default_value, $default: expr)) => {
                encode_tlv!($stream, $type, $field, required)
        };
+       ($stream: expr, $type: expr, $field: expr, (static_value, $value: expr)) => {
+               let _ = &$field; // Ensure we "use" the $field
+       };
        ($stream: expr, $type: expr, $field: expr, required) => {
                BigSize($type).write($stream)?;
                BigSize($field.serialized_length() as u64).write($stream)?;
@@ -34,6 +37,17 @@ macro_rules! encode_tlv {
        };
 }
 
+
+macro_rules! check_encoded_tlv_order {
+       ($last_type: expr, $type: expr, (static_value, $value: expr)) => { };
+       ($last_type: expr, $type: expr, $fieldty: tt) => {
+               if let Some(t) = $last_type {
+                       debug_assert!(t <= $type);
+               }
+               $last_type = Some($type);
+       };
+}
+
 macro_rules! encode_tlv_stream {
        ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),* $(,)*}) => { {
                #[allow(unused_imports)]
@@ -52,10 +66,7 @@ macro_rules! encode_tlv_stream {
                {
                        let mut last_seen: Option<u64> = None;
                        $(
-                               if let Some(t) = last_seen {
-                                       debug_assert!(t <= $type);
-                               }
-                               last_seen = Some($type);
+                               check_encoded_tlv_order!(last_seen, $type, $fieldty);
                        )*
                }
        } }
@@ -65,6 +76,8 @@ macro_rules! get_varint_length_prefixed_tlv_length {
        ($len: expr, $type: expr, $field: expr, (default_value, $default: expr)) => {
                get_varint_length_prefixed_tlv_length!($len, $type, $field, required)
        };
+       ($len: expr, $type: expr, $field: expr, (static_value, $value: expr)) => {
+       };
        ($len: expr, $type: expr, $field: expr, required) => {
                BigSize($type).write(&mut $len).expect("No in-memory data may fail to serialize");
                let field_len = $field.serialized_length();
@@ -100,7 +113,7 @@ macro_rules! encode_varint_length_prefixed_tlv {
        } }
 }
 
-macro_rules! check_tlv_order {
+macro_rules! check_decoded_tlv_order {
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{
                #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true
                let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type;
@@ -108,6 +121,8 @@ macro_rules! check_tlv_order {
                        $field = $default.into();
                }
        }};
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (static_value, $value: expr)) => {
+       };
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, required) => {{
                #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true
                let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type;
@@ -140,6 +155,9 @@ macro_rules! check_missing_tlv {
                        $field = $default.into();
                }
        }};
+       ($last_seen_type: expr, $type: expr, $field: expr, (static_value, $value: expr)) => {
+               $field = $value;
+       };
        ($last_seen_type: expr, $type: expr, $field: ident, required) => {{
                #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true
                let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type;
@@ -168,6 +186,8 @@ macro_rules! decode_tlv {
        ($reader: expr, $field: ident, (default_value, $default: expr)) => {{
                decode_tlv!($reader, $field, required)
        }};
+       ($reader: expr, $field: ident, (static_value, $value: expr)) => {{
+       }};
        ($reader: expr, $field: ident, required) => {{
                $field = $crate::util::ser::Readable::read(&mut $reader)?;
        }};
@@ -195,12 +215,28 @@ macro_rules! decode_tlv {
        }};
 }
 
+macro_rules! _decode_tlv_stream_match_check {
+       ($val: ident, $type: expr, (static_value, $value: expr)) => { false };
+       ($val: ident, $type: expr, $fieldty: tt) => { $val == $type }
+}
+
 // `$decode_custom_tlv` is a closure that may be optionally provided to handle custom message types.
 // If it is provided, it will be called with the custom type and the `FixedLengthReader` containing
 // the message contents. It should return `Ok(true)` if the custom message is successfully parsed,
 // `Ok(false)` if the message type is unknown, and `Err(DecodeError)` if parsing fails.
 macro_rules! decode_tlv_stream {
        ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
+        $(, $decode_custom_tlv: expr)?) => { {
+               let rewind = |_, _| { unreachable!() };
+               use core::ops::RangeBounds;
+               decode_tlv_stream_range!(
+                       $stream, .., rewind, {$(($type, $field, $fieldty)),*} $(, $decode_custom_tlv)?
+               );
+       } }
+}
+
+macro_rules! decode_tlv_stream_range {
+       ($stream: expr, $range: expr, $rewind: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
         $(, $decode_custom_tlv: expr)?) => { {
                use $crate::ln::msgs::DecodeError;
                let mut last_seen_type: Option<u64> = None;
@@ -215,7 +251,7 @@ macro_rules! decode_tlv_stream {
                                // UnexpectedEof. This should in every case be largely cosmetic, but its nice to
                                // pass the TLV test vectors exactly, which requre this distinction.
                                let mut tracking_reader = ser::ReadTrackingReader::new(&mut stream_ref);
-                               match $crate::util::ser::Readable::read(&mut tracking_reader) {
+                               match <$crate::util::ser::BigSize as $crate::util::ser::Readable>::read(&mut tracking_reader) {
                                        Err(DecodeError::ShortRead) => {
                                                if !tracking_reader.have_read {
                                                        break 'tlv_read;
@@ -224,7 +260,15 @@ macro_rules! decode_tlv_stream {
                                                }
                                        },
                                        Err(e) => return Err(e),
-                                       Ok(t) => t,
+                                       Ok(t) => if $range.contains(&t.0) { t } else {
+                                               drop(tracking_reader);
+
+                                               // Assumes the type id is minimally encoded, which is enforced on read.
+                                               use $crate::util::ser::Writeable;
+                                               let bytes_read = t.serialized_length();
+                                               $rewind(stream_ref, bytes_read);
+                                               break 'tlv_read;
+                                       },
                                }
                        };
 
@@ -237,7 +281,7 @@ macro_rules! decode_tlv_stream {
                        }
                        // As we read types, make sure we hit every required type:
                        $({
-                               check_tlv_order!(last_seen_type, typ, $type, $field, $fieldty);
+                               check_decoded_tlv_order!(last_seen_type, typ, $type, $field, $fieldty);
                        })*
                        last_seen_type = Some(typ.0);
 
@@ -245,7 +289,7 @@ macro_rules! decode_tlv_stream {
                        let length: ser::BigSize = $crate::util::ser::Readable::read(&mut stream_ref)?;
                        let mut s = ser::FixedLengthReader::new(&mut stream_ref, length.0);
                        match typ.0 {
-                               $($type => {
+                               $(_t if _decode_tlv_stream_match_check!(_t, $type, $fieldty) => {
                                        decode_tlv!(s, $field, $fieldty);
                                        if s.bytes_remain() {
                                                s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes
@@ -386,6 +430,9 @@ macro_rules! init_tlv_based_struct_field {
        ($field: ident, (default_value, $default: expr)) => {
                $field.0.unwrap()
        };
+       ($field: ident, (static_value, $value: expr)) => {
+               $field
+       };
        ($field: ident, option) => {
                $field
        };
@@ -401,6 +448,9 @@ macro_rules! init_tlv_field_var {
        ($field: ident, (default_value, $default: expr)) => {
                let mut $field = $crate::util::ser::OptionDeserWrapper(None);
        };
+       ($field: ident, (static_value, $value: expr)) => {
+               let $field;
+       };
        ($field: ident, required) => {
                let mut $field = $crate::util::ser::OptionDeserWrapper(None);
        };
@@ -412,6 +462,18 @@ macro_rules! init_tlv_field_var {
        };
 }
 
+macro_rules! init_and_read_tlv_fields {
+       ($reader: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => {
+               $(
+                       init_tlv_field_var!($field, $fieldty);
+               )*
+
+               read_tlv_fields!($reader, {
+                       $(($type, $field, $fieldty)),*
+               });
+       }
+}
+
 /// Implements Readable/Writeable for a struct storing it as a set of TLVs
 /// If $fieldty is `required`, then $field is a required field that is not an Option nor a Vec.
 /// If $fieldty is `option`, then $field is optional field.
@@ -446,10 +508,7 @@ macro_rules! impl_writeable_tlv_based {
 
                impl $crate::util::ser::Readable for $st {
                        fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Self, $crate::ln::msgs::DecodeError> {
-                               $(
-                                       init_tlv_field_var!($field, $fieldty);
-                               )*
-                               read_tlv_fields!(reader, {
+                               init_and_read_tlv_fields!(reader, {
                                        $(($type, $field, $fieldty)),*
                                });
                                Ok(Self {
@@ -472,19 +531,20 @@ macro_rules! impl_writeable_tlv_based {
 /// [`Readable`]: crate::util::ser::Readable
 /// [`Writeable`]: crate::util::ser::Writeable
 macro_rules! tlv_stream {
-       ($name:ident, $nameref:ident, {
+       ($name:ident, $nameref:ident, $range:expr, {
                $(($type:expr, $field:ident : $fieldty:tt)),* $(,)*
        }) => {
                #[derive(Debug)]
-               struct $name {
+               pub(super) struct $name {
                        $(
-                               $field: Option<tlv_record_type!($fieldty)>,
+                               pub(super) $field: Option<tlv_record_type!($fieldty)>,
                        )*
                }
 
-               pub(crate) struct $nameref<'a> {
+               #[derive(Debug, PartialEq)]
+               pub(super) struct $nameref<'a> {
                        $(
-                               pub(crate) $field: Option<tlv_record_ref_type!($fieldty)>,
+                               pub(super) $field: Option<tlv_record_ref_type!($fieldty)>,
                        )*
                }
 
@@ -497,12 +557,15 @@ macro_rules! tlv_stream {
                        }
                }
 
-               impl $crate::util::ser::Readable for $name {
-                       fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Self, $crate::ln::msgs::DecodeError> {
+               impl $crate::util::ser::SeekReadable for $name {
+                       fn read<R: $crate::io::Read + $crate::io::Seek>(reader: &mut R) -> Result<Self, $crate::ln::msgs::DecodeError> {
                                $(
                                        init_tlv_field_var!($field, option);
                                )*
-                               decode_tlv_stream!(reader, {
+                               let rewind = |cursor: &mut R, offset: usize| {
+                                       cursor.seek($crate::io::SeekFrom::Current(-(offset as i64))).expect("");
+                               };
+                               decode_tlv_stream_range!(reader, $range, rewind, {
                                        $(($type, $field, (option, encoding: $fieldty))),*
                                });
 
@@ -583,10 +646,7 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable {
                                                // Because read_tlv_fields creates a labeled loop, we cannot call it twice
                                                // in the same function body. Instead, we define a closure and call it.
                                                let f = || {
-                                                       $(
-                                                               init_tlv_field_var!($field, $fieldty);
-                                                       )*
-                                                       read_tlv_fields!(reader, {
+                                                       init_and_read_tlv_fields!(reader, {
                                                                $(($type, $field, $fieldty)),*
                                                        });
                                                        Ok(Some($st::$variant_name {
@@ -636,10 +696,7 @@ macro_rules! impl_writeable_tlv_based_enum {
                                                // Because read_tlv_fields creates a labeled loop, we cannot call it twice
                                                // in the same function body. Instead, we define a closure and call it.
                                                let f = || {
-                                                       $(
-                                                               init_tlv_field_var!($field, $fieldty);
-                                                       )*
-                                                       read_tlv_fields!(reader, {
+                                                       init_and_read_tlv_fields!(reader, {
                                                                $(($type, $field, $fieldty)),*
                                                        });
                                                        Ok($st::$variant_name {
@@ -654,7 +711,7 @@ macro_rules! impl_writeable_tlv_based_enum {
                                                Ok($st::$tuple_variant_name(Readable::read(reader)?))
                                        }),*
                                        _ => {
-                                               Err(DecodeError::UnknownRequiredFeature)
+                                               Err($crate::ln::msgs::DecodeError::UnknownRequiredFeature)
                                        },
                                }
                        }