use core::{cmp, mem};
use core::cell::RefCell;
use crate::io::Read;
-use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock};
+use crate::sync::{Arc, Mutex, RwLock, RwLockReadGuard, FairRwLock, LockTestExt, LockHeldState};
use core::sync::atomic::{AtomicUsize, Ordering};
use core::time::Duration;
use core::ops::Deref;
match $internal {
Ok(msg) => Ok(msg),
Err(MsgHandleErrInternal { err, chan_id, shutdown_finish }) => {
- #[cfg(any(feature = "_test_utils", test))]
- {
- // In testing, ensure there are no deadlocks where the lock is already held upon
- // entering the macro.
- debug_assert!($self.pending_events.try_lock().is_ok());
- debug_assert!($self.per_peer_state.try_write().is_ok());
- }
+ // In testing, ensure there are no deadlocks where the lock is already held upon
+ // entering the macro.
+ debug_assert_ne!($self.pending_events.held_by_thread(), LockHeldState::HeldByThread);
+ debug_assert_ne!($self.per_peer_state.held_by_thread(), LockHeldState::HeldByThread);
let mut msg_events = Vec::with_capacity(2);
/// Fails an HTLC backwards to the sender of it to us.
/// Note that we do not assume that channels corresponding to failed HTLCs are still available.
fn fail_htlc_backwards_internal(&self, source: &HTLCSource, payment_hash: &PaymentHash, onion_error: &HTLCFailReason, destination: HTLCDestination) {
- #[cfg(any(feature = "_test_utils", test))]
- {
- // Ensure that the peer state channel storage lock is not held when calling this
- // function.
- // This ensures that future code doesn't introduce a lock_order requirement for
- // `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
- // this function with any `per_peer_state` peer lock aquired would.
- let per_peer_state = self.per_peer_state.read().unwrap();
- for (_, peer) in per_peer_state.iter() {
- debug_assert!(peer.try_lock().is_ok());
- }
+ // Ensure that no peer state channel storage lock is held when calling this function.
+ // This ensures that future code doesn't introduce a lock-order requirement for
+ // `forward_htlcs` to be locked after the `per_peer_state` peer locks, which calling
+ // this function with any `per_peer_state` peer lock acquired would.
+ for (_, peer) in self.per_peer_state.read().unwrap().iter() {
+ debug_assert_ne!(peer.held_by_thread(), LockHeldState::HeldByThread);
}
//TODO: There is a timing attack here where if a node fails an HTLC back to us they can
use crate::prelude::HashMap;
+use super::{LockTestExt, LockHeldState};
+
#[cfg(feature = "backtrace")]
use {crate::prelude::hash_map, backtrace::Backtrace, std::sync::Once};
fn pre_lock(this: &Arc<LockMetadata>) { Self::_pre_lock(this, false); }
fn pre_read_lock(this: &Arc<LockMetadata>) -> bool { Self::_pre_lock(this, true) }
+ fn held_by_thread(this: &Arc<LockMetadata>) -> LockHeldState {
+ let mut res = LockHeldState::NotHeldByThread;
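+		// LOCKS_HELD is a per-thread record of the locks currently held by this thread;
+		// if this lock's index appears in it, the current thread holds the lock.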
+ LOCKS_HELD.with(|held| {
+ for (locked_idx, _locked) in held.borrow().iter() {
+ if *locked_idx == this.lock_idx {
+ res = LockHeldState::HeldByThread;
+ }
+ }
+ });
+ res
+ }
+
fn try_locked(this: &Arc<LockMetadata>) {
LOCKS_HELD.with(|held| {
// Since a try-lock will simply fail if the lock is held already, we do not
}
}
+impl<T> LockTestExt for Mutex<T> {
+ #[inline]
+ fn held_by_thread(&self) -> LockHeldState {
+ LockMetadata::held_by_thread(&self.deps)
+ }
+}
+
pub struct RwLock<T: Sized> {
inner: StdRwLock<T>,
deps: Arc<LockMetadata>,
}
}
+impl<T> LockTestExt for RwLock<T> {
+ #[inline]
+ fn held_by_thread(&self) -> LockHeldState {
+ LockMetadata::held_by_thread(&self.deps)
+ }
+}
+
pub type FairRwLock<T> = RwLock<T>;
use std::sync::{LockResult, RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockResult};
use std::sync::atomic::{AtomicUsize, Ordering};
+use super::{LockHeldState, LockTestExt};
/// Rust libstd's RwLock does not provide any fairness guarantees (and, in fact, when used on
/// Linux with pthreads under the hood, readers trivially and completely starve writers).
self.lock.try_write()
}
}
+
+impl<T> LockTestExt for FairRwLock<T> {
+ #[inline]
+ fn held_by_thread(&self) -> LockHeldState {
+		// fairrwlock is only built in non-test modes, so we never need test-only lock-state tracking here.
+ LockHeldState::Unsupported
+ }
+}
+#[allow(dead_code)] // Depending on the compilation flags, some variants are never used.
+#[derive(Debug, PartialEq, Eq)]
+pub(crate) enum LockHeldState {
+ HeldByThread,
+ NotHeldByThread,
+ #[cfg(any(feature = "_bench_unstable", not(test)))]
+ Unsupported,
+}
+
+pub(crate) trait LockTestExt {
+ fn held_by_thread(&self) -> LockHeldState;
+}
+
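The `LockTestExt` trait gives every lock flavor a uniform, test-friendly way to ask whether the current thread already holds it. As a minimal sketch of the intended call pattern (the helper `debug_assert_not_locked` below is illustrative and not part of this change):

```rust
use crate::sync::{LockHeldState, LockTestExt, Mutex};

/// Hypothetical helper: assert in debug builds that the current thread does not
/// already hold `lock` before running code which is about to take it.
fn debug_assert_not_locked<T>(lock: &Mutex<T>) {
	debug_assert_ne!(lock.held_by_thread(), LockHeldState::HeldByThread);
}
```

In production `std` builds the `Mutex`/`RwLock` implementations return `LockHeldState::Unsupported`, so such assertions can never fire there; they only have teeth under the test-only `debug_sync` wrappers.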
#[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
mod debug_sync;
#[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))]
#[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
pub use {std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard}, fairrwlock::FairRwLock};
+#[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))]
+mod ext_impl {
+ use super::*;
+ impl<T> LockTestExt for Mutex<T> {
+ #[inline]
+ fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+ }
+ impl<T> LockTestExt for RwLock<T> {
+ #[inline]
+ fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported }
+ }
+}
+
#[cfg(not(feature = "std"))]
mod nostd_sync;
#[cfg(not(feature = "std"))]
use core::ops::{Deref, DerefMut};
use core::time::Duration;
use core::cell::{RefCell, Ref, RefMut};
+use super::{LockTestExt, LockHeldState};
pub type LockResult<Guard> = Result<Guard, ()>;
}
}
+impl<T> LockTestExt for Mutex<T> {
+ #[inline]
+ fn held_by_thread(&self) -> LockHeldState {
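+		// In a single-threaded no-std build, a failed lock attempt can only mean that this
+		// thread already holds the lock.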
+		if self.lock().is_err() { LockHeldState::HeldByThread }
+		else { LockHeldState::NotHeldByThread }
+ }
+}
+
pub struct RwLock<T: ?Sized> {
inner: RefCell<T>
}
}
}
+impl<T> LockTestExt for RwLock<T> {
+ #[inline]
+ fn held_by_thread(&self) -> LockHeldState {
+		if self.write().is_err() { LockHeldState::HeldByThread }
+		else { LockHeldState::NotHeldByThread }
+ }
+}
+
pub type FairRwLock<T> = RwLock<T>;
use crate::sync::debug_sync::{RwLock, Mutex};
+use super::{LockHeldState, LockTestExt};
+
+use std::sync::Arc;
+use std::thread;
+
#[test]
#[should_panic]
#[cfg(not(feature = "backtrace"))]
let _a = a.write().unwrap();
}
}
+
+#[test]
+fn test_thread_locked_state() {
+ let mtx = Arc::new(Mutex::new(()));
+ let mtx_ref = Arc::clone(&mtx);
+ assert_eq!(mtx.held_by_thread(), LockHeldState::NotHeldByThread);
+
+ let lck = mtx.lock().unwrap();
+ assert_eq!(mtx.held_by_thread(), LockHeldState::HeldByThread);
+
+	let thrd = thread::spawn(move || {
+ assert_eq!(mtx_ref.held_by_thread(), LockHeldState::NotHeldByThread);
+ });
+ thrd.join().unwrap();
+ assert_eq!(mtx.held_by_thread(), LockHeldState::HeldByThread);
+
+ std::mem::drop(lck);
+ assert_eq!(mtx.held_by_thread(), LockHeldState::NotHeldByThread);
+}
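A companion test for the `RwLock` path could look like the sketch below. It is not part of the change, and it assumes write locks are tracked in the same per-thread state as mutexes:

```rust
#[test]
fn test_rwlock_held_state() {
	// Sketch only: exercises held_by_thread() via a write lock on the debug_sync RwLock.
	let rw = RwLock::new(());
	assert_eq!(rw.held_by_thread(), LockHeldState::NotHeldByThread);

	let w = rw.write().unwrap();
	assert_eq!(rw.held_by_thread(), LockHeldState::HeldByThread);

	std::mem::drop(w);
	assert_eq!(rw.held_by_thread(), LockHeldState::NotHeldByThread);
}
```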