]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Implement lockorder checking on RwLocks in debug_sync
authorMatt Corallo <git@bluematt.me>
Wed, 9 Mar 2022 06:24:06 +0000 (06:24 +0000)
committerMatt Corallo <git@bluematt.me>
Fri, 18 Mar 2022 18:54:27 +0000 (18:54 +0000)
lightning/src/debug_sync.rs

index 3eb342af2288c20a32081263d39262346fd81a23..04fc86bbbc758fa54366ef3217103d901fe0d84c 100644 (file)
@@ -74,12 +74,23 @@ impl MutexMetadata {
                }
        }
 
-       fn pre_lock(this: &Arc<MutexMetadata>) {
+       // Returns whether we were a recursive lock (only relevant for read)
+       fn _pre_lock(this: &Arc<MutexMetadata>, read: bool) -> bool {
+               let mut inserted = false;
                MUTEXES_HELD.with(|held| {
                        // For each mutex which is currently locked, check that no mutex's locked-before
                        // set includes the mutex we're about to lock, which would imply a lockorder
                        // inversion.
                        for locked in held.borrow().iter() {
+                               if read && *locked == *this {
+                                       // Recursive read locks are explicitly allowed
+                                       return;
+                               }
+                       }
+                       for locked in held.borrow().iter() {
+                               if !read && *locked == *this {
+                                       panic!("Tried to lock a mutex while it was held!");
+                               }
                                for locked_dep in locked.locked_before.lock().unwrap().iter() {
                                        if *locked_dep == *this {
                                                #[cfg(feature = "backtrace")]
@@ -92,9 +103,14 @@ impl MutexMetadata {
                                this.locked_before.lock().unwrap().insert(Arc::clone(locked));
                        }
                        held.borrow_mut().insert(Arc::clone(this));
+                       inserted = true;
                });
+               inserted
        }
 
+       fn pre_lock(this: &Arc<MutexMetadata>) { Self::_pre_lock(this, false); }
+       fn pre_read_lock(this: &Arc<MutexMetadata>) -> bool { Self::_pre_lock(this, true) }
+
        fn try_locked(this: &Arc<MutexMetadata>) {
                MUTEXES_HELD.with(|held| {
                        // Since a try-lock will simply fail if the lock is held already, we do not
@@ -171,19 +187,23 @@ impl<T> Mutex<T> {
        }
 }
 
-pub struct RwLock<T: ?Sized> {
-       inner: StdRwLock<T>
+pub struct RwLock<T: Sized> {
+       inner: StdRwLock<T>,
+       deps: Arc<MutexMetadata>,
 }
 
-pub struct RwLockReadGuard<'a, T: ?Sized + 'a> {
+pub struct RwLockReadGuard<'a, T: Sized + 'a> {
+       mutex: &'a RwLock<T>,
+       first_lock: bool,
        lock: StdRwLockReadGuard<'a, T>,
 }
 
-pub struct RwLockWriteGuard<'a, T: ?Sized + 'a> {
+pub struct RwLockWriteGuard<'a, T: Sized + 'a> {
+       mutex: &'a RwLock<T>,
        lock: StdRwLockWriteGuard<'a, T>,
 }
 
-impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
+impl<T: Sized> Deref for RwLockReadGuard<'_, T> {
        type Target = T;
 
        fn deref(&self) -> &T {
@@ -191,7 +211,21 @@ impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
        }
 }
 
-impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
+impl<T: Sized> Drop for RwLockReadGuard<'_, T> {
+       fn drop(&mut self) {
+               if !self.first_lock {
+                       // Note that its not strictly true that the first taken read lock will get unlocked
+                       // last, but in practice our locks are always taken as RAII, so it should basically
+                       // always be true.
+                       return;
+               }
+               MUTEXES_HELD.with(|held| {
+                       held.borrow_mut().remove(&self.mutex.deps);
+               });
+       }
+}
+
+impl<T: Sized> Deref for RwLockWriteGuard<'_, T> {
        type Target = T;
 
        fn deref(&self) -> &T {
@@ -199,7 +233,15 @@ impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
        }
 }
 
-impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
+impl<T: Sized> Drop for RwLockWriteGuard<'_, T> {
+       fn drop(&mut self) {
+               MUTEXES_HELD.with(|held| {
+                       held.borrow_mut().remove(&self.mutex.deps);
+               });
+       }
+}
+
+impl<T: Sized> DerefMut for RwLockWriteGuard<'_, T> {
        fn deref_mut(&mut self) -> &mut T {
                self.lock.deref_mut()
        }
@@ -207,18 +249,116 @@ impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
 
 impl<T> RwLock<T> {
        pub fn new(inner: T) -> RwLock<T> {
-               RwLock { inner: StdRwLock::new(inner) }
+               RwLock { inner: StdRwLock::new(inner), deps: Arc::new(MutexMetadata::new()) }
        }
 
        pub fn read<'a>(&'a self) -> LockResult<RwLockReadGuard<'a, T>> {
-               self.inner.read().map(|lock| RwLockReadGuard { lock }).map_err(|_| ())
+               let first_lock = MutexMetadata::pre_read_lock(&self.deps);
+               self.inner.read().map(|lock| RwLockReadGuard { mutex: self, lock, first_lock }).map_err(|_| ())
        }
 
        pub fn write<'a>(&'a self) -> LockResult<RwLockWriteGuard<'a, T>> {
-               self.inner.write().map(|lock| RwLockWriteGuard { lock }).map_err(|_| ())
+               MutexMetadata::pre_lock(&self.deps);
+               self.inner.write().map(|lock| RwLockWriteGuard { mutex: self, lock }).map_err(|_| ())
        }
 
        pub fn try_write<'a>(&'a self) -> LockResult<RwLockWriteGuard<'a, T>> {
-               self.inner.try_write().map(|lock| RwLockWriteGuard { lock }).map_err(|_| ())
+               let res = self.inner.try_write().map(|lock| RwLockWriteGuard { mutex: self, lock }).map_err(|_| ());
+               if res.is_ok() {
+                       MutexMetadata::try_locked(&self.deps);
+               }
+               res
+       }
+}
+
+#[test]
+#[should_panic]
+fn recursive_lock_fail() {
+       let mutex = Mutex::new(());
+       let _a = mutex.lock().unwrap();
+       let _b = mutex.lock().unwrap();
+}
+
+#[test]
+fn recursive_read() {
+       let lock = RwLock::new(());
+       let _a = lock.read().unwrap();
+       let _b = lock.read().unwrap();
+}
+
+#[test]
+#[should_panic]
+fn lockorder_fail() {
+       let a = Mutex::new(());
+       let b = Mutex::new(());
+       {
+               let _a = a.lock().unwrap();
+               let _b = b.lock().unwrap();
+       }
+       {
+               let _b = b.lock().unwrap();
+               let _a = a.lock().unwrap();
+       }
+}
+
+#[test]
+#[should_panic]
+fn write_lockorder_fail() {
+       let a = RwLock::new(());
+       let b = RwLock::new(());
+       {
+               let _a = a.write().unwrap();
+               let _b = b.write().unwrap();
+       }
+       {
+               let _b = b.write().unwrap();
+               let _a = a.write().unwrap();
+       }
+}
+
+#[test]
+#[should_panic]
+fn read_lockorder_fail() {
+       let a = RwLock::new(());
+       let b = RwLock::new(());
+       {
+               let _a = a.read().unwrap();
+               let _b = b.read().unwrap();
+       }
+       {
+               let _b = b.read().unwrap();
+               let _a = a.read().unwrap();
+       }
+}
+
+#[test]
+fn read_recurisve_no_lockorder() {
+       // Like the above, but note that no lockorder is implied when we recursively read-lock a
+       // RwLock, causing this to pass just fine.
+       let a = RwLock::new(());
+       let b = RwLock::new(());
+       let _outer = a.read().unwrap();
+       {
+               let _a = a.read().unwrap();
+               let _b = b.read().unwrap();
+       }
+       {
+               let _b = b.read().unwrap();
+               let _a = a.read().unwrap();
+       }
+}
+
+#[test]
+#[should_panic]
+fn read_write_lockorder_fail() {
+       let a = RwLock::new(());
+       let b = RwLock::new(());
+       {
+               let _a = a.write().unwrap();
+               let _b = b.read().unwrap();
+       }
+       {
+               let _b = b.read().unwrap();
+               let _a = a.write().unwrap();
        }
 }