Merge pull request #1861 from TheBlueMatt/2022-11-tx-connection-idempotency
[rust-lightning] / lightning / src / util / ser_macros.rs
index 019bea3ee935d00d898973f5a619583f7a12584f..3e1d8a9280d43da344076854739a5a81436f450b 100644 (file)
@@ -26,6 +26,12 @@ macro_rules! encode_tlv {
                        field.write($stream)?;
                }
        };
+       ($stream: expr, $type: expr, $field: expr, (option, encoding: ($fieldty: ty, $encoding: ident))) => {
+               encode_tlv!($stream, $type, $field.map(|f| $encoding(f)), option);
+       };
+       ($stream: expr, $type: expr, $field: expr, (option, encoding: $fieldty: ty)) => {
+               encode_tlv!($stream, $type, $field, option);
+       };
 }
 
 macro_rules! encode_tlv_stream {
@@ -121,6 +127,9 @@ macro_rules! check_tlv_order {
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {{
                // no-op
        }};
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (option, encoding: $encoding: tt)) => {{
+               // no-op
+       }};
 }
 
 macro_rules! check_missing_tlv {
@@ -150,6 +159,9 @@ macro_rules! check_missing_tlv {
        ($last_seen_type: expr, $type: expr, $field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {{
                // no-op
        }};
+       ($last_seen_type: expr, $type: expr, $field: ident, (option, encoding: $encoding: tt)) => {{
+               // no-op
+       }};
 }
 
 macro_rules! decode_tlv {
@@ -172,6 +184,15 @@ macro_rules! decode_tlv {
        ($reader: expr, $field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {{
                $field = Some($trait::read(&mut $reader $(, $read_arg)*)?);
        }};
+       ($reader: expr, $field: ident, (option, encoding: ($fieldty: ty, $encoding: ident))) => {{
+               $field = {
+                       let field: $encoding<$fieldty> = ser::Readable::read(&mut $reader)?;
+                       Some(field.0)
+               };
+       }};
+       ($reader: expr, $field: ident, (option, encoding: $fieldty: ty)) => {{
+               decode_tlv!($reader, $field, option);
+       }};
 }
 
 // `$decode_custom_tlv` is a closure that may be optionally provided to handle custom message types.
@@ -180,6 +201,17 @@ macro_rules! decode_tlv {
 // `Ok(false)` if the message type is unknown, and `Err(DecodeError)` if parsing fails.
 macro_rules! decode_tlv_stream {
        ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
+        $(, $decode_custom_tlv: expr)?) => { {
+               let rewind = |_, _| { unreachable!() };
+               use core::ops::RangeBounds;
+               decode_tlv_stream_range!(
+                       $stream, .., rewind, {$(($type, $field, $fieldty)),*} $(, $decode_custom_tlv)?
+               );
+       } }
+}
+
+macro_rules! decode_tlv_stream_range {
+       ($stream: expr, $range: expr, $rewind: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
         $(, $decode_custom_tlv: expr)?) => { {
                use $crate::ln::msgs::DecodeError;
                let mut last_seen_type: Option<u64> = None;
@@ -194,7 +226,7 @@ macro_rules! decode_tlv_stream {
                                // UnexpectedEof. This should in every case be largely cosmetic, but its nice to
                                // pass the TLV test vectors exactly, which requre this distinction.
                                let mut tracking_reader = ser::ReadTrackingReader::new(&mut stream_ref);
-                               match $crate::util::ser::Readable::read(&mut tracking_reader) {
+                               match <$crate::util::ser::BigSize as $crate::util::ser::Readable>::read(&mut tracking_reader) {
                                        Err(DecodeError::ShortRead) => {
                                                if !tracking_reader.have_read {
                                                        break 'tlv_read;
@@ -203,7 +235,15 @@ macro_rules! decode_tlv_stream {
                                                }
                                        },
                                        Err(e) => return Err(e),
-                                       Ok(t) => t,
+                                       Ok(t) => if $range.contains(&t.0) { t } else {
+                                               drop(tracking_reader);
+
+                                               // Assumes the type id is minimally encoded, which is enforced on read.
+                                               use $crate::util::ser::Writeable;
+                                               let bytes_read = t.serialized_length();
+                                               $rewind(stream_ref, bytes_read);
+                                               break 'tlv_read;
+                                       },
                                }
                        };
 
@@ -391,6 +431,18 @@ macro_rules! init_tlv_field_var {
        };
 }
 
+macro_rules! init_and_read_tlv_fields {
+       ($reader: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => {
+               $(
+                       init_tlv_field_var!($field, $fieldty);
+               )*
+
+               read_tlv_fields!($reader, {
+                       $(($type, $field, $fieldty)),*
+               });
+       }
+}
+
 /// Implements Readable/Writeable for a struct storing it as a set of TLVs
 /// If $fieldty is `required`, then $field is a required field that is not an Option nor a Vec.
 /// If $fieldty is `option`, then $field is optional field.
@@ -425,10 +477,7 @@ macro_rules! impl_writeable_tlv_based {
 
                impl $crate::util::ser::Readable for $st {
                        fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Self, $crate::ln::msgs::DecodeError> {
-                               $(
-                                       init_tlv_field_var!($field, $fieldty);
-                               )*
-                               read_tlv_fields!(reader, {
+                               init_and_read_tlv_fields!(reader, {
                                        $(($type, $field, $fieldty)),*
                                });
                                Ok(Self {
@@ -441,6 +490,78 @@ macro_rules! impl_writeable_tlv_based {
        }
 }
 
+/// Defines a struct for a TLV stream and a similar struct using references for non-primitive types,
+/// implementing [`Readable`] for the former and [`Writeable`] for the latter. Useful as an
+/// intermediary format when reading or writing a type encoded as a TLV stream. Note that each field
+/// representing a TLV record has its type wrapped with an [`Option`]. A tuple consisting of a type
+/// and a serialization wrapper may be given in place of a type when custom serialization is
+/// required.
+///
+/// [`Readable`]: crate::util::ser::Readable
+/// [`Writeable`]: crate::util::ser::Writeable
+macro_rules! tlv_stream {
+       ($name:ident, $nameref:ident, $range:expr, {
+               $(($type:expr, $field:ident : $fieldty:tt)),* $(,)*
+       }) => {
+               #[derive(Debug)]
+               pub(crate) struct $name {
+                       $(
+                               $field: Option<tlv_record_type!($fieldty)>,
+                       )*
+               }
+
+               pub(crate) struct $nameref<'a> {
+                       $(
+                               pub(crate) $field: Option<tlv_record_ref_type!($fieldty)>,
+                       )*
+               }
+
+               impl<'a> $crate::util::ser::Writeable for $nameref<'a> {
+                       fn write<W: $crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), $crate::io::Error> {
+                               encode_tlv_stream!(writer, {
+                                       $(($type, self.$field, (option, encoding: $fieldty))),*
+                               });
+                               Ok(())
+                       }
+               }
+
+               impl $crate::util::ser::SeekReadable for $name {
+                       fn read<R: $crate::io::Read + $crate::io::Seek>(reader: &mut R) -> Result<Self, $crate::ln::msgs::DecodeError> {
+                               $(
+                                       init_tlv_field_var!($field, option);
+                               )*
+                               let rewind = |cursor: &mut R, offset: usize| {
+                                       cursor.seek($crate::io::SeekFrom::Current(-(offset as i64))).expect("");
+                               };
+                               decode_tlv_stream_range!(reader, $range, rewind, {
+                                       $(($type, $field, (option, encoding: $fieldty))),*
+                               });
+
+                               Ok(Self {
+                                       $(
+                                               $field: $field
+                                       ),*
+                               })
+                       }
+               }
+       }
+}
+
+macro_rules! tlv_record_type {
+       (($type:ty, $wrapper:ident)) => { $type };
+       ($type:ty) => { $type };
+}
+
+macro_rules! tlv_record_ref_type {
+       (char) => { char };
+       (u8) => { u8 };
+       ((u16, $wrapper: ident)) => { u16 };
+       ((u32, $wrapper: ident)) => { u32 };
+       ((u64, $wrapper: ident)) => { u64 };
+       (($type:ty, $wrapper:ident)) => { &'a $type };
+       ($type:ty) => { &'a $type };
+}
+
 macro_rules! _impl_writeable_tlv_based_enum_common {
        ($st: ident, $(($variant_id: expr, $variant_name: ident) =>
                {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
@@ -493,10 +614,7 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable {
                                                // Because read_tlv_fields creates a labeled loop, we cannot call it twice
                                                // in the same function body. Instead, we define a closure and call it.
                                                let f = || {
-                                                       $(
-                                                               init_tlv_field_var!($field, $fieldty);
-                                                       )*
-                                                       read_tlv_fields!(reader, {
+                                                       init_and_read_tlv_fields!(reader, {
                                                                $(($type, $field, $fieldty)),*
                                                        });
                                                        Ok(Some($st::$variant_name {
@@ -546,10 +664,7 @@ macro_rules! impl_writeable_tlv_based_enum {
                                                // Because read_tlv_fields creates a labeled loop, we cannot call it twice
                                                // in the same function body. Instead, we define a closure and call it.
                                                let f = || {
-                                                       $(
-                                                               init_tlv_field_var!($field, $fieldty);
-                                                       )*
-                                                       read_tlv_fields!(reader, {
+                                                       init_and_read_tlv_fields!(reader, {
                                                                $(($type, $field, $fieldty)),*
                                                        });
                                                        Ok($st::$variant_name {