X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Fser_macros.rs;h=d37a841eaff31945ecc2af27d624ca8570552006;hb=22501c3c5ec42efb74bf469a8081e89e1d006f83;hp=3ec0848680ffe62a2532e03d47d1fc5b62d7f4b8;hpb=384c4dc7753e4b7ac53ea380e52809babd8f0f9b;p=rust-lightning diff --git a/lightning/src/util/ser_macros.rs b/lightning/src/util/ser_macros.rs index 3ec08486..d37a841e 100644 --- a/lightning/src/util/ser_macros.rs +++ b/lightning/src/util/ser_macros.rs @@ -11,6 +11,9 @@ macro_rules! encode_tlv { ($stream: expr, $type: expr, $field: expr, (default_value, $default: expr)) => { encode_tlv!($stream, $type, $field, required) }; + ($stream: expr, $type: expr, $field: expr, (static_value, $value: expr)) => { + let _ = &$field; // Ensure we "use" the $field + }; ($stream: expr, $type: expr, $field: expr, required) => { BigSize($type).write($stream)?; BigSize($field.serialized_length() as u64).write($stream)?; @@ -34,6 +37,17 @@ macro_rules! encode_tlv { }; } + +macro_rules! check_encoded_tlv_order { + ($last_type: expr, $type: expr, (static_value, $value: expr)) => { }; + ($last_type: expr, $type: expr, $fieldty: tt) => { + if let Some(t) = $last_type { + debug_assert!(t <= $type); + } + $last_type = Some($type); + }; +} + macro_rules! encode_tlv_stream { ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),* $(,)*}) => { { #[allow(unused_imports)] @@ -52,10 +66,7 @@ macro_rules! encode_tlv_stream { { let mut last_seen: Option = None; $( - if let Some(t) = last_seen { - debug_assert!(t <= $type); - } - last_seen = Some($type); + check_encoded_tlv_order!(last_seen, $type, $fieldty); )* } } } @@ -65,6 +76,8 @@ macro_rules! get_varint_length_prefixed_tlv_length { ($len: expr, $type: expr, $field: expr, (default_value, $default: expr)) => { get_varint_length_prefixed_tlv_length!($len, $type, $field, required) }; + ($len: expr, $type: expr, $field: expr, (static_value, $value: expr)) => { + }; ($len: expr, $type: expr, $field: expr, required) => { BigSize($type).write(&mut $len).expect("No in-memory data may fail to serialize"); let field_len = $field.serialized_length(); @@ -100,7 +113,7 @@ macro_rules! encode_varint_length_prefixed_tlv { } } } -macro_rules! check_tlv_order { +macro_rules! check_decoded_tlv_order { ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{ #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; @@ -108,6 +121,8 @@ macro_rules! check_tlv_order { $field = $default.into(); } }}; + ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (static_value, $value: expr)) => { + }; ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, required) => {{ #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; @@ -140,6 +155,9 @@ macro_rules! check_missing_tlv { $field = $default.into(); } }}; + ($last_seen_type: expr, $type: expr, $field: expr, (static_value, $value: expr)) => { + $field = $value; + }; ($last_seen_type: expr, $type: expr, $field: ident, required) => {{ #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type; @@ -168,6 +186,8 @@ macro_rules! decode_tlv { ($reader: expr, $field: ident, (default_value, $default: expr)) => {{ decode_tlv!($reader, $field, required) }}; + ($reader: expr, $field: ident, (static_value, $value: expr)) => {{ + }}; ($reader: expr, $field: ident, required) => {{ $field = $crate::util::ser::Readable::read(&mut $reader)?; }}; @@ -195,12 +215,28 @@ macro_rules! decode_tlv { }}; } +macro_rules! _decode_tlv_stream_match_check { + ($val: ident, $type: expr, (static_value, $value: expr)) => { false }; + ($val: ident, $type: expr, $fieldty: tt) => { $val == $type } +} + // `$decode_custom_tlv` is a closure that may be optionally provided to handle custom message types. // If it is provided, it will be called with the custom type and the `FixedLengthReader` containing // the message contents. It should return `Ok(true)` if the custom message is successfully parsed, // `Ok(false)` if the message type is unknown, and `Err(DecodeError)` if parsing fails. macro_rules! decode_tlv_stream { ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*} + $(, $decode_custom_tlv: expr)?) => { { + let rewind = |_, _| { unreachable!() }; + use core::ops::RangeBounds; + decode_tlv_stream_range!( + $stream, .., rewind, {$(($type, $field, $fieldty)),*} $(, $decode_custom_tlv)? + ); + } } +} + +macro_rules! decode_tlv_stream_range { + ($stream: expr, $range: expr, $rewind: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*} $(, $decode_custom_tlv: expr)?) => { { use $crate::ln::msgs::DecodeError; let mut last_seen_type: Option = None; @@ -215,7 +251,7 @@ macro_rules! decode_tlv_stream { // UnexpectedEof. This should in every case be largely cosmetic, but its nice to // pass the TLV test vectors exactly, which requre this distinction. let mut tracking_reader = ser::ReadTrackingReader::new(&mut stream_ref); - match $crate::util::ser::Readable::read(&mut tracking_reader) { + match <$crate::util::ser::BigSize as $crate::util::ser::Readable>::read(&mut tracking_reader) { Err(DecodeError::ShortRead) => { if !tracking_reader.have_read { break 'tlv_read; @@ -224,7 +260,15 @@ macro_rules! decode_tlv_stream { } }, Err(e) => return Err(e), - Ok(t) => t, + Ok(t) => if $range.contains(&t.0) { t } else { + drop(tracking_reader); + + // Assumes the type id is minimally encoded, which is enforced on read. + use $crate::util::ser::Writeable; + let bytes_read = t.serialized_length(); + $rewind(stream_ref, bytes_read); + break 'tlv_read; + }, } }; @@ -237,7 +281,7 @@ macro_rules! decode_tlv_stream { } // As we read types, make sure we hit every required type: $({ - check_tlv_order!(last_seen_type, typ, $type, $field, $fieldty); + check_decoded_tlv_order!(last_seen_type, typ, $type, $field, $fieldty); })* last_seen_type = Some(typ.0); @@ -245,7 +289,7 @@ macro_rules! decode_tlv_stream { let length: ser::BigSize = $crate::util::ser::Readable::read(&mut stream_ref)?; let mut s = ser::FixedLengthReader::new(&mut stream_ref, length.0); match typ.0 { - $($type => { + $(_t if _decode_tlv_stream_match_check!(_t, $type, $fieldty) => { decode_tlv!(s, $field, $fieldty); if s.bytes_remain() { s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes @@ -386,6 +430,9 @@ macro_rules! init_tlv_based_struct_field { ($field: ident, (default_value, $default: expr)) => { $field.0.unwrap() }; + ($field: ident, (static_value, $value: expr)) => { + $field + }; ($field: ident, option) => { $field }; @@ -401,6 +448,9 @@ macro_rules! init_tlv_field_var { ($field: ident, (default_value, $default: expr)) => { let mut $field = $crate::util::ser::OptionDeserWrapper(None); }; + ($field: ident, (static_value, $value: expr)) => { + let $field; + }; ($field: ident, required) => { let mut $field = $crate::util::ser::OptionDeserWrapper(None); }; @@ -412,6 +462,18 @@ macro_rules! init_tlv_field_var { }; } +macro_rules! init_and_read_tlv_fields { + ($reader: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { + $( + init_tlv_field_var!($field, $fieldty); + )* + + read_tlv_fields!($reader, { + $(($type, $field, $fieldty)),* + }); + } +} + /// Implements Readable/Writeable for a struct storing it as a set of TLVs /// If $fieldty is `required`, then $field is a required field that is not an Option nor a Vec. /// If $fieldty is `option`, then $field is optional field. @@ -446,10 +508,7 @@ macro_rules! impl_writeable_tlv_based { impl $crate::util::ser::Readable for $st { fn read(reader: &mut R) -> Result { - $( - init_tlv_field_var!($field, $fieldty); - )* - read_tlv_fields!(reader, { + init_and_read_tlv_fields!(reader, { $(($type, $field, $fieldty)),* }); Ok(Self { @@ -472,19 +531,20 @@ macro_rules! impl_writeable_tlv_based { /// [`Readable`]: crate::util::ser::Readable /// [`Writeable`]: crate::util::ser::Writeable macro_rules! tlv_stream { - ($name:ident, $nameref:ident, { + ($name:ident, $nameref:ident, $range:expr, { $(($type:expr, $field:ident : $fieldty:tt)),* $(,)* }) => { #[derive(Debug)] - struct $name { + pub(super) struct $name { $( - $field: Option, + pub(super) $field: Option, )* } - pub(crate) struct $nameref<'a> { + #[derive(Debug, PartialEq)] + pub(super) struct $nameref<'a> { $( - pub(crate) $field: Option, + pub(super) $field: Option, )* } @@ -497,12 +557,15 @@ macro_rules! tlv_stream { } } - impl $crate::util::ser::Readable for $name { - fn read(reader: &mut R) -> Result { + impl $crate::util::ser::SeekReadable for $name { + fn read(reader: &mut R) -> Result { $( init_tlv_field_var!($field, option); )* - decode_tlv_stream!(reader, { + let rewind = |cursor: &mut R, offset: usize| { + cursor.seek($crate::io::SeekFrom::Current(-(offset as i64))).expect(""); + }; + decode_tlv_stream_range!(reader, $range, rewind, { $(($type, $field, (option, encoding: $fieldty))),* }); @@ -583,10 +646,7 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable { // Because read_tlv_fields creates a labeled loop, we cannot call it twice // in the same function body. Instead, we define a closure and call it. let f = || { - $( - init_tlv_field_var!($field, $fieldty); - )* - read_tlv_fields!(reader, { + init_and_read_tlv_fields!(reader, { $(($type, $field, $fieldty)),* }); Ok(Some($st::$variant_name { @@ -636,10 +696,7 @@ macro_rules! impl_writeable_tlv_based_enum { // Because read_tlv_fields creates a labeled loop, we cannot call it twice // in the same function body. Instead, we define a closure and call it. let f = || { - $( - init_tlv_field_var!($field, $fieldty); - )* - read_tlv_fields!(reader, { + init_and_read_tlv_fields!(reader, { $(($type, $field, $fieldty)),* }); Ok($st::$variant_name { @@ -654,7 +711,7 @@ macro_rules! impl_writeable_tlv_based_enum { Ok($st::$tuple_variant_name(Readable::read(reader)?)) }),* _ => { - Err(DecodeError::UnknownRequiredFeature) + Err($crate::ln::msgs::DecodeError::UnknownRequiredFeature) }, } }