Fix unknown handling in `impl_writeable_tlv_based_enum_upgradable`
[rust-lightning] / lightning / src / util / ser_macros.rs
index 4f1fada952a49af110bd0abc67e38edb5b8ba293..5d6988ba1f193585c0d325c49fa679e52b4553c5 100644 (file)
@@ -791,6 +791,9 @@ macro_rules! _init_tlv_field_var {
 
 /// Equivalent to running [`_init_tlv_field_var`] then [`read_tlv_fields`].
 ///
+/// If any unused values are read, their type MUST be specified or else `rustc` will read them as an
+/// `i64`.
+///
 /// This is exported for use by other exported macros, do not use directly.
 #[doc(hidden)]
 #[macro_export]
@@ -807,6 +810,9 @@ macro_rules! _init_and_read_len_prefixed_tlv_fields {
 }
 
 /// Equivalent to running [`_init_tlv_field_var`] then [`decode_tlv_stream`].
+///
+/// If any unused values are read, their type MUST be specified or else `rustc` will read them as an
+/// `i64`.
 macro_rules! _init_and_read_tlv_stream {
        ($reader: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => {
                $(
@@ -911,7 +917,7 @@ macro_rules! tlv_stream {
 
                #[cfg_attr(test, derive(PartialEq))]
                #[derive(Debug)]
-               pub(super) struct $nameref<'a> {
+               pub(crate) struct $nameref<'a> {
                        $(
                                pub(super) $field: Option<tlv_record_ref_type!($fieldty)>,
                        )*
@@ -1059,6 +1065,10 @@ macro_rules! impl_writeable_tlv_based_enum {
 /// when [`MaybeReadable`] is practical instead of just [`Readable`] as it provides an upgrade path for
 /// new variants to be added which are simply ignored by existing clients.
 ///
+/// Note that only struct and unit variants (not tuple variants) will support downgrading, thus any
+/// new odd variants MUST be non-tuple (i.e. described using `$variant_id` and `$variant_name` not
+/// `$tuple_variant_id` and `$tuple_variant_name`).
+///
 /// [`MaybeReadable`]: crate::util::ser::MaybeReadable
 /// [`Writeable`]: crate::util::ser::Writeable
 /// [`DecodeError::UnknownRequiredFeature`]: crate::ln::msgs::DecodeError::UnknownRequiredFeature
@@ -1096,7 +1106,14 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable {
                                        $($($tuple_variant_id => {
                                                Ok(Some($st::$tuple_variant_name(Readable::read(reader)?)))
                                        }),*)*
-                                       _ if id % 2 == 1 => Ok(None),
+                                       _ if id % 2 == 1 => {
+                                               // Assume that a $variant_id was written, not a $tuple_variant_id, and read
+                                               // the length prefix and discard the correct number of bytes.
+                                               let tlv_len: $crate::util::ser::BigSize = $crate::util::ser::Readable::read(reader)?;
+                                               let mut rd = $crate::util::ser::FixedLengthReader::new(reader, tlv_len.0);
+                                               rd.eat_remaining().map_err(|_| $crate::ln::msgs::DecodeError::ShortRead)?;
+                                               Ok(None)
+                                       },
                                        _ => Err($crate::ln::msgs::DecodeError::UnknownRequiredFeature),
                                }
                        }
@@ -1110,6 +1127,7 @@ mod tests {
        use crate::prelude::*;
        use crate::ln::msgs::DecodeError;
        use crate::util::ser::{Writeable, HighZeroBytesDroppedBigSize, VecWriter};
+       use bitcoin::hashes::hex::FromHex;
        use bitcoin::secp256k1::PublicKey;
 
        // The BOLT TLV test cases don't include any tests which use our "required-value" logic since
@@ -1127,7 +1145,7 @@ mod tests {
        #[test]
        fn tlv_v_short_read() {
                // We only expect a u32 for type 3 (which we are given), but the L says its 8 bytes.
-               if let Err(DecodeError::ShortRead) = tlv_reader(&::hex::decode(
+               if let Err(DecodeError::ShortRead) = tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")
                                ).unwrap()[..]) {
                } else { panic!(); }
@@ -1135,12 +1153,12 @@ mod tests {
 
        #[test]
        fn tlv_types_out_of_order() {
-               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")
                                ).unwrap()[..]) {
                } else { panic!(); }
                // ...even if its some field we don't understand
-               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")
                                ).unwrap()[..]) {
                } else { panic!(); }
@@ -1149,17 +1167,17 @@ mod tests {
        #[test]
        fn tlv_req_type_missing_or_extra() {
                // It's also bad if they included even fields we don't understand
-               if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&::hex::decode(
+               if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")
                                ).unwrap()[..]) {
                } else { panic!(); }
                // ... or if they're missing fields we need
-               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0100", "0208deadbeef1badbeef")
                                ).unwrap()[..]) {
                } else { panic!(); }
                // ... even if that field is even
-               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0304deadbeef", "0500")
                                ).unwrap()[..]) {
                } else { panic!(); }
@@ -1167,11 +1185,11 @@ mod tests {
 
        #[test]
        fn tlv_simple_good_cases() {
-               assert_eq!(tlv_reader(&::hex::decode(
+               assert_eq!(tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0208deadbeef1badbeef", "03041bad1dea")
                                ).unwrap()[..]).unwrap(),
                        (0xdeadbeef1badbeef, 0x1bad1dea, None));
-               assert_eq!(tlv_reader(&::hex::decode(
+               assert_eq!(tlv_reader(&<Vec<u8>>::from_hex(
                                concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")
                                ).unwrap()[..]).unwrap(),
                        (0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
@@ -1195,12 +1213,12 @@ mod tests {
 
        #[test]
        fn upgradable_tlv_simple_good_cases() {
-               assert_eq!(upgradable_tlv_reader(&::hex::decode(
+               assert_eq!(upgradable_tlv_reader(&<Vec<u8>>::from_hex(
                        concat!("0204deadbeef", "03041bad1dea", "0404deadbeef")
                ).unwrap()[..]).unwrap(),
                Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: Some(0xdeadbeef) }));
 
-               assert_eq!(upgradable_tlv_reader(&::hex::decode(
+               assert_eq!(upgradable_tlv_reader(&<Vec<u8>>::from_hex(
                        concat!("0204deadbeef", "03041bad1dea")
                ).unwrap()[..]).unwrap(),
                Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: None}));
@@ -1208,11 +1226,11 @@ mod tests {
 
        #[test]
        fn missing_required_upgradable() {
-               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&::hex::decode(
+               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&<Vec<u8>>::from_hex(
                        concat!("0100", "0204deadbeef")
                        ).unwrap()[..]) {
                } else { panic!(); }
-               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&::hex::decode(
+               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&<Vec<u8>>::from_hex(
                        concat!("0100", "03041bad1dea")
                ).unwrap()[..]) {
                } else { panic!(); }
@@ -1233,7 +1251,7 @@ mod tests {
        fn bolt_tlv_bogus_stream() {
                macro_rules! do_test {
                        ($stream: expr, $reason: ident) => {
-                               if let Err(DecodeError::$reason) = tlv_reader_n1(&::hex::decode($stream).unwrap()[..]) {
+                               if let Err(DecodeError::$reason) = tlv_reader_n1(&<Vec<u8>>::from_hex($stream).unwrap()[..]) {
                                } else { panic!(); }
                        }
                }
@@ -1258,7 +1276,7 @@ mod tests {
        fn bolt_tlv_bogus_n1_stream() {
                macro_rules! do_test {
                        ($stream: expr, $reason: ident) => {
-                               if let Err(DecodeError::$reason) = tlv_reader_n1(&::hex::decode($stream).unwrap()[..]) {
+                               if let Err(DecodeError::$reason) = tlv_reader_n1(&<Vec<u8>>::from_hex($stream).unwrap()[..]) {
                                } else { panic!(); }
                        }
                }
@@ -1298,7 +1316,7 @@ mod tests {
        fn bolt_tlv_valid_n1_stream() {
                macro_rules! do_test {
                        ($stream: expr, $tlv1: expr, $tlv2: expr, $tlv3: expr, $tlv4: expr) => {
-                               if let Ok((tlv1, tlv2, tlv3, tlv4)) = tlv_reader_n1(&::hex::decode($stream).unwrap()[..]) {
+                               if let Ok((tlv1, tlv2, tlv3, tlv4)) = tlv_reader_n1(&<Vec<u8>>::from_hex($stream).unwrap()[..]) {
                                        assert_eq!(tlv1.map(|v| v.0), $tlv1);
                                        assert_eq!(tlv2, $tlv2);
                                        assert_eq!(tlv3, $tlv3);
@@ -1327,7 +1345,7 @@ mod tests {
                do_test!(concat!("02", "08", "0000000000000226"), None, Some((0 << 30) | (0 << 5) | (550 << 0)), None, None);
                do_test!(concat!("03", "31", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002"),
                        None, None, Some((
-                               PublicKey::from_slice(&::hex::decode("023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb").unwrap()[..]).unwrap(), 1, 2)),
+                               PublicKey::from_slice(&<Vec<u8>>::from_hex("023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb").unwrap()[..]).unwrap(), 1, 2)),
                        None);
                do_test!(concat!("fd00fe", "02", "0226"), None, None, None, Some(550));
        }
@@ -1337,27 +1355,27 @@ mod tests {
 
                stream.0.clear();
                _encode_varint_length_prefixed_tlv!(&mut stream, {(1, 1u8, required), (42, None::<u64>, option)});
-               assert_eq!(stream.0, ::hex::decode("03010101").unwrap());
+               assert_eq!(stream.0, <Vec<u8>>::from_hex("03010101").unwrap());
 
                stream.0.clear();
                _encode_varint_length_prefixed_tlv!(&mut stream, {(1, Some(1u8), option)});
-               assert_eq!(stream.0, ::hex::decode("03010101").unwrap());
+               assert_eq!(stream.0, <Vec<u8>>::from_hex("03010101").unwrap());
 
                stream.0.clear();
                _encode_varint_length_prefixed_tlv!(&mut stream, {(4, 0xabcdu16, required), (42, None::<u64>, option)});
-               assert_eq!(stream.0, ::hex::decode("040402abcd").unwrap());
+               assert_eq!(stream.0, <Vec<u8>>::from_hex("040402abcd").unwrap());
 
                stream.0.clear();
                _encode_varint_length_prefixed_tlv!(&mut stream, {(42, None::<u64>, option), (0xff, 0xabcdu16, required)});
-               assert_eq!(stream.0, ::hex::decode("06fd00ff02abcd").unwrap());
+               assert_eq!(stream.0, <Vec<u8>>::from_hex("06fd00ff02abcd").unwrap());
 
                stream.0.clear();
                _encode_varint_length_prefixed_tlv!(&mut stream, {(0, 1u64, required), (42, None::<u64>, option), (0xff, HighZeroBytesDroppedBigSize(0u64), required)});
-               assert_eq!(stream.0, ::hex::decode("0e00080000000000000001fd00ff00").unwrap());
+               assert_eq!(stream.0, <Vec<u8>>::from_hex("0e00080000000000000001fd00ff00").unwrap());
 
                stream.0.clear();
                _encode_varint_length_prefixed_tlv!(&mut stream, {(0, Some(1u64), option), (0xff, HighZeroBytesDroppedBigSize(0u64), required)});
-               assert_eq!(stream.0, ::hex::decode("0e00080000000000000001fd00ff00").unwrap());
+               assert_eq!(stream.0, <Vec<u8>>::from_hex("0e00080000000000000001fd00ff00").unwrap());
 
                Ok(())
        }