Merge pull request #298 from TheBlueMatt/2019-01-271-cleanup
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Wed, 23 Jan 2019 19:48:13 +0000 (14:48 -0500)
committerGitHub <noreply@github.com>
Wed, 23 Jan 2019 19:48:13 +0000 (14:48 -0500)
Implement serialize/deserialize for Router

src/ln/channel.rs
src/ln/channelmanager.rs
src/ln/channelmonitor.rs
src/ln/msgs.rs
src/ln/peer_handler.rs
src/ln/router.rs
src/util/ser.rs
src/util/test_utils.rs

index 77ea03994a810a06b7d00ecb8f23d496ef623a6a..10b0725cdf2bc75ee2907ee548f3bd123a2a3006 100644 (file)
@@ -15,7 +15,7 @@ use secp256k1::{Secp256k1,Signature};
 use secp256k1;
 
 use ln::msgs;
-use ln::msgs::DecodeError;
+use ln::msgs::{DecodeError, OptionalField};
 use ln::channelmonitor::ChannelMonitor;
 use ln::channelmanager::{PendingHTLCStatus, HTLCSource, HTLCFailReason, HTLCFailureMsg, PendingForwardHTLCInfo, RAACommitmentOrder, PaymentPreimage, PaymentHash};
 use ln::chan_utils::{TxCreationKeys,HTLCOutputInCommitment,HTLC_SUCCESS_TX_WEIGHT,HTLC_TIMEOUT_TX_WEIGHT};
@@ -2943,7 +2943,7 @@ impl Channel {
                        htlc_basepoint: PublicKey::from_secret_key(&self.secp_ctx, &self.local_keys.htlc_base_key),
                        first_per_commitment_point: PublicKey::from_secret_key(&self.secp_ctx, &local_commitment_secret),
                        channel_flags: if self.config.announced_channel {1} else {0},
-                       shutdown_scriptpubkey: None,
+                       shutdown_scriptpubkey: OptionalField::Absent
                }
        }
 
@@ -2975,7 +2975,7 @@ impl Channel {
                        delayed_payment_basepoint: PublicKey::from_secret_key(&self.secp_ctx, &self.local_keys.delayed_payment_base_key),
                        htlc_basepoint: PublicKey::from_secret_key(&self.secp_ctx, &self.local_keys.htlc_base_key),
                        first_per_commitment_point: PublicKey::from_secret_key(&self.secp_ctx, &local_commitment_secret),
-                       shutdown_scriptpubkey: None,
+                       shutdown_scriptpubkey: OptionalField::Absent
                }
        }
 
@@ -3103,7 +3103,7 @@ impl Channel {
                        // dropped this channel on disconnect as it hasn't yet reached FundingSent so we can't
                        // overflow here.
                        next_remote_commitment_number: INITIAL_COMMITMENT_NUMBER - self.cur_remote_commitment_transaction_number - 1,
-                       data_loss_protect: None,
+                       data_loss_protect: OptionalField::Absent,
                }
        }
 
@@ -3688,14 +3688,6 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                        });
                }
 
-               macro_rules! read_option { () => {
-                       match <u8 as Readable<R>>::read(reader)? {
-                               0 => None,
-                               1 => Some(Readable::read(reader)?),
-                               _ => return Err(DecodeError::InvalidValue),
-                       }
-               } }
-
                let pending_outbound_htlc_count: u64 = Readable::read(reader)?;
                let mut pending_outbound_htlcs = Vec::with_capacity(cmp::min(pending_outbound_htlc_count as usize, OUR_MAX_HTLCS as usize));
                for _ in 0..pending_outbound_htlc_count {
@@ -3705,7 +3697,7 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                                cltv_expiry: Readable::read(reader)?,
                                payment_hash: Readable::read(reader)?,
                                source: Readable::read(reader)?,
-                               fail_reason: read_option!(),
+                               fail_reason: Readable::read(reader)?,
                                state: match <u8 as Readable<R>>::read(reader)? {
                                        0 => OutboundHTLCState::LocalAnnounced(Box::new(Readable::read(reader)?)),
                                        1 => OutboundHTLCState::Committed,
@@ -3763,8 +3755,8 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                        monitor_pending_failures.push((Readable::read(reader)?, Readable::read(reader)?, Readable::read(reader)?));
                }
 
-               let pending_update_fee = read_option!();
-               let holding_cell_update_fee = read_option!();
+               let pending_update_fee = Readable::read(reader)?;
+               let holding_cell_update_fee = Readable::read(reader)?;
 
                let next_local_htlc_id = Readable::read(reader)?;
                let next_remote_htlc_id = Readable::read(reader)?;
@@ -3786,8 +3778,8 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                        _ => return Err(DecodeError::InvalidValue),
                };
 
-               let funding_tx_confirmed_in = read_option!();
-               let short_channel_id = read_option!();
+               let funding_tx_confirmed_in = Readable::read(reader)?;
+               let short_channel_id = Readable::read(reader)?;
 
                let last_block_connected = Readable::read(reader)?;
                let funding_tx_confirmations = Readable::read(reader)?;
@@ -3802,17 +3794,17 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                let their_max_accepted_htlcs = Readable::read(reader)?;
                let minimum_depth = Readable::read(reader)?;
 
-               let their_funding_pubkey = read_option!();
-               let their_revocation_basepoint = read_option!();
-               let their_payment_basepoint = read_option!();
-               let their_delayed_payment_basepoint = read_option!();
-               let their_htlc_basepoint = read_option!();
-               let their_cur_commitment_point = read_option!();
+               let their_funding_pubkey = Readable::read(reader)?;
+               let their_revocation_basepoint = Readable::read(reader)?;
+               let their_payment_basepoint = Readable::read(reader)?;
+               let their_delayed_payment_basepoint = Readable::read(reader)?;
+               let their_htlc_basepoint = Readable::read(reader)?;
+               let their_cur_commitment_point = Readable::read(reader)?;
 
-               let their_prev_commitment_point = read_option!();
+               let their_prev_commitment_point = Readable::read(reader)?;
                let their_node_id = Readable::read(reader)?;
 
-               let their_shutdown_scriptpubkey = read_option!();
+               let their_shutdown_scriptpubkey = Readable::read(reader)?;
                let (monitor_last_block, channel_monitor) = ReadableArgs::read(reader, logger.clone())?;
                // We drop the ChannelMonitor's last block connected hash cause we don't actually bother
                // doing full block connection operations on the internal CHannelMonitor copies
index 06d179f40c15e711b78536b46b9018066c41585a..fd638754756a881a8055bc70e19570c8df953537 100644 (file)
@@ -2621,12 +2621,7 @@ const MIN_SERIALIZATION_VERSION: u8 = 1;
 
 impl Writeable for PendingForwardHTLCInfo {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               if let &Some(ref onion) = &self.onion_packet {
-                       1u8.write(writer)?;
-                       onion.write(writer)?;
-               } else {
-                       0u8.write(writer)?;
-               }
+               self.onion_packet.write(writer)?;
                self.incoming_shared_secret.write(writer)?;
                self.payment_hash.write(writer)?;
                self.short_channel_id.write(writer)?;
@@ -2638,13 +2633,8 @@ impl Writeable for PendingForwardHTLCInfo {
 
 impl<R: ::std::io::Read> Readable<R> for PendingForwardHTLCInfo {
        fn read(reader: &mut R) -> Result<PendingForwardHTLCInfo, DecodeError> {
-               let onion_packet = match <u8 as Readable<R>>::read(reader)? {
-                       0 => None,
-                       1 => Some(msgs::OnionPacket::read(reader)?),
-                       _ => return Err(DecodeError::InvalidValue),
-               };
                Ok(PendingForwardHTLCInfo {
-                       onion_packet,
+                       onion_packet: Readable::read(reader)?,
                        incoming_shared_secret: Readable::read(reader)?,
                        payment_hash: Readable::read(reader)?,
                        short_channel_id: Readable::read(reader)?,
index 8e70ce7adba6b04b77fa369f549b15ac350733d5..8e1ba12f2ab8445f5bd6400bcd4fd03cee8832be 100644 (file)
@@ -1960,13 +1960,6 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                }
                        }
                }
-               macro_rules! read_option { () => {
-                       match <u8 as Readable<R>>::read(reader)? {
-                               0 => None,
-                               1 => Some(Readable::read(reader)?),
-                               _ => return Err(DecodeError::InvalidValue),
-                       }
-               } }
 
                let _ver: u8 = Readable::read(reader)?;
                let min_ver: u8 = Readable::read(reader)?;
@@ -1983,16 +1976,8 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                let delayed_payment_base_key = Readable::read(reader)?;
                                let payment_base_key = Readable::read(reader)?;
                                let shutdown_pubkey = Readable::read(reader)?;
-                               let prev_latest_per_commitment_point = match <u8 as Readable<R>>::read(reader)? {
-                                       0 => None,
-                                       1 => Some(Readable::read(reader)?),
-                                       _ => return Err(DecodeError::InvalidValue),
-                               };
-                               let latest_per_commitment_point = match <u8 as Readable<R>>::read(reader)? {
-                                       0 => None,
-                                       1 => Some(Readable::read(reader)?),
-                                       _ => return Err(DecodeError::InvalidValue),
-                               };
+                               let prev_latest_per_commitment_point = Readable::read(reader)?;
+                               let latest_per_commitment_point = Readable::read(reader)?;
                                // Technically this can fail and serialize fail a round-trip, but only for serialization of
                                // barely-init'd ChannelMonitors that we can't do anything with.
                                let outpoint = OutPoint {
@@ -2000,8 +1985,8 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                        index: Readable::read(reader)?,
                                };
                                let funding_info = Some((outpoint, Readable::read(reader)?));
-                               let current_remote_commitment_txid = read_option!();
-                               let prev_remote_commitment_txid = read_option!();
+                               let current_remote_commitment_txid = Readable::read(reader)?;
+                               let prev_remote_commitment_txid = Readable::read(reader)?;
                                Storage::Local {
                                        revocation_base_key,
                                        htlc_base_key,
@@ -2052,7 +2037,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                        let amount_msat: u64 = Readable::read(reader)?;
                                        let cltv_expiry: u32 = Readable::read(reader)?;
                                        let payment_hash: PaymentHash = Readable::read(reader)?;
-                                       let transaction_output_index: Option<u32> = read_option!();
+                                       let transaction_output_index: Option<u32> = Readable::read(reader)?;
 
                                        HTLCOutputInCommitment {
                                                offered, amount_msat, cltv_expiry, payment_hash, transaction_output_index
@@ -2068,7 +2053,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                        let htlcs_count: u64 = Readable::read(reader)?;
                        let mut htlcs = Vec::with_capacity(cmp::min(htlcs_count as usize, MAX_ALLOC_SIZE / 32));
                        for _ in 0..htlcs_count {
-                               htlcs.push((read_htlc_in_commitment!(), read_option!().map(|o: HTLCSource| Box::new(o))));
+                               htlcs.push((read_htlc_in_commitment!(), <Option<HTLCSource> as Readable<R>>::read(reader)?.map(|o: HTLCSource| Box::new(o))));
                        }
                        if let Some(_) = remote_claimable_outpoints.insert(txid, htlcs) {
                                return Err(DecodeError::InvalidValue);
@@ -2131,7 +2116,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                                        1 => Some((Readable::read(reader)?, Readable::read(reader)?)),
                                                        _ => return Err(DecodeError::InvalidValue),
                                                };
-                                               htlcs.push((htlc, sigs, read_option!()));
+                                               htlcs.push((htlc, sigs, Readable::read(reader)?));
                                        }
 
                                        LocalSignedTx {
index 57e755019532bcb899b01fb94982954f5f9c280b..62c4bb0d288512605a7a9487658fe52c4cfa9927 100644 (file)
@@ -26,7 +26,7 @@ use std::{cmp, fmt};
 use std::io::Read;
 use std::result::Result;
 
-use util::{byte_utils, events};
+use util::events;
 use util::ser::{Readable, Writeable, Writer};
 
 use ln::channelmanager::{PaymentPreimage, PaymentHash};
@@ -47,7 +47,6 @@ pub enum DecodeError {
        /// node_announcement included more than one address of a given type!
        ExtraAddressesPerType,
        /// A length descriptor in the packet didn't describe the later data correctly
-       /// (currently only generated in node_announcement)
        BadLengthDescriptor,
        /// Error from std::io
        Io(::std::io::Error),
@@ -191,7 +190,7 @@ pub struct OpenChannel {
        pub(crate) htlc_basepoint: PublicKey,
        pub(crate) first_per_commitment_point: PublicKey,
        pub(crate) channel_flags: u8,
-       pub(crate) shutdown_scriptpubkey: Option<Script>,
+       pub(crate) shutdown_scriptpubkey: OptionalField<Script>,
 }
 
 /// An accept_channel message to be sent or received from a peer
@@ -211,7 +210,7 @@ pub struct AcceptChannel {
        pub(crate) delayed_payment_basepoint: PublicKey,
        pub(crate) htlc_basepoint: PublicKey,
        pub(crate) first_per_commitment_point: PublicKey,
-       pub(crate) shutdown_scriptpubkey: Option<Script>,
+       pub(crate) shutdown_scriptpubkey: OptionalField<Script>
 }
 
 /// A funding_created message to be sent or received from a peer
@@ -323,7 +322,7 @@ pub struct ChannelReestablish {
        pub(crate) channel_id: [u8; 32],
        pub(crate) next_local_commitment_number: u64,
        pub(crate) next_remote_commitment_number: u64,
-       pub(crate) data_loss_protect: Option<DataLossProtect>,
+       pub(crate) data_loss_protect: OptionalField<DataLossProtect>,
 }
 
 /// An announcement_signatures message to be sent or received from a peer
@@ -336,7 +335,7 @@ pub struct AnnouncementSignatures {
 }
 
 /// An address which can be used to connect to a remote peer
-#[derive(Clone)]
+#[derive(PartialEq, Clone)]
 pub enum NetAddress {
        /// An IPv4 address/port on which the peer is listenting.
        IPv4 {
@@ -382,9 +381,84 @@ impl NetAddress {
                        &NetAddress::OnionV3 {..} => { 4 },
                }
        }
+
+       /// Strict byte-length of address descriptor, 1-byte type not recorded
+       fn len(&self) -> u16 {
+               match self {
+                       &NetAddress::IPv4 { .. } => { 6 },
+                       &NetAddress::IPv6 { .. } => { 18 },
+                       &NetAddress::OnionV2 { .. } => { 12 },
+                       &NetAddress::OnionV3 { .. } => { 37 },
+               }
+       }
 }
 
-#[derive(Clone)]
+impl Writeable for NetAddress {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               match self {
+                       &NetAddress::IPv4 { ref addr, ref port } => {
+                               1u8.write(writer)?;
+                               addr.write(writer)?;
+                               port.write(writer)?;
+                       },
+                       &NetAddress::IPv6 { ref addr, ref port } => {
+                               2u8.write(writer)?;
+                               addr.write(writer)?;
+                               port.write(writer)?;
+                       },
+                       &NetAddress::OnionV2 { ref addr, ref port } => {
+                               3u8.write(writer)?;
+                               addr.write(writer)?;
+                               port.write(writer)?;
+                       },
+                       &NetAddress::OnionV3 { ref ed25519_pubkey, ref checksum, ref version, ref port } => {
+                               4u8.write(writer)?;
+                               ed25519_pubkey.write(writer)?;
+                               checksum.write(writer)?;
+                               version.write(writer)?;
+                               port.write(writer)?;
+                       }
+               }
+               Ok(())
+       }
+}
+
+impl<R: ::std::io::Read>  Readable<R> for Result<NetAddress, u8> {
+       fn read(reader: &mut R) -> Result<Result<NetAddress, u8>, DecodeError> {
+               let byte = <u8 as Readable<R>>::read(reader)?;
+               match byte {
+                       1 => {
+                               Ok(Ok(NetAddress::IPv4 {
+                                       addr: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       2 => {
+                               Ok(Ok(NetAddress::IPv6 {
+                                       addr: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       3 => {
+                               Ok(Ok(NetAddress::OnionV2 {
+                                       addr: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       4 => {
+                               Ok(Ok(NetAddress::OnionV3 {
+                                       ed25519_pubkey: Readable::read(reader)?,
+                                       checksum: Readable::read(reader)?,
+                                       version: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       _ => return Ok(Err(byte)),
+               }
+       }
+}
+
+#[derive(PartialEq, Clone)]
 // Only exposed as broadcast of node_announcement should be filtered by node_id
 /// The unsigned part of a node_announcement
 pub struct UnsignedNodeAnnouncement {
@@ -401,7 +475,7 @@ pub struct UnsignedNodeAnnouncement {
        pub(crate) excess_address_data: Vec<u8>,
        pub(crate) excess_data: Vec<u8>,
 }
-#[derive(Clone)]
+#[derive(PartialEq, Clone)]
 /// A node_announcement message to be sent or received from a peer
 pub struct NodeAnnouncement {
        pub(crate) signature: Signature,
@@ -518,6 +592,18 @@ pub enum HTLCFailChannelUpdate {
        }
 }
 
+/// Messages could have optional fields to use with extended features
+/// As we wish to serialize these differently from Option<T>s (Options get a tag byte, but
+/// OptionalFeild simply gets Present if there are enough bytes to read into it), we have a
+/// separate enum type for them.
+#[derive(Clone, PartialEq)]
+pub enum OptionalField<T> {
+       /// Optional field is included in message
+       Present(T),
+       /// Optional field is absent in message
+       Absent
+}
+
 /// A trait to describe an object which can receive channel messages.
 ///
 /// Messages MAY be called in parallel when they originate from different their_node_ids, however
@@ -696,8 +782,35 @@ impl From<::std::io::Error> for DecodeError {
        }
 }
 
+impl Writeable for OptionalField<Script> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               match *self {
+                       OptionalField::Present(ref script) => {
+                               // Note that Writeable for script includes the 16-bit length tag for us
+                               script.write(w)?;
+                       },
+                       OptionalField::Absent => {}
+               }
+               Ok(())
+       }
+}
+
+impl<R: Read> Readable<R> for OptionalField<Script> {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               match <u16 as Readable<R>>::read(r) {
+                       Ok(len) => {
+                               let mut buf = vec![0; len as usize];
+                               r.read_exact(&mut buf)?;
+                               Ok(OptionalField::Present(Script::from(buf)))
+                       },
+                       Err(DecodeError::ShortRead) => Ok(OptionalField::Absent),
+                       Err(e) => Err(e)
+               }
+       }
+}
+
 impl_writeable_len_match!(AcceptChannel, {
-               {AcceptChannel{ shutdown_scriptpubkey: Some(ref script), ..}, 270 + 2 + script.len()},
+               {AcceptChannel{ shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 270 + 2 + script.len()},
                {_, 270}
        }, {
        temporary_channel_id,
@@ -726,13 +839,16 @@ impl_writeable!(AnnouncementSignatures, 32+8+64*2, {
 
 impl Writeable for ChannelReestablish {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(if self.data_loss_protect.is_some() { 32+2*8+33+32 } else { 32+2*8 });
+               w.size_hint(if let OptionalField::Present(..) = self.data_loss_protect { 32+2*8+33+32 } else { 32+2*8 });
                self.channel_id.write(w)?;
                self.next_local_commitment_number.write(w)?;
                self.next_remote_commitment_number.write(w)?;
-               if let Some(ref data_loss_protect) = self.data_loss_protect {
-                       data_loss_protect.your_last_per_commitment_secret.write(w)?;
-                       data_loss_protect.my_current_per_commitment_point.write(w)?;
+               match self.data_loss_protect {
+                       OptionalField::Present(ref data_loss_protect) => {
+                               (*data_loss_protect).your_last_per_commitment_secret.write(w)?;
+                               (*data_loss_protect).my_current_per_commitment_point.write(w)?;
+                       },
+                       OptionalField::Absent => {}
                }
                Ok(())
        }
@@ -747,11 +863,11 @@ impl<R: Read> Readable<R> for ChannelReestablish{
                        data_loss_protect: {
                                match <[u8; 32] as Readable<R>>::read(r) {
                                        Ok(your_last_per_commitment_secret) =>
-                                               Some(DataLossProtect {
+                                               OptionalField::Present(DataLossProtect {
                                                        your_last_per_commitment_secret,
                                                        my_current_per_commitment_point: Readable::read(r)?,
                                                }),
-                                       Err(DecodeError::ShortRead) => None,
+                                       Err(DecodeError::ShortRead) => OptionalField::Absent,
                                        Err(e) => return Err(e)
                                }
                        }
@@ -818,8 +934,8 @@ impl_writeable_len_match!(Init, {
 });
 
 impl_writeable_len_match!(OpenChannel, {
-               { OpenChannel { shutdown_scriptpubkey: Some(ref script), .. }, 319 + 2 + script.len() },
-               { OpenChannel { shutdown_scriptpubkey: None, .. }, 319 }
+               { OpenChannel { shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 319 + 2 + script.len() },
+               { _, 319 }
        }, {
        chain_hash,
        temporary_channel_id,
@@ -1150,38 +1266,17 @@ impl Writeable for UnsignedNodeAnnouncement {
                w.write_all(&self.rgb)?;
                self.alias.write(w)?;
 
-               let mut addr_slice = Vec::with_capacity(self.addresses.len() * 18);
                let mut addrs_to_encode = self.addresses.clone();
                addrs_to_encode.sort_unstable_by(|a, b| { a.get_id().cmp(&b.get_id()) });
                addrs_to_encode.dedup_by(|a, b| { a.get_id() == b.get_id() });
-               for addr in addrs_to_encode.iter() {
-                       match addr {
-                               &NetAddress::IPv4{addr, port} => {
-                                       addr_slice.push(1);
-                                       addr_slice.extend_from_slice(&addr);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                               &NetAddress::IPv6{addr, port} => {
-                                       addr_slice.push(2);
-                                       addr_slice.extend_from_slice(&addr);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                               &NetAddress::OnionV2{addr, port} => {
-                                       addr_slice.push(3);
-                                       addr_slice.extend_from_slice(&addr);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                               &NetAddress::OnionV3{ed25519_pubkey, checksum, version, port} => {
-                                       addr_slice.push(4);
-                                       addr_slice.extend_from_slice(&ed25519_pubkey);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(checksum));
-                                       addr_slice.push(version);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                       }
+               let mut addr_len = 0;
+               for addr in &addrs_to_encode {
+                       addr_len += 1 + addr.len();
+               }
+               (addr_len + self.excess_address_data.len() as u16).write(w)?;
+               for addr in addrs_to_encode {
+                       addr.write(w)?;
                }
-               ((addr_slice.len() + self.excess_address_data.len()) as u16).write(w)?;
-               w.write_all(&addr_slice[..])?;
                w.write_all(&self.excess_address_data[..])?;
                w.write_all(&self.excess_data[..])?;
                Ok(())
@@ -1200,112 +1295,77 @@ impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
                r.read_exact(&mut rgb)?;
                let alias: [u8; 32] = Readable::read(r)?;
 
-               let addrlen: u16 = Readable::read(r)?;
+               let addr_len: u16 = Readable::read(r)?;
+               let mut addresses: Vec<NetAddress> = Vec::with_capacity(4);
                let mut addr_readpos = 0;
-               let mut addresses = Vec::with_capacity(4);
-               let mut f: u8 = 0;
-               let mut excess = 0;
+               let mut excess = false;
+               let mut excess_byte = 0;
                loop {
-                       if addrlen <= addr_readpos { break; }
-                       f = Readable::read(r)?;
-                       match f {
-                               1 => {
-                                       if addresses.len() > 0 {
-                                               return Err(DecodeError::ExtraAddressesPerType);
-                                       }
-                                       if addrlen < addr_readpos + 1 + 6 {
-                                               return Err(DecodeError::BadLengthDescriptor);
-                                       }
-                                       addresses.push(NetAddress::IPv4 {
-                                               addr: {
-                                                       let mut addr = [0; 4];
-                                                       r.read_exact(&mut addr)?;
-                                                       addr
+                       if addr_len <= addr_readpos { break; }
+                       match Readable::read(r) {
+                               Ok(Ok(addr)) => {
+                                       match addr {
+                                               NetAddress::IPv4 { .. } => {
+                                                       if addresses.len() > 0 {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
                                                },
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 6
-                               },
-                               2 => {
-                                       if addresses.len() > 1 || (addresses.len() == 1 && addresses[0].get_id() != 1) {
-                                               return Err(DecodeError::ExtraAddressesPerType);
-                                       }
-                                       if addrlen < addr_readpos + 1 + 18 {
-                                               return Err(DecodeError::BadLengthDescriptor);
-                                       }
-                                       addresses.push(NetAddress::IPv6 {
-                                               addr: {
-                                                       let mut addr = [0; 16];
-                                                       r.read_exact(&mut addr)?;
-                                                       addr
+                                               NetAddress::IPv6 { .. } => {
+                                                       if addresses.len() > 1 || (addresses.len() == 1 && addresses[0].get_id() != 1) {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
                                                },
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 18
-                               },
-                               3 => {
-                                       if addresses.len() > 2 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 2) {
-                                               return Err(DecodeError::ExtraAddressesPerType);
-                                       }
-                                       if addrlen < addr_readpos + 1 + 12 {
-                                               return Err(DecodeError::BadLengthDescriptor);
-                                       }
-                                       addresses.push(NetAddress::OnionV2 {
-                                               addr: {
-                                                       let mut addr = [0; 10];
-                                                       r.read_exact(&mut addr)?;
-                                                       addr
+                                               NetAddress::OnionV2 { .. } => {
+                                                       if addresses.len() > 2 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 2) {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
+                                               },
+                                               NetAddress::OnionV3 { .. } => {
+                                                       if addresses.len() > 3 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 3) {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
                                                },
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 12
-                               },
-                               4 => {
-                                       if addresses.len() > 3 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 3) {
-                                               return Err(DecodeError::ExtraAddressesPerType);
                                        }
-                                       if addrlen < addr_readpos + 1 + 37 {
+                                       if addr_len < addr_readpos + 1 + addr.len() {
                                                return Err(DecodeError::BadLengthDescriptor);
                                        }
-                                       addresses.push(NetAddress::OnionV3 {
-                                               ed25519_pubkey: Readable::read(r)?,
-                                               checksum: Readable::read(r)?,
-                                               version: Readable::read(r)?,
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 37
+                                       addr_readpos += (1 + addr.len()) as u16;
+                                       addresses.push(addr);
                                },
-                               _ => { excess = 1; break; }
+                               Ok(Err(unknown_descriptor)) => {
+                                       excess = true;
+                                       excess_byte = unknown_descriptor;
+                                       break;
+                               },
+                               Err(DecodeError::ShortRead) => return Err(DecodeError::BadLengthDescriptor),
+                               Err(e) => return Err(e),
                        }
                }
 
                let mut excess_data = vec![];
-               let excess_address_data = if addr_readpos < addrlen {
-                       let mut excess_address_data = vec![0; (addrlen - addr_readpos) as usize];
-                       r.read_exact(&mut excess_address_data[excess..])?;
-                       if excess == 1 {
-                               excess_address_data[0] = f;
+               let excess_address_data = if addr_readpos < addr_len {
+                       let mut excess_address_data = vec![0; (addr_len - addr_readpos) as usize];
+                       r.read_exact(&mut excess_address_data[if excess { 1 } else { 0 }..])?;
+                       if excess {
+                               excess_address_data[0] = excess_byte;
                        }
                        excess_address_data
                } else {
-                       if excess == 1 {
-                               excess_data.push(f);
+                       if excess {
+                               excess_data.push(excess_byte);
                        }
                        Vec::new()
                };
-
+               r.read_to_end(&mut excess_data)?;
                Ok(UnsignedNodeAnnouncement {
-                       features: features,
-                       timestamp: timestamp,
-                       node_id: node_id,
-                       rgb: rgb,
-                       alias: alias,
-                       addresses: addresses,
-                       excess_address_data: excess_address_data,
-                       excess_data: {
-                               r.read_to_end(&mut excess_data)?;
-                               excess_data
-                       },
+                       features,
+                       timestamp,
+                       node_id,
+                       rgb,
+                       alias,
+                       addresses,
+                       excess_address_data,
+                       excess_data,
                })
        }
 }
@@ -1322,6 +1382,7 @@ impl_writeable_len_match!(NodeAnnouncement, {
 mod tests {
        use hex;
        use ln::msgs;
+       use ln::msgs::OptionalField;
        use util::ser::Writeable;
        use secp256k1::key::{PublicKey,SecretKey};
        use secp256k1::Secp256k1;
@@ -1332,7 +1393,7 @@ mod tests {
                        channel_id: [4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0],
                        next_local_commitment_number: 3,
                        next_remote_commitment_number: 4,
-                       data_loss_protect: None,
+                       data_loss_protect: OptionalField::Absent,
                };
 
                let encoded_value = cr.encode();
@@ -1353,7 +1414,7 @@ mod tests {
                        channel_id: [4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0],
                        next_local_commitment_number: 3,
                        next_remote_commitment_number: 4,
-                       data_loss_protect: Some(msgs::DataLossProtect { your_last_per_commitment_secret: [9;32], my_current_per_commitment_point: public_key}),
+                       data_loss_protect: OptionalField::Present(msgs::DataLossProtect { your_last_per_commitment_secret: [9;32], my_current_per_commitment_point: public_key}),
                };
 
                let encoded_value = cr.encode();
index 78a5896f2ce993899d3a0fd1b2b1c3450d4c26a8..e0a1f6f2d76d5c266a131317ed14c7de831cad96 100644 (file)
@@ -478,8 +478,14 @@ impl<Descriptor: SocketDescriptor> PeerManager<Descriptor> {
                                                                                                        log_debug!(self, "Got a channel/node announcement with an known required feature flag, you may want to udpate!");
                                                                                                        continue;
                                                                                                },
-                                                                                               msgs::DecodeError::InvalidValue => return Err(PeerHandleError{ no_connection_possible: false }),
-                                                                                               msgs::DecodeError::ShortRead => return Err(PeerHandleError{ no_connection_possible: false }),
+                                                                                               msgs::DecodeError::InvalidValue => {
+                                                                                                       log_debug!(self, "Got an invalid value while deserializing message");
+                                                                                                       return Err(PeerHandleError{ no_connection_possible: false });
+                                                                                               },
+                                                                                               msgs::DecodeError::ShortRead => {
+                                                                                                       log_debug!(self, "Deserialization failed due to shortness of message");
+                                                                                                       return Err(PeerHandleError{ no_connection_possible: false });
+                                                                                               },
                                                                                                msgs::DecodeError::ExtraAddressesPerType => {
                                                                                                        log_debug!(self, "Error decoding message, ignoring due to lnd spec incompatibility. See https://github.com/lightningnetwork/lnd/issues/1407");
                                                                                                        continue;
index e25f9b223c05b0b9fc7216c7852191b8493c2033..f4a702fb709387f859b5fe54c82457dc5229ded1 100644 (file)
@@ -15,7 +15,7 @@ use chain::chaininterface::{ChainError, ChainWatchInterface};
 use ln::channelmanager;
 use ln::msgs::{DecodeError,ErrorAction,HandleError,RoutingMessageHandler,NetAddress,GlobalFeatures};
 use ln::msgs;
-use util::ser::{Writeable, Readable};
+use util::ser::{Writeable, Readable, Writer, ReadableArgs};
 use util::logger::Logger;
 
 use std::cmp;
@@ -78,6 +78,7 @@ impl<R: ::std::io::Read> Readable<R> for Route {
        }
 }
 
+#[derive(PartialEq)]
 struct DirectionalChannelInfo {
        src_node_id: PublicKey,
        last_update: u32,
@@ -96,6 +97,18 @@ impl std::fmt::Display for DirectionalChannelInfo {
        }
 }
 
+impl_writeable!(DirectionalChannelInfo, 0, {
+       src_node_id,
+       last_update,
+       enabled,
+       cltv_expiry_delta,
+       htlc_minimum_msat,
+       fee_base_msat,
+       fee_proportional_millionths,
+       last_update_message
+});
+
+#[derive(PartialEq)]
 struct ChannelInfo {
        features: GlobalFeatures,
        one_to_two: DirectionalChannelInfo,
@@ -112,6 +125,14 @@ impl std::fmt::Display for ChannelInfo {
        }
 }
 
+impl_writeable!(ChannelInfo, 0, {
+       features,
+       one_to_two,
+       two_to_one,
+       announcement_message
+});
+
+#[derive(PartialEq)]
 struct NodeInfo {
        #[cfg(feature = "non_bitcoin_chain_hash_routing")]
        channels: Vec<(u64, Sha256dHash)>,
@@ -138,6 +159,68 @@ impl std::fmt::Display for NodeInfo {
        }
 }
 
+impl Writeable for NodeInfo {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               (self.channels.len() as u64).write(writer)?;
+               for ref chan in self.channels.iter() {
+                       chan.write(writer)?;
+               }
+               self.lowest_inbound_channel_fee_base_msat.write(writer)?;
+               self.lowest_inbound_channel_fee_proportional_millionths.write(writer)?;
+               self.features.write(writer)?;
+               self.last_update.write(writer)?;
+               self.rgb.write(writer)?;
+               self.alias.write(writer)?;
+               (self.addresses.len() as u64).write(writer)?;
+               for ref addr in &self.addresses {
+                       addr.write(writer)?;
+               }
+               self.announcement_message.write(writer)?;
+               Ok(())
+       }
+}
+
+const MAX_ALLOC_SIZE: u64 = 64*1024;
+
+impl<R: ::std::io::Read> Readable<R> for NodeInfo {
+       fn read(reader: &mut R) -> Result<NodeInfo, DecodeError> {
+               let channels_count: u64 = Readable::read(reader)?;
+               let mut channels = Vec::with_capacity(cmp::min(channels_count, MAX_ALLOC_SIZE / 8) as usize);
+               for _ in 0..channels_count {
+                       channels.push(Readable::read(reader)?);
+               }
+               let lowest_inbound_channel_fee_base_msat = Readable::read(reader)?;
+               let lowest_inbound_channel_fee_proportional_millionths = Readable::read(reader)?;
+               let features = Readable::read(reader)?;
+               let last_update = Readable::read(reader)?;
+               let rgb = Readable::read(reader)?;
+               let alias = Readable::read(reader)?;
+               let addresses_count: u64 = Readable::read(reader)?;
+               let mut addresses = Vec::with_capacity(cmp::min(addresses_count, MAX_ALLOC_SIZE / 40) as usize);
+               for _ in 0..addresses_count {
+                       match Readable::read(reader) {
+                               Ok(Ok(addr)) => { addresses.push(addr); },
+                               Ok(Err(_)) => return Err(DecodeError::InvalidValue),
+                               Err(DecodeError::ShortRead) => return Err(DecodeError::BadLengthDescriptor),
+                               _ => unreachable!(),
+                       }
+               }
+               let announcement_message = Readable::read(reader)?;
+               Ok(NodeInfo {
+                       channels,
+                       lowest_inbound_channel_fee_base_msat,
+                       lowest_inbound_channel_fee_proportional_millionths,
+                       features,
+                       last_update,
+                       rgb,
+                       alias,
+                       addresses,
+                       announcement_message
+               })
+       }
+}
+
+#[derive(PartialEq)]
 struct NetworkMap {
        #[cfg(feature = "non_bitcoin_chain_hash_routing")]
        channels: BTreeMap<(u64, Sha256dHash), ChannelInfo>,
@@ -147,6 +230,49 @@ struct NetworkMap {
        our_node_id: PublicKey,
        nodes: BTreeMap<PublicKey, NodeInfo>,
 }
+
+impl Writeable for NetworkMap {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               (self.channels.len() as u64).write(writer)?;
+               for (ref chan_id, ref chan_info) in self.channels.iter() {
+                       (*chan_id).write(writer)?;
+                       chan_info.write(writer)?;
+               }
+               self.our_node_id.write(writer)?;
+               (self.nodes.len() as u64).write(writer)?;
+               for (ref node_id, ref node_info) in self.nodes.iter() {
+                       node_id.write(writer)?;
+                       node_info.write(writer)?;
+               }
+               Ok(())
+       }
+}
+
+impl<R: ::std::io::Read> Readable<R> for NetworkMap {
+       fn read(reader: &mut R) -> Result<NetworkMap, DecodeError> {
+               let channels_count: u64 = Readable::read(reader)?;
+               let mut channels = BTreeMap::new();
+               for _ in 0..channels_count {
+                       let chan_id: u64 = Readable::read(reader)?;
+                       let chan_info = Readable::read(reader)?;
+                       channels.insert(chan_id, chan_info);
+               }
+               let our_node_id = Readable::read(reader)?;
+               let nodes_count: u64 = Readable::read(reader)?;
+               let mut nodes = BTreeMap::new();
+               for _ in 0..nodes_count {
+                       let node_id = Readable::read(reader)?;
+                       let node_info = Readable::read(reader)?;
+                       nodes.insert(node_id, node_info);
+               }
+               Ok(NetworkMap {
+                       channels,
+                       our_node_id,
+                       nodes,
+               })
+       }
+}
+
 struct MutNetworkMap<'a> {
        #[cfg(feature = "non_bitcoin_chain_hash_routing")]
        channels: &'a mut BTreeMap<(u64, Sha256dHash), ChannelInfo>,
@@ -228,6 +354,51 @@ pub struct Router {
        logger: Arc<Logger>,
 }
 
+const SERIALIZATION_VERSION: u8 = 1;
+const MIN_SERIALIZATION_VERSION: u8 = 1;
+
+impl Writeable for Router {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               writer.write_all(&[SERIALIZATION_VERSION; 1])?;
+               writer.write_all(&[MIN_SERIALIZATION_VERSION; 1])?;
+
+               let network = self.network_map.read().unwrap();
+               network.write(writer)?;
+               Ok(())
+       }
+}
+
+/// Arguments for the creation of a Router that are not deserialized.
+/// At a high-level, the process for deserializing a Router and resuming normal operation is:
+/// 1) Deserialize the Router by filling in this struct and calling <Router>::read(reaser, args).
+/// 2) Register the new Router with your ChainWatchInterface
+pub struct RouterReadArgs {
+       /// The ChainWatchInterface for use in the Router in the future.
+       ///
+       /// No calls to the ChainWatchInterface will be made during deserialization.
+       pub chain_monitor: Arc<ChainWatchInterface>,
+       /// The Logger for use in the ChannelManager and which may be used to log information during
+       /// deserialization.
+       pub logger: Arc<Logger>,
+}
+
+impl<R: ::std::io::Read> ReadableArgs<R, RouterReadArgs> for Router {
+       fn read(reader: &mut R, args: RouterReadArgs) -> Result<Router, DecodeError> {
+               let _ver: u8 = Readable::read(reader)?;
+               let min_ver: u8 = Readable::read(reader)?;
+               if min_ver > SERIALIZATION_VERSION {
+                       return Err(DecodeError::UnknownVersion);
+               }
+               let network_map = Readable::read(reader)?;
+               Ok(Router {
+                       secp_ctx: Secp256k1::verification_only(),
+                       network_map: RwLock::new(network_map),
+                       chain_monitor: args.chain_monitor,
+                       logger: args.logger,
+               })
+       }
+}
+
 macro_rules! secp_verify_sig {
        ( $secp_ctx: expr, $msg: expr, $sig: expr, $pubkey: expr ) => {
                match $secp_ctx.verify($msg, $sig, $pubkey) {
@@ -845,7 +1016,9 @@ mod tests {
        use ln::router::{Router,NodeInfo,NetworkMap,ChannelInfo,DirectionalChannelInfo,RouteHint};
        use ln::msgs::GlobalFeatures;
        use util::test_utils;
+       use util::test_utils::TestVecWriter;
        use util::logger::Logger;
+       use util::ser::{Writeable, Readable};
 
        use bitcoin::util::hash::Sha256dHash;
        use bitcoin::network::constants::Network;
@@ -1439,5 +1612,14 @@ mod tests {
                        assert_eq!(route.hops[4].fee_msat, 2000);
                        assert_eq!(route.hops[4].cltv_expiry_delta, 42);
                }
+
+               { // Test Router serialization/deserialization
+                       let mut w = TestVecWriter(Vec::new());
+                       let network = router.network_map.read().unwrap();
+                       assert!(!network.channels.is_empty());
+                       assert!(!network.nodes.is_empty());
+                       network.write(&mut w).unwrap();
+                       assert!(<NetworkMap>::read(&mut ::std::io::Cursor::new(&w.0)).unwrap() == *network);
+               }
        }
 }
index 0b2a626fb0beff87d7df6b7a44adf1b75f35cc89..d832c7018825e75df028bd2e2b8ea538962cc0e3 100644 (file)
@@ -203,6 +203,10 @@ macro_rules! impl_array {
 }
 
 //TODO: performance issue with [u8; size] with impl_array!()
+impl_array!(3); // for rgb
+impl_array!(4); // for IPv4
+impl_array!(10); // for OnionV2
+impl_array!(16); // for IPv6
 impl_array!(32); // for channel id & hmac
 impl_array!(33); // for PublicKey
 impl_array!(64); // for Signature
@@ -302,29 +306,6 @@ impl<R: Read> Readable<R> for Script {
        }
 }
 
-impl Writeable for Option<Script> {
-       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               if let &Some(ref script) = self {
-                       script.write(w)?;
-               }
-               Ok(())
-       }
-}
-
-impl<R: Read> Readable<R> for Option<Script> {
-       fn read(r: &mut R) -> Result<Self, DecodeError> {
-               match <u16 as Readable<R>>::read(r) {
-                       Ok(len) => {
-                               let mut buf = vec![0; len as usize];
-                               r.read_exact(&mut buf)?;
-                               Ok(Some(Script::from(buf)))
-                       },
-                       Err(DecodeError::ShortRead) => Ok(None),
-                       Err(e) => Err(e)
-               }
-       }
-}
-
 impl Writeable for PublicKey {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
                self.serialize().write(w)
@@ -413,3 +394,29 @@ impl<R: Read> Readable<R> for PaymentHash {
                Ok(PaymentHash(buf))
        }
 }
+
+impl<T: Writeable> Writeable for Option<T> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               match *self {
+                       None => 0u8.write(w)?,
+                       Some(ref data) => {
+                               1u8.write(w)?;
+                               data.write(w)?;
+                       }
+               }
+               Ok(())
+       }
+}
+
+impl<R, T> Readable<R> for Option<T>
+       where R: Read,
+             T: Readable<R>
+{
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               match <u8 as Readable<R>>::read(r)? {
+                       0 => Ok(None),
+                       1 => Ok(Some(Readable::read(r)?)),
+                       _ => return Err(DecodeError::InvalidValue),
+               }
+       }
+}
index b04f728b0f71edde16afa738fa2caa248c242bdb..b889e254eb47557df5d1055f9fb2b86811d30654 100644 (file)
@@ -20,8 +20,8 @@ use secp256k1::{SecretKey, PublicKey};
 use std::sync::{Arc,Mutex};
 use std::{mem};
 
-struct VecWriter(Vec<u8>);
-impl Writer for VecWriter {
+pub struct TestVecWriter(pub Vec<u8>);
+impl Writer for TestVecWriter {
        fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> {
                self.0.extend_from_slice(buf);
                Ok(())
@@ -58,7 +58,7 @@ impl channelmonitor::ManyChannelMonitor for TestChannelMonitor {
        fn add_update_monitor(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
                // At every point where we get a monitor update, we should be able to send a useful monitor
                // to a watchtower and disk...
-               let mut w = VecWriter(Vec::new());
+               let mut w = TestVecWriter(Vec::new());
                monitor.write_for_disk(&mut w).unwrap();
                assert!(<(Sha256dHash, channelmonitor::ChannelMonitor)>::read(
                                &mut ::std::io::Cursor::new(&w.0), Arc::new(TestLogger::new())).unwrap().1 == monitor);