Compute InflightHtlcs from available information in ChannelManager
[rust-lightning] / lightning / src / ln / channel.rs
index 9b1bf2e45241fbca6659c3de48c3bbbe7cda29d6..9b73d30f23f1d95c47c7df4e6c3c075008003db1 100644 (file)
@@ -509,7 +509,7 @@ pub(super) struct Channel<Signer: Sign> {
 
        inbound_handshake_limits_override: Option<ChannelHandshakeLimits>,
 
-       user_id: u64,
+       user_id: u128,
 
        channel_id: [u8; 32],
        channel_state: u32,
@@ -902,7 +902,7 @@ impl<Signer: Sign> Channel<Signer> {
        // Constructors:
        pub fn new_outbound<K: Deref, F: Deref>(
                fee_estimator: &LowerBoundedFeeEstimator<F>, keys_provider: &K, counterparty_node_id: PublicKey, their_features: &InitFeatures,
-               channel_value_satoshis: u64, push_msat: u64, user_id: u64, config: &UserConfig, current_chain_height: u32,
+               channel_value_satoshis: u64, push_msat: u64, user_id: u128, config: &UserConfig, current_chain_height: u32,
                outbound_scid_alias: u64
        ) -> Result<Channel<Signer>, APIError>
        where K::Target: KeysInterface<Signer = Signer>,
@@ -1102,7 +1102,7 @@ impl<Signer: Sign> Channel<Signer> {
        /// Assumes chain_hash has already been checked and corresponds with what we expect!
        pub fn new_from_req<K: Deref, F: Deref, L: Deref>(
                fee_estimator: &LowerBoundedFeeEstimator<F>, keys_provider: &K, counterparty_node_id: PublicKey, their_features: &InitFeatures,
-               msg: &msgs::OpenChannel, user_id: u64, config: &UserConfig, current_chain_height: u32, logger: &L,
+               msg: &msgs::OpenChannel, user_id: u128, config: &UserConfig, current_chain_height: u32, logger: &L,
                outbound_scid_alias: u64
        ) -> Result<Channel<Signer>, ChannelError>
                where K::Target: KeysInterface<Signer = Signer>,
@@ -4205,7 +4205,7 @@ impl<Signer: Sign> Channel<Signer> {
        pub fn shutdown<K: Deref>(
                &mut self, keys_provider: &K, their_features: &InitFeatures, msg: &msgs::Shutdown
        ) -> Result<(Option<msgs::Shutdown>, Option<ChannelMonitorUpdate>, Vec<(HTLCSource, PaymentHash)>), ChannelError>
-       where K::Target: KeysInterface<Signer = Signer>
+       where K::Target: KeysInterface
        {
                if self.channel_state & (ChannelState::PeerDisconnected as u32) == ChannelState::PeerDisconnected as u32 {
                        return Err(ChannelError::Close("Peer sent shutdown when we needed a channel_reestablish".to_owned()));
@@ -4482,7 +4482,7 @@ impl<Signer: Sign> Channel<Signer> {
 
        /// Gets the "user_id" value passed into the construction of this channel. It has no special
        /// meaning and exists only to allow users to have a persistent identifier of a channel.
-       pub fn get_user_id(&self) -> u64 {
+       pub fn get_user_id(&self) -> u128 {
                self.user_id
        }
 
@@ -5178,7 +5178,7 @@ impl<Signer: Sign> Channel<Signer> {
        /// should be sent back to the counterparty node.
        ///
        /// [`msgs::AcceptChannel`]: crate::ln::msgs::AcceptChannel
-       pub fn accept_inbound_channel(&mut self, user_id: u64) -> msgs::AcceptChannel {
+       pub fn accept_inbound_channel(&mut self, user_id: u128) -> msgs::AcceptChannel {
                if self.is_outbound() {
                        panic!("Tried to send accept_channel for an outbound channel?");
                }
@@ -5825,7 +5825,7 @@ impl<Signer: Sign> Channel<Signer> {
        /// holding cell HTLCs for payment failure.
        pub fn get_shutdown<K: Deref>(&mut self, keys_provider: &K, their_features: &InitFeatures, target_feerate_sats_per_kw: Option<u32>)
        -> Result<(msgs::Shutdown, Option<ChannelMonitorUpdate>, Vec<(HTLCSource, PaymentHash)>), APIError>
-       where K::Target: KeysInterface<Signer = Signer> {
+       where K::Target: KeysInterface {
                for htlc in self.pending_outbound_htlcs.iter() {
                        if let OutboundHTLCState::LocalAnnounced(_) = htlc.state {
                                return Err(APIError::APIMisuseError{err: "Cannot begin shutdown with pending HTLCs. Process pending events first".to_owned()});
@@ -5941,6 +5941,17 @@ impl<Signer: Sign> Channel<Signer> {
                self.update_time_counter += 1;
                (monitor_update, dropped_outbound_htlcs)
        }
+
+       pub fn inflight_htlc_sources(&self) -> impl Iterator<Item=&HTLCSource> {
+               self.holding_cell_htlc_updates.iter()
+                       .flat_map(|htlc_update| {
+                               match htlc_update {
+                                       HTLCUpdateAwaitingACK::AddHTLC { source, .. } => { Some(source) }
+                                       _ => None
+                               }
+                       })
+                       .chain(self.pending_outbound_htlcs.iter().map(|htlc| &htlc.source))
+       }
 }
 
 const SERIALIZATION_VERSION: u8 = 2;
@@ -6007,7 +6018,11 @@ impl<Signer: Sign> Writeable for Channel<Signer> {
 
                write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
 
-               self.user_id.write(writer)?;
+               // `user_id` used to be a single u64 value. In order to remain backwards compatible with
+               // versions prior to 0.0.113, the u128 is serialized as two separate u64 values. We write
+               // the low bytes now and the optional high bytes later.
+               let user_id_low = self.user_id as u64;
+               user_id_low.write(writer)?;
 
                // Version 1 deserializers expected to read parts of the config object here. Version 2
                // deserializers (0.0.99) now read config through TLVs, and as we now require them for
@@ -6254,6 +6269,11 @@ impl<Signer: Sign> Writeable for Channel<Signer> {
 
                let channel_ready_event_emitted = Some(self.channel_ready_event_emitted);
 
+               // `user_id` used to be a single u64 value. In order to remain backwards compatible with
+               // versions prior to 0.0.113, the u128 is serialized as two separate u64 values. Therefore,
+               // we write the high bytes as an option here.
+               let user_id_high_opt = Some((self.user_id >> 64) as u64);
+
                write_tlv_fields!(writer, {
                        (0, self.announcement_sigs, option),
                        // minimum_depth and counterparty_selected_channel_reserve_satoshis used to have a
@@ -6277,6 +6297,7 @@ impl<Signer: Sign> Writeable for Channel<Signer> {
                        (19, self.latest_inbound_scid_alias, option),
                        (21, self.outbound_scid_alias, required),
                        (23, channel_ready_event_emitted, option),
+                       (25, user_id_high_opt, option),
                });
 
                Ok(())
@@ -6284,13 +6305,16 @@ impl<Signer: Sign> Writeable for Channel<Signer> {
 }
 
 const MAX_ALLOC_SIZE: usize = 64*1024;
-impl<'a, Signer: Sign, K: Deref> ReadableArgs<(&'a K, u32)> for Channel<Signer>
-               where K::Target: KeysInterface<Signer = Signer> {
+impl<'a, K: Deref> ReadableArgs<(&'a K, u32)> for Channel<<K::Target as KeysInterface>::Signer>
+               where K::Target: KeysInterface {
        fn read<R : io::Read>(reader: &mut R, args: (&'a K, u32)) -> Result<Self, DecodeError> {
                let (keys_source, serialized_height) = args;
                let ver = read_ver_prefix!(reader, SERIALIZATION_VERSION);
 
-               let user_id = Readable::read(reader)?;
+               // `user_id` used to be a single u64 value. In order to remain backwards compatible with
+               // versions prior to 0.0.113, the u128 is serialized as two separate u64 values. We read
+               // the low bytes now and the high bytes later.
+               let user_id_low: u64 = Readable::read(reader)?;
 
                let mut config = Some(LegacyChannelConfig::default());
                if ver == 1 {
@@ -6536,6 +6560,8 @@ impl<'a, Signer: Sign, K: Deref> ReadableArgs<(&'a K, u32)> for Channel<Signer>
                let mut outbound_scid_alias = None;
                let mut channel_ready_event_emitted = None;
 
+               let mut user_id_high_opt: Option<u64> = None;
+
                read_tlv_fields!(reader, {
                        (0, announcement_sigs, option),
                        (1, minimum_depth, option),
@@ -6553,6 +6579,7 @@ impl<'a, Signer: Sign, K: Deref> ReadableArgs<(&'a K, u32)> for Channel<Signer>
                        (19, latest_inbound_scid_alias, option),
                        (21, outbound_scid_alias, option),
                        (23, channel_ready_event_emitted, option),
+                       (25, user_id_high_opt, option),
                });
 
                if let Some(preimages) = preimages_opt {
@@ -6589,6 +6616,11 @@ impl<'a, Signer: Sign, K: Deref> ReadableArgs<(&'a K, u32)> for Channel<Signer>
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&keys_source.get_secure_random_bytes());
 
+               // `user_id` used to be a single u64 value. In order to remain backwards
+               // compatible with versions prior to 0.0.113, the u128 is serialized as two
+               // separate u64 values.
+               let user_id = user_id_low as u128 + ((user_id_high_opt.unwrap_or(0) as u128) << 64);
+
                Ok(Channel {
                        user_id,