Merge pull request #167 from TheBlueMatt/2018-09-dup-htlc
[rust-lightning] / src / util / ser.rs
1 use std::result::Result;
2 use std::io::{Read, Write};
3 use std::collections::HashMap;
4 use std::hash::Hash;
5 use std::mem;
6
7 use secp256k1::{Secp256k1, Signature};
8 use secp256k1::key::PublicKey;
9 use bitcoin::util::hash::Sha256dHash;
10 use bitcoin::blockdata::script::Script;
11 use std::marker::Sized;
12 use ln::msgs::DecodeError;
13
14 use util::byte_utils::{be64_to_array, be32_to_array, be16_to_array, slice_to_be16, slice_to_be32, slice_to_be64};
15
/// Upper bound, in bytes, on `len * size_of::<T>()` for vectors accepted by the
/// `Vec<T>` serializer and deserializer (16 KiB).
const MAX_BUF_SIZE: usize = 16384;

/// Serialization sink: wraps any `io::Write` so `Writeable` values can be written to it.
pub struct Writer<W> {
	writer: W,
}

/// Deserialization source: wraps any `io::Read` so `Readable` values can be read from it.
pub struct Reader<R> {
	reader: R,
}
20
21 pub trait Writeable<W: Write> {
22         fn write(&self, writer: &mut Writer<W>) -> Result<(), DecodeError>;
23 }
24
25 pub trait Readable<R>
26         where Self: Sized,
27               R: Read
28 {
29         fn read(reader: &mut Reader<R>) -> Result<Self, DecodeError>;
30 }
31
32 impl<W: Write> Writer<W> {
33         pub fn new(writer: W) -> Writer<W> {
34                 return Writer { writer }
35         }
36         pub fn into_inner(self) -> W { self.writer }
37         pub fn get_ref(&self) -> &W { &self.writer }
38         fn write_u64(&mut self, v: u64) -> Result<(), DecodeError> {
39                 Ok(self.writer.write_all(&be64_to_array(v))?)
40         }
41         fn write_u32(&mut self, v: u32) -> Result<(), DecodeError> {
42                 Ok(self.writer.write_all(&be32_to_array(v))?)
43         }
44         fn write_u16(&mut self, v: u16) -> Result<(), DecodeError> {
45                 Ok(self.writer.write_all(&be16_to_array(v))?)
46         }
47         fn write_u8(&mut self, v: u8) -> Result<(), DecodeError> {
48                 Ok(self.writer.write_all(&[v])?)
49         }
50         fn write_bool(&mut self, v: bool) -> Result<(), DecodeError> {
51                 Ok(self.writer.write_all(&[if v {1} else {0}])?)
52         }
53         pub fn write_all(&mut self, v: &[u8]) -> Result<(), DecodeError> {
54                 Ok(self.writer.write_all(v)?)
55         }
56 }
57
58 impl<R: Read> Reader<R> {
59         pub fn new(reader: R) -> Reader<R> {
60                 return Reader { reader }
61         }
62         pub fn into_inner(self) -> R { self.reader }
63         pub fn get_ref(&self) -> &R { &self.reader }
64
65         fn read_u64(&mut self) -> Result<u64, DecodeError> {
66                 let mut buf = [0; 8];
67                 self.reader.read_exact(&mut buf)?;
68                 Ok(slice_to_be64(&buf))
69         }
70
71         fn read_u32(&mut self) -> Result<u32, DecodeError> {
72                 let mut buf = [0; 4];
73                 self.reader.read_exact(&mut buf)?;
74                 Ok(slice_to_be32(&buf))
75         }
76
77         fn read_u16(&mut self) -> Result<u16, DecodeError> {
78                 let mut buf = [0; 2];
79                 self.reader.read_exact(&mut buf)?;
80                 Ok(slice_to_be16(&buf))
81         }
82
83         fn read_u8(&mut self) -> Result<u8, DecodeError> {
84                 let mut buf = [0; 1];
85                 self.reader.read_exact(&mut buf)?;
86                 Ok(buf[0])
87         }
88         fn read_bool(&mut self) -> Result<bool, DecodeError> {
89                 let mut buf = [0; 1];
90                 self.reader.read_exact(&mut buf)?;
91                 if buf[0] != 0 && buf[0] != 1 {
92                         return Err(DecodeError::InvalidValue);
93                 }
94                 Ok(buf[0] == 1)
95         }
96         pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), DecodeError> {
97                 Ok(self.reader.read_exact(buf)?)
98         }
99         pub fn read_to_end(&mut self, buf: &mut Vec<u8>) -> Result<usize, DecodeError> {
100                 Ok(self.reader.read_to_end(buf)?)
101         }
102 }
103
/// Implements `Writeable` and `Readable` for a primitive type by forwarding to the
/// named private helper methods on `Writer` and `Reader`.
macro_rules! impl_writeable_primitive {
	($prim:ty, $write_fn:ident, $read_fn:ident) => {
		impl<W: Write> Writeable<W> for $prim {
			#[inline]
			fn write(&self, writer: &mut Writer<W>) -> Result<(), DecodeError> {
				let value: $prim = *self;
				writer.$write_fn(value)
			}
		}

		impl<R: Read> Readable<R> for $prim {
			#[inline]
			fn read(reader: &mut Reader<R>) -> Result<$prim, DecodeError> {
				reader.$read_fn()
			}
		}
	};
}
120
121 impl_writeable_primitive!(u64, write_u64, read_u64);
122 impl_writeable_primitive!(u32, write_u32, read_u32);
123 impl_writeable_primitive!(u16, write_u16, read_u16);
124 impl_writeable_primitive!(u8, write_u8, read_u8);
125 impl_writeable_primitive!(bool, write_bool, read_bool);
126
// u8 arrays

/// Implements `Writeable`/`Readable` for `[u8; $size]`. The bytes are written and read
/// raw, with no length prefix, because the length is fixed by the type.
macro_rules! impl_array {
	($size:expr) => {
		impl<W: Write> Writeable<W> for [u8; $size] {
			#[inline]
			fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
				w.write_all(&self[..])
			}
		}

		impl<R: Read> Readable<R> for [u8; $size] {
			#[inline]
			fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
				let mut bytes = [0u8; $size];
				r.read_exact(&mut bytes[..])?;
				Ok(bytes)
			}
		}
	};
}
152
153 //TODO: performance issue with [u8; size] with impl_array!()
154 impl_array!(32); // for channel id & hmac
155 impl_array!(33); // for PublicKey
156 impl_array!(64); // for Signature
157 impl_array!(1300); // for OnionPacket.hop_data
158
159 // HashMap
160 impl<W, K, V> Writeable<W> for HashMap<K, V>
161         where W: Write,
162               K: Writeable<W> + Eq + Hash,
163               V: Writeable<W>
164 {
165         #[inline]
166         fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
167         (self.len() as u16).write(w)?;
168                 for (key, value) in self.iter() {
169                         key.write(w)?;
170                         value.write(w)?;
171                 }
172                 Ok(())
173         }
174 }
175
176 impl<R, K, V> Readable<R> for HashMap<K, V>
177         where R: Read,
178               K: Readable<R> + Eq + Hash,
179               V: Readable<R>
180 {
181         #[inline]
182         fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
183                 let len: u16 = Readable::read(r)?;
184                 let mut ret = HashMap::with_capacity(len as usize);
185                 for _ in 0..len {
186                                 ret.insert(K::read(r)?, V::read(r)?);
187                 }
188                 Ok(ret)
189         }
190 }
191
192 // Vectors
193 impl<W: Write, T: Writeable<W>> Writeable<W> for Vec<T> {
194         #[inline]
195         fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
196                 let byte_size = (self.len() as usize)
197                                 .checked_mul(mem::size_of::<T>())
198                                 .ok_or(DecodeError::BadLengthDescriptor)?;
199                 if byte_size > MAX_BUF_SIZE {
200                                 return Err(DecodeError::BadLengthDescriptor);
201                 }
202                 (self.len() as u16).write(w)?;
203                 // performance with Vec<u8>
204                 for e in self.iter() {
205                         e.write(w)?;
206                 }
207                 Ok(())
208         }
209 }
210
211 impl<R: Read, T: Readable<R>> Readable<R> for Vec<T> {
212         #[inline]
213         fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
214                         let len: u16 = Readable::read(r)?;
215                         let byte_size = (len as usize)
216                                         .checked_mul(mem::size_of::<T>())
217                                         .ok_or(DecodeError::BadLengthDescriptor)?;
218                         if byte_size > MAX_BUF_SIZE {
219                                         return Err(DecodeError::BadLengthDescriptor);
220                         }
221                         let mut ret = Vec::with_capacity(len as usize);
222                         for _ in 0..len { ret.push(T::read(r)?); }
223                         Ok(ret)
224         }
225 }
226
227 impl<W: Write> Writeable<W> for Script {
228         fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
229                 self.to_bytes().to_vec().write(w)
230         }
231 }
232
233 impl<R: Read> Readable<R> for Script {
234         fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
235                 let len = <u16 as Readable<R>>::read(r)? as usize;
236                 let mut buf = vec![0; len];
237                 r.read_exact(&mut buf)?;
238                 Ok(Script::from(buf))
239         }
240 }
241
242 impl<W: Write> Writeable<W> for Option<Script> {
243         fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
244                 if let &Some(ref script) = self {
245                         script.write(w)?;
246                 }
247                 Ok(())
248         }
249 }
250
251 impl<R: Read> Readable<R> for Option<Script> {
252         fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
253                 match <u16 as Readable<R>>::read(r) {
254                         Ok(len) => {
255                                 let mut buf = vec![0; len as usize];
256                                 r.read_exact(&mut buf)?;
257                                 Ok(Some(Script::from(buf)))
258                         },
259                         Err(DecodeError::ShortRead) => Ok(None),
260                         Err(e) => Err(e)
261                 }
262         }
263 }
264
265 impl<W: Write> Writeable<W> for PublicKey {
266         fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
267                 self.serialize().write(w)
268         }
269 }
270
271 impl<R: Read> Readable<R> for PublicKey {
272         fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
273                 let buf: [u8; 33] = Readable::read(r)?;
274                 match PublicKey::from_slice(&Secp256k1::without_caps(), &buf) {
275                         Ok(key) => Ok(key),
276                         Err(_) => return Err(DecodeError::BadPublicKey),
277                 }
278         }
279 }
280
281 impl<W: Write> Writeable<W> for Sha256dHash {
282         fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
283                 self.as_bytes().write(w)
284         }
285 }
286
287 impl<R: Read> Readable<R> for Sha256dHash {
288         fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
289                 let buf: [u8; 32] = Readable::read(r)?;
290                 Ok(From::from(&buf[..]))
291         }
292 }
293
294 impl<W: Write> Writeable<W> for Signature {
295         fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
296                 self.serialize_compact(&Secp256k1::without_caps()).write(w)
297         }
298 }
299
300 impl<R: Read> Readable<R> for Signature {
301         fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
302                 let buf: [u8; 64] = Readable::read(r)?;
303                 match Signature::from_compact(&Secp256k1::without_caps(), &buf) {
304                         Ok(sig) => Ok(sig),
305                         Err(_) => return Err(DecodeError::BadSignature),
306                 }
307         }
308 }
309
/// Implements `Writeable` and `Readable` for a struct by (de)serializing the listed
/// fields one after another, in the order given, with no framing between them.
macro_rules! impl_writeable {
	($name:ident, {$($field:ident),*}) => {
		impl<W: ::std::io::Write> Writeable<W> for $name {
			fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
				$( self.$field.write(w)?; )*
				Ok(())
			}
		}

		impl<R: ::std::io::Read> Readable<R> for $name {
			fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
				// Struct-literal fields are evaluated in source order, so the reads happen
				// in the same order the fields were written.
				Ok($name {
					$( $field: Readable::read(r)? ),*
				})
			}
		}
	};
}