Fix upgradable_required fields to actually be required in lower level macros
[rust-lightning] / lightning / src / util / ser_macros.rs
index 2aa3846ad99a33ee671209ece86ca1040ac9d82e..1f617de40b3ced31491cd8deedc6cf11a9b83dab 100644 (file)
@@ -217,7 +217,7 @@ macro_rules! _check_decoded_tlv_order {
                // no-op
        }};
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, upgradable_required) => {{
-               // no-op
+               _check_decoded_tlv_order!($last_seen_type, $typ, $type, $field, required)
        }};
        ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, upgradable_option) => {{
                // no-op
@@ -259,7 +259,7 @@ macro_rules! _check_missing_tlv {
                // no-op
        }};
        ($last_seen_type: expr, $type: expr, $field: ident, upgradable_required) => {{
-               // no-op
+               _check_missing_tlv!($last_seen_type, $type, $field, required)
        }};
        ($last_seen_type: expr, $type: expr, $field: ident, upgradable_option) => {{
                // no-op
@@ -293,7 +293,10 @@ macro_rules! _decode_tlv {
                $field = Some($crate::util::ser::Readable::read(&mut $reader)?);
        }};
        ($reader: expr, $field: ident, upgradable_required) => {{
-               $field = $crate::util::ser::MaybeReadable::read(&mut $reader)?;
+               $field = match $crate::util::ser::MaybeReadable::read(&mut $reader)? {
+                       Some(res) => res,
+                       _ => return Ok(None)
+               };
        }};
        ($reader: expr, $field: ident, upgradable_option) => {{
                $field = $crate::util::ser::MaybeReadable::read(&mut $reader)?;
@@ -636,7 +639,7 @@ macro_rules! _init_tlv_based_struct_field {
                $field
        };
        ($field: ident, upgradable_required) => {
-               if $field.is_none() { return Ok(None); } else { $field.unwrap() }
+               $field.0.unwrap()
        };
        ($field: ident, upgradable_option) => {
                $field
@@ -671,7 +674,7 @@ macro_rules! _init_tlv_field_var {
                let mut $field = None;
        };
        ($field: ident, upgradable_required) => {
-               let mut $field = None;
+               let mut $field = $crate::util::ser::UpgradableRequired(None);
        };
        ($field: ident, upgradable_option) => {
                let mut $field = None;
@@ -1050,6 +1053,47 @@ mod tests {
                        (0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
        }
 
+       #[derive(Debug, PartialEq)]
+       struct TestUpgradable {
+               a: u32,
+               b: u32,
+               c: Option<u32>,
+       }
+
+       fn upgradable_tlv_reader(s: &[u8]) -> Result<Option<TestUpgradable>, DecodeError> {
+               let mut s = Cursor::new(s);
+               let mut a = 0;
+               let mut b = 0;
+               let mut c: Option<u32> = None;
+               decode_tlv_stream!(&mut s, {(2, a, upgradable_required), (3, b, upgradable_required), (4, c, upgradable_option)});
+               Ok(Some(TestUpgradable { a, b, c, }))
+       }
+
+       #[test]
+       fn upgradable_tlv_simple_good_cases() {
+               assert_eq!(upgradable_tlv_reader(&::hex::decode(
+                       concat!("0204deadbeef", "03041bad1dea", "0404deadbeef")
+               ).unwrap()[..]).unwrap(),
+               Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: Some(0xdeadbeef) }));
+
+               assert_eq!(upgradable_tlv_reader(&::hex::decode(
+                       concat!("0204deadbeef", "03041bad1dea")
+               ).unwrap()[..]).unwrap(),
+               Some(TestUpgradable { a: 0xdeadbeef, b: 0x1bad1dea, c: None}));
+       }
+
+       #[test]
+       fn missing_required_upgradable() {
+               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&::hex::decode(
+                       concat!("0100", "0204deadbeef")
+                       ).unwrap()[..]) {
+               } else { panic!(); }
+               if let Err(DecodeError::InvalidValue) = upgradable_tlv_reader(&::hex::decode(
+                       concat!("0100", "03041bad1dea")
+               ).unwrap()[..]) {
+               } else { panic!(); }
+       }
+
        // BOLT TLV test cases
        fn tlv_reader_n1(s: &[u8]) -> Result<(Option<HighZeroBytesDroppedBigSize<u64>>, Option<u64>, Option<(PublicKey, u64, u64)>, Option<u16>), DecodeError> {
                let mut s = Cursor::new(s);