Fix TLV serialization to work with large types.
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 9c1f8ecfd97d54b474efc61eab62015238cf5e11..e7b1ff48783d3a37d0dfbd0fa2f2dd9985b2c53a 100644 (file)
@@ -64,7 +64,6 @@ use util::errors::APIError;
 use prelude::*;
 use core::{cmp, mem};
 use core::cell::RefCell;
-use std::collections::{HashMap, hash_map, HashSet};
 use std::io::{Cursor, Read};
 use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard};
 use core::sync::atomic::{AtomicUsize, Ordering};
@@ -498,6 +497,7 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
 /// Typically, the block-specific parameters are derived from the best block hash for the network,
 /// as a newly constructed `ChannelManager` will not have created any channels yet. These parameters
 /// are not needed when deserializing a previously constructed `ChannelManager`.
+#[derive(Clone, Copy, PartialEq)]
 pub struct ChainParameters {
        /// The network for determining the `chain_hash` in Lightning messages.
        pub network: Network,
@@ -509,7 +509,7 @@ pub struct ChainParameters {
 }
 
 /// The best known block as identified by its hash and height.
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, PartialEq)]
 pub struct BestBlock {
        block_hash: BlockHash,
        height: u32,
@@ -2321,7 +2321,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        }
 
        #[cfg(any(test, feature = "_test_utils"))]
-       pub(crate) fn test_process_background_events(&self) {
+       /// Process background events, for functional testing
+       pub fn test_process_background_events(&self) {
                self.process_background_events();
        }
 
@@ -4317,261 +4318,87 @@ impl PersistenceNotifier {
 const SERIALIZATION_VERSION: u8 = 1;
 const MIN_SERIALIZATION_VERSION: u8 = 1;
 
-impl Writeable for PendingHTLCRouting {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               match &self {
-                       &PendingHTLCRouting::Forward { ref onion_packet, ref short_channel_id } => {
-                               0u8.write(writer)?;
-                               onion_packet.write(writer)?;
-                               short_channel_id.write(writer)?;
-                       },
-                       &PendingHTLCRouting::Receive { ref payment_data, ref incoming_cltv_expiry } => {
-                               1u8.write(writer)?;
-                               payment_data.payment_secret.write(writer)?;
-                               payment_data.total_msat.write(writer)?;
-                               incoming_cltv_expiry.write(writer)?;
-                       },
-               }
-               Ok(())
-       }
-}
-
-impl Readable for PendingHTLCRouting {
-       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<PendingHTLCRouting, DecodeError> {
-               match Readable::read(reader)? {
-                       0u8 => Ok(PendingHTLCRouting::Forward {
-                               onion_packet: Readable::read(reader)?,
-                               short_channel_id: Readable::read(reader)?,
-                       }),
-                       1u8 => Ok(PendingHTLCRouting::Receive {
-                               payment_data: msgs::FinalOnionHopData {
-                                       payment_secret: Readable::read(reader)?,
-                                       total_msat: Readable::read(reader)?,
-                               },
-                               incoming_cltv_expiry: Readable::read(reader)?,
-                       }),
-                       _ => Err(DecodeError::InvalidValue),
-               }
+impl_writeable_tlv_based_enum!(PendingHTLCRouting,
+       (0, Forward) => {
+               (0, onion_packet, required),
+               (2, short_channel_id, required),
+       },
+       (1, Receive) => {
+               (0, payment_data, required),
+               (2, incoming_cltv_expiry, required),
        }
-}
+;);
 
 impl_writeable_tlv_based!(PendingHTLCInfo, {
-       (0, routing),
-       (2, incoming_shared_secret),
-       (4, payment_hash),
-       (6, amt_to_forward),
-       (8, outgoing_cltv_value)
-}, {}, {});
-
-impl Writeable for HTLCFailureMsg {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               match self {
-                       &HTLCFailureMsg::Relay(ref fail_msg) => {
-                               0u8.write(writer)?;
-                               fail_msg.write(writer)?;
-                       },
-                       &HTLCFailureMsg::Malformed(ref fail_msg) => {
-                               1u8.write(writer)?;
-                               fail_msg.write(writer)?;
-                       }
-               }
-               Ok(())
-       }
-}
-
-impl Readable for HTLCFailureMsg {
-       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<HTLCFailureMsg, DecodeError> {
-               match <u8 as Readable>::read(reader)? {
-                       0 => Ok(HTLCFailureMsg::Relay(Readable::read(reader)?)),
-                       1 => Ok(HTLCFailureMsg::Malformed(Readable::read(reader)?)),
-                       _ => Err(DecodeError::InvalidValue),
-               }
-       }
-}
-
-impl Writeable for PendingHTLCStatus {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               match self {
-                       &PendingHTLCStatus::Forward(ref forward_info) => {
-                               0u8.write(writer)?;
-                               forward_info.write(writer)?;
-                       },
-                       &PendingHTLCStatus::Fail(ref fail_msg) => {
-                               1u8.write(writer)?;
-                               fail_msg.write(writer)?;
-                       }
-               }
-               Ok(())
-       }
-}
-
-impl Readable for PendingHTLCStatus {
-       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<PendingHTLCStatus, DecodeError> {
-               match <u8 as Readable>::read(reader)? {
-                       0 => Ok(PendingHTLCStatus::Forward(Readable::read(reader)?)),
-                       1 => Ok(PendingHTLCStatus::Fail(Readable::read(reader)?)),
-                       _ => Err(DecodeError::InvalidValue),
-               }
-       }
-}
+       (0, routing, required),
+       (2, incoming_shared_secret, required),
+       (4, payment_hash, required),
+       (6, amt_to_forward, required),
+       (8, outgoing_cltv_value, required)
+});
+
+impl_writeable_tlv_based_enum!(HTLCFailureMsg, ;
+       (0, Relay),
+       (1, Malformed),
+);
+impl_writeable_tlv_based_enum!(PendingHTLCStatus, ;
+       (0, Forward),
+       (1, Fail),
+);
 
 impl_writeable_tlv_based!(HTLCPreviousHopData, {
-       (0, short_channel_id),
-       (2, outpoint),
-       (4, htlc_id),
-       (6, incoming_packet_shared_secret)
-}, {}, {});
-
-impl Writeable for ClaimableHTLC {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               write_tlv_fields!(writer, {
-                       (0, self.prev_hop),
-                       (2, self.value),
-                       (4, self.payment_data.payment_secret),
-                       (6, self.payment_data.total_msat),
-                       (8, self.cltv_expiry)
-               }, {});
-               Ok(())
-       }
-}
-
-impl Readable for ClaimableHTLC {
-       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
-               let mut prev_hop = HTLCPreviousHopData {
-                       short_channel_id: 0, htlc_id: 0,
-                       incoming_packet_shared_secret: [0; 32],
-                       outpoint: OutPoint::null(),
-               };
-               let mut value = 0;
-               let mut payment_secret = PaymentSecret([0; 32]);
-               let mut total_msat = 0;
-               let mut cltv_expiry = 0;
-               read_tlv_fields!(reader, {
-                       (0, prev_hop),
-                       (2, value),
-                       (4, payment_secret),
-                       (6, total_msat),
-                       (8, cltv_expiry)
-               }, {});
-               Ok(ClaimableHTLC {
-                       prev_hop,
-                       value,
-                       payment_data: msgs::FinalOnionHopData {
-                               payment_secret,
-                               total_msat,
-                       },
-                       cltv_expiry,
-               })
-       }
-}
-
-impl Writeable for HTLCSource {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               match self {
-                       &HTLCSource::PreviousHopData(ref hop_data) => {
-                               0u8.write(writer)?;
-                               hop_data.write(writer)?;
-                       },
-                       &HTLCSource::OutboundRoute { ref path, ref session_priv, ref first_hop_htlc_msat } => {
-                               1u8.write(writer)?;
-                               path.write(writer)?;
-                               session_priv.write(writer)?;
-                               first_hop_htlc_msat.write(writer)?;
-                       }
-               }
-               Ok(())
-       }
-}
-
-impl Readable for HTLCSource {
-       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<HTLCSource, DecodeError> {
-               match <u8 as Readable>::read(reader)? {
-                       0 => Ok(HTLCSource::PreviousHopData(Readable::read(reader)?)),
-                       1 => Ok(HTLCSource::OutboundRoute {
-                               path: Readable::read(reader)?,
-                               session_priv: Readable::read(reader)?,
-                               first_hop_htlc_msat: Readable::read(reader)?,
-                       }),
-                       _ => Err(DecodeError::InvalidValue),
-               }
-       }
-}
-
-impl Writeable for HTLCFailReason {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               match self {
-                       &HTLCFailReason::LightningError { ref err } => {
-                               0u8.write(writer)?;
-                               err.write(writer)?;
-                       },
-                       &HTLCFailReason::Reason { ref failure_code, ref data } => {
-                               1u8.write(writer)?;
-                               failure_code.write(writer)?;
-                               data.write(writer)?;
-                       }
-               }
-               Ok(())
-       }
-}
-
-impl Readable for HTLCFailReason {
-       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<HTLCFailReason, DecodeError> {
-               match <u8 as Readable>::read(reader)? {
-                       0 => Ok(HTLCFailReason::LightningError { err: Readable::read(reader)? }),
-                       1 => Ok(HTLCFailReason::Reason {
-                               failure_code: Readable::read(reader)?,
-                               data: Readable::read(reader)?,
-                       }),
-                       _ => Err(DecodeError::InvalidValue),
-               }
-       }
-}
-
-impl Writeable for HTLCForwardInfo {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               match self {
-                       &HTLCForwardInfo::AddHTLC { ref prev_short_channel_id, ref prev_funding_outpoint, ref prev_htlc_id, ref forward_info } => {
-                               0u8.write(writer)?;
-                               prev_short_channel_id.write(writer)?;
-                               prev_funding_outpoint.write(writer)?;
-                               prev_htlc_id.write(writer)?;
-                               forward_info.write(writer)?;
-                       },
-                       &HTLCForwardInfo::FailHTLC { ref htlc_id, ref err_packet } => {
-                               1u8.write(writer)?;
-                               htlc_id.write(writer)?;
-                               err_packet.write(writer)?;
-                       },
-               }
-               Ok(())
-       }
-}
-
-impl Readable for HTLCForwardInfo {
-       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<HTLCForwardInfo, DecodeError> {
-               match <u8 as Readable>::read(reader)? {
-                       0 => Ok(HTLCForwardInfo::AddHTLC {
-                               prev_short_channel_id: Readable::read(reader)?,
-                               prev_funding_outpoint: Readable::read(reader)?,
-                               prev_htlc_id: Readable::read(reader)?,
-                               forward_info: Readable::read(reader)?,
-                       }),
-                       1 => Ok(HTLCForwardInfo::FailHTLC {
-                               htlc_id: Readable::read(reader)?,
-                               err_packet: Readable::read(reader)?,
-                       }),
-                       _ => Err(DecodeError::InvalidValue),
-               }
-       }
-}
+       (0, short_channel_id, required),
+       (2, outpoint, required),
+       (4, htlc_id, required),
+       (6, incoming_packet_shared_secret, required)
+});
+
+impl_writeable_tlv_based!(ClaimableHTLC, {
+       (0, prev_hop, required),
+       (2, value, required),
+       (4, payment_data, required),
+       (6, cltv_expiry, required),
+});
+
+impl_writeable_tlv_based_enum!(HTLCSource,
+       (0, OutboundRoute) => {
+               (0, session_priv, required),
+               (2, first_hop_htlc_msat, required),
+               (4, path, vec_type),
+       }, ;
+       (1, PreviousHopData)
+);
+
+impl_writeable_tlv_based_enum!(HTLCFailReason,
+       (0, LightningError) => {
+               (0, err, required),
+       },
+       (1, Reason) => {
+               (0, failure_code, required),
+               (2, data, vec_type),
+       },
+;);
+
+impl_writeable_tlv_based_enum!(HTLCForwardInfo,
+       (0, AddHTLC) => {
+               (0, forward_info, required),
+               (2, prev_short_channel_id, required),
+               (4, prev_htlc_id, required),
+               (6, prev_funding_outpoint, required),
+       },
+       (1, FailHTLC) => {
+               (0, htlc_id, required),
+               (2, err_packet, required),
+       },
+;);
 
 impl_writeable_tlv_based!(PendingInboundPayment, {
-       (0, payment_secret),
-       (2, expiry_time),
-       (4, user_payment_id),
-       (6, payment_preimage),
-       (8, min_value_msat),
-}, {}, {});
+       (0, payment_secret, required),
+       (2, expiry_time, required),
+       (4, user_payment_id, required),
+       (6, payment_preimage, required),
+       (8, min_value_msat, required),
+});
 
 impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable for ChannelManager<Signer, M, T, K, F, L>
        where M::Target: chain::Watch<Signer>,
@@ -4666,7 +4493,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                        session_priv.write(writer)?;
                }
 
-               write_tlv_fields!(writer, {}, {});
+               write_tlv_fields!(writer, {});
 
                Ok(())
        }
@@ -4911,7 +4738,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        }
                }
 
-               read_tlv_fields!(reader, {}, {});
+               read_tlv_fields!(reader, {});
 
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes());