Append backwards-compat TLVs to serialization of larger structs
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 922912932e7c6df71a21f78ef459c80e06a34ccf..4919a225aacc2133c1623d245c5bc83c4be6c05f 100644 (file)
@@ -61,15 +61,15 @@ use util::chacha20::{ChaCha20, ChaChaReader};
 use util::logger::Logger;
 use util::errors::APIError;
 
-use std::{cmp, mem};
+use core::{cmp, mem};
 use std::collections::{HashMap, hash_map, HashSet};
 use std::io::{Cursor, Read};
 use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard};
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::time::Duration;
+use core::sync::atomic::{AtomicUsize, Ordering};
+use core::time::Duration;
 #[cfg(any(test, feature = "allow_wallclock_use"))]
 use std::time::Instant;
-use std::ops::Deref;
+use core::ops::Deref;
 use bitcoin::hashes::hex::ToHex;
 
 // We hold various information about HTLC relay in the HTLC objects in Channel itself:
@@ -442,6 +442,18 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// Locked *after* channel_state.
        pending_inbound_payments: Mutex<HashMap<PaymentHash, PendingInboundPayment>>,
 
+       /// The session_priv bytes of outbound payments which are pending resolution.
+       /// The authoritative state of these HTLCs resides either within Channels or ChannelMonitors
+       /// (if the channel has been force-closed), however we track them here to prevent duplicative
+       /// PaymentSent/PaymentFailed events. Specifically, in the case of a duplicative
+       /// update_fulfill_htlc message after a reconnect, we may "claim" a payment twice.
+       /// Additionally, because ChannelMonitors are often not re-serialized after connecting block(s)
+       /// which may generate a claim event, we may receive similar duplicate claim/fail MonitorEvents
+       /// after reloading from disk while replaying blocks against ChannelMonitors.
+       ///
+       /// Locked *after* channel_state.
+       pending_outbound_payments: Mutex<HashSet<[u8; 32]>>,
+
        our_network_key: SecretKey,
        our_network_pubkey: PublicKey,
 
@@ -913,6 +925,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(HashMap::new()),
+                       pending_outbound_payments: Mutex::new(HashSet::new()),
 
                        our_network_key: keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &keys_manager.get_node_secret()),
@@ -1467,7 +1480,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        pub(crate) fn send_payment_along_path(&self, path: &Vec<RouteHop>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32) -> Result<(), APIError> {
                log_trace!(self.logger, "Attempting to send payment for path with next hop {}", path.first().unwrap().short_channel_id);
                let prng_seed = self.keys_manager.get_secure_random_bytes();
-               let session_priv = SecretKey::from_slice(&self.keys_manager.get_secure_random_bytes()[..]).expect("RNG is busted");
+               let session_priv_bytes = self.keys_manager.get_secure_random_bytes();
+               let session_priv = SecretKey::from_slice(&session_priv_bytes[..]).expect("RNG is busted");
 
                let onion_keys = onion_utils::construct_onion_keys(&self.secp_ctx, &path, &session_priv)
                        .map_err(|_| APIError::RouteError{err: "Pubkey along hop was maliciously selected"})?;
@@ -1478,6 +1492,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash);
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
+               assert!(self.pending_outbound_payments.lock().unwrap().insert(session_priv_bytes));
 
                let err: Result<(), _> = loop {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -1766,7 +1781,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        // be absurd. We ensure this by checking that at least 500 (our stated public contract on when
        // broadcast_node_announcement panics) of the maximum-length addresses would fit in a 64KB
        // message...
-       const HALF_MESSAGE_IS_ADDRS: u32 = ::std::u16::MAX as u32 / (NetAddress::MAX_LEN as u32 + 1) / 2;
+       const HALF_MESSAGE_IS_ADDRS: u32 = ::core::u16::MAX as u32 / (NetAddress::MAX_LEN as u32 + 1) / 2;
        #[deny(const_err)]
        #[allow(dead_code)]
        // ...by failing to compile if the number of addresses that would be half of a message is
@@ -2228,17 +2243,25 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        self.fail_htlc_backwards_internal(channel_state,
                                                htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data});
                                },
-                               HTLCSource::OutboundRoute { .. } => {
-                                       self.pending_events.lock().unwrap().push(
-                                               events::Event::PaymentFailed {
-                                                       payment_hash,
-                                                       rejected_by_dest: false,
+                               HTLCSource::OutboundRoute { session_priv, .. } => {
+                                       if {
+                                               let mut session_priv_bytes = [0; 32];
+                                               session_priv_bytes.copy_from_slice(&session_priv[..]);
+                                               self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
+                                       } {
+                                               self.pending_events.lock().unwrap().push(
+                                                       events::Event::PaymentFailed {
+                                                               payment_hash,
+                                                               rejected_by_dest: false,
 #[cfg(test)]
-                                                       error_code: None,
+                                                               error_code: None,
 #[cfg(test)]
-                                                       error_data: None,
-                                               }
-                                       )
+                                                               error_data: None,
+                                                       }
+                                               )
+                                       } else {
+                                               log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
+                                       }
                                },
                        };
                }
@@ -2260,7 +2283,15 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // from block_connected which may run during initialization prior to the chain_monitor
                // being fully configured. See the docs for `ChannelManagerReadArgs` for more.
                match source {
-                       HTLCSource::OutboundRoute { ref path, .. } => {
+                       HTLCSource::OutboundRoute { ref path, session_priv, .. } => {
+                               if {
+                                       let mut session_priv_bytes = [0; 32];
+                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
+                                       !self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
+                               } {
+                                       log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
+                                       return;
+                               }
                                log_trace!(self.logger, "Failing outbound payment HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                mem::drop(channel_state_lock);
                                match &onion_error {
@@ -2489,12 +2520,20 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage) {
                match source {
-                       HTLCSource::OutboundRoute { .. } => {
+                       HTLCSource::OutboundRoute { session_priv, .. } => {
                                mem::drop(channel_state_lock);
-                               let mut pending_events = self.pending_events.lock().unwrap();
-                               pending_events.push(events::Event::PaymentSent {
-                                       payment_preimage
-                               });
+                               if {
+                                       let mut session_priv_bytes = [0; 32];
+                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
+                                       self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
+                               } {
+                                       let mut pending_events = self.pending_events.lock().unwrap();
+                                       pending_events.push(events::Event::PaymentSent {
+                                               payment_preimage
+                                       });
+                               } else {
+                                       log_trace!(self.logger, "Received duplicative fulfill for HTLC with payment_preimage {}", log_bytes!(payment_preimage.0));
+                               }
                        },
                        HTLCSource::PreviousHopData(hop_data) => {
                                let prev_outpoint = hop_data.outpoint;
@@ -3517,7 +3556,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// `invoice_expiry_delta_secs` describes the number of seconds that the invoice is valid for
        /// in excess of the current time. This should roughly match the expiry time set in the invoice.
        /// After this many seconds, we will remove the inbound payment, resulting in any attempts to
-       /// pay the invoice failing. The BOLT spec suggests 7,200 secs as a default validity time for
+       /// pay the invoice failing. The BOLT spec suggests 3,600 secs as a default validity time for
        /// invoices when no timeout is set.
        ///
        /// Note that we use block header time to time-out pending inbound payments (with some margin
@@ -4392,8 +4431,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
                let _consistency_lock = self.total_consistency_lock.write().unwrap();
 
-               writer.write_all(&[SERIALIZATION_VERSION; 1])?;
-               writer.write_all(&[MIN_SERIALIZATION_VERSION; 1])?;
+               write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
 
                self.genesis_hash.write(writer)?;
                {
@@ -4470,6 +4508,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                        pending_payment.write(writer)?;
                }
 
+               let pending_outbound_payments = self.pending_outbound_payments.lock().unwrap();
+               (pending_outbound_payments.len() as u64).write(writer)?;
+               for session_priv in pending_outbound_payments.iter() {
+                       session_priv.write(writer)?;
+               }
+
+               write_tlv_fields!(writer, {});
+
                Ok(())
        }
 }
@@ -4593,11 +4639,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
         L::Target: Logger,
 {
        fn read<R: ::std::io::Read>(reader: &mut R, mut args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result<Self, DecodeError> {
-               let _ver: u8 = Readable::read(reader)?;
-               let min_ver: u8 = Readable::read(reader)?;
-               if min_ver > SERIALIZATION_VERSION {
-                       return Err(DecodeError::UnknownVersion);
-               }
+               let _ver = read_ver_prefix!(reader, SERIALIZATION_VERSION);
 
                let genesis_hash: BlockHash = Readable::read(reader)?;
                let best_block_height: u32 = Readable::read(reader)?;
@@ -4709,6 +4751,16 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        }
                }
 
+               let pending_outbound_payments_count: u64 = Readable::read(reader)?;
+               let mut pending_outbound_payments: HashSet<[u8; 32]> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32));
+               for _ in 0..pending_outbound_payments_count {
+                       if !pending_outbound_payments.insert(Readable::read(reader)?) {
+                               return Err(DecodeError::InvalidValue);
+                       }
+               }
+
+               read_tlv_fields!(reader, {}, {});
+
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes());
 
@@ -4728,6 +4780,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(pending_inbound_payments),
+                       pending_outbound_payments: Mutex::new(pending_outbound_payments),
 
                        our_network_key: args.keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &args.keys_manager.get_node_secret()),
@@ -4763,9 +4816,9 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
 mod tests {
        use ln::channelmanager::PersistenceNotifier;
        use std::sync::Arc;
-       use std::sync::atomic::{AtomicBool, Ordering};
+       use core::sync::atomic::{AtomicBool, Ordering};
        use std::thread;
-       use std::time::Duration;
+       use core::time::Duration;
 
        #[test]
        fn test_wait_timeout() {