Use a variable-length integer for many serialization wrappers 2023-01-ser-cleanups
authorMatt Corallo <git@bluematt.me>
Mon, 28 Nov 2022 01:00:38 +0000 (01:00 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 17 Jan 2023 21:48:23 +0000 (21:48 +0000)
The lightning protocol uses u16s for lengths in most cases. As our
serialization framework primarily targets that, we must as well.
However, because we may serialize objects that have  more than 65K
entries, we want to be able to store larger values. Thus, we define
a variable length integer here which is backwards-compatible but
treats 0xffff as "read eight more bytes".

This doesn't address any specific known issue, but feels like good
practice just in case.

lightning/src/util/ser.rs

index 44725e722321cab2c99f0888a2a939442090e9b3..84d1a2e084feb65e1472969875a586c81ccd08e5 100644 (file)
@@ -383,6 +383,40 @@ impl Readable for BigSize {
        }
 }
 
+/// The lightning protocol uses u16s for lengths in most cases. As our serialization framework
+/// primarily targets that, we must as well. However, because we may serialize objects that have
+/// more than 65K entries, we need to be able to store larger values. Thus, we define a variable
+/// length integer here that is backwards-compatible for values < 0xffff. We treat 0xffff as
+/// "read eight more bytes".
+///
+/// To ensure we only have one valid encoding per value, we add 0xffff to values written as eight
+/// bytes. Thus, 0xfffe is serialized as 0xfffe, whereas 0xffff is serialized as
+/// 0xffff0000000000000000 (i.e. read-eight-bytes then zero).
+struct CollectionLength(pub u64);
+impl Writeable for CollectionLength {
+       #[inline]
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               if self.0 < 0xffff {
+                       (self.0 as u16).write(writer)
+               } else {
+                       0xffffu16.write(writer)?;
+                       (self.0 - 0xffff).write(writer)
+               }
+       }
+}
+
+impl Readable for CollectionLength {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let mut val: u64 = <u16 as Readable>::read(r)? as u64;
+               if val == 0xffff {
+                       val = <u64 as Readable>::read(r)?
+                               .checked_add(0xffff).ok_or(DecodeError::InvalidValue)?;
+               }
+               Ok(CollectionLength(val))
+       }
+}
+
 /// In TLV we occasionally send fields which only consist of, or potentially end with, a
 /// variable-length integer which is simply truncated by skipping high zero bytes. This type
 /// encapsulates such integers implementing [`Readable`]/[`Writeable`] for them.
@@ -597,7 +631,7 @@ macro_rules! impl_for_map {
                {
                        #[inline]
                        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-                               (self.len() as u16).write(w)?;
+                               CollectionLength(self.len() as u64).write(w)?;
                                for (key, value) in self.iter() {
                                        key.write(w)?;
                                        value.write(w)?;
@@ -611,9 +645,9 @@ macro_rules! impl_for_map {
                {
                        #[inline]
                        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-                               let len: u16 = Readable::read(r)?;
-                               let mut ret = $constr(len as usize);
-                               for _ in 0..len {
+                               let len: CollectionLength = Readable::read(r)?;
+                               let mut ret = $constr(len.0 as usize);
+                               for _ in 0..len.0 {
                                        let k = K::read(r)?;
                                        let v_opt = V::read(r)?;
                                        if let Some(v) = v_opt {
@@ -637,7 +671,7 @@ where T: Writeable + Eq + Hash
 {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                for item in self.iter() {
                        item.write(w)?;
                }
@@ -650,9 +684,9 @@ where T: Readable + Eq + Hash
 {
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let mut ret = HashSet::with_capacity(len as usize);
-               for _ in 0..len {
+               let len: CollectionLength = Readable::read(r)?;
+               let mut ret = HashSet::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::<T>()));
+               for _ in 0..len.0 {
                        if !ret.insert(T::read(r)?) {
                                return Err(DecodeError::InvalidValue)
                        }
@@ -667,7 +701,7 @@ macro_rules! impl_for_vec {
                impl<$($name : Writeable),*> Writeable for Vec<$ty> {
                        #[inline]
                        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-                               (self.len() as u16).write(w)?;
+                               CollectionLength(self.len() as u64).write(w)?;
                                for elem in self.iter() {
                                        elem.write(w)?;
                                }
@@ -678,9 +712,9 @@ macro_rules! impl_for_vec {
                impl<$($name : Readable),*> Readable for Vec<$ty> {
                        #[inline]
                        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-                               let len: u16 = Readable::read(r)?;
-                               let mut ret = Vec::with_capacity(cmp::min(len as usize, MAX_BUF_SIZE / core::mem::size_of::<$ty>()));
-                               for _ in 0..len {
+                               let len: CollectionLength = Readable::read(r)?;
+                               let mut ret = Vec::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::<$ty>()));
+                               for _ in 0..len.0 {
                                        if let Some(val) = MaybeReadable::read(r)? {
                                                ret.push(val);
                                        }
@@ -694,7 +728,7 @@ macro_rules! impl_for_vec {
 impl Writeable for Vec<u8> {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                w.write_all(&self)
        }
 }
@@ -702,10 +736,15 @@ impl Writeable for Vec<u8> {
 impl Readable for Vec<u8> {
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               let len: u16 = Readable::read(r)?;
-               let mut ret = Vec::with_capacity(len as usize);
-               ret.resize(len as usize, 0);
-               r.read_exact(&mut ret)?;
+               let mut len: CollectionLength = Readable::read(r)?;
+               let mut ret = Vec::new();
+               while len.0 > 0 {
+                       let readamt = cmp::min(len.0 as usize, MAX_BUF_SIZE);
+                       let readstart = ret.len();
+                       ret.resize(readstart + readamt, 0);
+                       r.read_exact(&mut ret[readstart..])?;
+                       len.0 -= readamt as u64;
+               }
                Ok(ret)
        }
 }
@@ -1040,7 +1079,7 @@ impl Readable for () {
 impl Writeable for String {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               (self.len() as u16).write(w)?;
+               CollectionLength(self.len() as u64).write(w)?;
                w.write_all(self.as_bytes())
        }
 }