Disallow taking two instances of the same mutex at the same time
author     Matt Corallo <git@bluematt.me>
           Wed, 22 Feb 2023 22:54:38 +0000 (22:54 +0000)
committer  Matt Corallo <git@bluematt.me>
           Tue, 28 Feb 2023 01:06:35 +0000 (01:06 +0000)
Taking two instances of the same mutex at the same time may be totally
fine, but it requires a total lockorder that we cannot (trivially)
check. Thus, it's generally unsafe to do if we can avoid it.

To discourage doing this, here we default to panicking on such locks
in our lockorder tests, with a separate lock function added that is
clearly labeled "unsafe" to allow doing so when we can guarantee a
total lockorder.

This requires adapting a number of sites to the new API, including
fixing a bug this turned up in `ChannelMonitor`'s `PartialEq` where
no lockorder was guaranteed.
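
For context (this example is not part of the commit), the hazard and the fix
look roughly like this with plain `std::sync` types: if two threads compare
the same pair of values with the arguments swapped, naive locking inverts the
lockorder and can deadlock, while sorting by memory address first gives every
thread the same total order.

```rust
use std::sync::Mutex;

struct Monitor { inner: Mutex<u64> }

// Assumes `a` and `b` are distinct objects; locking the lower-addressed
// mutex first means concurrent eq(a, b) and eq(b, a) calls cannot deadlock.
fn eq_by_addr_order(a: &Monitor, b: &Monitor) -> bool {
    let a_first = (a as *const Monitor as usize) < (b as *const Monitor as usize);
    let (first, second) = if a_first { (a, b) } else { (b, a) };
    let g1 = first.inner.lock().unwrap();
    let g2 = second.inner.lock().unwrap();
    *g1 == *g2
}
```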

lightning/src/chain/channelmonitor.rs
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/routing/utxo.rs
lightning/src/sync/debug_sync.rs
lightning/src/sync/fairrwlock.rs
lightning/src/sync/mod.rs
lightning/src/sync/nostd_sync.rs

diff --git a/lightning/src/chain/channelmonitor.rs b/lightning/src/chain/channelmonitor.rs
index 8b281db6030dba2f681720a91eba819918d20d3f..aaf78fdf1228a2d071b3d2e8f2ec01c5fe8f9fb6 100644
@@ -60,7 +60,7 @@ use core::{cmp, mem};
 use crate::io::{self, Error};
 use core::convert::TryInto;
 use core::ops::Deref;
-use crate::sync::Mutex;
+use crate::sync::{Mutex, LockTestExt};
 
 /// An update generated by the underlying channel itself which contains some new information the
 /// [`ChannelMonitor`] should be made aware of.
@@ -851,9 +851,13 @@ pub type TransactionOutputs = (Txid, Vec<(u32, TxOut)>);
 
 impl<Signer: WriteableEcdsaChannelSigner> PartialEq for ChannelMonitor<Signer> where Signer: PartialEq {
        fn eq(&self, other: &Self) -> bool {
-               let inner = self.inner.lock().unwrap();
-               let other = other.inner.lock().unwrap();
-               inner.eq(&other)
+               // We need some kind of total lockorder. Absent a better idea, we sort by position in
+               // memory and take locks in that order (assuming that we can't move within memory while a
+               // lock is held).
+               let ord = ((self as *const _) as usize) < ((other as *const _) as usize);
+               let a = if ord { self.inner.unsafe_well_ordered_double_lock_self() } else { other.inner.unsafe_well_ordered_double_lock_self() };
+               let b = if ord { other.inner.unsafe_well_ordered_double_lock_self() } else { self.inner.unsafe_well_ordered_double_lock_self() };
+               a.eq(&b)
        }
 }
 
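A note on why the address comparison above is sound: `eq` holds `&self` and
`&other` for its whole body, and the returned guards borrow the two monitors,
so neither value can move or be dropped while the locks are held, which is
exactly the assumption called out in the comment. It also presumes `self` and
`other` are distinct objects; comparing a monitor with itself would take the
same inner mutex twice and presumably deadlock under `std` locks.
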
diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs
index 4f174c226b313390847ba8ad4de6210c31af8647..18e461f61a3c0da0ca74c8579542e7571bc39bf2 100644
@@ -108,12 +108,13 @@ fn test_monitor_and_persister_update_fail() {
                blocks: Arc::new(Mutex::new(vec![(genesis_block(Network::Testnet), 200); 200])),
        };
        let chain_mon = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, ChannelMonitor<EnforcingSigner>)>::read(
-                       &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, ChannelMonitor<EnforcingSigner>)>::read(
+                               &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(chain_mon.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                chain_mon
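
The restructuring here is not just cosmetic: as I read it, the guard returned
by `get_monitor` must be dropped before the fresh `TestChainMonitor` takes its
own locks in `watch_channel`, since under `feature = "backtrace"` locks
created at the same source location share one identity and would otherwise
count as a double lock. A toy reduction of the pattern (names hypothetical):

```rust
use std::sync::{Mutex, MutexGuard};

struct ChainMon { monitors: Mutex<Vec<u32>> }

impl ChainMon {
    fn get_first(&self) -> MutexGuard<'_, Vec<u32>> { self.monitors.lock().unwrap() }
    fn watch(&self, v: u32) { self.monitors.lock().unwrap().push(v); }
}

fn main() {
    // One construction site for both mutexes, mirroring the test harness.
    let mk = |v: Vec<u32>| ChainMon { monitors: Mutex::new(v) };
    let (old, new) = (mk(vec![1]), mk(Vec::new()));
    // Scope the guard from `old` so it drops before `new` locks a
    // same-construction-site mutex of its own.
    let copied = {
        let guard = old.get_first();
        guard[0]
    }; // `guard` dropped here
    new.watch(copied);
}
```
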
diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs
index 5e0b56d05ab51217c241a9a3d887adae481e5b16..e7e3acddae1b213ecf7bced34e6c5fd4b9e7b0d1 100644
@@ -6941,7 +6941,10 @@ where
                let mut monitor_update_blocked_actions_per_peer = None;
                let mut peer_states = Vec::new();
                for (_, peer_state_mutex) in per_peer_state.iter() {
-                       peer_states.push(peer_state_mutex.lock().unwrap());
+                       // Because we're holding the owning `per_peer_state` write lock here, there's no chance
+                       // of a lockorder violation deadlock - no other thread can be holding any
+                       // per_peer_state lock at all.
+                       peer_states.push(peer_state_mutex.unsafe_well_ordered_double_lock_self());
                }
 
                (serializable_peer_count).write(writer)?;
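
The safety argument in this comment generalizes: while a thread holds the
*write* side of the outer lock, no other thread can reach any of the inner
per-peer mutexes at all, so taking several of them in arbitrary order is
trivially well ordered. A hedged sketch of the shape (types simplified, not
the real `ChannelManager` fields):

```rust
use std::collections::HashMap;
use std::sync::{Mutex, RwLock};

fn serialize_all_peers(per_peer_state: &RwLock<HashMap<u8, Mutex<String>>>) {
    // Exclusive access to the map implies exclusive access to every entry:
    // no other thread can currently be blocked holding an inner lock.
    let map = per_peer_state.write().unwrap();
    let guards: Vec<_> = map.values().map(|m| m.lock().unwrap()).collect();
    for peer in &guards {
        let _ = peer.len(); // ... write each peer's state ...
    }
}
```
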
@@ -8280,10 +8283,10 @@ mod tests {
                        let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap();
                        assert_eq!(nodes_0_lock.len(), 1);
                        assert!(nodes_0_lock.contains_key(channel_id));
-
-                       assert_eq!(nodes[1].node.id_to_peer.lock().unwrap().len(), 0);
                }
 
+               assert_eq!(nodes[1].node.id_to_peer.lock().unwrap().len(), 0);
+
                let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
 
                nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &funding_created_msg);
@@ -8291,7 +8294,9 @@ mod tests {
                        let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap();
                        assert_eq!(nodes_0_lock.len(), 1);
                        assert!(nodes_0_lock.contains_key(channel_id));
+               }
 
+               {
                        // Assert that `nodes[1]`'s `id_to_peer` map is populated with the channel as soon
                        // as it has the funding transaction.
                        let nodes_1_lock = nodes[1].node.id_to_peer.lock().unwrap();
@@ -8321,7 +8326,9 @@ mod tests {
                        let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap();
                        assert_eq!(nodes_0_lock.len(), 1);
                        assert!(nodes_0_lock.contains_key(channel_id));
+               }
 
+               {
                        // At this stage, `nodes[1]` has proposed a fee for the closing transaction in the
                        // `handle_closing_signed` call above. As `nodes[1]` has not yet received the signature
                        // from `nodes[0]` for the closing transaction with the proposed fee, the channel is
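
Each of these test fixups follows the rule the new check enforces: both
nodes' `id_to_peer` maps are created by the same setup helper, so with
`feature = "backtrace"` their mutexes share a lock identity, and a guard on
one must be released before locking the other. A self-contained sketch of
the scoping discipline (toy types, not the real test harness):

```rust
use std::sync::Mutex;

fn main() {
    // Both mutexes come from one construction site, mirroring how the test
    // harness builds every node from the same helper.
    let mk = || Mutex::new(Vec::<u8>::new());
    let (node_a_map, node_b_map) = (mk(), mk());

    {
        let a = node_a_map.lock().unwrap();
        assert_eq!(a.len(), 0);
    } // `a` released: only now is it safe to lock the sibling map

    assert_eq!(node_b_map.lock().unwrap().len(), 0);
}
```
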
diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs
index 553cec5bedfa64169ff508623a321422f039fe53..96ce9312eed3acec30cb9c7e4db7a8114fa804eb 100644
@@ -44,7 +44,7 @@ use crate::io;
 use crate::prelude::*;
 use core::cell::RefCell;
 use alloc::rc::Rc;
-use crate::sync::{Arc, Mutex};
+use crate::sync::{Arc, Mutex, LockTestExt};
 use core::mem;
 use core::iter::repeat;
 use bitcoin::{PackedLockTime, TxMerkleNode};
@@ -466,8 +466,8 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        panic!();
                                }
                        }
-                       assert_eq!(*chain_source.watched_txn.lock().unwrap(), *self.chain_source.watched_txn.lock().unwrap());
-                       assert_eq!(*chain_source.watched_outputs.lock().unwrap(), *self.chain_source.watched_outputs.lock().unwrap());
+                       assert_eq!(*chain_source.watched_txn.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_txn.unsafe_well_ordered_double_lock_self());
+                       assert_eq!(*chain_source.watched_outputs.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_outputs.unsafe_well_ordered_double_lock_self());
                }
        }
 }
diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs
index f48d6f90d1fd4d8fd5b426a5642cbb24272c6848..36f10f742be8c4c808e868d445a39f5b2f88994e 100644
@@ -8150,12 +8150,13 @@ fn test_update_err_monitor_lockdown() {
        let logger = test_utils::TestLogger::with_id(format!("node {}", 0));
        let persister = test_utils::TestPersister::new();
        let watchtower = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
-                               &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
+                                       &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                watchtower
@@ -8217,12 +8218,13 @@ fn test_concurrent_monitor_claim() {
        let logger = test_utils::TestLogger::with_id(format!("node {}", "Alice"));
        let persister = test_utils::TestPersister::new();
        let watchtower_alice = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
-                               &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
+                                       &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                watchtower
@@ -8246,12 +8248,13 @@ fn test_concurrent_monitor_claim() {
        let logger = test_utils::TestLogger::with_id(format!("node {}", "Bob"));
        let persister = test_utils::TestPersister::new();
        let watchtower_bob = {
-               let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
-               let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write(&mut w).unwrap();
-               let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
-                               &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
-               assert!(new_monitor == *monitor);
+               let new_monitor = {
+                       let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap();
+                       let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingSigner>)>::read(
+                                       &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1;
+                       assert!(new_monitor == *monitor);
+                       new_monitor
+               };
                let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager);
                assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed);
                watchtower
diff --git a/lightning/src/routing/utxo.rs b/lightning/src/routing/utxo.rs
index 09e110c2dfeb853c7d10caee35032c3340b258a7..74abd4276432b0d0d048d6791fe0ffb648b4f54d 100644
@@ -26,7 +26,7 @@ use crate::util::ser::Writeable;
 use crate::prelude::*;
 
 use alloc::sync::{Arc, Weak};
-use crate::sync::Mutex;
+use crate::sync::{Mutex, LockTestExt};
 use core::ops::Deref;
 
 /// An error when accessing the chain via [`UtxoLookup`].
@@ -404,7 +404,10 @@ impl PendingChecks {
                                // lookup if we haven't gotten that far yet).
                                match Weak::upgrade(&e.get()) {
                                        Some(pending_msgs) => {
-                                               let pending_matches = match &pending_msgs.lock().unwrap().channel_announce {
+                                               // This may be called with the mutex held on a different UtxoMessages
+                                               // struct; however, in that case we have a global lockorder of new messages
+                                               // -> old messages, which makes this safe.
+                                               let pending_matches = match &pending_msgs.unsafe_well_ordered_double_lock_self().channel_announce {
                                                        Some(ChannelAnnouncement::Full(pending_msg)) => Some(pending_msg) == full_msg,
                                                        Some(ChannelAnnouncement::Unsigned(pending_msg)) => pending_msg == msg,
                                                        None => {
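
Unlike the address-based ordering used in `ChannelMonitor`'s `PartialEq`, the
total order invoked here comes from the domain itself: a pending check always
locks the newer message set before the older one. A hedged sketch of that
style of ordering (types hypothetical):

```rust
use std::sync::Mutex;

struct PendingMsgs { generation: u64, data: Mutex<Vec<u8>> }

// Safe as long as *every* code path that holds two of these locks takes the
// newer-generation lock first; that convention is itself a total order.
fn inspect_older(newer: &PendingMsgs, older: &PendingMsgs) {
    assert!(newer.generation > older.generation);
    let _n = newer.data.lock().unwrap();
    let _o = older.data.lock().unwrap(); // always second: new -> old
}
```
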
diff --git a/lightning/src/sync/debug_sync.rs b/lightning/src/sync/debug_sync.rs
index aa9f5fe9c17d19ca4ffe778faadd1b128123e6f0..721245811771fa56659aa0e9ba221ab880db9ca2 100644
@@ -124,17 +124,26 @@ impl LockMetadata {
                res
        }
 
-       fn pre_lock(this: &Arc<LockMetadata>) {
+       fn pre_lock(this: &Arc<LockMetadata>, _double_lock_self_allowed: bool) {
                LOCKS_HELD.with(|held| {
                        // For each lock which is currently locked, check that no lock's locked-before
                        // set includes the lock we're about to lock, which would imply a lockorder
                        // inversion.
                        for (locked_idx, locked) in held.borrow().iter() {
                                if *locked_idx == this.lock_idx {
-                                       // With `feature = "backtrace"` set, we may be looking at different instances
-                                       // of the same lock.
-                                       debug_assert!(cfg!(feature = "backtrace"), "Tried to acquire a lock while it was held!");
+                                       // Note that with `feature = "backtrace"` set, we may be looking at different
+                       // instances of the same lock. Still, doing so is quite risky: a total order
+                       // must be maintained, and doing so across a set of otherwise-identical mutexes
+                       // is fraught with issues.
+                                       #[cfg(feature = "backtrace")]
+                                       debug_assert!(_double_lock_self_allowed,
+                                               "Tried to acquire a lock while it was held!\nLock constructed at {}",
+                                               get_construction_location(&this._lock_construction_bt));
+                                       #[cfg(not(feature = "backtrace"))]
+                                       panic!("Tried to acquire a lock while it was held!");
                                }
+                       }
+                       for (locked_idx, locked) in held.borrow().iter() {
                                for (locked_dep_idx, _locked_dep) in locked.locked_before.lock().unwrap().iter() {
                                        if *locked_dep_idx == this.lock_idx && *locked_dep_idx != locked.lock_idx {
                                                #[cfg(feature = "backtrace")]
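
For readers new to the harness: `pre_lock` consults a thread-local set of
currently-held locks and fails if the one being acquired is already in it,
unless the caller explicitly opted in. Stripped of the backtrace machinery
and the lockorder bookkeeping, the mechanism reduces to roughly this (a
simplification, not the real implementation):

```rust
use std::cell::RefCell;

thread_local! {
    // Identities of locks held by the current thread. In the real code these
    // identities collapse per construction site under `feature = "backtrace"`.
    static HELD: RefCell<Vec<usize>> = RefCell::new(Vec::new());
}

fn pre_lock(lock_idx: usize, double_lock_self_allowed: bool) {
    HELD.with(|held| {
        if held.borrow().contains(&lock_idx) && !double_lock_self_allowed {
            panic!("Tried to acquire a lock while it was held!");
        }
        held.borrow_mut().push(lock_idx);
    });
}
```
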
@@ -236,7 +245,7 @@ impl<T> Mutex<T> {
        }
 
        pub fn lock<'a>(&'a self) -> LockResult<MutexGuard<'a, T>> {
-               LockMetadata::pre_lock(&self.deps);
+               LockMetadata::pre_lock(&self.deps, false);
                self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).map_err(|_| ())
        }
 
@@ -249,11 +258,17 @@ impl<T> Mutex<T> {
        }
 }
 
-impl <T> LockTestExt for Mutex<T> {
+impl<'a, T: 'a> LockTestExt<'a> for Mutex<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                LockMetadata::held_by_thread(&self.deps)
        }
+       type ExclLock = MutexGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<T> {
+               LockMetadata::pre_lock(&self.deps, true);
+               self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).unwrap()
+       }
 }
 
 pub struct RwLock<T: Sized> {
@@ -317,13 +332,14 @@ impl<T> RwLock<T> {
        pub fn read<'a>(&'a self) -> LockResult<RwLockReadGuard<'a, T>> {
                // Note that while we could be taking a recursive read lock here, Rust's `RwLock` may
                // deadlock trying to take a second read lock if another thread is waiting on the write
-               // lock. Its platform dependent (but our in-tree `FairRwLock` guarantees this behavior).
-               LockMetadata::pre_lock(&self.deps);
+               // lock. This behavior is platform dependent, but our in-tree `FairRwLock` guarantees
+               // such a deadlock.
+               LockMetadata::pre_lock(&self.deps, false);
                self.inner.read().map(|guard| RwLockReadGuard { lock: self, guard }).map_err(|_| ())
        }
 
        pub fn write<'a>(&'a self) -> LockResult<RwLockWriteGuard<'a, T>> {
-               LockMetadata::pre_lock(&self.deps);
+               LockMetadata::pre_lock(&self.deps, false);
                self.inner.write().map(|guard| RwLockWriteGuard { lock: self, guard }).map_err(|_| ())
        }
 
@@ -336,11 +352,17 @@ impl<T> RwLock<T> {
        }
 }
 
-impl <T> LockTestExt for RwLock<T> {
+impl<'a, T: 'a> LockTestExt<'a> for RwLock<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                LockMetadata::held_by_thread(&self.deps)
        }
+       type ExclLock = RwLockWriteGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<'a, T> {
+               LockMetadata::pre_lock(&self.deps, true);
+               self.inner.write().map(|guard| RwLockWriteGuard { lock: self, guard }).unwrap()
+       }
 }
 
 pub type FairRwLock<T> = RwLock<T>;
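
In lockorder tests, the practical difference between the two entry points is
then as follows (a sketch, assuming the crate-internal `debug_sync::Mutex`
with `feature = "backtrace"` enabled):

```rust
// Both mutexes share a construction site, hence one lock identity.
let mk = || Mutex::new(0u8);
let (m1, m2) = (mk(), mk());

let _a = m1.lock().unwrap();
// let _b = m2.lock().unwrap();                     // would trip the new check
let _b = m2.unsafe_well_ordered_double_lock_self(); // caller vouches for order
```
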
diff --git a/lightning/src/sync/fairrwlock.rs b/lightning/src/sync/fairrwlock.rs
index a9519ac240cde1e24fb3dd09e37e36eddc552216..de609d5b3d711059568daca1e9d408be80891321 100644
@@ -50,10 +50,15 @@ impl<T> FairRwLock<T> {
        }
 }
 
-impl<T> LockTestExt for FairRwLock<T> {
+impl<'a, T: 'a> LockTestExt<'a> for FairRwLock<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                // fairrwlock is only built in non-test modes, so we should never support tests.
                LockHeldState::Unsupported
        }
+       type ExclLock = RwLockWriteGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<'a, T> {
+               self.write().unwrap()
+       }
 }
diff --git a/lightning/src/sync/mod.rs b/lightning/src/sync/mod.rs
index 50ef40e295f50d0a0d9095a0f3dfe4c73916a340..1b2b9a739b8c5d3078204917f55b0fef86e87f13 100644
@@ -7,8 +7,17 @@ pub(crate) enum LockHeldState {
        Unsupported,
 }
 
-pub(crate) trait LockTestExt {
+pub(crate) trait LockTestExt<'a> {
        fn held_by_thread(&self) -> LockHeldState;
+       type ExclLock;
+       /// If two instances of the same mutex are being taken at the same time, it's very easy to have
+       /// a lockorder inversion and risk deadlock. Thus, we default to disabling such locks.
+       ///
+       /// However, sometimes they cannot be avoided. In such cases, this method exists to take a
+       /// mutex while avoiding a test failure. It is deliberately verbose and includes the term
+       /// "unsafe" to indicate that special care needs to be taken to ensure no deadlocks are
+       /// possible.
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> Self::ExclLock;
 }
 
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
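
A hypothetical generic helper (not in the commit) shows how a caller could
satisfy the "well ordered" contract for any two locks of the same type, using
the same address trick as `ChannelMonitor`'s `PartialEq`:

```rust
fn double_lock_by_addr<'a, L: LockTestExt<'a>>(a: &'a L, b: &'a L)
    -> (L::ExclLock, L::ExclLock)
{
    // Lower address first: a process-wide total order over distinct objects.
    let a_first = (a as *const L as usize) < (b as *const L as usize);
    let (first, second) = if a_first { (a, b) } else { (b, a) };
    let g1 = first.unsafe_well_ordered_double_lock_self();
    let g2 = second.unsafe_well_ordered_double_lock_self();
    // Return the guards in (a, b) order regardless of which was taken first.
    if a_first { (g1, g2) } else { (g2, g1) }
}
```
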
@@ -27,13 +36,19 @@ pub use {std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, R
 #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
 mod ext_impl {
        use super::*;
-       impl<T> LockTestExt for Mutex<T> {
+       impl<'a, T: 'a> LockTestExt<'a> for Mutex<T> {
                #[inline]
                fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+               type ExclLock = MutexGuard<'a, T>;
+               #[inline]
+               fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<T> { self.lock().unwrap() }
        }
-       impl<T> LockTestExt for RwLock<T> {
+       impl<'a, T: 'a> LockTestExt<'a> for RwLock<T> {
                #[inline]
                fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+               type ExclLock = RwLockWriteGuard<'a, T>;
+               #[inline]
+               fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<T> { self.write().unwrap() }
        }
 }
 
diff --git a/lightning/src/sync/nostd_sync.rs b/lightning/src/sync/nostd_sync.rs
index e17aa6ab15faa5dc33b66dbc1b80ab79fd36cb18..858f60db5b5b46a5bb703f2d454d8d2b2c1f2934 100644
@@ -62,12 +62,15 @@ impl<T> Mutex<T> {
        }
 }
 
-impl<T> LockTestExt for Mutex<T> {
+impl<'a, T: 'a> LockTestExt<'a> for Mutex<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                if self.lock().is_err() { return LockHeldState::HeldByThread; }
                else { return LockHeldState::NotHeldByThread; }
        }
+       type ExclLock = MutexGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard<T> { self.lock().unwrap() }
 }
 
 pub struct RwLock<T: ?Sized> {
@@ -125,12 +128,15 @@ impl<T> RwLock<T> {
        }
 }
 
-impl<T> LockTestExt for RwLock<T> {
+impl<'a, T: 'a> LockTestExt<'a> for RwLock<T> {
        #[inline]
        fn held_by_thread(&self) -> LockHeldState {
                if self.write().is_err() { return LockHeldState::HeldByThread; }
                else { return LockHeldState::NotHeldByThread; }
        }
+       type ExclLock = RwLockWriteGuard<'a, T>;
+       #[inline]
+       fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<T> { self.write().unwrap() }
 }
 
 pub type FairRwLock<T> = RwLock<T>;