]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Use a total lockorder for `NetworkGraph`'s `PartialEq` impl
authorDuncan Dean <git@dunxen.dev>
Tue, 9 May 2023 09:44:48 +0000 (11:44 +0200)
committerDuncan Dean <git@dunxen.dev>
Tue, 23 May 2023 20:51:22 +0000 (22:51 +0200)
`NetworkGraph`'s `PartialEq` impl before this commit was deadlock-prone.
Similarly to `ChannelMonitor`'s, `PartialEq` impl, we use position in
memory for a total lockorder. This uses the assumption that the objects
cannot move within memory while the inner locks are held.

lightning/src/routing/gossip.rs

index d11eaa6bc36d4221c481d544aa31b436c13f39d6..f134bd0569e9905cf7a54b5359215c6b434ae4d0 100644 (file)
@@ -40,7 +40,7 @@ use crate::io_extras::{copy, sink};
 use crate::prelude::*;
 use core::{cmp, fmt};
 use core::convert::TryFrom;
-use crate::sync::{RwLock, RwLockReadGuard};
+use crate::sync::{RwLock, RwLockReadGuard, LockTestExt};
 #[cfg(feature = "std")]
 use core::sync::atomic::{AtomicUsize, Ordering};
 use crate::sync::Mutex;
@@ -1327,9 +1327,14 @@ impl<L: Deref> fmt::Display for NetworkGraph<L> where L::Target: Logger {
 impl<L: Deref> Eq for NetworkGraph<L> where L::Target: Logger {}
 impl<L: Deref> PartialEq for NetworkGraph<L> where L::Target: Logger {
        fn eq(&self, other: &Self) -> bool {
-               self.genesis_hash == other.genesis_hash &&
-                       *self.channels.read().unwrap() == *other.channels.read().unwrap() &&
-                       *self.nodes.read().unwrap() == *other.nodes.read().unwrap()
+               // For a total lockorder, sort by position in memory and take the inner locks in that order.
+               // (Assumes that we can't move within memory while a lock is held).
+               let ord = ((self as *const _) as usize) < ((other as *const _) as usize);
+               let a = if ord { (&self.channels, &self.nodes) } else { (&other.channels, &other.nodes) };
+               let b = if ord { (&other.channels, &other.nodes) } else { (&self.channels, &self.nodes) };
+               let (channels_a, channels_b) = (a.0.unsafe_well_ordered_double_lock_self(), b.0.unsafe_well_ordered_double_lock_self());
+               let (nodes_a, nodes_b) = (a.1.unsafe_well_ordered_double_lock_self(), b.1.unsafe_well_ordered_double_lock_self());
+               self.genesis_hash.eq(&other.genesis_hash) && channels_a.eq(&channels_b) && nodes_a.eq(&nodes_b)
        }
 }