Add size_hint in ser and call size_hint in all message serializers
[rust-lightning] / src / ln / msgs.rs
index b4432c7084d9ebf8340b547ec2ca515700650ca0..2c89acf2387686ac8c74a1197109ef07d1daad2f 100644 (file)
@@ -7,10 +7,11 @@ use bitcoin::blockdata::script::Script;
 
 use std::error::Error;
 use std::{cmp, fmt};
+use std::io::Read;
 use std::result::Result;
 
 use util::{byte_utils, internal_traits, events};
-use util::ser::{Readable, Reader, Writeable, Writer};
+use util::ser::{Readable, Writeable, Writer};
 
 pub trait MsgEncodable {
        fn encode(&self) -> Vec<u8>;
@@ -1703,7 +1704,10 @@ impl MsgDecodable for ErrorMessage {
        }
 }
 
-impl_writeable!(AcceptChannel, {
+impl_writeable_len_match!(AcceptChannel, {
+               {AcceptChannel{ shutdown_scriptpubkey: Some(ref script), ..}, 270 + 2 + script.len()},
+               {_, 270}
+       }, {
        temporary_channel_id,
        dust_limit_satoshis,
        max_htlc_value_in_flight_msat,
@@ -1721,15 +1725,16 @@ impl_writeable!(AcceptChannel, {
        shutdown_scriptpubkey
 });
 
-impl_writeable!(AnnouncementSignatures, {
+impl_writeable!(AnnouncementSignatures, 32+8+64*2, {
        channel_id,
        short_channel_id,
        node_signature,
        bitcoin_signature
 });
 
-impl<W: ::std::io::Write> Writeable<W> for ChannelReestablish {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for ChannelReestablish {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(if self.data_loss_protect.is_some() { 32+2*8+33+32 } else { 32+2*8 });
                self.channel_id.write(w)?;
                self.next_local_commitment_number.write(w)?;
                self.next_remote_commitment_number.write(w)?;
@@ -1741,8 +1746,8 @@ impl<W: ::std::io::Write> Writeable<W> for ChannelReestablish {
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for ChannelReestablish{
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for ChannelReestablish{
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Self {
                        channel_id: Readable::read(r)?,
                        next_local_commitment_number: Readable::read(r)?,
@@ -1762,55 +1767,68 @@ impl<R: ::std::io::Read> Readable<R> for ChannelReestablish{
        }
 }
 
-impl_writeable!(ClosingSigned, {
+impl_writeable!(ClosingSigned, 32+8+64, {
        channel_id,
        fee_satoshis,
        signature
 });
 
-impl_writeable!(CommitmentSigned, {
+impl_writeable_len_match!(CommitmentSigned, {
+               { CommitmentSigned { ref htlc_signatures, .. }, 32+64+2+htlc_signatures.len()*64 }
+       }, {
        channel_id,
        signature,
        htlc_signatures
 });
 
-impl_writeable!(DecodedOnionErrorPacket, {
+impl_writeable_len_match!(DecodedOnionErrorPacket, {
+               { DecodedOnionErrorPacket { ref failuremsg, ref pad, .. }, 32 + 4 + failuremsg.len() + pad.len() }
+       }, {
        hmac,
        failuremsg,
        pad
 });
 
-impl_writeable!(FundingCreated, {
+impl_writeable!(FundingCreated, 32+32+2+64, {
        temporary_channel_id,
        funding_txid,
        funding_output_index,
        signature
 });
 
-impl_writeable!(FundingSigned, {
+impl_writeable!(FundingSigned, 32+64, {
        channel_id,
        signature
 });
 
-impl_writeable!(FundingLocked, {
+impl_writeable!(FundingLocked, 32+33, {
        channel_id,
        next_per_commitment_point
 });
 
-impl_writeable!(GlobalFeatures, {
+impl_writeable_len_match!(GlobalFeatures, {
+               { GlobalFeatures { ref flags }, flags.len() + 2 }
+       }, {
        flags
 });
 
-impl_writeable!(LocalFeatures, {
+impl_writeable_len_match!(LocalFeatures, {
+               { LocalFeatures { ref flags }, flags.len() + 2 }
+       }, {
        flags
 });
 
-impl_writeable!(Init, {
+impl_writeable_len_match!(Init, {
+               { Init { ref global_features, ref local_features }, global_features.flags.len() + local_features.flags.len() + 4 }
+       }, {
        global_features,
        local_features
 });
 
-impl_writeable!(OpenChannel, {
+impl_writeable_len_match!(OpenChannel, {
+               { OpenChannel { shutdown_scriptpubkey: Some(ref script), .. }, 319 + 2 + script.len() },
+               { OpenChannel { shutdown_scriptpubkey: None, .. }, 319 }
+       }, {
        chain_hash,
        temporary_channel_id,
        funding_satoshis,
@@ -1832,47 +1850,54 @@ impl_writeable!(OpenChannel, {
        shutdown_scriptpubkey
 });
 
-impl_writeable!(RevokeAndACK, {
+impl_writeable!(RevokeAndACK, 32+32+33, {
        channel_id,
        per_commitment_secret,
        next_per_commitment_point
 });
 
-impl_writeable!(Shutdown, {
+impl_writeable_len_match!(Shutdown, {
+               { Shutdown { ref scriptpubkey, .. }, 32 + 2 + scriptpubkey.len() }
+       }, {
        channel_id,
        scriptpubkey
 });
 
-impl_writeable!(UpdateFailHTLC, {
+impl_writeable_len_match!(UpdateFailHTLC, {
+               { UpdateFailHTLC { ref reason, .. }, 32 + 10 + reason.data.len() }
+       }, {
        channel_id,
        htlc_id,
        reason
 });
 
-impl_writeable!(UpdateFailMalformedHTLC, {
+impl_writeable!(UpdateFailMalformedHTLC, 32+8+32+2, {
        channel_id,
        htlc_id,
        sha256_of_onion,
        failure_code
 });
 
-impl_writeable!(UpdateFee, {
+impl_writeable!(UpdateFee, 32+4, {
        channel_id,
        feerate_per_kw
 });
 
-impl_writeable!(UpdateFulfillHTLC, {
+impl_writeable!(UpdateFulfillHTLC, 32+8+32, {
        channel_id,
        htlc_id,
        payment_preimage
 });
 
-impl_writeable!(OnionErrorPacket, {
+impl_writeable_len_match!(OnionErrorPacket, {
+               { OnionErrorPacket { ref data, .. }, 2 + data.len() }
+       }, {
        data
 });
 
-impl<W: ::std::io::Write> Writeable<W> for OnionPacket {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for OnionPacket {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(1 + 33 + 20*65 + 32);
                self.version.write(w)?;
                match self.public_key {
                        Ok(pubkey) => pubkey.write(w)?,
@@ -1884,8 +1909,8 @@ impl<W: ::std::io::Write> Writeable<W> for OnionPacket {
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for OnionPacket {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for OnionPacket {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(OnionPacket {
                        version: Readable::read(r)?,
                        public_key: {
@@ -1899,7 +1924,7 @@ impl<R: ::std::io::Read> Readable<R> for OnionPacket {
        }
 }
 
-impl_writeable!(UpdateAddHTLC, {
+impl_writeable!(UpdateAddHTLC, 32+8+8+32+4+1366, {
        channel_id,
        htlc_id,
        amount_msat,
@@ -1908,8 +1933,9 @@ impl_writeable!(UpdateAddHTLC, {
        onion_routing_packet
 });
 
-impl<W: ::std::io::Write> Writeable<W> for OnionRealm0HopData {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for OnionRealm0HopData {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(32);
                self.short_channel_id.write(w)?;
                self.amt_to_forward.write(w)?;
                self.outgoing_cltv_value.write(w)?;
@@ -1918,8 +1944,8 @@ impl<W: ::std::io::Write> Writeable<W> for OnionRealm0HopData {
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for OnionRealm0HopData {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for OnionRealm0HopData {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(OnionRealm0HopData {
                        short_channel_id: Readable::read(r)?,
                        amt_to_forward: Readable::read(r)?,
@@ -1932,8 +1958,9 @@ impl<R: ::std::io::Read> Readable<R> for OnionRealm0HopData {
        }
 }
 
-impl<W: ::std::io::Write> Writeable<W> for OnionHopData {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for OnionHopData {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(65);
                self.realm.write(w)?;
                self.data.write(w)?;
                self.hmac.write(w)?;
@@ -1941,8 +1968,8 @@ impl<W: ::std::io::Write> Writeable<W> for OnionHopData {
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for OnionHopData {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for OnionHopData {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(OnionHopData {
                        realm: {
                                let r: u8 = Readable::read(r)?;
@@ -1957,16 +1984,17 @@ impl<R: ::std::io::Read> Readable<R> for OnionHopData {
        }
 }
 
-impl<W: ::std::io::Write> Writeable<W> for Ping {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for Ping {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(self.byteslen as usize + 4);
                self.ponglen.write(w)?;
                vec![0u8; self.byteslen as usize].write(w)?; // size-unchecked write
                Ok(())
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for Ping {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for Ping {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Ping {
                        ponglen: Readable::read(r)?,
                        byteslen: {
@@ -1978,15 +2006,16 @@ impl<R: ::std::io::Read> Readable<R> for Ping {
        }
 }
 
-impl<W: ::std::io::Write> Writeable<W> for Pong {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for Pong {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(self.byteslen as usize + 2);
                vec![0u8; self.byteslen as usize].write(w)?; // size-unchecked write
                Ok(())
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for Pong {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for Pong {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Pong {
                        byteslen: {
                                let byteslen = Readable::read(r)?;
@@ -1997,8 +2026,9 @@ impl<R: ::std::io::Read> Readable<R> for Pong {
        }
 }
 
-impl<W: ::std::io::Write> Writeable<W> for UnsignedChannelAnnouncement {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for UnsignedChannelAnnouncement {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(2 + 2*32 + 4*33 + self.features.flags.len() + self.excess_data.len());
                self.features.write(w)?;
                self.chain_hash.write(w)?;
                self.short_channel_id.write(w)?;
@@ -2011,8 +2041,8 @@ impl<W: ::std::io::Write> Writeable<W> for UnsignedChannelAnnouncement {
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for UnsignedChannelAnnouncement {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for UnsignedChannelAnnouncement {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Self {
                        features: {
                                let f: GlobalFeatures = Readable::read(r)?;
@@ -2036,7 +2066,10 @@ impl<R: ::std::io::Read> Readable<R> for UnsignedChannelAnnouncement {
        }
 }
 
-impl_writeable!(ChannelAnnouncement,{
+impl_writeable_len_match!(ChannelAnnouncement, {
+               { ChannelAnnouncement { contents: UnsignedChannelAnnouncement {ref features, ref excess_data, ..}, .. },
+                       2 + 2*32 + 4*33 + features.flags.len() + excess_data.len() + 4*64 }
+       }, {
        node_signature_1,
        node_signature_2,
        bitcoin_signature_1,
@@ -2044,8 +2077,9 @@ impl_writeable!(ChannelAnnouncement,{
        contents
 });
 
-impl<W: ::std::io::Write> Writeable<W> for UnsignedChannelUpdate {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for UnsignedChannelUpdate {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(64 + self.excess_data.len());
                self.chain_hash.write(w)?;
                self.short_channel_id.write(w)?;
                self.timestamp.write(w)?;
@@ -2059,8 +2093,8 @@ impl<W: ::std::io::Write> Writeable<W> for UnsignedChannelUpdate {
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for UnsignedChannelUpdate {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for UnsignedChannelUpdate {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Self {
                        chain_hash: Readable::read(r)?,
                        short_channel_id: Readable::read(r)?,
@@ -2079,21 +2113,26 @@ impl<R: ::std::io::Read> Readable<R> for UnsignedChannelUpdate {
        }
 }
 
-impl_writeable!(ChannelUpdate, {
+impl_writeable_len_match!(ChannelUpdate, {
+               { ChannelUpdate { contents: UnsignedChannelUpdate {ref excess_data, ..}, .. },
+                       64 + excess_data.len() + 64 }
+       }, {
        signature,
        contents
 });
 
-impl<W: ::std::io::Write> Writeable<W> for ErrorMessage {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for ErrorMessage {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(32 + 2 + self.data.len());
                self.channel_id.write(w)?;
-               self.data.as_bytes().to_vec().write(w)?; // write with size prefix
+               (self.data.len() as u16).write(w)?;
+               w.write_all(self.data.as_bytes())?;
                Ok(())
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for ErrorMessage {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for ErrorMessage {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                Ok(Self {
                        channel_id: Readable::read(r)?,
                        data: {
@@ -2110,8 +2149,9 @@ impl<R: ::std::io::Read> Readable<R> for ErrorMessage {
        }
 }
 
-impl<W: ::std::io::Write> Writeable<W> for UnsignedNodeAnnouncement {
-       fn write(&self, w: &mut Writer<W>) -> Result<(), DecodeError> {
+impl<W: Writer> Writeable<W> for UnsignedNodeAnnouncement {
+       fn write(&self, w: &mut W) -> Result<(), DecodeError> {
+               w.size_hint(64 + 76 + self.features.flags.len() + self.addresses.len()*38 + self.excess_address_data.len() + self.excess_data.len());
                self.features.write(w)?;
                self.timestamp.write(w)?;
                self.node_id.write(w)?;
@@ -2156,8 +2196,8 @@ impl<W: ::std::io::Write> Writeable<W> for UnsignedNodeAnnouncement {
        }
 }
 
-impl<R: ::std::io::Read> Readable<R> for UnsignedNodeAnnouncement {
-       fn read(r: &mut Reader<R>) -> Result<Self, DecodeError> {
+impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
                let features: GlobalFeatures = Readable::read(r)?;
                if features.requires_unknown_bits() {
                        return Err(DecodeError::UnknownRequiredFeature);
@@ -2278,7 +2318,10 @@ impl<R: ::std::io::Read> Readable<R> for UnsignedNodeAnnouncement {
        }
 }
 
-impl_writeable!(NodeAnnouncement, {
+impl_writeable_len_match!(NodeAnnouncement, {
+               { NodeAnnouncement { contents: UnsignedNodeAnnouncement { ref features, ref addresses, ref excess_address_data, ref excess_data, ..}, .. },
+                       64 + 76 + features.flags.len() + addresses.len()*38 + excess_address_data.len() + excess_data.len() }
+       }, {
        signature,
        contents
 });