Use a variable-length integer for many serialization wrappers
[rust-lightning] / lightning / src / util / ser.rs
index 44725e722321cab2c99f0888a2a939442090e9b3..84d1a2e084feb65e1472969875a586c81ccd08e5 100644 (file)
@@ -383,6 +383,40 @@ impl Readable for BigSize {
        }
 }
 
+/// The lightning protocol uses u16s for lengths in most cases. As our serialization framework
+/// primarily targets that, we must as well. However, because we may serialize objects that have
+/// more than 65K entries, we need to be able to store larger values. Thus, we define a variable
+/// length integer here that is backwards-compatible for values < 0xffff. We treat 0xffff as
+/// "read eight more bytes".
+///
+/// To ensure we only have one valid encoding per value, we add 0xffff to values written as eight
+/// bytes. Thus, 0xfffe is serialized as 0xfffe, whereas 0xffff is serialized as
+/// 0xffff0000000000000000 (i.e. read-eight-bytes then zero).
+struct CollectionLength(pub u64);
+impl Writeable for CollectionLength {
+       #[inline]
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               if self.0 < 0xffff {
+                       (self.0 as u16).write(writer)
+               } else {
+                       0xffffu16.write(writer)?;
+                       (self.0 - 0xffff).write(writer)
+               }
+       }
+}
+
+impl Readable for CollectionLength {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let mut val: u64 = <u16 as Readable>::read(r)? as u64;
+               if val == 0xffff {
+                       val = <u64 as Readable>::read(r)?
+                               .checked_add(0xffff).ok_or(DecodeError::InvalidValue)?;
+               }
+               Ok(CollectionLength(val))
+       }
+}
+
 /// In TLV we occasionally send fields which only consist of, or potentially end with, a
 /// variable-length integer which is simply truncated by skipping high zero bytes. This type
 /// encapsulates such integers implementing [`Readable`]/[`Writeable`] for them.
@@ -597,7 +631,7 @@ macro_rules! impl_for_map {
                {
                        #[inline]
                        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-                               (self.len() as u16).write(w)?;
+                               CollectionLength(self.len() as u64).write(w)?;
                                for (key, value) in self.iter() {
                                        key.write(w)?;
                                        value.write(w)?;
@@ -611,9 +645,9 @@ macro_rules! impl_for_map {
                {
                        #[inline]
                        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-                               let len: u16 = Readable::read(r)?;
-                               let mut ret = $constr(len as usize);
-                               for _ in 0..len {
+                               let len: CollectionLength = Readable::read(r)?;
+                               let mut ret = $constr(len.0 as usize);
+                               for _ in 0..len.0 {
                                        let k = K::read(r)?;
                                        let v_opt = V::read(r)?;
                                        if let Some(v) = v_opt {
@@ -637,7 +671,7 @@ where T: Writeable + Eq + Hash
 {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                for item in self.iter() {
                        item.write(w)?;
                }
@@ -650,9 +684,9 @@ where T: Readable + Eq + Hash
 {
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let mut ret = HashSet::with_capacity(len as usize);
-               for _ in 0..len {
+               let len: CollectionLength = Readable::read(r)?;
+               let mut ret = HashSet::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::<T>()));
+               for _ in 0..len.0 {
                        if !ret.insert(T::read(r)?) {
                                return Err(DecodeError::InvalidValue)
                        }
@@ -667,7 +701,7 @@ macro_rules! impl_for_vec {
                impl<$($name : Writeable),*> Writeable for Vec<$ty> {
                        #[inline]
                        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-                               (self.len() as u16).write(w)?;
+                               CollectionLength(self.len() as u64).write(w)?;
                                for elem in self.iter() {
                                        elem.write(w)?;
                                }
@@ -678,9 +712,9 @@ macro_rules! impl_for_vec {
                impl<$($name : Readable),*> Readable for Vec<$ty> {
                        #[inline]
                        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-                               let len: u16 = Readable::read(r)?;
-                               let mut ret = Vec::with_capacity(cmp::min(len as usize, MAX_BUF_SIZE / core::mem::size_of::<$ty>()));
-                               for _ in 0..len {
+                               let len: CollectionLength = Readable::read(r)?;
+                               let mut ret = Vec::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::<$ty>()));
+                               for _ in 0..len.0 {
                                        if let Some(val) = MaybeReadable::read(r)? {
                                                ret.push(val);
                                        }
@@ -694,7 +728,7 @@ macro_rules! impl_for_vec {
 impl Writeable for Vec<u8> {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                w.write_all(&self)
        }
 }
@@ -702,10 +736,15 @@ impl Writeable for Vec<u8> {
 impl Readable for Vec<u8> {
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let mut ret = Vec::with_capacity(len as usize);
-               ret.resize(len as usize, 0);
-               r.read_exact(&mut ret)?;
+               let mut len: CollectionLength = Readable::read(r)?;
+               let mut ret = Vec::new();
+               while len.0 > 0 {
+                       let readamt = cmp::min(len.0 as usize, MAX_BUF_SIZE);
+                       let readstart = ret.len();
+                       ret.resize(readstart + readamt, 0);
+                       r.read_exact(&mut ret[readstart..])?;
+                       len.0 -= readamt as u64;
+               }
                Ok(ret)
        }
 }
@@ -1040,7 +1079,7 @@ impl Readable for () {
 impl Writeable for String {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                w.write_all(self.as_bytes())
        }
 }