Change ChannelManager::wait to be more descriptive
[rust-lightning] / lightning / src / ln / channelmanager.rs
index b356641a0f2eb03b01f1a09a3a95d0cc5bed03b4..de0ce53590dcebfad11af1ca78e11c93547a4ad2 100644 (file)
@@ -423,7 +423,7 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        pub(super) latest_block_height: AtomicUsize,
        #[cfg(not(test))]
        latest_block_height: AtomicUsize,
-       last_block_hash: Mutex<BlockHash>,
+       last_block_hash: RwLock<BlockHash>,
        secp_ctx: Secp256k1<secp256k1::All>,
 
        #[cfg(any(test, feature = "_test_utils"))]
@@ -480,10 +480,11 @@ pub struct ChainParameters {
 }
 
 /// Whenever we release the `ChannelManager`'s `total_consistency_lock`, from read mode, it is
-/// desirable to notify any listeners on `wait_timeout`/`wait` that new updates are available for
-/// persistence. Therefore, this struct is responsible for locking the total consistency lock and,
-/// upon going out of scope, sending the aforementioned notification (since the lock being released
-/// indicates that the updates are ready for persistence).
+/// desirable to notify any listeners on `await_persistable_update_timeout`/
+/// `await_persistable_update` that new updates are available for persistence. Therefore, this
+/// struct is responsible for locking the total consistency lock and, upon going out of scope,
+/// sending the aforementioned notification (since the lock being released indicates that the
+/// updates are ready for persistence).
 struct PersistenceNotifierGuard<'a> {
        persistence_notifier: &'a PersistenceNotifier,
        // We hold onto this result so the lock doesn't get released immediately.
@@ -803,7 +804,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        tx_broadcaster,
 
                        latest_block_height: AtomicUsize::new(params.latest_height),
-                       last_block_hash: Mutex::new(params.latest_hash),
+                       last_block_hash: RwLock::new(params.latest_hash),
                        secp_ctx,
 
                        channel_state: Mutex::new(ChannelHolder{
@@ -2454,6 +2455,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn internal_funding_created(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingCreated) -> Result<(), MsgHandleErrInternal> {
                let ((funding_msg, monitor), mut chan) = {
+                       let last_block_hash = *self.last_block_hash.read().unwrap();
                        let mut channel_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_lock;
                        match channel_state.by_id.entry(msg.temporary_channel_id.clone()) {
@@ -2461,7 +2463,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        if chan.get().get_counterparty_node_id() != *counterparty_node_id {
                                                return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.temporary_channel_id));
                                        }
-                                       let last_block_hash = *self.last_block_hash.lock().unwrap();
                                        (try_chan_entry!(self, chan.get_mut().funding_created(msg, last_block_hash, &self.logger), channel_state, chan), chan.remove())
                                },
                                hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close("Failed to find corresponding channel".to_owned(), msg.temporary_channel_id))
@@ -2511,6 +2512,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn internal_funding_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingSigned) -> Result<(), MsgHandleErrInternal> {
                let (funding_txo, user_id) = {
+                       let last_block_hash = *self.last_block_hash.read().unwrap();
                        let mut channel_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_lock;
                        match channel_state.by_id.entry(msg.channel_id) {
@@ -2518,7 +2520,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        if chan.get().get_counterparty_node_id() != *counterparty_node_id {
                                                return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.channel_id));
                                        }
-                                       let last_block_hash = *self.last_block_hash.lock().unwrap();
                                        let monitor = match chan.get_mut().funding_signed(&msg, last_block_hash, &self.logger) {
                                                Ok(update) => update,
                                                Err(e) => try_chan_entry!(self, Err(e), channel_state, chan),
@@ -3255,9 +3256,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // Note that we MUST NOT end up calling methods on self.chain_monitor here - we're called
                // during initialization prior to the chain_monitor being fully configured in some cases.
                // See the docs for `ChannelManagerReadArgs` for more.
-               let header_hash = header.block_hash();
-               log_trace!(self.logger, "Block {} at height {} connected", header_hash, height);
+               let block_hash = header.block_hash();
+               log_trace!(self.logger, "Block {} at height {} connected", block_hash, height);
+
                let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier);
+
+               self.latest_block_height.store(height as usize, Ordering::Release);
+               *self.last_block_hash.write().unwrap() = block_hash;
+
                let mut failed_channels = Vec::new();
                let mut timed_out_htlcs = Vec::new();
                {
@@ -3346,8 +3352,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for (source, payment_hash, reason) in timed_out_htlcs.drain(..) {
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, reason);
                }
-               self.latest_block_height.store(height as usize, Ordering::Release);
-               *self.last_block_hash.try_lock().expect("block_(dis)connected must not be called in parallel") = header_hash;
+
                loop {
                        // Update last_node_announcement_serial to be the max of its current value and the
                        // block timestamp. This should keep us close to the current time without relying on
@@ -3371,6 +3376,10 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // during initialization prior to the chain_monitor being fully configured in some cases.
                // See the docs for `ChannelManagerReadArgs` for more.
                let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier);
+
+               self.latest_block_height.fetch_sub(1, Ordering::AcqRel);
+               *self.last_block_hash.write().unwrap() = header.prev_blockhash;
+
                let mut failed_channels = Vec::new();
                {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -3394,23 +3403,24 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                }
                        });
                }
+
                self.handle_init_event_channel_failures(failed_channels);
-               self.latest_block_height.fetch_sub(1, Ordering::AcqRel);
-               *self.last_block_hash.try_lock().expect("block_(dis)connected must not be called in parallel") = header.block_hash();
        }
 
        /// Blocks until ChannelManager needs to be persisted or a timeout is reached. It returns a bool
-       /// indicating whether persistence is necessary. Only one listener on `wait_timeout` is
-       /// guaranteed to be woken up.
+       /// indicating whether persistence is necessary. Only one listener on
+       /// `await_persistable_update` or `await_persistable_update_timeout` is guaranteed to be woken
+       /// up.
        /// Note that the feature `allow_wallclock_use` must be enabled to use this function.
        #[cfg(any(test, feature = "allow_wallclock_use"))]
-       pub fn wait_timeout(&self, max_wait: Duration) -> bool {
+       pub fn await_persistable_update_timeout(&self, max_wait: Duration) -> bool {
                self.persistence_notifier.wait_timeout(max_wait)
        }
 
-       /// Blocks until ChannelManager needs to be persisted. Only one listener on `wait` is
-       /// guaranteed to be woken up.
-       pub fn wait(&self) {
+       /// Blocks until ChannelManager needs to be persisted. Only one listener on
+       /// `await_persistable_update` or `await_persistable_update_timeout` is guaranteed to be woken
+       /// up.
+       pub fn await_persistable_update(&self) {
                self.persistence_notifier.wait()
        }
 
@@ -3662,7 +3672,7 @@ impl<Signer: Sign, M: Deref + Sync + Send, T: Deref + Sync + Send, K: Deref + Sy
 }
 
 /// Used to signal to the ChannelManager persister that the manager needs to be re-persisted to
-/// disk/backups, through `wait_timeout` and `wait`.
+/// disk/backups, through `await_persistable_update_timeout` and `await_persistable_update`.
 struct PersistenceNotifier {
        /// Users won't access the persistence_lock directly, but rather wait on its bool using
        /// `wait_timeout` and `wait`.
@@ -3952,7 +3962,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
 
                self.genesis_hash.write(writer)?;
                (self.latest_block_height.load(Ordering::Acquire) as u32).write(writer)?;
-               self.last_block_hash.lock().unwrap().write(writer)?;
+               self.last_block_hash.read().unwrap().write(writer)?;
 
                let channel_state = self.channel_state.lock().unwrap();
                let mut unfunded_channels = 0;
@@ -4254,7 +4264,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        tx_broadcaster: args.tx_broadcaster,
 
                        latest_block_height: AtomicUsize::new(latest_block_height as usize),
-                       last_block_hash: Mutex::new(last_block_hash),
+                       last_block_hash: RwLock::new(last_block_hash),
                        secp_ctx,
 
                        channel_state: Mutex::new(ChannelHolder {