Implement `VecReadWrapper` for `MaybeReadable`
[rust-lightning] / lightning / src / util / ser_macros.rs
index b93115dcc95933bbc5a617bbf1dc5083f25cdd10..e988ad7eeab940d0e6979904c63dcd88087f023e 100644 (file)
@@ -8,6 +8,9 @@
 // licenses.
 
 macro_rules! encode_tlv {
+       ($stream: expr, $type: expr, $field: expr, (default_value, $default: expr)) => {
+               encode_tlv!($stream, $type, $field, required)
+       };
        ($stream: expr, $type: expr, $field: expr, required) => {
                BigSize($type).write($stream)?;
                BigSize($field.serialized_length() as u64).write($stream)?;
@@ -26,7 +29,7 @@ macro_rules! encode_tlv {
 }
 
 macro_rules! encode_tlv_stream {
-       ($stream: expr, {$(($type: expr, $field: expr, $fieldty: ident)),*}) => { {
+       ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),* $(,)*}) => { {
                #[allow(unused_imports)]
                use {
                        ln::msgs::DecodeError,
@@ -53,6 +56,9 @@ macro_rules! encode_tlv_stream {
 }
 
 macro_rules! get_varint_length_prefixed_tlv_length {
+       ($len: expr, $type: expr, $field: expr, (default_value, $default: expr)) => {
+               get_varint_length_prefixed_tlv_length!($len, $type, $field, required)
+       };
        ($len: expr, $type: expr, $field: expr, required) => {
                BigSize($type).write(&mut $len).expect("No in-memory data may fail to serialize");
                let field_len = $field.serialized_length();
@@ -73,7 +79,7 @@ macro_rules! get_varint_length_prefixed_tlv_length {
 }
 
 macro_rules! encode_varint_length_prefixed_tlv {
-       ($stream: expr, {$(($type: expr, $field: expr, $fieldty: ident)),*}) => { {
+       ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),*}) => { {
                use util::ser::BigSize;
                let len = {
                        #[allow(unused_mut)]
@@ -89,38 +95,55 @@ macro_rules! encode_varint_length_prefixed_tlv {
 }
 
 macro_rules! check_tlv_order {
-       ($last_seen_type: expr, $typ: expr, $type: expr, required) => {{
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{
+               #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true
+               let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type;
+               if invalid_order {
+                       $field = $default;
+               }
+       }};
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, required) => {{
                #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true
                let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type;
                if invalid_order {
-                       Err(DecodeError::InvalidValue)?
+                       return Err(DecodeError::InvalidValue);
                }
        }};
-       ($last_seen_type: expr, $typ: expr, $type: expr, option) => {{
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, option) => {{
                // no-op
        }};
-       ($last_seen_type: expr, $typ: expr, $type: expr, vec_type) => {{
+       ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, vec_type) => {{
                // no-op
        }};
 }
 
 macro_rules! check_missing_tlv {
-       ($last_seen_type: expr, $type: expr, required) => {{
+       ($last_seen_type: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{
+               #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true
+               let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type;
+               if missing_req_type {
+                       $field = $default;
+               }
+       }};
+       ($last_seen_type: expr, $type: expr, $field: ident, required) => {{
                #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true
                let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type;
                if missing_req_type {
-                       Err(DecodeError::InvalidValue)?
+                       return Err(DecodeError::InvalidValue);
                }
        }};
-       ($last_seen_type: expr, $type: expr, vec_type) => {{
+       ($last_seen_type: expr, $type: expr, $field: ident, vec_type) => {{
                // no-op
        }};
-       ($last_seen_type: expr, $type: expr, option) => {{
+       ($last_seen_type: expr, $type: expr, $field: ident, option) => {{
                // no-op
        }};
 }
 
 macro_rules! decode_tlv {
+       ($reader: expr, $field: ident, (default_value, $default: expr)) => {{
+               decode_tlv!($reader, $field, required)
+       }};
        ($reader: expr, $field: ident, required) => {{
                $field = ser::Readable::read(&mut $reader)?;
        }};
@@ -133,7 +156,7 @@ macro_rules! decode_tlv {
 }
 
 macro_rules! decode_tlv_stream {
-       ($stream: expr, {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*}) => { {
+       ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { {
                use ln::msgs::DecodeError;
                let mut last_seen_type: Option<u64> = None;
                'tlv_read: loop {
@@ -149,12 +172,12 @@ macro_rules! decode_tlv_stream {
                                match ser::Readable::read(&mut tracking_reader) {
                                        Err(DecodeError::ShortRead) => {
                                                if !tracking_reader.have_read {
-                                                       break 'tlv_read
+                                                       break 'tlv_read;
                                                } else {
-                                                       Err(DecodeError::ShortRead)?
+                                                       return Err(DecodeError::ShortRead);
                                                }
                                        },
-                                       Err(e) => Err(e)?,
+                                       Err(e) => return Err(e),
                                        Ok(t) => t,
                                }
                        };
@@ -162,29 +185,29 @@ macro_rules! decode_tlv_stream {
                        // Types must be unique and monotonically increasing:
                        match last_seen_type {
                                Some(t) if typ.0 <= t => {
-                                       Err(DecodeError::InvalidValue)?
+                                       return Err(DecodeError::InvalidValue);
                                },
                                _ => {},
                        }
                        // As we read types, make sure we hit every required type:
                        $({
-                               check_tlv_order!(last_seen_type, typ, $type, $fieldty);
+                               check_tlv_order!(last_seen_type, typ, $type, $field, $fieldty);
                        })*
                        last_seen_type = Some(typ.0);
 
                        // Finally, read the length and value itself:
-                       let length: ser::BigSize = Readable::read($stream)?;
+                       let length: ser::BigSize = ser::Readable::read($stream)?;
                        let mut s = ser::FixedLengthReader::new($stream, length.0);
                        match typ.0 {
                                $($type => {
                                        decode_tlv!(s, $field, $fieldty);
                                        if s.bytes_remain() {
                                                s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes
-                                               Err(DecodeError::InvalidValue)?
+                                               return Err(DecodeError::InvalidValue);
                                        }
                                },)*
                                x if x % 2 == 0 => {
-                                       Err(DecodeError::UnknownRequiredFeature)?
+                                       return Err(DecodeError::UnknownRequiredFeature);
                                },
                                _ => {},
                        }
@@ -192,7 +215,7 @@ macro_rules! decode_tlv_stream {
                }
                // Make sure we got to each required type after we've read every TLV:
                $({
-                       check_missing_tlv!(last_seen_type, $type, $fieldty);
+                       check_missing_tlv!(last_seen_type, $type, $field, $fieldty);
                })*
        } }
 }
@@ -200,7 +223,7 @@ macro_rules! decode_tlv_stream {
 macro_rules! impl_writeable {
        ($st:ident, $len: expr, {$($field:ident),*}) => {
                impl ::util::ser::Writeable for $st {
-                       fn write<W: ::util::ser::Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+                       fn write<W: ::util::ser::Writer>(&self, w: &mut W) -> Result<(), $crate::io::Error> {
                                if $len != 0 {
                                        w.size_hint($len);
                                }
@@ -235,7 +258,7 @@ macro_rules! impl_writeable {
                }
 
                impl ::util::ser::Readable for $st {
-                       fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
+                       fn read<R: $crate::io::Read>(r: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
                                Ok(Self {
                                        $($field: ::util::ser::Readable::read(r)?),*
                                })
@@ -246,7 +269,7 @@ macro_rules! impl_writeable {
 macro_rules! impl_writeable_len_match {
        ($struct: ident, $cmp: tt, ($calc_len: expr), {$({$match: pat, $length: expr}),*}, {$($field:ident),*}) => {
                impl Writeable for $struct {
-                       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+                       fn write<W: Writer>(&self, w: &mut W) -> Result<(), $crate::io::Error> {
                                let len = match *self {
                                        $($match => $length,)*
                                };
@@ -282,7 +305,7 @@ macro_rules! impl_writeable_len_match {
                }
 
                impl ::util::ser::Readable for $struct {
-                       fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, DecodeError> {
+                       fn read<R: $crate::io::Read>(r: &mut R) -> Result<Self, DecodeError> {
                                Ok(Self {
                                        $($field: Readable::read(r)?),*
                                })
@@ -326,7 +349,7 @@ macro_rules! write_ver_prefix {
 /// This is the preferred method of adding new fields that old nodes can ignore and still function
 /// correctly.
 macro_rules! write_tlv_fields {
-       ($stream: expr, {$(($type: expr, $field: expr, $fieldty: ident)),* $(,)*}) => {
+       ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),* $(,)*}) => {
                encode_varint_length_prefixed_tlv!($stream, {$(($type, $field, $fieldty)),*});
        }
 }
@@ -347,8 +370,8 @@ macro_rules! read_ver_prefix {
 
 /// Reads a suffix added by write_tlv_fields.
 macro_rules! read_tlv_fields {
-       ($stream: expr, {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*}) => { {
-               let tlv_len = ::util::ser::BigSize::read($stream)?;
+       ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { {
+               let tlv_len: ::util::ser::BigSize = ::util::ser::Readable::read($stream)?;
                let mut rd = ::util::ser::FixedLengthReader::new($stream, tlv_len.0);
                decode_tlv_stream!(&mut rd, {$(($type, $field, $fieldty)),*});
                rd.eat_remaining().map_err(|_| ::ln::msgs::DecodeError::ShortRead)?;
@@ -356,6 +379,9 @@ macro_rules! read_tlv_fields {
 }
 
 macro_rules! init_tlv_based_struct_field {
+       ($field: ident, (default_value, $default: expr)) => {
+               $field
+       };
        ($field: ident, option) => {
                $field
        };
@@ -368,6 +394,9 @@ macro_rules! init_tlv_based_struct_field {
 }
 
 macro_rules! init_tlv_field_var {
+       ($field: ident, (default_value, $default: expr)) => {
+               let mut $field = $default;
+       };
        ($field: ident, required) => {
                let mut $field = ::util::ser::OptionDeserWrapper(None);
        };
@@ -385,9 +414,9 @@ macro_rules! init_tlv_field_var {
 /// if $fieldty is `vec_type`, then $field is a Vec, which needs to have its individual elements
 /// serialized.
 macro_rules! impl_writeable_tlv_based {
-       ($st: ident, {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*}) => {
+       ($st: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => {
                impl ::util::ser::Writeable for $st {
-                       fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+                       fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), $crate::io::Error> {
                                write_tlv_fields!(writer, {
                                        $(($type, self.$field, $fieldty)),*
                                });
@@ -412,7 +441,7 @@ macro_rules! impl_writeable_tlv_based {
                }
 
                impl ::util::ser::Readable for $st {
-                       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
+                       fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
                                $(
                                        init_tlv_field_var!($field, $fieldty);
                                )*
@@ -441,11 +470,11 @@ macro_rules! impl_writeable_tlv_based {
 /// Attempts to read an unknown type byte result in DecodeError::UnknownRequiredFeature.
 macro_rules! impl_writeable_tlv_based_enum {
        ($st: ident, $(($variant_id: expr, $variant_name: ident) =>
-               {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*}
+               {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}
        ),* $(,)*;
        $(($tuple_variant_id: expr, $tuple_variant_name: ident)),*  $(,)*) => {
                impl ::util::ser::Writeable for $st {
-                       fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+                       fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), $crate::io::Error> {
                                match self {
                                        $($st::$variant_name { $(ref $field),* } => {
                                                let id: u8 = $variant_id;
@@ -465,7 +494,7 @@ macro_rules! impl_writeable_tlv_based_enum {
                }
 
                impl ::util::ser::Readable for $st {
-                       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
+                       fn read<R: $crate::io::Read>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
                                let id: u8 = ::util::ser::Readable::read(reader)?;
                                match id {
                                        $($variant_id => {
@@ -490,7 +519,7 @@ macro_rules! impl_writeable_tlv_based_enum {
                                                Ok($st::$tuple_variant_name(Readable::read(reader)?))
                                        }),*
                                        _ => {
-                                               Err(DecodeError::UnknownRequiredFeature)?
+                                               Err(DecodeError::UnknownRequiredFeature)
                                        },
                                }
                        }
@@ -500,10 +529,10 @@ macro_rules! impl_writeable_tlv_based_enum {
 
 #[cfg(test)]
 mod tests {
+       use io::{self, Cursor};
        use prelude::*;
-       use std::io::Cursor;
        use ln::msgs::DecodeError;
-       use util::ser::{Readable, Writeable, HighZeroBytesDroppedVarInt, VecWriter};
+       use util::ser::{Writeable, HighZeroBytesDroppedVarInt, VecWriter};
        use bitcoin::secp256k1::PublicKey;
 
        // The BOLT TLV test cases don't include any tests which use our "required-value" logic since
@@ -685,7 +714,7 @@ mod tests {
                do_test!(concat!("fd00fe", "02", "0226"), None, None, None, Some(550));
        }
 
-       fn do_simple_test_tlv_write() -> Result<(), ::std::io::Error> {
+       fn do_simple_test_tlv_write() -> Result<(), io::Error> {
                let mut stream = VecWriter(Vec::new());
 
                stream.0.clear();