]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Test if a given mutex is locked by the current thread in tests
authorMatt Corallo <git@bluematt.me>
Tue, 7 Feb 2023 19:46:08 +0000 (19:46 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 16 Feb 2023 21:35:23 +0000 (21:35 +0000)
In anticipation of the next commit(s) adding threaded tests, we
need to ensure our lockorder checks work fine with multiple
threads. Sadly, currently we have tests in the form
`assert!(mutex.try_lock().is_ok())` to assert that a given mutex is
not locked by the caller to a function.

The fix is rather simple given we already track mutexes locked by a
thread in our `debug_sync` logic - simply replace the check with a
new extension trait which (for test builds) checks the locked state
by only looking at what was locked by the current thread.

lightning/src/ln/channelmanager.rs
lightning/src/sync/debug_sync.rs
lightning/src/sync/fairrwlock.rs
lightning/src/sync/mod.rs
lightning/src/sync/nostd_sync.rs
lightning/src/sync/test_lockorder_checks.rs

index a042d2f95eedaf7ed37579f67a4de29904f3b47d..2ced6c9486f594b8e8e21ab11892b3aed13accf5 100644 (file)
@@ -70,7 +70,7 @@ use crate::prelude::*;
 use core::{cmp, mem};
 use core::cell::RefCell;
 use crate::io::Read;
-use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock};
+use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock, LockTestExt, LockHeldState};
 use core::sync::atomic::{AtomicUsize, Ordering};
 use core::time::Duration;
 use core::ops::Deref;
@@ -1218,13 +1218,10 @@ macro_rules! handle_error {
                match $internal {
                        Ok(msg) => Ok(msg),
                        Err(MsgHandleErrInternal { err, chan_id, shutdown_finish }) => {
-                               #[cfg(any(feature = "_test_utils", test))]
-                               {
-                                       // In testing, ensure there are no deadlocks where the lock is already held upon
-                                       // entering the macro.
-                                       debug_assert!($self.pending_events.try_lock().is_ok());
-                                       debug_assert!($self.per_peer_state.try_write().is_ok());
-                               }
+                               // In testing, ensure there are no deadlocks where the lock is already held upon
+                               // entering the macro.
+                               debug_assert_ne!($self.pending_events.held_by_thread(), LockHeldState::HeldByThread);
+                               debug_assert_ne!($self.per_peer_state.held_by_thread(), LockHeldState::HeldByThread);
 
                                let mut msg_events = Vec::with_capacity(2);
 
@@ -3743,17 +3740,12 @@ where
        /// Fails an HTLC backwards to the sender of it to us.
        /// Note that we do not assume that channels corresponding to failed HTLCs are still available.
        fn fail_htlc_backwards_internal(&self, source: &HTLCSource, payment_hash: &PaymentHash, onion_error: &HTLCFailReason, destination: HTLCDestination) {
-               #[cfg(any(feature = "_test_utils", test))]
-               {
-                       // Ensure that the peer state channel storage lock is not held when calling this
-                       // function.
-                       // This ensures that future code doesn't introduce a lock_order requirement for
-                       // `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
-                       // this function with any `per_peer_state` peer lock aquired would.
-                       let per_peer_state = self.per_peer_state.read().unwrap();
-                       for (_, peer) in per_peer_state.iter() {
-                               debug_assert!(peer.try_lock().is_ok());
-                       }
+               // Ensure that no peer state channel storage lock is held when calling this function.
+               // This ensures that future code doesn't introduce a lock-order requirement for
+               // `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
+               // this function with any `per_peer_state` peer lock acquired would.
+               for (_, peer) in self.per_peer_state.read().unwrap().iter() {
+                       debug_assert_ne!(peer.held_by_thread(), LockHeldState::HeldByThread);
                }
 
                //TODO: There is a timing attack here where if a node fails an HTLC back to us they can
index 9f7caa2c1804350818a5804f49fb0896c6d33f0f..5631093723733f16f4d7447c0865f58219d1cf40 100644 (file)
@@ -14,6 +14,8 @@ use std::sync::Condvar as StdCondvar;
 
 use crate::prelude::HashMap;
 
+use super::{LockTestExt, LockHeldState};
+
 #[cfg(feature = "backtrace")]
 use {crate::prelude::hash_map, backtrace::Backtrace, std::sync::Once};
 
@@ -168,6 +170,18 @@ impl LockMetadata {
        fn pre_lock(this: &Arc<LockMetadata>) { Self::_pre_lock(this, false); }
        fn pre_read_lock(this: &Arc<LockMetadata>) -> bool { Self::_pre_lock(this, true) }
 
+       fn held_by_thread(this: &Arc<LockMetadata>) -> LockHeldState {
+               let mut res = LockHeldState::NotHeldByThread;
+               LOCKS_HELD.with(|held| {
+                       for (locked_idx, _locked) in held.borrow().iter() {
+                               if *locked_idx == this.lock_idx {
+                                       res = LockHeldState::HeldByThread;
+                               }
+                       }
+               });
+               res
+       }
+
        fn try_locked(this: &Arc<LockMetadata>) {
                LOCKS_HELD.with(|held| {
                        // Since a try-lock will simply fail if the lock is held already, we do not
@@ -248,6 +262,13 @@ impl<T> Mutex<T> {
        }
 }
 
+impl <T> LockTestExt for Mutex<T> {
+       #[inline]
+       fn held_by_thread(&self) -> LockHeldState {
+               LockMetadata::held_by_thread(&self.deps)
+       }
+}
+
 pub struct RwLock<T: Sized> {
        inner: StdRwLock<T>,
        deps: Arc<LockMetadata>,
@@ -332,4 +353,11 @@ impl<T> RwLock<T> {
        }
 }
 
+impl <T> LockTestExt for RwLock<T> {
+       #[inline]
+       fn held_by_thread(&self) -> LockHeldState {
+               LockMetadata::held_by_thread(&self.deps)
+       }
+}
+
 pub type FairRwLock<T> = RwLock<T>;
index 5715a8cf646cd67e29b7aa5c21c8722d89a45964..a9519ac240cde1e24fb3dd09e37e36eddc552216 100644 (file)
@@ -1,5 +1,6 @@
 use std::sync::{LockResult, RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockResult};
 use std::sync::atomic::{AtomicUsize, Ordering};
+use super::{LockHeldState, LockTestExt};
 
 /// Rust libstd's RwLock does not provide any fairness guarantees (and, in fact, when used on
 /// Linux with pthreads under the hood, readers trivially and completely starve writers).
@@ -48,3 +49,11 @@ impl<T> FairRwLock<T> {
                self.lock.try_write()
        }
 }
+
+impl<T> LockTestExt for FairRwLock<T> {
+       #[inline]
+       fn held_by_thread(&self) -> LockHeldState {
+               // fairrwlock is only built in non-test modes, so we should never support tests.
+               LockHeldState::Unsupported
+       }
+}
index 5b0cc9c1886611bd0dc47a8d39188f36363b825a..50ef40e295f50d0a0d9095a0f3dfe4c73916a340 100644 (file)
@@ -1,3 +1,16 @@
+#[allow(dead_code)] // Depending on the compilation flags some variants are never used
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) enum LockHeldState {
+       HeldByThread,
+       NotHeldByThread,
+       #[cfg(any(feature = "_bench_unstable", not(test)))]
+       Unsupported,
+}
+
+pub(crate) trait LockTestExt {
+       fn held_by_thread(&self) -> LockHeldState;
+}
+
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
 mod debug_sync;
 #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
@@ -11,6 +24,19 @@ pub(crate) mod fairrwlock;
 #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
 pub use {std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}, fairrwlock::FairRwLock};
 
+#[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
+mod ext_impl {
+       use super::*;
+       impl<T> LockTestExt for Mutex<T> {
+               #[inline]
+               fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+       }
+       impl<T> LockTestExt for RwLock<T> {
+               #[inline]
+               fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+       }
+}
+
 #[cfg(not(feature = "std"))]
 mod nostd_sync;
 #[cfg(not(feature = "std"))]
index caf88a7cc04a8617a86e63988a7cc7a291fe96bc..e17aa6ab15faa5dc33b66dbc1b80ab79fd36cb18 100644 (file)
@@ -2,6 +2,7 @@ pub use ::alloc::sync::Arc;
 use core::ops::{Deref, DerefMut};
 use core::time::Duration;
 use core::cell::{RefCell, Ref, RefMut};
+use super::{LockTestExt, LockHeldState};
 
 pub type LockResult<Guard> = Result<Guard, ()>;
 
@@ -61,6 +62,14 @@ impl<T> Mutex<T> {
        }
 }
 
+impl<T> LockTestExt for Mutex<T> {
+       #[inline]
+       fn held_by_thread(&self) -> LockHeldState {
+               if self.lock().is_err() { return LockHeldState::HeldByThread; }
+               else { return LockHeldState::NotHeldByThread; }
+       }
+}
+
 pub struct RwLock<T: ?Sized> {
        inner: RefCell<T>
 }
@@ -116,4 +125,12 @@ impl<T> RwLock<T> {
        }
 }
 
+impl<T> LockTestExt for RwLock<T> {
+       #[inline]
+       fn held_by_thread(&self) -> LockHeldState {
+               if self.write().is_err() { return LockHeldState::HeldByThread; }
+               else { return LockHeldState::NotHeldByThread; }
+       }
+}
+
 pub type FairRwLock<T> = RwLock<T>;
index f9f30e2cfa28efb4d42960d3fb4c1afc713b8933..a3f746b11dc80dd32bd7badf5f9786acf6426a89 100644 (file)
@@ -1,5 +1,10 @@
 use crate::sync::debug_sync::{RwLock, Mutex};
 
+use super::{LockHeldState, LockTestExt};
+
+use std::sync::Arc;
+use std::thread;
+
 #[test]
 #[should_panic]
 #[cfg(not(feature = "backtrace"))]
@@ -92,3 +97,22 @@ fn read_write_lockorder_fail() {
                let _a = a.write().unwrap();
        }
 }
+
+#[test]
+fn test_thread_locked_state() {
+       let mtx = Arc::new(Mutex::new(()));
+       let mtx_ref = Arc::clone(&mtx);
+       assert_eq!(mtx.held_by_thread(), LockHeldState::NotHeldByThread);
+
+       let lck = mtx.lock().unwrap();
+       assert_eq!(mtx.held_by_thread(), LockHeldState::HeldByThread);
+
+       let thrd = std::thread::spawn(move || {
+               assert_eq!(mtx_ref.held_by_thread(), LockHeldState::NotHeldByThread);
+       });
+       thrd.join().unwrap();
+       assert_eq!(mtx.held_by_thread(), LockHeldState::HeldByThread);
+
+       std::mem::drop(lck);
+       assert_eq!(mtx.held_by_thread(), LockHeldState::NotHeldByThread);
+}