X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Fser.rs;h=928cc61946e46f358a70e82672b06bda1dec4867;hb=4155f54716f6bc8632d2a501d22c51a2545670b1;hp=a656604d458c492c17e8fc9a206c20afe4faa5ef;hpb=3bf2f7189c3afbfb269eef9c923c307ff64c6c79;p=rust-lightning diff --git a/lightning/src/util/ser.rs b/lightning/src/util/ser.rs index a656604d..928cc619 100644 --- a/lightning/src/util/ser.rs +++ b/lightning/src/util/ser.rs @@ -22,6 +22,8 @@ use core::cmp; use core::convert::TryFrom; use core::ops::Deref; +use alloc::collections::BTreeMap; + use bitcoin::secp256k1::{PublicKey, SecretKey}; use bitcoin::secp256k1::constants::{PUBLIC_KEY_SIZE, SECRET_KEY_SIZE, COMPACT_SIGNATURE_SIZE, SCHNORR_SIGNATURE_SIZE}; use bitcoin::secp256k1::ecdsa; @@ -381,6 +383,40 @@ impl Readable for BigSize { } } +/// The lightning protocol uses u16s for lengths in most cases. As our serialization framework +/// primarily targets that, we must as well. However, because we may serialize objects that have +/// more than 65K entries, we need to be able to store larger values. Thus, we define a variable +/// length integer here that is backwards-compatible for values < 0xffff. We treat 0xffff as +/// "read eight more bytes". +/// +/// To ensure we only have one valid encoding per value, we add 0xffff to values written as eight +/// bytes. Thus, 0xfffe is serialized as 0xfffe, whereas 0xffff is serialized as +/// 0xffff0000000000000000 (i.e. read-eight-bytes then zero). +struct CollectionLength(pub u64); +impl Writeable for CollectionLength { + #[inline] + fn write(&self, writer: &mut W) -> Result<(), io::Error> { + if self.0 < 0xffff { + (self.0 as u16).write(writer) + } else { + 0xffffu16.write(writer)?; + (self.0 - 0xffff).write(writer) + } + } +} + +impl Readable for CollectionLength { + #[inline] + fn read(r: &mut R) -> Result { + let mut val: u64 = ::read(r)? as u64; + if val == 0xffff { + val = ::read(r)? + .checked_add(0xffff).ok_or(DecodeError::InvalidValue)?; + } + Ok(CollectionLength(val)) + } +} + /// In TLV we occasionally send fields which only consist of, or potentially end with, a /// variable-length integer which is simply truncated by skipping high zero bytes. This type /// encapsulates such integers implementing [`Readable`]/[`Writeable`] for them. @@ -588,50 +624,74 @@ impl<'a, T> From<&'a Vec> for WithoutLength<&'a Vec> { fn from(v: &'a Vec) -> Self { Self(v) } } -// HashMap -impl Writeable for HashMap - where K: Writeable + Eq + Hash, - V: Writeable -{ +#[derive(Debug)] +pub(crate) struct Iterable<'a, I: Iterator + Clone, T: 'a>(pub I); + +impl<'a, I: Iterator + Clone, T: 'a + Writeable> Writeable for Iterable<'a, I, T> { #[inline] - fn write(&self, w: &mut W) -> Result<(), io::Error> { - (self.len() as u16).write(w)?; - for (key, value) in self.iter() { - key.write(w)?; - value.write(w)?; + fn write(&self, writer: &mut W) -> Result<(), io::Error> { + for ref v in self.0.clone() { + v.write(writer)?; } Ok(()) } } -impl Readable for HashMap - where K: Readable + Eq + Hash, - V: MaybeReadable -{ - #[inline] - fn read(r: &mut R) -> Result { - let len: u16 = Readable::read(r)?; - let mut ret = HashMap::with_capacity(len as usize); - for _ in 0..len { - let k = K::read(r)?; - let v_opt = V::read(r)?; - if let Some(v) = v_opt { - if ret.insert(k, v).is_some() { - return Err(DecodeError::InvalidValue); +#[cfg(test)] +impl<'a, I: Iterator + Clone, T: 'a + PartialEq> PartialEq for Iterable<'a, I, T> { + fn eq(&self, other: &Self) -> bool { + self.0.clone().collect::>() == other.0.clone().collect::>() + } +} + +macro_rules! impl_for_map { + ($ty: ident, $keybound: ident, $constr: expr) => { + impl Writeable for $ty + where K: Writeable + Eq + $keybound, V: Writeable + { + #[inline] + fn write(&self, w: &mut W) -> Result<(), io::Error> { + CollectionLength(self.len() as u64).write(w)?; + for (key, value) in self.iter() { + key.write(w)?; + value.write(w)?; } + Ok(()) + } + } + + impl Readable for $ty + where K: Readable + Eq + $keybound, V: MaybeReadable + { + #[inline] + fn read(r: &mut R) -> Result { + let len: CollectionLength = Readable::read(r)?; + let mut ret = $constr(len.0 as usize); + for _ in 0..len.0 { + let k = K::read(r)?; + let v_opt = V::read(r)?; + if let Some(v) = v_opt { + if ret.insert(k, v).is_some() { + return Err(DecodeError::InvalidValue); + } + } + } + Ok(ret) } } - Ok(ret) } } +impl_for_map!(BTreeMap, Ord, |_| BTreeMap::new()); +impl_for_map!(HashMap, Hash, |len| HashMap::with_capacity(len)); + // HashSet impl Writeable for HashSet where T: Writeable + Eq + Hash { #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { - (self.len() as u16).write(w)?; + CollectionLength(self.len() as u64).write(w)?; for item in self.iter() { item.write(w)?; } @@ -644,9 +704,9 @@ where T: Readable + Eq + Hash { #[inline] fn read(r: &mut R) -> Result { - let len: u16 = Readable::read(r)?; - let mut ret = HashSet::with_capacity(len as usize); - for _ in 0..len { + let len: CollectionLength = Readable::read(r)?; + let mut ret = HashSet::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::())); + for _ in 0..len.0 { if !ret.insert(T::read(r)?) { return Err(DecodeError::InvalidValue) } @@ -656,51 +716,62 @@ where T: Readable + Eq + Hash } // Vectors -impl Writeable for Vec { - #[inline] - fn write(&self, w: &mut W) -> Result<(), io::Error> { - (self.len() as u16).write(w)?; - w.write_all(&self) - } -} +macro_rules! impl_for_vec { + ($ty: ty $(, $name: ident)*) => { + impl<$($name : Writeable),*> Writeable for Vec<$ty> { + #[inline] + fn write(&self, w: &mut W) -> Result<(), io::Error> { + CollectionLength(self.len() as u64).write(w)?; + for elem in self.iter() { + elem.write(w)?; + } + Ok(()) + } + } -impl Readable for Vec { - #[inline] - fn read(r: &mut R) -> Result { - let len: u16 = Readable::read(r)?; - let mut ret = Vec::with_capacity(len as usize); - ret.resize(len as usize, 0); - r.read_exact(&mut ret)?; - Ok(ret) + impl<$($name : Readable),*> Readable for Vec<$ty> { + #[inline] + fn read(r: &mut R) -> Result { + let len: CollectionLength = Readable::read(r)?; + let mut ret = Vec::with_capacity(cmp::min(len.0 as usize, MAX_BUF_SIZE / core::mem::size_of::<$ty>())); + for _ in 0..len.0 { + if let Some(val) = MaybeReadable::read(r)? { + ret.push(val); + } + } + Ok(ret) + } + } } } -impl Writeable for Vec { + +impl Writeable for Vec { #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { - (self.len() as u16).write(w)?; - for e in self.iter() { - e.write(w)?; - } - Ok(()) + CollectionLength(self.len() as u64).write(w)?; + w.write_all(&self) } } -impl Readable for Vec { +impl Readable for Vec { #[inline] fn read(r: &mut R) -> Result { - let len: u16 = Readable::read(r)?; - let byte_size = (len as usize) - .checked_mul(COMPACT_SIGNATURE_SIZE) - .ok_or(DecodeError::BadLengthDescriptor)?; - if byte_size > MAX_BUF_SIZE { - return Err(DecodeError::BadLengthDescriptor); + let mut len: CollectionLength = Readable::read(r)?; + let mut ret = Vec::new(); + while len.0 > 0 { + let readamt = cmp::min(len.0 as usize, MAX_BUF_SIZE); + let readstart = ret.len(); + ret.resize(readstart + readamt, 0); + r.read_exact(&mut ret[readstart..])?; + len.0 -= readamt as u64; } - let mut ret = Vec::with_capacity(len as usize); - for _ in 0..len { ret.push(Readable::read(r)?); } Ok(ret) } } +impl_for_vec!(ecdsa::Signature); +impl_for_vec!((A, B), A, B); + impl Writeable for Script { fn write(&self, w: &mut W) -> Result<(), io::Error> { (self.len() as u16).write(w)?; @@ -1014,6 +1085,24 @@ impl Writeable for (A, B, C) { } } +impl Readable for (A, B, C, D) { + fn read(r: &mut R) -> Result { + let a: A = Readable::read(r)?; + let b: B = Readable::read(r)?; + let c: C = Readable::read(r)?; + let d: D = Readable::read(r)?; + Ok((a, b, c, d)) + } +} +impl Writeable for (A, B, C, D) { + fn write(&self, w: &mut W) -> Result<(), io::Error> { + self.0.write(w)?; + self.1.write(w)?; + self.2.write(w)?; + self.3.write(w) + } +} + impl Writeable for () { fn write(&self, _: &mut W) -> Result<(), io::Error> { Ok(()) @@ -1028,7 +1117,7 @@ impl Readable for () { impl Writeable for String { #[inline] fn write(&self, w: &mut W) -> Result<(), io::Error> { - (self.len() as u16).write(w)?; + CollectionLength(self.len() as u64).write(w)?; w.write_all(self.as_bytes()) } }