Hold ChannelManager locks independently
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 0940cda2e2d66017a1eb84f78eb7b7725da96bcf..450d014033a895656d39cec81f7daf0caeb142ab 100644 (file)
@@ -423,7 +423,7 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        pub(super) latest_block_height: AtomicUsize,
        #[cfg(not(test))]
        latest_block_height: AtomicUsize,
-       last_block_hash: Mutex<BlockHash>,
+       last_block_hash: RwLock<BlockHash>,
        secp_ctx: Secp256k1<secp256k1::All>,
 
        #[cfg(any(test, feature = "_test_utils"))]
@@ -803,7 +803,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        tx_broadcaster,
 
                        latest_block_height: AtomicUsize::new(params.latest_height),
-                       last_block_hash: Mutex::new(params.latest_hash),
+                       last_block_hash: RwLock::new(params.latest_hash),
                        secp_ctx,
 
                        channel_state: Mutex::new(ChannelHolder{
@@ -2454,6 +2454,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn internal_funding_created(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingCreated) -> Result<(), MsgHandleErrInternal> {
                let ((funding_msg, monitor), mut chan) = {
+                       let last_block_hash = *self.last_block_hash.read().unwrap();
                        let mut channel_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_lock;
                        match channel_state.by_id.entry(msg.temporary_channel_id.clone()) {
@@ -2461,7 +2462,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        if chan.get().get_counterparty_node_id() != *counterparty_node_id {
                                                return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.temporary_channel_id));
                                        }
-                                       (try_chan_entry!(self, chan.get_mut().funding_created(msg, &self.logger), channel_state, chan), chan.remove())
+                                       (try_chan_entry!(self, chan.get_mut().funding_created(msg, last_block_hash, &self.logger), channel_state, chan), chan.remove())
                                },
                                hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close("Failed to find corresponding channel".to_owned(), msg.temporary_channel_id))
                        }
@@ -2510,6 +2511,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn internal_funding_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingSigned) -> Result<(), MsgHandleErrInternal> {
                let (funding_txo, user_id) = {
+                       let last_block_hash = *self.last_block_hash.read().unwrap();
                        let mut channel_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_lock;
                        match channel_state.by_id.entry(msg.channel_id) {
@@ -2517,7 +2519,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        if chan.get().get_counterparty_node_id() != *counterparty_node_id {
                                                return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.channel_id));
                                        }
-                                       let monitor = match chan.get_mut().funding_signed(&msg, &self.logger) {
+                                       let monitor = match chan.get_mut().funding_signed(&msg, last_block_hash, &self.logger) {
                                                Ok(update) => update,
                                                Err(e) => try_chan_entry!(self, Err(e), channel_state, chan),
                                        };
@@ -3253,9 +3255,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // Note that we MUST NOT end up calling methods on self.chain_monitor here - we're called
                // during initialization prior to the chain_monitor being fully configured in some cases.
                // See the docs for `ChannelManagerReadArgs` for more.
-               let header_hash = header.block_hash();
-               log_trace!(self.logger, "Block {} at height {} connected", header_hash, height);
+               let block_hash = header.block_hash();
+               log_trace!(self.logger, "Block {} at height {} connected", block_hash, height);
+
                let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier);
+
+               self.latest_block_height.store(height as usize, Ordering::Release);
+               *self.last_block_hash.write().unwrap() = block_hash;
+
                let mut failed_channels = Vec::new();
                let mut timed_out_htlcs = Vec::new();
                {
@@ -3344,8 +3351,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for (source, payment_hash, reason) in timed_out_htlcs.drain(..) {
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, reason);
                }
-               self.latest_block_height.store(height as usize, Ordering::Release);
-               *self.last_block_hash.try_lock().expect("block_(dis)connected must not be called in parallel") = header_hash;
+
                loop {
                        // Update last_node_announcement_serial to be the max of its current value and the
                        // block timestamp. This should keep us close to the current time without relying on
@@ -3369,6 +3375,10 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // during initialization prior to the chain_monitor being fully configured in some cases.
                // See the docs for `ChannelManagerReadArgs` for more.
                let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier);
+
+               self.latest_block_height.fetch_sub(1, Ordering::AcqRel);
+               *self.last_block_hash.write().unwrap() = header.block_hash();
+
                let mut failed_channels = Vec::new();
                {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -3392,9 +3402,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                }
                        });
                }
+
                self.handle_init_event_channel_failures(failed_channels);
-               self.latest_block_height.fetch_sub(1, Ordering::AcqRel);
-               *self.last_block_hash.try_lock().expect("block_(dis)connected must not be called in parallel") = header.block_hash();
        }
 
        /// Blocks until ChannelManager needs to be persisted or a timeout is reached. It returns a bool
@@ -3950,7 +3959,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
 
                self.genesis_hash.write(writer)?;
                (self.latest_block_height.load(Ordering::Acquire) as u32).write(writer)?;
-               self.last_block_hash.lock().unwrap().write(writer)?;
+               self.last_block_hash.read().unwrap().write(writer)?;
 
                let channel_state = self.channel_state.lock().unwrap();
                let mut unfunded_channels = 0;
@@ -4153,10 +4162,6 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                let mut short_to_id = HashMap::with_capacity(cmp::min(channel_count as usize, 128));
                for _ in 0..channel_count {
                        let mut channel: Channel<Signer> = Channel::read(reader, &args.keys_manager)?;
-                       if channel.last_block_connected != Default::default() && channel.last_block_connected != last_block_hash {
-                               return Err(DecodeError::InvalidValue);
-                       }
-
                        let funding_txo = channel.get_funding_txo().ok_or(DecodeError::InvalidValue)?;
                        funding_txo_set.insert(funding_txo.clone());
                        if let Some(ref mut monitor) = args.channel_monitors.get_mut(&funding_txo) {
@@ -4256,7 +4261,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        tx_broadcaster: args.tx_broadcaster,
 
                        latest_block_height: AtomicUsize::new(latest_block_height as usize),
-                       last_block_hash: Mutex::new(last_block_hash),
+                       last_block_hash: RwLock::new(last_block_hash),
                        secp_ctx,
 
                        channel_state: Mutex::new(ChannelHolder {