Merge pull request #1273 from jkczyz/2022-01-invoice-expiry
[rust-lightning] / lightning / src / util / ser.rs
index b27cf4b348c3d98fc805dd7478dbe640910d80f4..6a52e5e1c898230c0d31a4cb9ce20c6acc236076 100644 (file)
@@ -27,6 +27,7 @@ use bitcoin::consensus::Encodable;
 use bitcoin::hashes::sha256d::Hash as Sha256dHash;
 use bitcoin::hash_types::{Txid, BlockHash};
 use core::marker::Sized;
+use core::time::Duration;
 use ln::msgs::DecodeError;
 use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
 
@@ -35,18 +36,13 @@ use util::byte_utils::{be48_to_array, slice_to_be48};
 /// serialization buffer size
 pub const MAX_BUF_SIZE: usize = 64 * 1024;
 
-/// A trait that is similar to std::io::Write but has one extra function which can be used to size
-/// buffers being written into.
-/// An impl is provided for any type that also impls std::io::Write which simply ignores size
-/// hints.
+/// A simplified version of std::io::Write that exists largely for backwards compatibility.
+/// An impl is provided for any type that also impls std::io::Write.
 ///
 /// (C-not exported) as we only export serialization to/from byte arrays instead
 pub trait Writer {
        /// Writes the given buf out. See std::io::Write::write_all for more
        fn write_all(&mut self, buf: &[u8]) -> Result<(), io::Error>;
-       /// Hints that data of the given size is about the be written. This may not always be called
-       /// prior to data being written and may be safely ignored.
-       fn size_hint(&mut self, size: usize);
 }
 
 impl<W: Write> Writer for W {
@@ -54,8 +50,6 @@ impl<W: Write> Writer for W {
        fn write_all(&mut self, buf: &[u8]) -> Result<(), io::Error> {
                <Self as io::Write>::write_all(self, buf)
        }
-       #[inline]
-       fn size_hint(&mut self, _size: usize) { }
 }
 
 pub(crate) struct WriterWriteAdaptor<'a, W: Writer + 'a>(pub &'a mut W);
@@ -82,10 +76,6 @@ impl Writer for VecWriter {
                self.0.extend_from_slice(buf);
                Ok(())
        }
-       #[inline]
-       fn size_hint(&mut self, size: usize) {
-               self.0.reserve_exact(size);
-       }
 }
 
 /// Writer that only tracks the amount of data written - useful if you need to calculate the length
@@ -97,8 +87,6 @@ impl Writer for LengthCalculatingWriter {
                self.0 += buf.len();
                Ok(())
        }
-       #[inline]
-       fn size_hint(&mut self, _size: usize) {}
 }
 
 /// Essentially std::io::Take but a bit simpler and with a method to walk the underlying stream
@@ -241,6 +229,13 @@ pub trait MaybeReadable
        fn read<R: Read>(reader: &mut R) -> Result<Option<Self>, DecodeError>;
 }
 
+impl<T: Readable> MaybeReadable for T {
+       #[inline]
+       fn read<R: Read>(reader: &mut R) -> Result<Option<T>, DecodeError> {
+               Ok(Some(Readable::read(reader)?))
+       }
+}
+
 pub(crate) struct OptionDeserWrapper<T: Readable>(pub Option<T>);
 impl<T: Readable> Readable for OptionDeserWrapper<T> {
        #[inline]
@@ -262,15 +257,16 @@ impl<'a, T: Writeable> Writeable for VecWriteWrapper<'a, T> {
 }
 
 /// Wrapper to read elements from a given stream until it reaches the end of the stream.
-pub(crate) struct VecReadWrapper<T: Readable>(pub Vec<T>);
-impl<T: Readable> Readable for VecReadWrapper<T> {
+pub(crate) struct VecReadWrapper<T>(pub Vec<T>);
+impl<T: MaybeReadable> Readable for VecReadWrapper<T> {
        #[inline]
        fn read<R: Read>(mut reader: &mut R) -> Result<Self, DecodeError> {
                let mut values = Vec::new();
                loop {
                        let mut track_read = ReadTrackingReader::new(&mut reader);
-                       match Readable::read(&mut track_read) {
-                               Ok(v) => { values.push(v); },
+                       match MaybeReadable::read(&mut track_read) {
+                               Ok(Some(v)) => { values.push(v); },
+                               Ok(None) => { },
                                // If we failed to read any bytes at all, we reached the end of our TLV
                                // stream and have simply exhausted all entries.
                                Err(ref e) if e == &DecodeError::ShortRead && !track_read.have_read => break,
@@ -479,10 +475,9 @@ macro_rules! impl_array {
        );
 }
 
-//TODO: performance issue with [u8; size] with impl_array!()
 impl_array!(3); // for rgb
 impl_array!(4); // for IPv4
-impl_array!(10); // for OnionV2
+impl_array!(12); // for OnionV2
 impl_array!(16); // for IPv6
 impl_array!(32); // for channel id & hmac
 impl_array!(PUBLIC_KEY_SIZE); // for PublicKey
@@ -507,14 +502,50 @@ impl<K, V> Writeable for HashMap<K, V>
 
 impl<K, V> Readable for HashMap<K, V>
        where K: Readable + Eq + Hash,
-             V: Readable
+             V: MaybeReadable
 {
        #[inline]
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
                let len: u16 = Readable::read(r)?;
                let mut ret = HashMap::with_capacity(len as usize);
                for _ in 0..len {
-                       ret.insert(K::read(r)?, V::read(r)?);
+                       let k = K::read(r)?;
+                       let v_opt = V::read(r)?;
+                       if let Some(v) = v_opt {
+                               if ret.insert(k, v).is_some() {
+                                       return Err(DecodeError::InvalidValue);
+                               }
+                       }
+               }
+               Ok(ret)
+       }
+}
+
+// HashSet
+impl<T> Writeable for HashSet<T>
+where T: Writeable + Eq + Hash
+{
+       #[inline]
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               (self.len() as u16).write(w)?;
+               for item in self.iter() {
+                       item.write(w)?;
+               }
+               Ok(())
+       }
+}
+
+impl<T> Readable for HashSet<T>
+where T: Readable + Eq + Hash
+{
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let len: u16 = Readable::read(r)?;
+               let mut ret = HashSet::with_capacity(len as usize);
+               for _ in 0..len {
+                       if !ret.insert(T::read(r)?) {
+                               return Err(DecodeError::InvalidValue)
+                       }
                }
                Ok(ret)
        }
@@ -561,7 +592,7 @@ impl Readable for Vec<Signature> {
                        return Err(DecodeError::BadLengthDescriptor);
                }
                let mut ret = Vec::with_capacity(len as usize);
-               for _ in 0..len { ret.push(Signature::read(r)?); }
+               for _ in 0..len { ret.push(Readable::read(r)?); }
                Ok(ret)
        }
 }
@@ -726,7 +757,8 @@ impl<T: Writeable> Writeable for Option<T> {
 impl<T: Readable> Readable for Option<T>
 {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
-               match BigSize::read(r)?.0 {
+               let len: BigSize = Readable::read(r)?;
+               match len.0 {
                        0 => Ok(None),
                        len => {
                                let mut reader = FixedLengthReader::new(r, len - 1);
@@ -852,3 +884,46 @@ impl<A: Writeable, B: Writeable, C: Writeable> Writeable for (A, B, C) {
                self.2.write(w)
        }
 }
+
+impl Writeable for () {
+       fn write<W: Writer>(&self, _: &mut W) -> Result<(), io::Error> {
+               Ok(())
+       }
+}
+impl Readable for () {
+       fn read<R: Read>(_r: &mut R) -> Result<Self, DecodeError> {
+               Ok(())
+       }
+}
+
+impl Writeable for String {
+       #[inline]
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               (self.len() as u16).write(w)?;
+               w.write_all(self.as_bytes())
+       }
+}
+impl Readable for String {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let v: Vec<u8> = Readable::read(r)?;
+               let ret = String::from_utf8(v).map_err(|_| DecodeError::InvalidValue)?;
+               Ok(ret)
+       }
+}
+
+impl Writeable for Duration {
+       #[inline]
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.as_secs().write(w)?;
+               self.subsec_nanos().write(w)
+       }
+}
+impl Readable for Duration {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let secs = Readable::read(r)?;
+               let nanos = Readable::read(r)?;
+               Ok(Duration::new(secs, nanos))
+       }
+}