Merge pull request #2043 from valentinewallace/2023-02-initial-send-path-fails
[rust-lightning] / lightning / src / util / ser_macros.rs
index 0a5c0b643dd0c315f2782d94b7caa469dcf4e46d..5e55ab1d3b8182baf80f87e0e1a929d90380057c 100644 (file)
@@ -39,9 +39,12 @@ macro_rules! _encode_tlv {
                        field.write($stream)?;
                }
        };
-       ($stream: expr, $type: expr, $field: expr, ignorable) => {
+       ($stream: expr, $type: expr, $field: expr, upgradable_required) => {
                $crate::_encode_tlv!($stream, $type, $field, required);
        };
+       ($stream: expr, $type: expr, $field: expr, upgradable_option) => {
+               $crate::_encode_tlv!($stream, $type, $field, option);
+       };
        ($stream: expr, $type: expr, $field: expr, (option, encoding: ($fieldty: ty, $encoding: ident))) => {
                $crate::_encode_tlv!($stream, $type, $field.map(|f| $encoding(f)), option);
        };
@@ -158,9 +161,12 @@ macro_rules! _get_varint_length_prefixed_tlv_length {
                        $len.0 += field_len;
                }
        };
-       ($len: expr, $type: expr, $field: expr, ignorable) => {
+       ($len: expr, $type: expr, $field: expr, upgradable_required) => {
                $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, required);
        };
+       ($len: expr, $type: expr, $field: expr, upgradable_option) => {
+               $crate::_get_varint_length_prefixed_tlv_length!($len, $type, $field, option);
+       };
 }
 
 /// See the documentation of [`write_tlv_fields`].
@@ -210,7 +216,10 @@ macro_rules! _check_decoded_tlv_order {
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, vec_type) => {{
                // no-op
        }};
-       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, ignorable) => {{
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, upgradable_required) => {{
+               _check_decoded_tlv_order!($last_seen_type, $typ, $type, $field, required)
+       }};
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, upgradable_option) => {{
                // no-op
        }};
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {{
@@ -249,7 +258,10 @@ macro_rules! _check_missing_tlv {
        ($last_seen_type: expr, $type: expr, $field: ident, option) => {{
                // no-op
        }};
-       ($last_seen_type: expr, $type: expr, $field: ident, ignorable) => {{
+       ($last_seen_type: expr, $type: expr, $field: ident, upgradable_required) => {{
+               _check_missing_tlv!($last_seen_type, $type, $field, required)
+       }};
+       ($last_seen_type: expr, $type: expr, $field: ident, upgradable_option) => {{
                // no-op
        }};
        ($last_seen_type: expr, $type: expr, $field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {{
@@ -280,7 +292,20 @@ macro_rules! _decode_tlv {
        ($reader: expr, $field: ident, option) => {{
                $field = Some($crate::util::ser::Readable::read(&mut $reader)?);
        }};
-       ($reader: expr, $field: ident, ignorable) => {{
+       // `upgradable_required` indicates we're reading a required TLV that may have been upgraded
+       // without backwards compat. We'll error if the field is missing, and return `Ok(None)` if the
+       // field is present but we can no longer understand it.
+       // Note that this variant can only be used within a `MaybeReadable` read.
+       ($reader: expr, $field: ident, upgradable_required) => {{
+               $field = match $crate::util::ser::MaybeReadable::read(&mut $reader)? {
+                       Some(res) => res,
+                       _ => return Ok(None)
+               };
+       }};
+       // `upgradable_option` indicates we're reading an Option-al TLV that may have been upgraded
+       // without backwards compat. $field will be None if the TLV is missing or if the field is present
+       // but we can no longer understand it.
+       ($reader: expr, $field: ident, upgradable_option) => {{
                $field = $crate::util::ser::MaybeReadable::read(&mut $reader)?;
        }};
        ($reader: expr, $field: ident, (option: $trait: ident $(, $read_arg: expr)?)) => {{
@@ -619,8 +644,11 @@ macro_rules! _init_tlv_based_struct_field {
        ($field: ident, option) => {
                $field
        };
-       ($field: ident, ignorable) => {
-               if $field.is_none() { return Ok(None); } else { $field.unwrap() }
+       ($field: ident, upgradable_required) => {
+               $field.0.unwrap()
+       };
+       ($field: ident, upgradable_option) => {
+               $field
        };
        ($field: ident, required) => {
                $field.0.unwrap()
@@ -637,13 +665,13 @@ macro_rules! _init_tlv_based_struct_field {
 #[macro_export]
 macro_rules! _init_tlv_field_var {
        ($field: ident, (default_value, $default: expr)) => {
-               let mut $field = $crate::util::ser::OptionDeserWrapper(None);
+               let mut $field = $crate::util::ser::RequiredWrapper(None);
        };
        ($field: ident, (static_value, $value: expr)) => {
                let $field;
        };
        ($field: ident, required) => {
-               let mut $field = $crate::util::ser::OptionDeserWrapper(None);
+               let mut $field = $crate::util::ser::RequiredWrapper(None);
        };
        ($field: ident, vec_type) => {
                let mut $field = Some(Vec::new());
@@ -651,7 +679,10 @@ macro_rules! _init_tlv_field_var {
        ($field: ident, option) => {
                let mut $field = None;
        };
-       ($field: ident, ignorable) => {
+       ($field: ident, upgradable_required) => {
+               let mut $field = $crate::util::ser::UpgradableRequired(None);
+       };
+       ($field: ident, upgradable_option) => {
                let mut $field = None;
        };
 }
@@ -948,7 +979,7 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable {
                                                Ok(Some($st::$tuple_variant_name(Readable::read(reader)?)))
                                        }),*)*
                                        _ if id % 2 == 1 => Ok(None),
-                                       _ => Err(DecodeError::UnknownRequiredFeature),
+                                       _ => Err($crate::ln::msgs::DecodeError::UnknownRequiredFeature),
                                }
                        }
                }
@@ -1028,6 +1059,47 @@ mod tests {
                        (0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
        }
 
+       #[derive(Debug, PartialEq)]
+       struct TestUpgradable {
+               a: u32,
+               b: u32,
+               c: Option<u32>,
+       }
+
+       fn upgradable_tlv_reader(s: &[u8]) -> Result<Option<TestUpgradable>, DecodeError> {
+               let mut s = Cursor::new(s);
+               let mut a = 0;
+               let mut b = 0;
+               let mut c: Option<u32> = None;
+               decode_tlv_stream!(&mut s, {(2, a, upgradable_required), (3, b, upgradable_required), (4, c, upgradable_option)});
+               Ok(Some(TestUpgradable { a, b, c, }))
+       }
+
+       #[test]
+       fn upgradable_tlv_simple_good_cases() {
+               assert_eq!(upgradable_tlv_reader(&::hex::decode(
+                       concat!("0204deadbeef", "03041bad1dea", "0404deadbeef")
+               ).unwrap()[..]).unwrap(),
+               Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: Some(0xdeadbeef) }));
+
+               assert_eq!(upgradable_tlv_reader(&::hex::decode(
+                       concat!("0204deadbeef", "03041bad1dea")
+               ).unwrap()[..]).unwrap(),
+               Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: None}));
+       }
+
+       #[test]
+       fn missing_required_upgradable() {
+               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&::hex::decode(
+                       concat!("0100", "0204deadbeef")
+                       ).unwrap()[..]) {
+               } else { panic!(); }
+               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&::hex::decode(
+                       concat!("0100", "03041bad1dea")
+               ).unwrap()[..]) {
+               } else { panic!(); }
+       }
+
        // BOLT TLV test cases
        fn tlv_reader_n1(s: &[u8]) -> Result<(Option<HighZeroBytesDroppedBigSize<u64>>, Option<u64>, Option<(PublicKey, u64, u64)>, Option<u16>), DecodeError> {
                let mut s = Cursor::new(s);