Add/announce features for payment_secret and basic_mpp
[rust-lightning] / lightning / src / util / ser_macros.rs
/// Writes each `(type, field)` pair to `$stream` as one TLV record:
/// the type as a `bitcoin::consensus::encode::VarInt`, then the serialized length of `field`
/// as a `VarInt`, then `field`'s own serialization.
///
/// `$stream` is evaluated several times per record, and `$field` is serialized twice (once into a
/// byte counter to learn its length, once for real). The expansion uses `?` on
/// `std::io::Error` results, so the enclosing function must return a compatible `Result`.
macro_rules! encode_tlv {
	($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
		use bitcoin::consensus::Encodable;
		use bitcoin::consensus::encode::{Error, VarInt};
		use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
		$(
			// Record type.
			VarInt($type)
				.consensus_encode(WriterWriteAdaptor($stream))
				.map_err(|err| match err {
					// Encoding into an io::Write is only expected to fail with an I/O error.
					Error::Io(io_err) => io_err,
					_ => unreachable!(),
				})?;
			// Record length: measure the field by serializing it into a byte counter.
			let mut field_len = LengthCalculatingWriter(0);
			$field.write(&mut field_len)?;
			VarInt(field_len.0 as u64)
				.consensus_encode(WriterWriteAdaptor($stream))
				.map_err(|err| match err {
					Error::Io(io_err) => io_err,
					_ => unreachable!(),
				})?;
			// Record value.
			$field.write($stream)?;
		)*
	} }
}
17
/// Like `encode_tlv!`, but first writes the total byte length of the whole encoded TLV stream
/// as a `bitcoin::consensus::encode::VarInt`, followed by the stream itself.
///
/// The records are encoded twice: once into a byte counter to compute the prefix, then into
/// `$stream`. The expansion uses `?` on `std::io::Error` results.
macro_rules! encode_varint_length_prefixed_tlv {
	($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
		use bitcoin::consensus::Encodable;
		use bitcoin::consensus::encode::{Error, VarInt};
		use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
		// First pass: encode every record into a byte counter to learn the stream's length.
		let mut stream_len = LengthCalculatingWriter(0);
		encode_tlv!(&mut stream_len, {
			$(($type, $field)),*
		});
		// Length prefix.
		let prefix = VarInt(stream_len.0 as u64);
		prefix
			.consensus_encode(WriterWriteAdaptor($stream))
			.map_err(|err| match err {
				Error::Io(io_err) => io_err,
				_ => unreachable!(),
			})?;
		// Second pass: the actual records.
		encode_tlv!($stream, {
			$(($type, $field)),*
		});
	} }
}
34
/// Decodes a TLV stream from `$stream`, assigning into variables the caller has already declared.
///
/// * First list, `{(type, var), ...}`: required records. Each `var` is assigned the value read
///   with `Readable::read` when its type is seen. If any required type is absent, decoding fails
///   with `DecodeError::InvalidValue`.
/// * Second list, `{(type, var), ...}`: optional records. Each `var` is set to `Some(..)` when its
///   type is seen and is otherwise left untouched.
///
/// Types and lengths are read as `bitcoin::consensus::encode::VarInt`s. Types must be strictly
/// increasing; duplicate or out-of-order types give `InvalidValue`. Unknown even types give
/// `DecodeError::UnknownRequiredFeature`, and unknown odd types are skipped. Each value is read
/// through a reader bounded to the record's declared length. Bytes the field did not consume are
/// skipped, and `DecodeError::ShortRead` is returned if the input ends first.
///
/// Decoding stops cleanly only when the input ends exactly at a record boundary. Input that ends
/// part-way through a record's type is reported as an error instead of being taken as the end.
///
/// Type values must be usable as `match` patterns (e.g. integer literals). The expansion uses
/// `Err(DecodeError::..)?`, so the enclosing function must return a `Result` whose error type can
/// be built from `DecodeError`.
///
/// NOTE(review): `VarInt` is Bitcoin's little-endian CompactSize. If these streams are meant to
/// follow BOLT #1's big-endian BigSize encoding, this needs revisiting.
macro_rules! decode_tlv {
	($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
		use ln::msgs::DecodeError;
		// One past the highest type read so far (0 before any record). Used to enforce strictly
		// increasing types and to detect skipped required types.
		let mut max_type: u64 = 0;
		'tlv_read: loop {
			use bitcoin::consensus::encode;
			use util::ser;
			use std;

			// Read the first byte of the type on its own. Running out of input here, before any
			// byte of a new record, is the normal end of the stream. Running out later in the type
			// must be an error, not end-of-stream. So the rest of the VarInt is decoded with the
			// first byte chained back in front of the stream.
			let mut type_first_byte = [0u8; 1];
			match std::io::Read::read_exact($stream, &mut type_first_byte) {
				Err(ref ioe) if ioe.kind() == std::io::ErrorKind::UnexpectedEof
					=> break 'tlv_read,
				Err(ioe) => Err(DecodeError::from(ioe))?,
				Ok(()) => {},
			}
			let typ: encode::VarInt = match encode::Decodable::consensus_decode(
					std::io::Read::chain(&type_first_byte[..], $stream)) {
				Err(encode::Error::Io(ioe)) => Err(DecodeError::from(ioe))?,
				Err(_) => Err(DecodeError::InvalidValue)?,
				Ok(t) => t,
			};
			// Reject duplicate or out-of-order types. u64::MAX is rejected first so that
			// `typ.0 + 1` cannot overflow.
			if typ.0 == std::u64::MAX || typ.0 + 1 <= max_type {
				Err(DecodeError::InvalidValue)?
			}
			// A required type not yet seen that this record skips past is missing.
			$(if max_type < $reqtype + 1 && typ.0 > $reqtype {
				Err(DecodeError::InvalidValue)?
			})*
			max_type = typ.0 + 1;

			let length: encode::VarInt = encode::Decodable::consensus_decode($stream)
				.map_err(|e| match e {
					encode::Error::Io(ioe) => DecodeError::from(ioe),
					_ => DecodeError::InvalidValue
				})?;
			// Bound reads of this record's value to its declared length.
			let mut s = ser::FixedLengthReader {
				read: $stream,
				read_len: 0,
				max_len: length.0,
			};
			match typ.0 {
				$($reqtype => {
					$reqfield = ser::Readable::read(&mut s)?;
				},)*
				$($type => {
					$field = Some(ser::Readable::read(&mut s)?);
				},)*
				// Unknown even types are treated as required, so they are fatal...
				x if x % 2 == 0 => {
					Err(DecodeError::UnknownRequiredFeature)?
				},
				// ...while unknown odd types are ignored.
				_ => {},
			}
			// Skip whatever part of the value the field's reader did not consume.
			s.eat_remaining().map_err(|_| DecodeError::ShortRead)?;
		}
		// Any required type above the last type read was never seen.
		$(if max_type < $reqtype + 1 {
			Err(DecodeError::InvalidValue)?
		})*
	} }
}
88
/// Implements `::util::ser::Writeable` and `::util::ser::Readable` for struct `$st` by
/// serializing / deserializing the listed fields, one after another, in the order given.
///
/// `$len` is passed to `Writer::size_hint` before the fields are written, unless it is 0, in
/// which case no hint is given. NOTE(review): presumably 0 means "length not fixed/known".
/// Confirm against callers.
///
/// The trait paths in the generated `impl` headers are crate-absolute. NOTE(review): the
/// per-field `.write(w)` method calls presumably still need `Writeable` in scope where the macro
/// is invoked.
macro_rules! impl_writeable {
	($st:ident, $len: expr, {$($field:ident),*}) => {
		impl ::util::ser::Writeable for $st {
			fn write<W: ::util::ser::Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
				// A `$len` of 0 skips the size hint entirely.
				if $len != 0 {
					w.size_hint($len);
				}
				$( self.$field.write(w)?; )*
				Ok(())
			}
		}

		impl<R: ::std::io::Read> ::util::ser::Readable<R> for $st {
			fn read(r: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
				// Fields are read in the listed order, mirroring the write order above.
				Ok(Self {
					$($field: ::util::ser::Readable::read(r)?),*
				})
			}
		}
	}
}
/// Implements `::util::ser::Writeable` and `::util::ser::Readable` for struct `$st`, like
/// `impl_writeable!`, but computes the size hint per value. `*self` is matched against each `$m`
/// pattern, and the paired `$l` expression is passed to `Writer::size_hint`. Fields are then
/// written, and read back, in the listed order.
///
/// Trait and error-type paths are crate-absolute, matching `impl_writeable!`, so the `impl`
/// headers no longer depend on the invoking module importing `Writeable`, `Writer`, `Readable`
/// or `DecodeError`. NOTE(review): the per-field `.write(w)` method calls presumably still need
/// `Writeable` in scope at the invocation site.
macro_rules! impl_writeable_len_match {
	($st:ident, {$({$m: pat, $l: expr}),*}, {$($field:ident),*}) => {
		impl ::util::ser::Writeable for $st {
			fn write<W: ::util::ser::Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
				w.size_hint(match *self {
					$($m => $l,)*
				});
				$( self.$field.write(w)?; )*
				Ok(())
			}
		}

		impl<R: ::std::io::Read> ::util::ser::Readable<R> for $st {
			fn read(r: &mut R) -> Result<Self, ::ln::msgs::DecodeError> {
				// Fields are read in the listed order, mirroring the write order above.
				Ok(Self {
					$($field: ::util::ser::Readable::read(r)?),*
				})
			}
		}
	}
}
131
#[cfg(test)]
mod tests {
	use std::io::Cursor;
	use ln::msgs::DecodeError;

	/// Runs `decode_tlv!` over `s` with required types 2 (`u64`) and 3 (`u32`) and optional
	/// type 4 (`u32`), returning the decoded values in that order.
	fn tlv_reader(s: &[u8]) -> Result<(u64, u32, Option<u32>), DecodeError> {
		let mut reader = Cursor::new(s);
		let mut req_two: u64 = 0;
		let mut req_three: u32 = 0;
		let mut opt_four: Option<u32> = None;
		decode_tlv!(&mut reader, {(2, req_two), (3, req_three)}, {(4, opt_four)});
		Ok((req_two, req_three, opt_four))
	}

	#[test]
	fn test_tlv() {
		// Hex-decodes `input` and feeds the bytes to `tlv_reader`.
		let read_hex = |input: &str| tlv_reader(&::hex::decode(input).unwrap()[..]);

		// Type 3's value is longer than the u32 read from it; the excess is skipped...
		assert_eq!(read_hex(concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d")).unwrap(),
			(0xdeadbeef1badbeef, 0xdeadbeef, None));
		// ...even when another record follows it...
		assert_eq!(read_hex(concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d", "0404ffffffff")).unwrap(),
			(0xdeadbeef1badbeef, 0xdeadbeef, Some(0xffffffff)));
		// ...but not when the declared length runs past the end of the input.
		match read_hex(concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")) {
			Err(DecodeError::ShortRead) => {},
			_ => panic!(),
		}

		// Records in the wrong order are rejected...
		match read_hex(concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")) {
			Err(DecodeError::InvalidValue) => {},
			_ => panic!(),
		}
		// ...even when the misplaced record is one we don't understand.
		match read_hex(concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")) {
			Err(DecodeError::InvalidValue) => {},
			_ => panic!(),
		}

		// Unknown even types are rejected...
		match read_hex(concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")) {
			Err(DecodeError::UnknownRequiredFeature) => {},
			_ => panic!(),
		}
		// ...as are streams missing a required type...
		match read_hex(concat!("0100", "0208deadbeef1badbeef")) {
			Err(DecodeError::InvalidValue) => {},
			_ => panic!(),
		}
		// ...including when the missing required type is even.
		match read_hex(concat!("0304deadbeef", "0500")) {
			Err(DecodeError::InvalidValue) => {},
			_ => panic!(),
		}

		// Well-formed streams decode as expected.
		assert_eq!(read_hex(concat!("0208deadbeef1badbeef", "03041bad1dea")).unwrap(),
			(0xdeadbeef1badbeef, 0x1bad1dea, None));
		assert_eq!(read_hex(concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")).unwrap(),
			(0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
	}
}