cf780ef06b2ab3093a93dc7167398302f170b5cd
[rust-lightning] / lightning / src / util / ser_macros.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
/// Writes the given fields to `$stream` as a TLV stream, without any outer length prefix.
///
/// Both lists hold `(type, value)` pairs. Required values are always written. Optional values
/// must be `Option`s and are written only when `Some`. Writer errors propagate via `?`.
macro_rules! encode_tlv {
	($stream: expr, {$(($type: expr, $field: expr)),*}, {$(($optional_type: expr, $optional_field: expr)),*}) => { {
		#[allow(unused_imports)]
		use util::ser::BigSize;
		// Records must be emitted in increasing type order, but required and optional fields
		// arrive as two separate lists. So we walk every type number from 0 up to the largest
		// one given, and write whichever field (if any) carries the current number.
		// LLVM appears to fold `upper_bound` to a constant but does not unroll the loop. If the
		// number of fields ever makes this slow, the design can be revisited.
		#[allow(unused_mut)]
		let mut upper_bound: u64 = 0;
		$( upper_bound = ::std::cmp::max(upper_bound, $type + 1); )*
		$( upper_bound = ::std::cmp::max(upper_bound, $optional_type + 1); )*
		#[allow(unused_variables)]
		for cur_type in 0..upper_bound {
			$(
				if cur_type == $type {
					BigSize($type).write($stream)?;
					BigSize($field.serialized_length() as u64).write($stream)?;
					$field.write($stream)?;
				}
			)*
			$(
				if cur_type == $optional_type {
					// Absent optional fields are simply left out of the stream.
					if let Some(ref value) = $optional_field {
						BigSize($optional_type).write($stream)?;
						BigSize(value.serialized_length() as u64).write($stream)?;
						value.write($stream)?;
					}
				}
			)*
		}
	} }
}
49
/// Computes the serialized length of the TLV stream that `encode_tlv!` would write for the given
/// fields. This excludes the BigSize length prefix that `encode_varint_length_prefixed_tlv!`
/// writes in front of the stream.
///
/// Required `(type, value)` pairs are always counted. Optional pairs are counted only when the
/// value is `Some`.
macro_rules! get_varint_length_prefixed_tlv_length {
	({$(($type: expr, $field: expr)),*}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}) => { {
		use util::ser::LengthCalculatingWriter;
		// Import `BigSize` here rather than silently requiring every caller to have it in scope.
		// It goes unused only when both field lists are empty.
		#[allow(unused_imports)]
		use util::ser::BigSize;
		#[allow(unused_mut)]
		let mut len = LengthCalculatingWriter(0);
		{
			$(
				BigSize($type).write(&mut len).expect("No in-memory data may fail to serialize");
				let field_len = $field.serialized_length();
				BigSize(field_len as u64).write(&mut len).expect("No in-memory data may fail to serialize");
				len.0 += field_len;
			)*
			$(
				if let Some(ref field) = $optional_field {
					BigSize($optional_type).write(&mut len).expect("No in-memory data may fail to serialize");
					let field_len = field.serialized_length();
					BigSize(field_len as u64).write(&mut len).expect("No in-memory data may fail to serialize");
					len.0 += field_len;
				}
			)*
		}
		len.0
	} }
}
74
/// Writes a BigSize prefix holding the total TLV stream length, followed by the stream itself.
///
/// Takes the same arguments as `encode_tlv!`. The length is computed up front with
/// `get_varint_length_prefixed_tlv_length!`, before any record is written.
macro_rules! encode_varint_length_prefixed_tlv {
	($stream: expr, {$(($type: expr, $field: expr)),*}, {$(($optional_type: expr, $optional_field: expr)),*}) => { {
		use util::ser::BigSize;
		let records_len = get_varint_length_prefixed_tlv_length!(
			{ $(($type, $field)),* },
			{ $(($optional_type, $optional_field)),* }
		);
		BigSize(records_len as u64).write($stream)?;
		encode_tlv!($stream, { $(($type, $field)),* }, { $(($optional_type, $optional_field)),* });
	} }
}
83
/// Decodes TLV records from `$stream` (no outer length prefix) into caller-provided locals.
///
/// The first list holds `(type, ident)` pairs for required records. Each ident must name a
/// mutable local, which is assigned the decoded value directly. The second list holds
/// `(type, ident)` pairs for optional records. Each ident must name a mutable `Option` local,
/// which is set to `Some(value)` when the record is present.
///
/// Decoding stops cleanly if the stream ends before any byte of a new record's type is read.
/// Errors are returned from the enclosing function via `?`:
/// * `DecodeError::InvalidValue` if:
///   - types are not strictly increasing;
///   - a required type is skipped or missing;
///   - a known record's value leaves unread bytes within its declared length.
/// * `DecodeError::UnknownRequiredFeature` for an unknown even type. Unknown odd types are
///   skipped.
/// * `DecodeError::ShortRead` if the stream ends partway through a record's type.
///
/// Other read errors propagate unchanged.
macro_rules! decode_tlv {
	($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
		use ln::msgs::DecodeError;
		let mut last_seen_type: Option<u64> = None;
		'tlv_read: loop {
			use util::ser;

			// First decode the type of this TLV:
			let typ: ser::BigSize = {
				// We track whether any bytes were read while reading the type, to decide whether to
				// break or return ShortRead if we get an UnexpectedEof. This should in every case
				// be largely cosmetic, but it's nice to pass the TLV test vectors exactly, which
				// require this distinction.
				let mut tracking_reader = ser::ReadTrackingReader::new($stream);
				match ser::Readable::read(&mut tracking_reader) {
					Err(DecodeError::ShortRead) => {
						if !tracking_reader.have_read {
							break 'tlv_read
						} else {
							Err(DecodeError::ShortRead)?
						}
					},
					Err(e) => Err(e)?,
					Ok(t) => t,
				}
			};

			// Types must be unique and monotonically increasing:
			match last_seen_type {
				Some(t) if typ.0 <= t => {
					Err(DecodeError::InvalidValue)?
				},
				_ => {},
			}
			// As we read types, make sure we hit every required type:
			$({
				#[allow(unused_comparisons)] // Note that $reqtype may be 0 making the second comparison always true
				let invalid_order = (last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype) && typ.0 > $reqtype;
				if invalid_order {
					Err(DecodeError::InvalidValue)?
				}
			})*
			last_seen_type = Some(typ.0);

			// Finally, read the length and value itself. As above, use the qualified trait path
			// so callers need not have `Readable` in scope:
			let length: ser::BigSize = ser::Readable::read($stream)?;
			let mut s = ser::FixedLengthReader::new($stream, length.0);
			match typ.0 {
				$($reqtype => {
					$reqfield = ser::Readable::read(&mut s)?;
					if s.bytes_remain() {
						s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes
						Err(DecodeError::InvalidValue)?
					}
				},)*
				$($type => {
					$field = Some(ser::Readable::read(&mut s)?);
					if s.bytes_remain() {
						s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes
						Err(DecodeError::InvalidValue)?
					}
				},)*
				x if x % 2 == 0 => {
					Err(DecodeError::UnknownRequiredFeature)?
				},
				_ => {},
			}
			s.eat_remaining()?;
		}
		// Make sure we got to each required type after we've read every TLV:
		$({
			#[allow(unused_comparisons)] // Note that $reqtype may be 0 making the second comparison always true
			let missing_req_type = last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype;
			if missing_req_type {
				Err(DecodeError::InvalidValue)?
			}
		})*
	} }
}
163
/// Implements `Writeable` and `Readable` for a struct by writing and reading each listed field
/// in order, with no framing.
///
/// `$len` is the struct's fixed serialized length. It is passed to `Writer::size_hint` before
/// writing and returned from `serialized_length` without recomputation. Pass 0 when the length is
/// not fixed: then no hint is given and `serialized_length` sums the fields' lengths. In test
/// and fuzz builds, a non-zero `$len` is cross-checked against the actual length.
macro_rules! impl_writeable {
	($st:ident, $len: expr, {$($field:ident),*}) => {
		impl ::util::ser::Writeable for $st {
			fn write<W: ::util::ser::Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
				// A `$len` of 0 means the length is not known up front.
				let has_fixed_len = $len != 0;
				if has_fixed_len {
					w.size_hint($len);
				}
				#[cfg(any(test, feature = "fuzztarget"))]
				{
					if has_fixed_len {
						// Check the hard-coded length against a byte-counting writer and
						// against `serialized_length`.
						use util::ser::LengthCalculatingWriter;
						let mut counter = LengthCalculatingWriter(0);
						$( self.$field.write(&mut counter).expect("No in-memory data may fail to serialize"); )*
						assert_eq!(counter.0, $len);
						assert_eq!(self.serialized_length(), $len);
					}
				}
				$( self.$field.write(w)?; )*
				Ok(())
			}

			#[inline]
			fn serialized_length(&self) -> usize {
				let checked_build = cfg!(any(test, feature = "fuzztarget"));
				if $len != 0 && !checked_build {
					return $len;
				}
				let mut total = 0;
				$( total += self.$field.serialized_length(); )*
				if $len == 0 {
					return total;
				}
				// Checked build with a fixed length: the hard-coded value must be accurate.
				assert_eq!(total, $len);
				$len
			}
		}

		impl ::util::ser::Readable for $st {
			fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
				Ok(Self {
					$($field: ::util::ser::Readable::read(r)?),*
				})
			}
		}
	}
}
/// Implements `Writeable` and `Readable` for a struct whose listed fields are written and read in
/// order. The struct's expected serialized length is chosen by matching `*self` against
/// `{pattern, length}` pairs.
///
/// Forms:
/// * `(Struct, cmp, (calc_len), { {pattern, length}, .. }, { field, .. })` is the full form.
///   - `cmp` is a comparison operator token, e.g. `==`. In test/fuzz builds it checks the
///     actually-written length against the matched `length`.
///   - `calc_len` decides what `serialized_length` does outside test/fuzz builds. If true, it
///     sums the fields' lengths; if false, it returns the matched `length`.
/// * `(Struct, cmp, { .. }, { .. })` is the full form with `calc_len = true`.
/// * `(Struct, { .. }, { .. })` is the full form with `cmp` `==` and `calc_len = false`.
///
/// The matched `length` is always passed to `size_hint` before writing.
/// NOTE(review): the generated impls use bare `Writeable`, `Writer`, `Readable` and `DecodeError`,
/// so these must be in scope at the invocation site.
macro_rules! impl_writeable_len_match {
	($struct: ident, $cmp: tt, ($calc_len: expr), {$({$match: pat, $length: expr}),*}, {$($field:ident),*}) => {
		impl Writeable for $struct {
			fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
				// Expected length for this value, used as a size hint.
				let len = match *self {
					$($match => $length,)*
				};
				w.size_hint(len);
				#[cfg(any(test, feature = "fuzztarget"))]
				{
					// In tests, check the actual written length against the matched length using
					// `$cmp`, and check that it agrees with `serialized_length()`.
					use util::ser::LengthCalculatingWriter;
					let mut len_calc = LengthCalculatingWriter(0);
					$( self.$field.write(&mut len_calc).expect("No in-memory data may fail to serialize"); )*
					assert!(len_calc.0 $cmp len);
					assert_eq!(len_calc.0, self.serialized_length());
				}
				$( self.$field.write(w)?; )*
				Ok(())
			}

			#[inline]
			fn serialized_length(&self) -> usize {
				// Sum the fields when asked to (`$calc_len`), and always in test/fuzz builds.
				if $calc_len || cfg!(any(test, feature = "fuzztarget")) {
					let mut len_calc = 0;
					$( len_calc += self.$field.serialized_length(); )*
					if !$calc_len {
						// The matched length is claimed to be exact here, so verify it.
						assert_eq!(len_calc, match *self {
							$($match => $length,)*
						});
					}
					return len_calc
				}
				// Otherwise, trust the matched length.
				match *self {
					$($match => $length,)*
				}
			}
		}

		impl ::util::ser::Readable for $struct {
			fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, DecodeError> {
				Ok(Self {
					$($field: Readable::read(r)?),*
				})
			}
		}
	};
	// No `calc_len` given: compute `serialized_length` from the fields.
	($struct: ident, $cmp: tt, {$({$match: pat, $length: expr}),*}, {$($field:ident),*}) => {
		impl_writeable_len_match!($struct, $cmp, (true), { $({ $match, $length }),* }, { $($field),* });
	};
	// Neither `cmp` nor `calc_len` given: the matched lengths are treated as exact.
	($struct: ident, {$({$match: pat, $length: expr}),*}, {$($field:ident),*}) => {
		impl_writeable_len_match!($struct, ==, (false), { $({ $match, $length }),* }, { $($field),* });
	}
}
265
/// Writes a two-byte version prefix ahead of an object's serialized data.
///
/// * `$this_version` identifies the serialization format being written. Bump it whenever the
///   type's format changes or gains a new interpretation. Readers use it to decide how to
///   interpret fields, or whether they can understand the object at all.
/// * `$min_version_that_can_read_this` is the oldest reader version able to understand the
///   object. `read_ver_prefix!` fails with `DecodeError::UnknownVersion` for older readers.
///
/// Changes to either value should be called out in release notes. Both values may be specific to
/// the type being serialized.
macro_rules! write_ver_prefix {
	($stream: expr, $this_version: expr, $min_version_that_can_read_this: expr) => {
		// Writer's format version first, then the minimum version needed to read it.
		$stream.write_all(&[$this_version])?;
		$stream.write_all(&[$min_version_that_can_read_this])?;
	}
}
285
/// Writes a BigSize length prefix followed by a TLV stream holding the given fields. It is meant
/// as a suffix to an object, carrying backwards-compatible fields that old nodes can ignore.
///
/// Required `(type, value)` pairs are always written. Optional `(type, value)` pairs are written
/// only when the value is `Some`. As with every TLV stream, readers reject unknown even types with
/// `DecodeError::UnknownRequiredFeature` and skip unknown odd types.
///
/// This is the preferred way to add new fields that old nodes can ignore while still functioning
/// correctly.
macro_rules! write_tlv_fields {
	($stream: expr, {$(($req_ty: expr, $req_val: expr)),* $(,)*}, {$(($opt_ty: expr, $opt_val: expr)),* $(,)*}) => {
		encode_varint_length_prefixed_tlv!(
			$stream,
			{$(($req_ty, $req_val)),*},
			{$(($opt_ty, $opt_val)),*}
		);
	}
}
299
/// Reads the two-byte prefix written by `write_ver_prefix!` and evaluates to the writer's
/// version byte.
///
/// `$this_version` is the current version of this object's serialization logic. If the stored
/// minimum reader version is greater than it, the enclosing function returns
/// `Err(DecodeError::UnknownVersion)`.
macro_rules! read_ver_prefix {
	($stream: expr, $this_version: expr) => { {
		let written_ver: u8 = Readable::read($stream)?;
		let required_reader_ver: u8 = Readable::read($stream)?;
		if required_reader_ver <= $this_version {
			written_ver
		} else {
			return Err(DecodeError::UnknownVersion);
		}
	} }
}
313
/// Reads a length-prefixed TLV stream written by `write_tlv_fields!`.
///
/// A BigSize length prefix is read first. The records within that length are then decoded with
/// `decode_tlv!` into the given required and optional locals. Finally, any bytes left within the
/// prefixed length are consumed, and any failure at that stage becomes `DecodeError::ShortRead`.
macro_rules! read_tlv_fields {
	($stream: expr, {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, {$(($type: expr, $field: ident)),* $(,)*}) => { {
		let prefix = ::util::ser::BigSize::read($stream)?;
		let mut bounded = ::util::ser::FixedLengthReader::new($stream, prefix.0);
		decode_tlv!(&mut bounded, {$(($reqtype, $reqfield)),*}, {$(($type, $field)),*});
		bounded.eat_remaining().map_err(|_| ::ln::msgs::DecodeError::ShortRead)?;
	} }
}
323
// Builds the `Ok(..)` result of the generated `Readable::read` for `impl_writeable_tlv_based!` and
// `impl_writeable_tlv_based_enum!`, from the locals filled in while reading TLVs:
// * required fields are `OptionDeserWrapper`s and are unwrapped with `.0.unwrap()`,
// * optional fields are plain `Option`s and are moved in as-is (field-init shorthand),
// * vec fields are `Option<VecReadWrapper>`s and are unwrapped with `.unwrap().0`.
// Every entry gets its own trailing comma, which struct literals accept. Any of the three groups
// may therefore be empty (even all of them) without leaving a stray leading `,` such as
// `Foo { , v: .. }`.
macro_rules! _init_tlv_based_struct {
	($($type: ident)::*, {$($reqfield: ident),*}, {$($field: ident),*}, {$($vecfield: ident),*}) => {
		Ok($($type)::* {
			$($reqfield: $reqfield.0.unwrap(),)*
			$($field,)*
			$($vecfield: $vecfield.unwrap().0,)*
		})
	}
}
354
// Writes a length-prefixed TLV stream from three groups of `(type, value)` pairs: required
// values, optional values, and `Option`-wrapped vec values. The last two groups are merged into
// `write_tlv_fields!`'s optional list. Each merged entry carries its own trailing comma, which
// that list accepts, so either group may be empty without leaving a stray leading `,`.
macro_rules! _write_tlv_fields {
	($stream: expr, {$(($type: expr, $field: expr)),* $(,)*}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}, {$(($optional_type_2: expr, $optional_field_2: expr)),* $(,)*}) => {
		write_tlv_fields!($stream, {$(($type, $field)),*}, {
			$(($optional_type, $optional_field),)*
			$(($optional_type_2, $optional_field_2),)*
		});
	}
}
// Length counterpart of `_write_tlv_fields!`: evaluates to the TLV stream length (excluding its
// BigSize prefix). Optional and vec groups are merged with per-entry trailing commas, which the
// optional list of `get_varint_length_prefixed_tlv_length!` accepts, so either may be empty.
macro_rules! _get_tlv_len {
	({$(($type: expr, $field: expr)),* $(,)*}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}, {$(($optional_type_2: expr, $optional_field_2: expr)),* $(,)*}) => {
		get_varint_length_prefixed_tlv_length!({$(($type, $field)),*}, {
			$(($optional_type, $optional_field),)*
			$(($optional_type_2, $optional_field_2),)*
		})
	}
}
// Read counterpart of `_write_tlv_fields!`: the optional and vec groups are both decoded as
// optional TLVs. They are merged with per-entry trailing commas, which `read_tlv_fields!`
// accepts, so either group may be empty.
macro_rules! _read_tlv_fields {
	($stream: expr, {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, {$(($type: expr, $field: ident)),* $(,)*}, {$(($type_2: expr, $field_2: ident)),* $(,)*}) => {
		read_tlv_fields!($stream, {$(($reqtype, $reqfield)),*}, {
			$(($type, $field),)*
			$(($type_2, $field_2),)*
		});
	}
}
382
/// Implements `Readable` and `Writeable` for a struct, storing it as a length-prefixed set of TLV
/// records.
///
/// Each block is a list of `(type, field)` pairs:
/// * The first block lists required fields. While reading, each is held in an
///   `OptionDeserWrapper` placeholder that never escapes the generated code. Reading fails if the
///   type is missing from the stream.
/// * The second block lists optional fields, which must be `Option`s. They are written only when
///   `Some`.
/// * The third block lists `Vec` fields whose elements are serialized individually (via
///   `VecWriteWrapper` / `VecReadWrapper`).
macro_rules! impl_writeable_tlv_based {
	($st: ident, {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, {$(($type: expr, $field: ident)),* $(,)*}, {$(($vectype: expr, $vecfield: ident)),* $(,)*}) => {
		impl ::util::ser::Writeable for $st {
			fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
				_write_tlv_fields!(writer,
					{ $(($reqtype, self.$reqfield)),* },
					{ $(($type, self.$field)),* },
					{ $(($vectype, Some(::util::ser::VecWriteWrapper(&self.$vecfield)))),* }
				);
				Ok(())
			}

			#[inline]
			fn serialized_length(&self) -> usize {
				use util::ser::{BigSize, LengthCalculatingWriter};
				// Size of the TLV records themselves...
				let records_len = _get_tlv_len!(
					{ $(($reqtype, self.$reqfield)),* },
					{ $(($type, self.$field)),* },
					{ $(($vectype, Some(::util::ser::VecWriteWrapper(&self.$vecfield)))),* }
				);
				// ...plus the BigSize prefix that encodes that size.
				let mut prefix_len = LengthCalculatingWriter(0);
				BigSize(records_len as u64).write(&mut prefix_len).expect("No in-memory data may fail to serialize");
				records_len + prefix_len.0
			}
		}

		impl ::util::ser::Readable for $st {
			fn read<R: ::std::io::Read>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
				// One local per field, named after the field, filled in by the TLV reader.
				$( let mut $reqfield = ::util::ser::OptionDeserWrapper(None); )*
				$( let mut $field = None; )*
				$( let mut $vecfield = Some(::util::ser::VecReadWrapper(Vec::new())); )*
				_read_tlv_fields!(reader,
					{ $(($reqtype, $reqfield)),* },
					{ $(($type, $field)),* },
					{ $(($vectype, $vecfield)),* }
				);
				_init_tlv_based_struct!($st, {$($reqfield),*}, {$($field),*}, {$($vecfield),*})
			}
		}
	}
}
441
/// Implements `Readable` and `Writeable` for an enum.
///
/// Struct variants are stored as a length-prefixed TLV stream, using the same three-block
/// `(required, optional, vec)` layout as `impl_writeable_tlv_based!`. Tuple variants must have
/// exactly one field, which is stored directly using its own serialization.
///
/// The format is, for example:
/// ```text
/// impl_writeable_tlv_based_enum!(EnumName,
///   (0, StructVariantA) => {(0, variant_field)}, {(1, variant_optional_field)}, {},
///   (1, StructVariantB) => {(0, variant_field_a), (1, variant_field_b)}, {}, {(2, variant_vec_field)};
///   (2, TupleVariantA), (3, TupleVariantB),
/// );
/// ```
/// The variant id is written as a single byte, followed by any variant data. Reading an unknown
/// id byte results in `DecodeError::UnknownRequiredFeature`.
macro_rules! impl_writeable_tlv_based_enum {
	($st: ident, $(($variant_id: expr, $variant_name: ident) =>
		{$(($reqtype: expr, $reqfield: ident)),* $(,)*},
		{$(($type: expr, $field: ident)),* $(,)*},
		{$(($vectype: expr, $vecfield: ident)),* $(,)*}
	),* $(,)*;
	$(($tuple_variant_id: expr, $tuple_variant_name: ident)),*  $(,)*) => {
		impl ::util::ser::Writeable for $st {
			fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
				match self {
					// Each binding carries its own trailing comma. A single separator between the
					// groups would be missing for variants with both required and optional
					// fields (`ref a ref b`), and would dangle for variants with no fields at all
					// (`{ , }`).
					$($st::$variant_name { $(ref $reqfield,)* $(ref $field,)* $(ref $vecfield,)* } => {
						let id: u8 = $variant_id;
						id.write(writer)?;
						_write_tlv_fields!(writer, {
							$(($reqtype, $reqfield)),*
						}, {
							$(($type, $field)),*
						}, {
							$(($vectype, Some(::util::ser::VecWriteWrapper(&$vecfield)))),*
						});
					}),*
					$($st::$tuple_variant_name (ref field) => {
						let id: u8 = $tuple_variant_id;
						id.write(writer)?;
						field.write(writer)?;
					}),*
				}
				Ok(())
			}
		}

		impl ::util::ser::Readable for $st {
			fn read<R: ::std::io::Read>(reader: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
				let id: u8 = ::util::ser::Readable::read(reader)?;
				match id {
					$($variant_id => {
						// Because read_tlv_fields creates a labeled loop, we cannot call it twice
						// in the same function body. Instead, we define a closure and call it.
						let f = || {
							$(
								let mut $reqfield = ::util::ser::OptionDeserWrapper(None);
							)*
							$(
								let mut $field = None;
							)*
							$(
								let mut $vecfield = Some(::util::ser::VecReadWrapper(Vec::new()));
							)*
							_read_tlv_fields!(reader, {
								$(($reqtype, $reqfield)),*
							}, {
								$(($type, $field)),*
							}, {
								$(($vectype, $vecfield)),*
							});
							_init_tlv_based_struct!($st::$variant_name, {$($reqfield),*}, {$($field),*}, {$($vecfield),*})
						};
						f()
					}),*
					$($tuple_variant_id => {
						Ok($st::$tuple_variant_name(::util::ser::Readable::read(reader)?))
					}),*
					_ => {
						Err(::ln::msgs::DecodeError::UnknownRequiredFeature)?
					},
				}
			}
		}
	}
}
522
#[cfg(test)]
mod tests {
	use prelude::*;
	use std::io::Cursor;
	use ln::msgs::DecodeError;
	use util::ser::{Readable, Writeable, HighZeroBytesDroppedVarInt, VecWriter};
	use bitcoin::secp256k1::PublicKey;

	// Hex-decodes `$stream`, feeds it to `$reader` and asserts that decoding fails with
	// `DecodeError::$reason`. On a mismatch it panics with the offending stream and the actual
	// outcome (a different error, or a successful decode), so a failing vector can be identified
	// directly from the test output.
	macro_rules! assert_decode_err {
		($reader: ident, $stream: expr, $reason: ident) => { {
			let stream = $stream;
			match $reader(&::hex::decode(&stream).unwrap()[..]) {
				Err(DecodeError::$reason) => {},
				Err(e) => panic!("decoding {} failed with {:?}, expected DecodeError::{}",
					stream, e, stringify!($reason)),
				Ok(_) => panic!("decoding {} succeeded, expected DecodeError::{}",
					stream, stringify!($reason)),
			}
		} }
	}

	// The BOLT TLV test cases don't include any tests which use our "required-value" logic since
	// the encoding layer in the BOLTs has no such concept, though it makes our macros easier to
	// work with so they're baked into the decoder. Thus, we have a few additional tests below.

	// Decodes a stream with required types 2 (read into a u64) and 3 (read into a u32) and an
	// optional type 4 (read into an Option<u32>).
	fn tlv_reader(s: &[u8]) -> Result<(u64, u32, Option<u32>), DecodeError> {
		let mut s = Cursor::new(s);
		let mut a: u64 = 0;
		let mut b: u32 = 0;
		let mut c: Option<u32> = None;
		decode_tlv!(&mut s, {(2, a), (3, b)}, {(4, c)});
		Ok((a, b, c))
	}

	#[test]
	fn tlv_v_short_read() {
		// We only expect a u32 for type 3 (which we are given), but the length (the "L" in TLV)
		// claims it is 8 bytes, and only 4 bytes remain in the stream.
		assert_decode_err!(tlv_reader, concat!("0100", "0208deadbeef1badbeef", "0308deadbeef"), ShortRead);
	}

	#[test]
	fn tlv_types_out_of_order() {
		// Type 3 appears before type 2.
		assert_decode_err!(tlv_reader, concat!("0100", "0304deadbeef", "0208deadbeef1badbeef"), InvalidValue);
		// ...even if it's some field we don't understand (type 1 appears after type 2).
		assert_decode_err!(tlv_reader, concat!("0208deadbeef1badbeef", "0100", "0304deadbeef"), InvalidValue);
	}

	#[test]
	fn tlv_req_type_missing_or_extra() {
		// It's also bad if they included even fields we don't understand (type 6 here).
		assert_decode_err!(tlv_reader, concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600"), UnknownRequiredFeature);
		// ... or if they're missing fields we need (required type 3 is absent here)...
		assert_decode_err!(tlv_reader, concat!("0100", "0208deadbeef1badbeef"), InvalidValue);
		// ... even if the missing required field's type is even (type 2 is absent here).
		assert_decode_err!(tlv_reader, concat!("0304deadbeef", "0500"), InvalidValue);
	}

	#[test]
	fn tlv_simple_good_cases() {
		assert_eq!(tlv_reader(&::hex::decode(
				concat!("0208deadbeef1badbeef", "03041bad1dea")
				).unwrap()[..]).unwrap(),
			(0xdeadbeef1badbeef, 0x1bad1dea, None));
		assert_eq!(tlv_reader(&::hex::decode(
				concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")
				).unwrap()[..]).unwrap(),
			(0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
	}

	// BOLT TLV test cases

	// Decodes the BOLT test-vector namespace "n1": types 1, 2, 3 and 254, all optional.
	fn tlv_reader_n1(s: &[u8]) -> Result<(Option<HighZeroBytesDroppedVarInt<u64>>, Option<u64>, Option<(PublicKey, u64, u64)>, Option<u16>), DecodeError> {
		let mut s = Cursor::new(s);
		let mut tlv1: Option<HighZeroBytesDroppedVarInt<u64>> = None;
		let mut tlv2: Option<u64> = None;
		let mut tlv3: Option<(PublicKey, u64, u64)> = None;
		let mut tlv4: Option<u16> = None;
		decode_tlv!(&mut s, {}, {(1, tlv1), (2, tlv2), (3, tlv3), (254, tlv4)});
		Ok((tlv1, tlv2, tlv3, tlv4))
	}

	#[test]
	fn bolt_tlv_bogus_stream() {
		// TLVs from the BOLT test cases which should not decode as either n1 or n2
		assert_decode_err!(tlv_reader_n1, concat!("fd01"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("fd0001", "00"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("fd0101"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("0f", "fd"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("0f", "fd26"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("0f", "fd2602"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("0f", "fd0001", "00"), InvalidValue);
		// Value truncated: the record declares a length of 0x0201 (513) bytes but fewer follow.
		// The zero value is generated rather than spelled out so its length is evident.
		assert_decode_err!(tlv_reader_n1, format!("{}{}{}", "0f", "fd0201", "00".repeat(512)), ShortRead);

		assert_decode_err!(tlv_reader_n1, concat!("12", "00"), UnknownRequiredFeature);
		assert_decode_err!(tlv_reader_n1, concat!("fd0102", "00"), UnknownRequiredFeature);
		assert_decode_err!(tlv_reader_n1, concat!("fe01000002", "00"), UnknownRequiredFeature);
		assert_decode_err!(tlv_reader_n1, concat!("ff0100000000000002", "00"), UnknownRequiredFeature);
	}

	#[test]
	fn bolt_tlv_bogus_n1_stream() {
		// TLVs from the BOLT test cases which should not decode as n1
		assert_decode_err!(tlv_reader_n1, concat!("01", "09", "ffffffffffffffffff"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "01", "00"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "02", "0001"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "03", "000100"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "04", "00010000"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "05", "0001000000"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "06", "000100000000"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "07", "00010000000000"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("01", "08", "0001000000000000"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("02", "07", "01010101010101"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("02", "09", "010101010101010101"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("03", "21", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("03", "29", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb0000000000000001"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("03", "30", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb000000000000000100000000000001"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("03", "31", "043da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("03", "32", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb0000000000000001000000000000000001"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("fd00fe", "00"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("fd00fe", "01", "01"), ShortRead);
		assert_decode_err!(tlv_reader_n1, concat!("fd00fe", "03", "010101"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("00", "00"), UnknownRequiredFeature);

		assert_decode_err!(tlv_reader_n1, concat!("02", "08", "0000000000000226", "01", "01", "2a"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("02", "08", "0000000000000231", "02", "08", "0000000000000451"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("1f", "00", "0f", "01", "2a"), InvalidValue);
		assert_decode_err!(tlv_reader_n1, concat!("1f", "00", "1f", "01", "2a"), InvalidValue);

		// The last BOLT test modified to not require creating a new decoder for one trivial test.
		assert_decode_err!(tlv_reader_n1, concat!("ffffffffffffffffff", "00", "01", "00"), InvalidValue);
	}

	#[test]
	fn bolt_tlv_valid_n1_stream() {
		// Decodes `$stream` as n1 and compares each field with the expected value, panicking
		// with the stream and the decode error if decoding fails outright.
		macro_rules! do_test {
			($stream: expr, $tlv1: expr, $tlv2: expr, $tlv3: expr, $tlv4: expr) => {
				match tlv_reader_n1(&::hex::decode($stream).unwrap()[..]) {
					Ok((tlv1, tlv2, tlv3, tlv4)) => {
						assert_eq!(tlv1.map(|v| v.0), $tlv1);
						assert_eq!(tlv2, $tlv2);
						assert_eq!(tlv3, $tlv3);
						assert_eq!(tlv4, $tlv4);
					},
					Err(e) => panic!("decoding {} failed with {:?}", $stream, e),
				}
			}
		}

		do_test!(concat!(""), None, None, None, None);
		do_test!(concat!("21", "00"), None, None, None, None);
		do_test!(concat!("fd0201", "00"), None, None, None, None);
		do_test!(concat!("fd00fd", "00"), None, None, None, None);
		do_test!(concat!("fd00ff", "00"), None, None, None, None);
		do_test!(concat!("fe02000001", "00"), None, None, None, None);
		do_test!(concat!("ff0200000000000001", "00"), None, None, None, None);

		do_test!(concat!("01", "00"), Some(0), None, None, None);
		do_test!(concat!("01", "01", "01"), Some(1), None, None, None);
		do_test!(concat!("01", "02", "0100"), Some(256), None, None, None);
		do_test!(concat!("01", "03", "010000"), Some(65536), None, None, None);
		do_test!(concat!("01", "04", "01000000"), Some(16777216), None, None, None);
		do_test!(concat!("01", "05", "0100000000"), Some(4294967296), None, None, None);
		do_test!(concat!("01", "06", "010000000000"), Some(1099511627776), None, None, None);
		do_test!(concat!("01", "07", "01000000000000"), Some(281474976710656), None, None, None);
		do_test!(concat!("01", "08", "0100000000000000"), Some(72057594037927936), None, None, None);
		do_test!(concat!("02", "08", "0000000000000226"), None, Some((0 << 30) | (0 << 5) | (550 << 0)), None, None);
		do_test!(concat!("03", "31", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002"),
			None, None, Some((
				PublicKey::from_slice(&::hex::decode("023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb").unwrap()[..]).unwrap(), 1, 2)),
			None);
		do_test!(concat!("fd00fe", "02", "0226"), None, None, None, Some(550));
	}

	// Checks the exact bytes produced by `encode_varint_length_prefixed_tlv!`: a total-length
	// prefix followed by the records in increasing type order, with `None` optional fields
	// omitted entirely (type 42 never appears in the output below).
	fn do_simple_test_tlv_write() -> Result<(), ::std::io::Error> {
		let mut stream = VecWriter(Vec::new());

		stream.0.clear();
		encode_varint_length_prefixed_tlv!(&mut stream, { (1, 1u8) }, { (42, None::<u64>) });
		assert_eq!(stream.0, ::hex::decode("03010101").unwrap());

		// A present optional field encodes identically to a required one.
		stream.0.clear();
		encode_varint_length_prefixed_tlv!(&mut stream, { }, { (1, Some(1u8)) });
		assert_eq!(stream.0, ::hex::decode("03010101").unwrap());

		stream.0.clear();
		encode_varint_length_prefixed_tlv!(&mut stream, { (4, 0xabcdu16) }, { (42, None::<u64>) });
		assert_eq!(stream.0, ::hex::decode("040402abcd").unwrap());

		stream.0.clear();
		encode_varint_length_prefixed_tlv!(&mut stream, { (0xff, 0xabcdu16) }, { (42, None::<u64>) });
		assert_eq!(stream.0, ::hex::decode("06fd00ff02abcd").unwrap());

		stream.0.clear();
		encode_varint_length_prefixed_tlv!(&mut stream, { (0, 1u64), (0xff, HighZeroBytesDroppedVarInt(0u64)) }, { (42, None::<u64>) });
		assert_eq!(stream.0, ::hex::decode("0e00080000000000000001fd00ff00").unwrap());

		// Same records split across required/optional lists: output is still in type order.
		stream.0.clear();
		encode_varint_length_prefixed_tlv!(&mut stream, { (0xff, HighZeroBytesDroppedVarInt(0u64)) }, { (0, Some(1u64)) });
		assert_eq!(stream.0, ::hex::decode("0e00080000000000000001fd00ff00").unwrap());

		Ok(())
	}

	#[test]
	fn simple_test_tlv_write() {
		do_simple_test_tlv_write().unwrap();
	}
}