Append backwards-compat TLVs to serialization of larger structs
[rust-lightning] / lightning / src / util / ser_macros.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 macro_rules! encode_tlv {
11         ($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
12                 #[allow(unused_imports)]
13                 use util::ser::{BigSize, LengthCalculatingWriter};
14                 $(
15                         BigSize($type).write($stream)?;
16                         let mut len_calc = LengthCalculatingWriter(0);
17                         $field.write(&mut len_calc)?;
18                         BigSize(len_calc.0 as u64).write($stream)?;
19                         $field.write($stream)?;
20                 )*
21         } }
22 }
23
24 macro_rules! encode_varint_length_prefixed_tlv {
25         ($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
26                 use util::ser::{BigSize, LengthCalculatingWriter};
27                 #[allow(unused_mut)]
28                 let mut len = LengthCalculatingWriter(0);
29                 {
30                         $(
31                                 BigSize($type).write(&mut len)?;
32                                 let mut field_len = LengthCalculatingWriter(0);
33                                 $field.write(&mut field_len)?;
34                                 BigSize(field_len.0 as u64).write(&mut len)?;
35                                 len.0 += field_len.0;
36                         )*
37                 }
38
39                 BigSize(len.0 as u64).write($stream)?;
40                 encode_tlv!($stream, {
41                         $(($type, $field)),*
42                 });
43         } }
44 }
45
46 macro_rules! decode_tlv {
47         ($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
48                 use ln::msgs::DecodeError;
49                 let mut last_seen_type: Option<u64> = None;
50                 'tlv_read: loop {
51                         use util::ser;
52
53                         // First decode the type of this TLV:
54                         let typ: ser::BigSize = {
55                                 // We track whether any bytes were read during the consensus_decode call to
56                                 // determine whether we should break or return ShortRead if we get an
57                                 // UnexpectedEof. This should in every case be largely cosmetic, but its nice to
58                                 // pass the TLV test vectors exactly, which requre this distinction.
59                                 let mut tracking_reader = ser::ReadTrackingReader::new($stream);
60                                 match ser::Readable::read(&mut tracking_reader) {
61                                         Err(DecodeError::ShortRead) => {
62                                                 if !tracking_reader.have_read {
63                                                         break 'tlv_read
64                                                 } else {
65                                                         Err(DecodeError::ShortRead)?
66                                                 }
67                                         },
68                                         Err(e) => Err(e)?,
69                                         Ok(t) => t,
70                                 }
71                         };
72
73                         // Types must be unique and monotonically increasing:
74                         match last_seen_type {
75                                 Some(t) if typ.0 <= t => {
76                                         Err(DecodeError::InvalidValue)?
77                                 },
78                                 _ => {},
79                         }
80                         // As we read types, make sure we hit every required type:
81                         $(if (last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype) && typ.0 > $reqtype {
82                                 Err(DecodeError::InvalidValue)?
83                         })*
84                         last_seen_type = Some(typ.0);
85
86                         // Finally, read the length and value itself:
87                         let length: ser::BigSize = Readable::read($stream)?;
88                         let mut s = ser::FixedLengthReader::new($stream, length.0);
89                         match typ.0 {
90                                 $($reqtype => {
91                                         $reqfield = ser::Readable::read(&mut s)?;
92                                         if s.bytes_remain() {
93                                                 s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes
94                                                 Err(DecodeError::InvalidValue)?
95                                         }
96                                 },)*
97                                 $($type => {
98                                         $field = Some(ser::Readable::read(&mut s)?);
99                                         if s.bytes_remain() {
100                                                 s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes
101                                                 Err(DecodeError::InvalidValue)?
102                                         }
103                                 },)*
104                                 x if x % 2 == 0 => {
105                                         Err(DecodeError::UnknownRequiredFeature)?
106                                 },
107                                 _ => {},
108                         }
109                         s.eat_remaining()?;
110                 }
111                 // Make sure we got to each required type after we've read every TLV:
112                 $(if last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype {
113                         Err(DecodeError::InvalidValue)?
114                 })*
115         } }
116 }
117
118 macro_rules! impl_writeable {
119         ($st:ident, $len: expr, {$($field:ident),*}) => {
120                 impl ::util::ser::Writeable for $st {
121                         fn write<W: ::util::ser::Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
122                                 if $len != 0 {
123                                         w.size_hint($len);
124                                 }
125                                 #[cfg(any(test, feature = "fuzztarget"))]
126                                 {
127                                         // In tests, assert that the hard-coded length matches the actual one
128                                         if $len != 0 {
129                                                 use util::ser::LengthCalculatingWriter;
130                                                 let mut len_calc = LengthCalculatingWriter(0);
131                                                 $( self.$field.write(&mut len_calc)?; )*
132                                                 assert_eq!(len_calc.0, $len);
133                                         }
134                                 }
135                                 $( self.$field.write(w)?; )*
136                                 Ok(())
137                         }
138                 }
139
140                 impl ::util::ser::Readable for $st {
141                         fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
142                                 Ok(Self {
143                                         $($field: ::util::ser::Readable::read(r)?),*
144                                 })
145                         }
146                 }
147         }
148 }
149 macro_rules! impl_writeable_len_match {
150         ($struct: ident, $cmp: tt, {$({$match: pat, $length: expr}),*}, {$($field:ident),*}) => {
151                 impl Writeable for $struct {
152                         fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
153                                 let len = match *self {
154                                         $($match => $length,)*
155                                 };
156                                 w.size_hint(len);
157                                 #[cfg(any(test, feature = "fuzztarget"))]
158                                 {
159                                         // In tests, assert that the hard-coded length matches the actual one
160                                         use util::ser::LengthCalculatingWriter;
161                                         let mut len_calc = LengthCalculatingWriter(0);
162                                         $( self.$field.write(&mut len_calc)?; )*
163                                         assert!(len_calc.0 $cmp len);
164                                 }
165                                 $( self.$field.write(w)?; )*
166                                 Ok(())
167                         }
168                 }
169
170                 impl ::util::ser::Readable for $struct {
171                         fn read<R: ::std::io::Read>(r: &mut R) -> Result<Self, DecodeError> {
172                                 Ok(Self {
173                                         $($field: Readable::read(r)?),*
174                                 })
175                         }
176                 }
177         };
178         ($struct: ident, {$({$match: pat, $length: expr}),*}, {$($field:ident),*}) => {
179                 impl_writeable_len_match!($struct, ==, { $({ $match, $length }),* }, { $($field),* });
180         }
181 }
182
183 /// Write out two bytes to indicate the version of an object.
184 /// $this_version represents a unique version of a type. Incremented whenever the type's
185 ///               serialization format has changed or has a new interpretation. Used by a type's
186 ///               reader to determine how to interpret fields or if it can understand a serialized
187 ///               object.
188 /// $min_version_that_can_read_this is the minimum reader version which can understand this
189 ///                                 serialized object. Previous versions will simply err with a
190 ///                                 DecodeError::UnknownVersion.
191 ///
192 /// Updates to either $this_version or $min_version_that_can_read_this should be included in
193 /// release notes.
194 ///
195 /// Both version fields can be specific to this type of object.
196 macro_rules! write_ver_prefix {
197         ($stream: expr, $this_version: expr, $min_version_that_can_read_this: expr) => {
198                 $stream.write_all(&[$this_version; 1])?;
199                 $stream.write_all(&[$min_version_that_can_read_this; 1])?;
200         }
201 }
202
203 /// Writes out a suffix to an object which contains potentially backwards-compatible, optional
204 /// fields which old nodes can happily ignore.
205 ///
206 /// It is written out in TLV format and, as with all TLV fields, unknown even fields cause a
207 /// DecodeError::UnknownRequiredFeature error, with unknown odd fields ignored.
208 ///
209 /// This is the preferred method of adding new fields that old nodes can ignore and still function
210 /// correctly.
211 macro_rules! write_tlv_fields {
212         ($stream: expr, {$(($type: expr, $field: expr)),*}) => {
213                 encode_varint_length_prefixed_tlv!($stream, {$(($type, $field)),*});
214         }
215 }
216
217 /// Reads a prefix added by write_ver_prefix!(), above. Takes the current version of the
218 /// serialization logic for this object. This is compared against the
219 /// $min_version_that_can_read_this added by write_ver_prefix!().
220 macro_rules! read_ver_prefix {
221         ($stream: expr, $this_version: expr) => { {
222                 let ver: u8 = Readable::read($stream)?;
223                 let min_ver: u8 = Readable::read($stream)?;
224                 if min_ver > $this_version {
225                         return Err(DecodeError::UnknownVersion);
226                 }
227                 ver
228         } }
229 }
230
231 /// Reads a suffix added by write_tlv_fields.
232 macro_rules! read_tlv_fields {
233         ($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
234                 let tlv_len = ::util::ser::BigSize::read($stream)?;
235                 let mut rd = ::util::ser::FixedLengthReader::new($stream, tlv_len.0);
236                 decode_tlv!(&mut rd, {$(($reqtype, $reqfield)),*}, {$(($type, $field)),*});
237                 rd.eat_remaining().map_err(|_| DecodeError::ShortRead)?;
238         } }
239 }
240
241 #[cfg(test)]
242 mod tests {
243         use std::io::{Cursor, Read};
244         use ln::msgs::DecodeError;
245         use util::ser::{Readable, Writeable, HighZeroBytesDroppedVarInt, VecWriter};
246         use bitcoin::secp256k1::PublicKey;
247
248         // The BOLT TLV test cases don't include any tests which use our "required-value" logic since
249         // the encoding layer in the BOLTs has no such concept, though it makes our macros easier to
250         // work with so they're baked into the decoder. Thus, we have a few additional tests below
251         fn tlv_reader(s: &[u8]) -> Result<(u64, u32, Option<u32>), DecodeError> {
252                 let mut s = Cursor::new(s);
253                 let mut a: u64 = 0;
254                 let mut b: u32 = 0;
255                 let mut c: Option<u32> = None;
256                 decode_tlv!(&mut s, {(2, a), (3, b)}, {(4, c)});
257                 Ok((a, b, c))
258         }
259
260         #[test]
261         fn tlv_v_short_read() {
262                 // We only expect a u32 for type 3 (which we are given), but the L says its 8 bytes.
263                 if let Err(DecodeError::ShortRead) = tlv_reader(&::hex::decode(
264                                 concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")
265                                 ).unwrap()[..]) {
266                 } else { panic!(); }
267         }
268
269         #[test]
270         fn tlv_types_out_of_order() {
271                 if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
272                                 concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")
273                                 ).unwrap()[..]) {
274                 } else { panic!(); }
275                 // ...even if its some field we don't understand
276                 if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
277                                 concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")
278                                 ).unwrap()[..]) {
279                 } else { panic!(); }
280         }
281
282         #[test]
283         fn tlv_req_type_missing_or_extra() {
284                 // It's also bad if they included even fields we don't understand
285                 if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&::hex::decode(
286                                 concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")
287                                 ).unwrap()[..]) {
288                 } else { panic!(); }
289                 // ... or if they're missing fields we need
290                 if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
291                                 concat!("0100", "0208deadbeef1badbeef")
292                                 ).unwrap()[..]) {
293                 } else { panic!(); }
294                 // ... even if that field is even
295                 if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
296                                 concat!("0304deadbeef", "0500")
297                                 ).unwrap()[..]) {
298                 } else { panic!(); }
299         }
300
301         #[test]
302         fn tlv_simple_good_cases() {
303                 assert_eq!(tlv_reader(&::hex::decode(
304                                 concat!("0208deadbeef1badbeef", "03041bad1dea")
305                                 ).unwrap()[..]).unwrap(),
306                         (0xdeadbeef1badbeef, 0x1bad1dea, None));
307                 assert_eq!(tlv_reader(&::hex::decode(
308                                 concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")
309                                 ).unwrap()[..]).unwrap(),
310                         (0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
311         }
312
313         impl Readable for (PublicKey, u64, u64) {
314                 #[inline]
315                 fn read<R: Read>(reader: &mut R) -> Result<(PublicKey, u64, u64), DecodeError> {
316                         Ok((Readable::read(reader)?, Readable::read(reader)?, Readable::read(reader)?))
317                 }
318         }
319
320         // BOLT TLV test cases
321         fn tlv_reader_n1(s: &[u8]) -> Result<(Option<HighZeroBytesDroppedVarInt<u64>>, Option<u64>, Option<(PublicKey, u64, u64)>, Option<u16>), DecodeError> {
322                 let mut s = Cursor::new(s);
323                 let mut tlv1: Option<HighZeroBytesDroppedVarInt<u64>> = None;
324                 let mut tlv2: Option<u64> = None;
325                 let mut tlv3: Option<(PublicKey, u64, u64)> = None;
326                 let mut tlv4: Option<u16> = None;
327                 decode_tlv!(&mut s, {}, {(1, tlv1), (2, tlv2), (3, tlv3), (254, tlv4)});
328                 Ok((tlv1, tlv2, tlv3, tlv4))
329         }
330
331         #[test]
332         fn bolt_tlv_bogus_stream() {
333                 macro_rules! do_test {
334                         ($stream: expr, $reason: ident) => {
335                                 if let Err(DecodeError::$reason) = tlv_reader_n1(&::hex::decode($stream).unwrap()[..]) {
336                                 } else { panic!(); }
337                         }
338                 }
339
340                 // TLVs from the BOLT test cases which should not decode as either n1 or n2
341                 do_test!(concat!("fd01"), ShortRead);
342                 do_test!(concat!("fd0001", "00"), InvalidValue);
343                 do_test!(concat!("fd0101"), ShortRead);
344                 do_test!(concat!("0f", "fd"), ShortRead);
345                 do_test!(concat!("0f", "fd26"), ShortRead);
346                 do_test!(concat!("0f", "fd2602"), ShortRead);
347                 do_test!(concat!("0f", "fd0001", "00"), InvalidValue);
348                 do_test!(concat!("0f", "fd0201", "000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), ShortRead);
349
350                 do_test!(concat!("12", "00"), UnknownRequiredFeature);
351                 do_test!(concat!("fd0102", "00"), UnknownRequiredFeature);
352                 do_test!(concat!("fe01000002", "00"), UnknownRequiredFeature);
353                 do_test!(concat!("ff0100000000000002", "00"), UnknownRequiredFeature);
354         }
355
356         #[test]
357         fn bolt_tlv_bogus_n1_stream() {
358                 macro_rules! do_test {
359                         ($stream: expr, $reason: ident) => {
360                                 if let Err(DecodeError::$reason) = tlv_reader_n1(&::hex::decode($stream).unwrap()[..]) {
361                                 } else { panic!(); }
362                         }
363                 }
364
365                 // TLVs from the BOLT test cases which should not decode as n1
366                 do_test!(concat!("01", "09", "ffffffffffffffffff"), InvalidValue);
367                 do_test!(concat!("01", "01", "00"), InvalidValue);
368                 do_test!(concat!("01", "02", "0001"), InvalidValue);
369                 do_test!(concat!("01", "03", "000100"), InvalidValue);
370                 do_test!(concat!("01", "04", "00010000"), InvalidValue);
371                 do_test!(concat!("01", "05", "0001000000"), InvalidValue);
372                 do_test!(concat!("01", "06", "000100000000"), InvalidValue);
373                 do_test!(concat!("01", "07", "00010000000000"), InvalidValue);
374                 do_test!(concat!("01", "08", "0001000000000000"), InvalidValue);
375                 do_test!(concat!("02", "07", "01010101010101"), ShortRead);
376                 do_test!(concat!("02", "09", "010101010101010101"), InvalidValue);
377                 do_test!(concat!("03", "21", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb"), ShortRead);
378                 do_test!(concat!("03", "29", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb0000000000000001"), ShortRead);
379                 do_test!(concat!("03", "30", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb000000000000000100000000000001"), ShortRead);
380                 do_test!(concat!("03", "31", "043da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002"), InvalidValue);
381                 do_test!(concat!("03", "32", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb0000000000000001000000000000000001"), InvalidValue);
382                 do_test!(concat!("fd00fe", "00"), ShortRead);
383                 do_test!(concat!("fd00fe", "01", "01"), ShortRead);
384                 do_test!(concat!("fd00fe", "03", "010101"), InvalidValue);
385                 do_test!(concat!("00", "00"), UnknownRequiredFeature);
386
387                 do_test!(concat!("02", "08", "0000000000000226", "01", "01", "2a"), InvalidValue);
388                 do_test!(concat!("02", "08", "0000000000000231", "02", "08", "0000000000000451"), InvalidValue);
389                 do_test!(concat!("1f", "00", "0f", "01", "2a"), InvalidValue);
390                 do_test!(concat!("1f", "00", "1f", "01", "2a"), InvalidValue);
391
392                 // The last BOLT test modified to not require creating a new decoder for one trivial test.
393                 do_test!(concat!("ffffffffffffffffff", "00", "01", "00"), InvalidValue);
394         }
395
396         #[test]
397         fn bolt_tlv_valid_n1_stream() {
398                 macro_rules! do_test {
399                         ($stream: expr, $tlv1: expr, $tlv2: expr, $tlv3: expr, $tlv4: expr) => {
400                                 if let Ok((tlv1, tlv2, tlv3, tlv4)) = tlv_reader_n1(&::hex::decode($stream).unwrap()[..]) {
401                                         assert_eq!(tlv1.map(|v| v.0), $tlv1);
402                                         assert_eq!(tlv2, $tlv2);
403                                         assert_eq!(tlv3, $tlv3);
404                                         assert_eq!(tlv4, $tlv4);
405                                 } else { panic!(); }
406                         }
407                 }
408
409                 do_test!(concat!(""), None, None, None, None);
410                 do_test!(concat!("21", "00"), None, None, None, None);
411                 do_test!(concat!("fd0201", "00"), None, None, None, None);
412                 do_test!(concat!("fd00fd", "00"), None, None, None, None);
413                 do_test!(concat!("fd00ff", "00"), None, None, None, None);
414                 do_test!(concat!("fe02000001", "00"), None, None, None, None);
415                 do_test!(concat!("ff0200000000000001", "00"), None, None, None, None);
416
417                 do_test!(concat!("01", "00"), Some(0), None, None, None);
418                 do_test!(concat!("01", "01", "01"), Some(1), None, None, None);
419                 do_test!(concat!("01", "02", "0100"), Some(256), None, None, None);
420                 do_test!(concat!("01", "03", "010000"), Some(65536), None, None, None);
421                 do_test!(concat!("01", "04", "01000000"), Some(16777216), None, None, None);
422                 do_test!(concat!("01", "05", "0100000000"), Some(4294967296), None, None, None);
423                 do_test!(concat!("01", "06", "010000000000"), Some(1099511627776), None, None, None);
424                 do_test!(concat!("01", "07", "01000000000000"), Some(281474976710656), None, None, None);
425                 do_test!(concat!("01", "08", "0100000000000000"), Some(72057594037927936), None, None, None);
426                 do_test!(concat!("02", "08", "0000000000000226"), None, Some((0 << 30) | (0 << 5) | (550 << 0)), None, None);
427                 do_test!(concat!("03", "31", "023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb00000000000000010000000000000002"),
428                         None, None, Some((
429                                 PublicKey::from_slice(&::hex::decode("023da092f6980e58d2c037173180e9a465476026ee50f96695963e8efe436f54eb").unwrap()[..]).unwrap(), 1, 2)),
430                         None);
431                 do_test!(concat!("fd00fe", "02", "0226"), None, None, None, Some(550));
432         }
433
434         fn do_simple_test_tlv_write() -> Result<(), ::std::io::Error> {
435                 let mut stream = VecWriter(Vec::new());
436
437                 stream.0.clear();
438                 encode_varint_length_prefixed_tlv!(&mut stream, { (1, 1u8) });
439                 assert_eq!(stream.0, ::hex::decode("03010101").unwrap());
440
441                 stream.0.clear();
442                 encode_varint_length_prefixed_tlv!(&mut stream, { (4, 0xabcdu16) });
443                 assert_eq!(stream.0, ::hex::decode("040402abcd").unwrap());
444
445                 stream.0.clear();
446                 encode_varint_length_prefixed_tlv!(&mut stream, { (0xff, 0xabcdu16) });
447                 assert_eq!(stream.0, ::hex::decode("06fd00ff02abcd").unwrap());
448
449                 stream.0.clear();
450                 encode_varint_length_prefixed_tlv!(&mut stream, { (0, 1u64), (0xff, HighZeroBytesDroppedVarInt(0u64)) });
451                 assert_eq!(stream.0, ::hex::decode("0e00080000000000000001fd00ff00").unwrap());
452
453                 Ok(())
454         }
455
456         #[test]
457         fn simple_test_tlv_write() {
458                 do_simple_test_tlv_write().unwrap();
459         }
460 }