Add macros for building TLV (de)serializers.
authorMatt Corallo <git@bluematt.me>
Fri, 27 Dec 2019 22:39:59 +0000 (17:39 -0500)
committerMatt Corallo <git@bluematt.me>
Sat, 25 Jan 2020 22:12:08 +0000 (17:12 -0500)
There's quite a bit of machinery included here, but it neatly
avoids any dynamic allocation during TLV deserialization, and the
calling side looks nice and simple. There's a few new state-tracking
read/write streams, but they should be pretty cheap (just a few
increments/decrements per read/write. The macro-generated code is
pretty nice, though has some redundant if statements (I haven't
checked if they get optimized out yet, but I can't imagine they
don't).

lightning/src/util/ser.rs
lightning/src/util/ser_macros.rs

index 7e4f789097940824a8ebfb3f1bb9f054cf3e145d..fd82330c85b1a7cfea29da454b80e288e1155e75 100644 (file)
@@ -6,6 +6,7 @@ use std::io::{Read, Write};
 use std::collections::HashMap;
 use std::hash::Hash;
 use std::sync::Mutex;
+use std::cmp;
 
 use secp256k1::Signature;
 use secp256k1::key::{PublicKey, SecretKey};
@@ -67,6 +68,46 @@ impl Writer for VecWriter {
        }
 }
 
+pub(crate) struct LengthCalculatingWriter(pub usize);
+impl Writer for LengthCalculatingWriter {
+       #[inline]
+       fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> {
+               self.0 += buf.len();
+               Ok(())
+       }
+       #[inline]
+       fn size_hint(&mut self, _size: usize) {}
+}
+
+/// Essentially std::io::Take but a bit simpler and exposing the amount read at the end, cause we
+/// may need to skip ahead that much at the end.
+pub(crate) struct FixedLengthReader<R: Read> {
+       pub read: R,
+       pub read_len: u64,
+       pub max_len: u64,
+}
+impl<R: Read> FixedLengthReader<R> {
+       pub fn eat_remaining(&mut self) -> Result<(), ::std::io::Error> {
+               while self.read_len != self.max_len {
+                       debug_assert!(self.read_len < self.max_len);
+                       let mut buf = [0; 1024];
+                       let readsz = cmp::min(1024, self.max_len - self.read_len) as usize;
+                       self.read_exact(&mut buf[0..readsz])?;
+               }
+               Ok(())
+       }
+}
+impl<R: Read> Read for FixedLengthReader<R> {
+       fn read(&mut self, dest: &mut [u8]) -> Result<usize, ::std::io::Error> {
+               if dest.len() as u64 > self.max_len - self.read_len {
+                       Ok(0)
+               } else {
+                       self.read_len += dest.len() as u64;
+                       self.read.read(dest)
+               }
+       }
+}
+
 /// A trait that various rust-lightning types implement allowing them to be written out to a Writer
 pub trait Writeable {
        /// Writes self out to the given Writer
index 48e87b3bc2108a35b97282981b8738b5b87a93fb..84f997648ac2a955cf465a5fc78b32bbee7fc6c6 100644 (file)
@@ -1,3 +1,91 @@
+macro_rules! encode_tlv {
+       ($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
+               use bitcoin::consensus::Encodable;
+               use bitcoin::consensus::encode::{Error, VarInt};
+               use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
+               $(
+                       VarInt($type).consensus_encode(WriterWriteAdaptor($stream))
+                               .map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
+                       let mut len_calc = LengthCalculatingWriter(0);
+                       $field.write(&mut len_calc)?;
+                       VarInt(len_calc.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
+                               .map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
+                       $field.write($stream)?;
+               )*
+       } }
+}
+
+macro_rules! encode_varint_length_prefixed_tlv {
+       ($stream: expr, {$(($type: expr, $field: expr)),*}) => { {
+               use bitcoin::consensus::Encodable;
+               use bitcoin::consensus::encode::{Error, VarInt};
+               use util::ser::{WriterWriteAdaptor, LengthCalculatingWriter};
+               let mut len = LengthCalculatingWriter(0);
+               encode_tlv!(&mut len, {
+                       $(($type, $field)),*
+               });
+               VarInt(len.0 as u64).consensus_encode(WriterWriteAdaptor($stream))
+                       .map_err(|e| if let Error::Io(ioe) = e { ioe } else { unreachable!() })?;
+               encode_tlv!($stream, {
+                       $(($type, $field)),*
+               });
+       } }
+}
+
+macro_rules! decode_tlv {
+       ($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { {
+               use ln::msgs::DecodeError;
+               let mut max_type: u64 = 0;
+               'tlv_read: loop {
+                       use bitcoin::consensus::encode;
+                       use util::ser;
+                       use std;
+
+                       let typ: encode::VarInt = match encode::Decodable::consensus_decode($stream) {
+                               Err(encode::Error::Io(ref ioe)) if ioe.kind() == std::io::ErrorKind::UnexpectedEof
+                                       => break 'tlv_read,
+                               Err(encode::Error::Io(ioe)) => Err(DecodeError::from(ioe))?,
+                               Err(_) => Err(DecodeError::InvalidValue)?,
+                               Ok(t) => t,
+                       };
+                       if typ.0 == std::u64::MAX || typ.0 + 1 <= max_type {
+                               Err(DecodeError::InvalidValue)?
+                       }
+                       $(if max_type < $reqtype + 1 && typ.0 > $reqtype {
+                               Err(DecodeError::InvalidValue)?
+                       })*
+                       max_type = typ.0 + 1;
+
+                       let length: encode::VarInt = encode::Decodable::consensus_decode($stream)
+                               .map_err(|e| match e {
+                                       encode::Error::Io(ioe) => DecodeError::from(ioe),
+                                       _ => DecodeError::InvalidValue
+                               })?;
+                       let mut s = ser::FixedLengthReader {
+                               read: $stream,
+                               read_len: 0,
+                               max_len: length.0,
+                       };
+                       match typ.0 {
+                               $($reqtype => {
+                                       $reqfield = ser::Readable::read(&mut s)?;
+                               },)*
+                               $($type => {
+                                       $field = Some(ser::Readable::read(&mut s)?);
+                               },)*
+                               x if x % 2 == 0 => {
+                                       Err(DecodeError::UnknownRequiredFeature)?
+                               },
+                               _ => {},
+                       }
+                       s.eat_remaining().map_err(|_| DecodeError::ShortRead)?;
+               }
+               $(if max_type < $reqtype + 1 {
+                       Err(DecodeError::InvalidValue)?
+               })*
+       } }
+}
+
 macro_rules! impl_writeable {
        ($st:ident, $len: expr, {$($field:ident),*}) => {
                impl ::util::ser::Writeable for $st {
@@ -40,3 +128,73 @@ macro_rules! impl_writeable_len_match {
                }
        }
 }
+
+#[cfg(test)]
+mod tests {
+       use std::io::Cursor;
+       use ln::msgs::DecodeError;
+
+       fn tlv_reader(s: &[u8]) -> Result<(u64, u32, Option<u32>), DecodeError> {
+               let mut s = Cursor::new(s);
+               let mut a: u64 = 0;
+               let mut b: u32 = 0;
+               let mut c: Option<u32> = None;
+               decode_tlv!(&mut s, {(2, a), (3, b)}, {(4, c)});
+               Ok((a, b, c))
+       }
+       #[test]
+       fn test_tlv() {
+               // Value for 3 is longer than we expect, but that's ok...
+               assert_eq!(tlv_reader(&::hex::decode(
+                               concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d")
+                               ).unwrap()[..]).unwrap(),
+                       (0xdeadbeef1badbeef, 0xdeadbeef, None));
+               // ...even if there's something afterwards
+               assert_eq!(tlv_reader(&::hex::decode(
+                               concat!("0100", "0208deadbeef1badbeef", "0308deadbeef1badf00d", "0404ffffffff")
+                               ).unwrap()[..]).unwrap(),
+                       (0xdeadbeef1badbeef, 0xdeadbeef, Some(0xffffffff)));
+               // ...but not if that extra length is missing
+               if let Err(DecodeError::ShortRead) = tlv_reader(&::hex::decode(
+                               concat!("0100", "0208deadbeef1badbeef", "0308deadbeef")
+                               ).unwrap()[..]) {
+               } else { panic!(); }
+
+               // If they're out of order that's also bad
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+                               concat!("0100", "0304deadbeef", "0208deadbeef1badbeef")
+                               ).unwrap()[..]) {
+               } else { panic!(); }
+               // ...even if its some field we don't understand
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+                               concat!("0208deadbeef1badbeef", "0100", "0304deadbeef")
+                               ).unwrap()[..]) {
+               } else { panic!(); }
+
+               // It's also bad if they included even fields we don't understand
+               if let Err(DecodeError::UnknownRequiredFeature) = tlv_reader(&::hex::decode(
+                               concat!("0100", "0208deadbeef1badbeef", "0304deadbeef", "0600")
+                               ).unwrap()[..]) {
+               } else { panic!(); }
+               // ... or if they're missing fields we need
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+                               concat!("0100", "0208deadbeef1badbeef")
+                               ).unwrap()[..]) {
+               } else { panic!(); }
+               // ... even if that field is even
+               if let Err(DecodeError::InvalidValue) = tlv_reader(&::hex::decode(
+                               concat!("0304deadbeef", "0500")
+                               ).unwrap()[..]) {
+               } else { panic!(); }
+
+               // But usually things are pretty much what we expect:
+               assert_eq!(tlv_reader(&::hex::decode(
+                               concat!("0208deadbeef1badbeef", "03041bad1dea")
+                               ).unwrap()[..]).unwrap(),
+                       (0xdeadbeef1badbeef, 0x1bad1dea, None));
+               assert_eq!(tlv_reader(&::hex::decode(
+                               concat!("0208deadbeef1badbeef", "03041bad1dea", "040401020304")
+                               ).unwrap()[..]).unwrap(),
+                       (0xdeadbeef1badbeef, 0x1bad1dea, Some(0x01020304)));
+       }
+}