From: Wilmer Paulino <9447167+wpaulino@users.noreply.github.com> Date: Tue, 28 Feb 2023 08:24:16 +0000 (-0800) Subject: Merge pull request #2006 from TheBlueMatt/2023-02-no-recursive-read-locks X-Git-Tag: v0.0.114-beta^0 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=8311581fe110a9ee561a6fda6b55c78a02068d43;hp=b8bea74d6a7c47c548670a851eb9a798a526daae;p=rust-lightning Merge pull request #2006 from TheBlueMatt/2023-02-no-recursive-read-locks Refuse recursive read locks --- diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9fb37f1b..2bb21c2c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -141,8 +141,9 @@ jobs: cargo test --verbose --color always --features esplora-async - name: Test backtrace-debug builds on Rust ${{ matrix.toolchain }} if: "matrix.toolchain == 'stable'" + shell: bash # Default on Winblows is powershell run: | - cd lightning && cargo test --verbose --color always --features backtrace + cd lightning && RUST_BACKTRACE=1 cargo test --verbose --color always --features backtrace - name: Test on Rust ${{ matrix.toolchain }} with net-tokio if: "matrix.build-net-tokio && !matrix.coverage" run: cargo test --verbose --color always diff --git a/lightning/src/chain/channelmonitor.rs b/lightning/src/chain/channelmonitor.rs index 8b281db6..a664c7c7 100644 --- a/lightning/src/chain/channelmonitor.rs +++ b/lightning/src/chain/channelmonitor.rs @@ -60,7 +60,7 @@ use core::{cmp, mem}; use crate::io::{self, Error}; use core::convert::TryInto; use core::ops::Deref; -use crate::sync::Mutex; +use crate::sync::{Mutex, LockTestExt}; /// An update generated by the underlying channel itself which contains some new information the /// [`ChannelMonitor`] should be made aware of. @@ -851,9 +851,13 @@ pub type TransactionOutputs = (Txid, Vec<(u32, TxOut)>); impl PartialEq for ChannelMonitor where Signer: PartialEq { fn eq(&self, other: &Self) -> bool { - let inner = self.inner.lock().unwrap(); - let other = other.inner.lock().unwrap(); - inner.eq(&other) + // We need some kind of total lockorder. Absent a better idea, we sort by position in + // memory and take locks in that order (assuming that we can't move within memory while a + // lock is held). 
+ let ord = ((self as *const _) as usize) < ((other as *const _) as usize); + let a = if ord { self.inner.unsafe_well_ordered_double_lock_self() } else { other.inner.unsafe_well_ordered_double_lock_self() }; + let b = if ord { other.inner.unsafe_well_ordered_double_lock_self() } else { self.inner.unsafe_well_ordered_double_lock_self() }; + a.eq(&b) } } @@ -4066,7 +4070,10 @@ mod tests { fn test_prune_preimages() { let secp_ctx = Secp256k1::new(); let logger = Arc::new(TestLogger::new()); - let broadcaster = Arc::new(TestBroadcaster{txn_broadcasted: Mutex::new(Vec::new()), blocks: Arc::new(Mutex::new(Vec::new()))}); + let broadcaster = Arc::new(TestBroadcaster { + txn_broadcasted: Mutex::new(Vec::new()), + blocks: Arc::new(Mutex::new(Vec::new())) + }); let fee_estimator = TestFeeEstimator { sat_per_kw: Mutex::new(253) }; let dummy_key = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap()); diff --git a/lightning/src/ln/chanmon_update_fail_tests.rs b/lightning/src/ln/chanmon_update_fail_tests.rs index 1532884f..18e461f6 100644 --- a/lightning/src/ln/chanmon_update_fail_tests.rs +++ b/lightning/src/ln/chanmon_update_fail_tests.rs @@ -108,12 +108,13 @@ fn test_monitor_and_persister_update_fail() { blocks: Arc::new(Mutex::new(vec![(genesis_block(Network::Testnet), 200); 200])), }; let chain_mon = { - let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); - let mut w = test_utils::TestVecWriter(Vec::new()); - monitor.write(&mut w).unwrap(); - let new_monitor = <(BlockHash, ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; - assert!(new_monitor == *monitor); + let new_monitor = { + let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); + let new_monitor = <(BlockHash, ChannelMonitor)>::read( + &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; + assert!(new_monitor == *monitor); + new_monitor + }; let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert_eq!(chain_mon.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed); chain_mon @@ -1426,9 +1427,11 @@ fn monitor_failed_no_reestablish_response() { { let mut node_0_per_peer_lock; let mut node_0_peer_state_lock; + get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived; + } + { let mut node_1_per_peer_lock; let mut node_1_peer_state_lock; - get_channel_ref!(nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived; get_channel_ref!(nodes[1], nodes[0], node_1_per_peer_lock, node_1_peer_state_lock, channel_id).announcement_sigs_state = AnnouncementSigsState::PeerReceived; } diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 80cea29f..0757e117 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -1417,7 +1417,7 @@ macro_rules! emit_channel_ready_event { } macro_rules! 
handle_monitor_update_completion { - ($self: ident, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan: expr) => { { + ($self: ident, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan: expr) => { { let mut updates = $chan.monitor_updating_restored(&$self.logger, &$self.node_signer, $self.genesis_hash, &$self.default_configuration, $self.best_block.read().unwrap().height()); @@ -1450,6 +1450,7 @@ macro_rules! handle_monitor_update_completion { let channel_id = $chan.channel_id(); core::mem::drop($peer_state_lock); + core::mem::drop($per_peer_state_lock); $self.handle_monitor_update_completion_actions(update_actions); @@ -1465,7 +1466,7 @@ macro_rules! handle_monitor_update_completion { } macro_rules! handle_new_monitor_update { - ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan: expr, MANUALLY_REMOVING, $remove: expr) => { { + ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan: expr, MANUALLY_REMOVING, $remove: expr) => { { // update_maps_on_chan_removal needs to be able to take id_to_peer, so make sure we can in // any case so that it won't deadlock. debug_assert!($self.id_to_peer.try_lock().is_ok()); @@ -1492,14 +1493,14 @@ macro_rules! handle_new_monitor_update { .update_id == $update_id) && $chan.get_latest_monitor_update_id() == $update_id { - handle_monitor_update_completion!($self, $update_id, $peer_state_lock, $peer_state, $chan); + handle_monitor_update_completion!($self, $update_id, $peer_state_lock, $peer_state, $per_peer_state_lock, $chan); } Ok(()) }, } } }; - ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $chan_entry: expr) => { - handle_new_monitor_update!($self, $update_res, $update_id, $peer_state_lock, $peer_state, $chan_entry.get_mut(), MANUALLY_REMOVING, $chan_entry.remove_entry()) + ($self: ident, $update_res: expr, $update_id: expr, $peer_state_lock: expr, $peer_state: expr, $per_peer_state_lock: expr, $chan_entry: expr) => { + handle_new_monitor_update!($self, $update_res, $update_id, $peer_state_lock, $peer_state, $per_peer_state_lock, $chan_entry.get_mut(), MANUALLY_REMOVING, $chan_entry.remove_entry()) } } @@ -1835,7 +1836,7 @@ where if let Some(monitor_update) = monitor_update_opt.take() { let update_id = monitor_update.update_id; let update_res = self.chain_monitor.update_channel(funding_txo_opt.unwrap(), monitor_update); - break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan_entry); + break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan_entry); } if chan_entry.get().is_shutdown() { @@ -2413,8 +2414,16 @@ where }) } - // Only public for testing, this should otherwise never be called direcly - pub(crate) fn send_payment_along_path(&self, path: &Vec, payment_params: &Option, payment_hash: &PaymentHash, payment_secret: &Option, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option, session_priv_bytes: [u8; 32]) -> Result<(), APIError> { + #[cfg(test)] + pub(crate) fn test_send_payment_along_path(&self, path: &Vec, payment_params: &Option, payment_hash: &PaymentHash, payment_secret: &Option, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option, session_priv_bytes: [u8; 32]) -> Result<(), APIError> { + let _lck = self.total_consistency_lock.read().unwrap(); + 
self.send_payment_along_path(path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv_bytes) + } + + fn send_payment_along_path(&self, path: &Vec, payment_params: &Option, payment_hash: &PaymentHash, payment_secret: &Option, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option, session_priv_bytes: [u8; 32]) -> Result<(), APIError> { + // The top-level caller should hold the total_consistency_lock read lock. + debug_assert!(self.total_consistency_lock.try_write().is_err()); + log_trace!(self.logger, "Attempting to send payment for path with next hop {}", path.first().unwrap().short_channel_id); let prng_seed = self.entropy_source.get_secure_random_bytes(); let session_priv = SecretKey::from_slice(&session_priv_bytes[..]).expect("RNG is busted"); @@ -2427,8 +2436,6 @@ where } let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash); - let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); - let err: Result<(), _> = loop { let (counterparty_node_id, id) = match self.short_to_chan_info.read().unwrap().get(&path.first().unwrap().short_channel_id) { None => return Err(APIError::ChannelUnavailable{err: "No channel available with first hop!".to_owned()}), @@ -2458,7 +2465,7 @@ where Some(monitor_update) => { let update_id = monitor_update.update_id; let update_res = self.chain_monitor.update_channel(funding_txo, monitor_update); - if let Err(e) = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan) { + if let Err(e) = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan) { break Err(e); } if update_res == ChannelMonitorUpdateStatus::InProgress { @@ -2555,6 +2562,7 @@ where /// [`ChannelMonitorUpdateStatus::InProgress`]: crate::chain::ChannelMonitorUpdateStatus::InProgress pub fn send_payment(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option, payment_id: PaymentId) -> Result<(), PaymentSendFailure> { let best_block_height = self.best_block.read().unwrap().height(); + let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); self.pending_outbound_payments .send_payment_with_route(route, payment_hash, payment_secret, payment_id, &self.entropy_source, &self.node_signer, best_block_height, |path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv| @@ -2565,6 +2573,7 @@ where /// `route_params` and retry failed payment paths based on `retry_strategy`. 
pub fn send_payment_with_retry(&self, payment_hash: PaymentHash, payment_secret: &Option, payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry) -> Result<(), RetryableSendFailure> { let best_block_height = self.best_block.read().unwrap().height(); + let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); self.pending_outbound_payments .send_payment(payment_hash, payment_secret, payment_id, retry_strategy, route_params, &self.router, self.list_usable_channels(), || self.compute_inflight_htlcs(), @@ -2577,6 +2586,7 @@ where #[cfg(test)] fn test_send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option, keysend_preimage: Option, payment_id: PaymentId, recv_value_msat: Option, onion_session_privs: Vec<[u8; 32]>) -> Result<(), PaymentSendFailure> { let best_block_height = self.best_block.read().unwrap().height(); + let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); self.pending_outbound_payments.test_send_payment_internal(route, payment_hash, payment_secret, keysend_preimage, payment_id, recv_value_msat, onion_session_privs, &self.node_signer, best_block_height, |path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv| self.send_payment_along_path(path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv)) @@ -2627,6 +2637,7 @@ where /// [`send_payment`]: Self::send_payment pub fn send_spontaneous_payment(&self, route: &Route, payment_preimage: Option, payment_id: PaymentId) -> Result { let best_block_height = self.best_block.read().unwrap().height(); + let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); self.pending_outbound_payments.send_spontaneous_payment_with_route( route, payment_preimage, payment_id, &self.entropy_source, &self.node_signer, best_block_height, @@ -2643,6 +2654,7 @@ where /// [`PaymentParameters::for_keysend`]: crate::routing::router::PaymentParameters::for_keysend pub fn send_spontaneous_payment_with_retry(&self, payment_preimage: Option, payment_id: PaymentId, route_params: RouteParameters, retry_strategy: Retry) -> Result { let best_block_height = self.best_block.read().unwrap().height(); + let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); self.pending_outbound_payments.send_spontaneous_payment(payment_preimage, payment_id, retry_strategy, route_params, &self.router, self.list_usable_channels(), || self.compute_inflight_htlcs(), &self.entropy_source, &self.node_signer, best_block_height, @@ -2656,6 +2668,7 @@ where /// us to easily discern them from real payments. 
pub fn send_probe(&self, hops: Vec) -> Result<(PaymentHash, PaymentId), PaymentSendFailure> { let best_block_height = self.best_block.read().unwrap().height(); + let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); self.pending_outbound_payments.send_probe(hops, self.probing_cookie_secret, &self.entropy_source, &self.node_signer, best_block_height, |path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv| self.send_payment_along_path(path, payment_params, payment_hash, payment_secret, total_value, cur_height, payment_id, keysend_preimage, session_priv)) @@ -3979,7 +3992,8 @@ where ) ).unwrap_or(None); - if let Some(mut peer_state_lock) = peer_state_opt.take() { + if peer_state_opt.is_some() { + let mut peer_state_lock = peer_state_opt.unwrap(); let peer_state = &mut *peer_state_lock; if let hash_map::Entry::Occupied(mut chan) = peer_state.channel_by_id.entry(chan_id) { let counterparty_node_id = chan.get().get_counterparty_node_id(); @@ -3994,7 +4008,7 @@ where let update_id = monitor_update.update_id; let update_res = self.chain_monitor.update_channel(prev_hop.outpoint, monitor_update); let res = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, - peer_state, chan); + peer_state, per_peer_state, chan); if let Err(e) = res { // TODO: This is a *critical* error - we probably updated the outbound edge // of the HTLC's monitor with a preimage. We should retry this monitor @@ -4164,7 +4178,7 @@ where } fn channel_monitor_updated(&self, funding_txo: &OutPoint, highest_applied_update_id: u64, counterparty_node_id: Option<&PublicKey>) { - let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); + debug_assert!(self.total_consistency_lock.try_write().is_err()); // Caller holds read lock let counterparty_node_id = match counterparty_node_id { Some(cp_id) => cp_id.clone(), @@ -4195,7 +4209,7 @@ where if !channel.get().is_awaiting_monitor_update() || channel.get().get_latest_monitor_update_id() != highest_applied_update_id { return; } - handle_monitor_update_completion!(self, highest_applied_update_id, peer_state_lock, peer_state, channel.get_mut()); + handle_monitor_update_completion!(self, highest_applied_update_id, peer_state_lock, peer_state, per_peer_state, channel.get_mut()); } /// Accepts a request to open a channel after a [`Event::OpenChannelRequest`]. @@ -4501,7 +4515,8 @@ where let monitor_res = self.chain_monitor.watch_channel(monitor.get_funding_txo().0, monitor); let chan = e.insert(chan); - let mut res = handle_new_monitor_update!(self, monitor_res, 0, peer_state_lock, peer_state, chan, MANUALLY_REMOVING, { peer_state.channel_by_id.remove(&new_channel_id) }); + let mut res = handle_new_monitor_update!(self, monitor_res, 0, peer_state_lock, peer_state, + per_peer_state, chan, MANUALLY_REMOVING, { peer_state.channel_by_id.remove(&new_channel_id) }); // Note that we reply with the new channel_id in error messages if we gave up on the // channel, not the temporary_channel_id. 
This is compatible with ourselves, but the @@ -4534,7 +4549,7 @@ where let monitor = try_chan_entry!(self, chan.get_mut().funding_signed(&msg, best_block, &self.signer_provider, &self.logger), chan); let update_res = self.chain_monitor.watch_channel(chan.get().get_funding_txo().unwrap(), monitor); - let mut res = handle_new_monitor_update!(self, update_res, 0, peer_state_lock, peer_state, chan); + let mut res = handle_new_monitor_update!(self, update_res, 0, peer_state_lock, peer_state, per_peer_state, chan); if let Err(MsgHandleErrInternal { ref mut shutdown_finish, .. }) = res { // We weren't able to watch the channel to begin with, so no updates should be made on // it. Previously, full_stack_target found an (unreachable) panic when the @@ -4630,7 +4645,7 @@ where if let Some(monitor_update) = monitor_update_opt { let update_id = monitor_update.update_id; let update_res = self.chain_monitor.update_channel(funding_txo_opt.unwrap(), monitor_update); - break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, chan_entry); + break handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, peer_state, per_peer_state, chan_entry); } break Ok(()); }, @@ -4822,7 +4837,7 @@ where let update_res = self.chain_monitor.update_channel(funding_txo.unwrap(), monitor_update); let update_id = monitor_update.update_id; handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, - peer_state, chan) + peer_state, per_peer_state, chan) }, hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close(format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", counterparty_node_id), msg.channel_id)) } @@ -4928,12 +4943,11 @@ where fn internal_revoke_and_ack(&self, counterparty_node_id: &PublicKey, msg: &msgs::RevokeAndACK) -> Result<(), MsgHandleErrInternal> { let (htlcs_to_fail, res) = { let per_peer_state = self.per_peer_state.read().unwrap(); - let peer_state_mutex = per_peer_state.get(counterparty_node_id) + let mut peer_state_lock = per_peer_state.get(counterparty_node_id) .ok_or_else(|| { debug_assert!(false); MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id) - })?; - let mut peer_state_lock = peer_state_mutex.lock().unwrap(); + }).map(|mtx| mtx.lock().unwrap())?; let peer_state = &mut *peer_state_lock; match peer_state.channel_by_id.entry(msg.channel_id) { hash_map::Entry::Occupied(mut chan) => { @@ -4941,8 +4955,8 @@ where let (htlcs_to_fail, monitor_update) = try_chan_entry!(self, chan.get_mut().revoke_and_ack(&msg, &self.logger), chan); let update_res = self.chain_monitor.update_channel(funding_txo.unwrap(), monitor_update); let update_id = monitor_update.update_id; - let res = handle_new_monitor_update!(self, update_res, update_id, peer_state_lock, - peer_state, chan); + let res = handle_new_monitor_update!(self, update_res, update_id, + peer_state_lock, peer_state, per_peer_state, chan); (htlcs_to_fail, res) }, hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close(format!("Got a message for a channel from the wrong node! No such channel for the passed counterparty_node_id {}", counterparty_node_id), msg.channel_id)) @@ -5104,6 +5118,8 @@ where /// Process pending events from the `chain::Watch`, returning whether any events were processed. 
fn process_pending_monitor_events(&self) -> bool { + debug_assert!(self.total_consistency_lock.try_write().is_err()); // Caller holds read lock + let mut failed_channels = Vec::new(); let mut pending_monitor_events = self.chain_monitor.release_pending_monitor_events(); let has_pending_monitor_events = !pending_monitor_events.is_empty(); @@ -5181,7 +5197,13 @@ where /// update events as a separate process method here. #[cfg(fuzzing)] pub fn process_monitor_events(&self) { - self.process_pending_monitor_events(); + PersistenceNotifierGuard::optionally_notify(&self.total_consistency_lock, &self.persistence_notifier, || { + if self.process_pending_monitor_events() { + NotifyOption::DoPersist + } else { + NotifyOption::SkipPersist + } + }); } /// Check the holding cell in each channel and free any pending HTLCs in them if possible. @@ -5191,38 +5213,45 @@ where let mut has_monitor_update = false; let mut failed_htlcs = Vec::new(); let mut handle_errors = Vec::new(); - let per_peer_state = self.per_peer_state.read().unwrap(); - for (_cp_id, peer_state_mutex) in per_peer_state.iter() { - 'chan_loop: loop { - let mut peer_state_lock = peer_state_mutex.lock().unwrap(); - let peer_state: &mut PeerState<_> = &mut *peer_state_lock; - for (channel_id, chan) in peer_state.channel_by_id.iter_mut() { - let counterparty_node_id = chan.get_counterparty_node_id(); - let funding_txo = chan.get_funding_txo(); - let (monitor_opt, holding_cell_failed_htlcs) = - chan.maybe_free_holding_cell_htlcs(&self.logger); - if !holding_cell_failed_htlcs.is_empty() { - failed_htlcs.push((holding_cell_failed_htlcs, *channel_id, counterparty_node_id)); - } - if let Some(monitor_update) = monitor_opt { - has_monitor_update = true; - - let update_res = self.chain_monitor.update_channel( - funding_txo.expect("channel is live"), monitor_update); - let update_id = monitor_update.update_id; - let channel_id: [u8; 32] = *channel_id; - let res = handle_new_monitor_update!(self, update_res, update_id, - peer_state_lock, peer_state, chan, MANUALLY_REMOVING, - peer_state.channel_by_id.remove(&channel_id)); - if res.is_err() { - handle_errors.push((counterparty_node_id, res)); + // Walk our list of channels and find any that need to update. Note that when we do find an + // update, if it includes actions that must be taken afterwards, we have to drop the + // per-peer state lock as well as the top level per_peer_state lock. Thus, we loop until we + // manage to go through all our peers without finding a single channel to update. 
+ 'peer_loop: loop { + let per_peer_state = self.per_peer_state.read().unwrap(); + for (_cp_id, peer_state_mutex) in per_peer_state.iter() { + 'chan_loop: loop { + let mut peer_state_lock = peer_state_mutex.lock().unwrap(); + let peer_state: &mut PeerState<_> = &mut *peer_state_lock; + for (channel_id, chan) in peer_state.channel_by_id.iter_mut() { + let counterparty_node_id = chan.get_counterparty_node_id(); + let funding_txo = chan.get_funding_txo(); + let (monitor_opt, holding_cell_failed_htlcs) = + chan.maybe_free_holding_cell_htlcs(&self.logger); + if !holding_cell_failed_htlcs.is_empty() { + failed_htlcs.push((holding_cell_failed_htlcs, *channel_id, counterparty_node_id)); + } + if let Some(monitor_update) = monitor_opt { + has_monitor_update = true; + + let update_res = self.chain_monitor.update_channel( + funding_txo.expect("channel is live"), monitor_update); + let update_id = monitor_update.update_id; + let channel_id: [u8; 32] = *channel_id; + let res = handle_new_monitor_update!(self, update_res, update_id, + peer_state_lock, peer_state, per_peer_state, chan, MANUALLY_REMOVING, + peer_state.channel_by_id.remove(&channel_id)); + if res.is_err() { + handle_errors.push((counterparty_node_id, res)); + } + continue 'peer_loop; } - continue 'chan_loop; } + break 'chan_loop; } - break 'chan_loop; } + break 'peer_loop; } let has_update = has_monitor_update || !failed_htlcs.is_empty() || !handle_errors.is_empty(); @@ -6921,7 +6950,10 @@ where let mut monitor_update_blocked_actions_per_peer = None; let mut peer_states = Vec::new(); for (_, peer_state_mutex) in per_peer_state.iter() { - peer_states.push(peer_state_mutex.lock().unwrap()); + // Because we're holding the owning `per_peer_state` write lock here there's no chance + // of a lockorder violation deadlock - no other thread can be holding any + // per_peer_state lock at all. + peer_states.push(peer_state_mutex.unsafe_well_ordered_double_lock_self()); } (serializable_peer_count).write(writer)?; @@ -7509,7 +7541,10 @@ where } } - let pending_outbounds = OutboundPayments { pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), retry_lock: Mutex::new(()) }; + let pending_outbounds = OutboundPayments { + pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), + retry_lock: Mutex::new(()) + }; if !forward_htlcs.is_empty() || pending_outbounds.needs_abandon() { // If we have pending HTLCs to forward, assume we either dropped a // `PendingHTLCsForwardable` or the user received it but never processed it as they @@ -7850,7 +7885,7 @@ mod tests { // indicates there are more HTLCs coming. let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match. 
let session_privs = nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(payment_secret), payment_id, &mpp_route).unwrap(); - nodes[0].node.send_payment_along_path(&mpp_route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap(); + nodes[0].node.test_send_payment_along_path(&mpp_route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap(); check_added_monitors!(nodes[0], 1); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); @@ -7880,7 +7915,7 @@ mod tests { expect_payment_failed!(nodes[0], our_payment_hash, true); // Send the second half of the original MPP payment. - nodes[0].node.send_payment_along_path(&mpp_route.paths[1], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[1]).unwrap(); + nodes[0].node.test_send_payment_along_path(&mpp_route.paths[1], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[1]).unwrap(); check_added_monitors!(nodes[0], 1); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); @@ -8257,10 +8292,10 @@ mod tests { let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap(); assert_eq!(nodes_0_lock.len(), 1); assert!(nodes_0_lock.contains_key(channel_id)); - - assert_eq!(nodes[1].node.id_to_peer.lock().unwrap().len(), 0); } + assert_eq!(nodes[1].node.id_to_peer.lock().unwrap().len(), 0); + let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id()); nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &funding_created_msg); @@ -8268,7 +8303,9 @@ mod tests { let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap(); assert_eq!(nodes_0_lock.len(), 1); assert!(nodes_0_lock.contains_key(channel_id)); + } + { // Assert that `nodes[1]`'s `id_to_peer` map is populated with the channel as soon as // as it has the funding transaction. let nodes_1_lock = nodes[1].node.id_to_peer.lock().unwrap(); @@ -8298,7 +8335,9 @@ mod tests { let nodes_0_lock = nodes[0].node.id_to_peer.lock().unwrap(); assert_eq!(nodes_0_lock.len(), 1); assert!(nodes_0_lock.contains_key(channel_id)); + } + { // At this stage, `nodes[1]` has proposed a fee for the closing transaction in the // `handle_closing_signed` call above. 
As `nodes[1]` has not yet received the signature // from `nodes[0]` for the closing transaction with the proposed fee, the channel is diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index b4249169..96ce9312 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -44,7 +44,7 @@ use crate::io; use crate::prelude::*; use core::cell::RefCell; use alloc::rc::Rc; -use crate::sync::{Arc, Mutex}; +use crate::sync::{Arc, Mutex, LockTestExt}; use core::mem; use core::iter::repeat; use bitcoin::{PackedLockTime, TxMerkleNode}; @@ -466,8 +466,8 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> { panic!(); } } - assert_eq!(*chain_source.watched_txn.lock().unwrap(), *self.chain_source.watched_txn.lock().unwrap()); - assert_eq!(*chain_source.watched_outputs.lock().unwrap(), *self.chain_source.watched_outputs.lock().unwrap()); + assert_eq!(*chain_source.watched_txn.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_txn.unsafe_well_ordered_double_lock_self()); + assert_eq!(*chain_source.watched_outputs.unsafe_well_ordered_double_lock_self(), *self.chain_source.watched_outputs.unsafe_well_ordered_double_lock_self()); } } } @@ -2151,9 +2151,10 @@ pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou assert!(err.contains("Cannot send value that would put us over the max HTLC value in flight our peer will accept"))); } -pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) { - let our_payment_preimage = route_payment(&origin, expected_route, recv_value).0; - claim_payment(&origin, expected_route, our_payment_preimage); +pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) -> (PaymentPreimage, PaymentHash, PaymentSecret) { + let res = route_payment(&origin, expected_route, recv_value); + claim_payment(&origin, expected_route, res.0); + res } pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_hash: PaymentHash) { diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index e16b20b8..454fbe2b 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -4083,7 +4083,7 @@ fn do_test_htlc_timeout(send_partial_mpp: bool) { let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match. 
let payment_id = PaymentId([42; 32]); let session_privs = nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(payment_secret), payment_id, &route).unwrap(); - nodes[0].node.send_payment_along_path(&route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap(); + nodes[0].node.test_send_payment_along_path(&route.paths[0], &route.payment_params, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None, session_privs[0]).unwrap(); check_added_monitors!(nodes[0], 1); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); @@ -8150,12 +8150,13 @@ fn test_update_err_monitor_lockdown() { let logger = test_utils::TestLogger::with_id(format!("node {}", 0)); let persister = test_utils::TestPersister::new(); let watchtower = { - let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); - let mut w = test_utils::TestVecWriter(Vec::new()); - monitor.write(&mut w).unwrap(); - let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; - assert!(new_monitor == *monitor); + let new_monitor = { + let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); + let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( + &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; + assert!(new_monitor == *monitor); + new_monitor + }; let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed); watchtower @@ -8217,12 +8218,13 @@ fn test_concurrent_monitor_claim() { let logger = test_utils::TestLogger::with_id(format!("node {}", "Alice")); let persister = test_utils::TestPersister::new(); let watchtower_alice = { - let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); - let mut w = test_utils::TestVecWriter(Vec::new()); - monitor.write(&mut w).unwrap(); - let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; - assert!(new_monitor == *monitor); + let new_monitor = { + let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); + let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( + &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; + assert!(new_monitor == *monitor); + new_monitor + }; let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed); watchtower @@ -8246,12 +8248,13 @@ fn test_concurrent_monitor_claim() { let logger = test_utils::TestLogger::with_id(format!("node {}", "Bob")); let persister = test_utils::TestPersister::new(); let watchtower_bob = { - let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); - let mut w = test_utils::TestVecWriter(Vec::new()); - monitor.write(&mut w).unwrap(); - let new_monitor = <(BlockHash, 
channelmonitor::ChannelMonitor)>::read( - &mut io::Cursor::new(&w.0), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; - assert!(new_monitor == *monitor); + let new_monitor = { + let monitor = nodes[0].chain_monitor.chain_monitor.get_monitor(outpoint).unwrap(); + let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor)>::read( + &mut io::Cursor::new(&monitor.encode()), (nodes[0].keys_manager, nodes[0].keys_manager)).unwrap().1; + assert!(new_monitor == *monitor); + new_monitor + }; let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister, &node_cfgs[0].keys_manager); assert_eq!(watchtower.watch_channel(outpoint, new_monitor), ChannelMonitorUpdateStatus::Completed); watchtower @@ -9141,20 +9144,20 @@ fn test_inconsistent_mpp_params() { dup_route.paths.push(route.paths[1].clone()); nodes[0].node.test_add_new_pending_payment(our_payment_hash, Some(our_payment_secret), payment_id, &dup_route).unwrap() }; - { - nodes[0].node.send_payment_along_path(&route.paths[0], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[0]).unwrap(); - check_added_monitors!(nodes[0], 1); + nodes[0].node.test_send_payment_along_path(&route.paths[0], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[0]).unwrap(); + check_added_monitors!(nodes[0], 1); + { let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); pass_along_path(&nodes[0], &[&nodes[1], &nodes[3]], 15_000_000, our_payment_hash, Some(our_payment_secret), events.pop().unwrap(), false, None); } assert!(nodes[3].node.get_and_clear_pending_events().is_empty()); - { - nodes[0].node.send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 14_000_000, cur_height, payment_id, &None, session_privs[1]).unwrap(); - check_added_monitors!(nodes[0], 1); + nodes[0].node.test_send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 14_000_000, cur_height, payment_id, &None, session_privs[1]).unwrap(); + check_added_monitors!(nodes[0], 1); + { let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); let payment_event = SendEvent::from_event(events.pop().unwrap()); @@ -9197,7 +9200,7 @@ fn test_inconsistent_mpp_params() { expect_payment_failed_conditions(&nodes[0], our_payment_hash, true, PaymentFailedConditions::new().mpp_parts_remain()); - nodes[0].node.send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[2]).unwrap(); + nodes[0].node.test_send_payment_along_path(&route.paths[1], &payment_params_opt, &our_payment_hash, &Some(our_payment_secret), 15_000_000, cur_height, payment_id, &None, session_privs[2]).unwrap(); check_added_monitors!(nodes[0], 1); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); diff --git a/lightning/src/ln/payment_tests.rs b/lightning/src/ln/payment_tests.rs index 72513087..c4cd0fc1 100644 --- a/lightning/src/ln/payment_tests.rs +++ b/lightning/src/ln/payment_tests.rs @@ -1192,33 +1192,31 @@ fn test_trivial_inflight_htlc_tracking(){ let (_, _, chan_2_id, _) = create_announced_chan_between_nodes(&nodes, 1, 2); // Send and claim the payment. Inflight HTLCs should be empty. 
- let (route, payment_hash, payment_preimage, payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[2], 500000); - nodes[0].node.send_payment(&route, payment_hash, &Some(payment_secret), PaymentId(payment_hash.0)).unwrap(); - check_added_monitors!(nodes[0], 1); - pass_along_route(&nodes[0], &[&vec!(&nodes[1], &nodes[2])[..]], 500000, payment_hash, payment_secret); - claim_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], payment_preimage); + let payment_hash = send_payment(&nodes[0], &[&nodes[1], &nodes[2]], 500000).1; + let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs(); { - let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs(); - let mut node_0_per_peer_lock; let mut node_0_peer_state_lock; - let mut node_1_per_peer_lock; - let mut node_1_peer_state_lock; let channel_1 = get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id); - let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id); let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat( &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) , &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()), channel_1.get_short_channel_id().unwrap() ); + assert_eq!(chan_1_used_liquidity, None); + } + { + let mut node_1_per_peer_lock; + let mut node_1_peer_state_lock; + let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id); + let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat( &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) , &NodeId::from_pubkey(&nodes[2].node.get_our_node_id()), channel_2.get_short_channel_id().unwrap() ); - assert_eq!(chan_1_used_liquidity, None); assert_eq!(chan_2_used_liquidity, None); } let pending_payments = nodes[0].node.list_recent_payments(); @@ -1231,30 +1229,32 @@ fn test_trivial_inflight_htlc_tracking(){ } // Send the payment, but do not claim it. Our inflight HTLCs should contain the pending payment. 
- let (payment_preimage, payment_hash, _) = route_payment(&nodes[0], &vec!(&nodes[1], &nodes[2])[..], 500000); + let (payment_preimage, payment_hash, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 500000); + let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs(); { - let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs(); - let mut node_0_per_peer_lock; let mut node_0_peer_state_lock; - let mut node_1_per_peer_lock; - let mut node_1_peer_state_lock; let channel_1 = get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id); - let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id); let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat( &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) , &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()), channel_1.get_short_channel_id().unwrap() ); + // First hop accounts for expected 1000 msat fee + assert_eq!(chan_1_used_liquidity, Some(501000)); + } + { + let mut node_1_per_peer_lock; + let mut node_1_peer_state_lock; + let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id); + let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat( &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) , &NodeId::from_pubkey(&nodes[2].node.get_our_node_id()), channel_2.get_short_channel_id().unwrap() ); - // First hop accounts for expected 1000 msat fee - assert_eq!(chan_1_used_liquidity, Some(501000)); assert_eq!(chan_2_used_liquidity, Some(500000)); } let pending_payments = nodes[0].node.list_recent_payments(); @@ -1269,28 +1269,29 @@ fn test_trivial_inflight_htlc_tracking(){ nodes[0].node.timer_tick_occurred(); } + let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs(); { - let inflight_htlcs = node_chanmgrs[0].compute_inflight_htlcs(); - let mut node_0_per_peer_lock; let mut node_0_peer_state_lock; - let mut node_1_per_peer_lock; - let mut node_1_peer_state_lock; let channel_1 = get_channel_ref!(&nodes[0], nodes[1], node_0_per_peer_lock, node_0_peer_state_lock, chan_1_id); - let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id); let chan_1_used_liquidity = inflight_htlcs.used_liquidity_msat( &NodeId::from_pubkey(&nodes[0].node.get_our_node_id()) , &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()), channel_1.get_short_channel_id().unwrap() ); + assert_eq!(chan_1_used_liquidity, None); + } + { + let mut node_1_per_peer_lock; + let mut node_1_peer_state_lock; + let channel_2 = get_channel_ref!(&nodes[1], nodes[2], node_1_per_peer_lock, node_1_peer_state_lock, chan_2_id); + let chan_2_used_liquidity = inflight_htlcs.used_liquidity_msat( &NodeId::from_pubkey(&nodes[1].node.get_our_node_id()) , &NodeId::from_pubkey(&nodes[2].node.get_our_node_id()), channel_2.get_short_channel_id().unwrap() ); - - assert_eq!(chan_1_used_liquidity, None); assert_eq!(chan_2_used_liquidity, None); } diff --git a/lightning/src/routing/utxo.rs b/lightning/src/routing/utxo.rs index 0e5bd8d6..74abd427 100644 --- a/lightning/src/routing/utxo.rs +++ b/lightning/src/routing/utxo.rs @@ -26,7 +26,7 @@ use crate::util::ser::Writeable; use crate::prelude::*; use alloc::sync::{Arc, Weak}; -use crate::sync::Mutex; +use crate::sync::{Mutex, LockTestExt}; use core::ops::Deref; /// An error when accessing the chain via [`UtxoLookup`]. @@ -404,7 +404,10 @@ impl PendingChecks { // lookup if we haven't gotten that far yet). 
match Weak::upgrade(&e.get()) { Some(pending_msgs) => { - let pending_matches = match &pending_msgs.lock().unwrap().channel_announce { + // This may be called with the mutex held on a different UtxoMessages + // struct, however in that case we have a global lockorder of new messages + // -> old messages, which makes this safe. + let pending_matches = match &pending_msgs.unsafe_well_ordered_double_lock_self().channel_announce { Some(ChannelAnnouncement::Full(pending_msg)) => Some(pending_msg) == full_msg, Some(ChannelAnnouncement::Unsigned(pending_msg)) => pending_msg == msg, None => { @@ -745,10 +748,11 @@ mod tests { Ok(TxOut { value: 1_000_000, script_pubkey: good_script })); assert_eq!(chan_update_a.contents.timestamp, chan_update_b.contents.timestamp); - assert!(network_graph.read_only().channels() + let graph_lock = network_graph.read_only(); + assert!(graph_lock.channels() .get(&valid_announcement.contents.short_channel_id).as_ref().unwrap() .one_to_two.as_ref().unwrap().last_update != - network_graph.read_only().channels() + graph_lock.channels() .get(&valid_announcement.contents.short_channel_id).as_ref().unwrap() .two_to_one.as_ref().unwrap().last_update); } diff --git a/lightning/src/sync/debug_sync.rs b/lightning/src/sync/debug_sync.rs index 56310937..5b6acbca 100644 --- a/lightning/src/sync/debug_sync.rs +++ b/lightning/src/sync/debug_sync.rs @@ -75,7 +75,7 @@ struct LockDep { } #[cfg(feature = "backtrace")] -fn get_construction_location(backtrace: &Backtrace) -> String { +fn get_construction_location(backtrace: &Backtrace) -> (String, Option) { // Find the first frame that is after `debug_sync` (or that is in our tests) and use // that as the mutex construction site. Note that the first few frames may be in // the `backtrace` crate, so we have to ignore those. 
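
The `lightning/src/sync/debug_sync.rs` hunks below collapse the old `pre_lock`/`pre_read_lock` pair into a single `pre_lock(this, double_lock_self_allowed)` that refuses any re-acquisition of a lock the current thread already holds, whether for reading or writing. As a rough mental model only (the `TestMutex` names below are invented for illustration, and the real `LockMetadata` additionally tracks lock-order dependencies and construction backtraces, which this sketch omits), the core mechanism is a thread-local set of held lock indices consulted before the underlying lock is taken:

use std::cell::RefCell;
use std::collections::HashSet;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, MutexGuard};

// Hands out a unique index per lock instance.
static NEXT_LOCK_IDX: AtomicUsize = AtomicUsize::new(0);

thread_local! {
    // Indices of the locks currently held by this thread.
    static LOCKS_HELD: RefCell<HashSet<usize>> = RefCell::new(HashSet::new());
}

pub struct TestMutex<T> {
    lock_idx: usize,
    inner: Mutex<T>,
}

pub struct TestMutexGuard<'a, T> {
    lock_idx: usize,
    guard: MutexGuard<'a, T>,
}

impl<'a, T> Drop for TestMutexGuard<'a, T> {
    fn drop(&mut self) {
        // Forget the lock once its guard is released.
        LOCKS_HELD.with(|held| { held.borrow_mut().remove(&self.lock_idx); });
    }
}

impl<'a, T> Deref for TestMutexGuard<'a, T> {
    type Target = T;
    fn deref(&self) -> &T { &self.guard }
}

impl<T> TestMutex<T> {
    pub fn new(val: T) -> Self {
        Self { lock_idx: NEXT_LOCK_IDX.fetch_add(1, Ordering::Relaxed), inner: Mutex::new(val) }
    }

    // Runs before the underlying lock is taken. Panicking here turns a latent
    // deadlock (or a platform-dependent recursive read lock) into a loud,
    // deterministic test failure.
    fn pre_lock(&self, double_lock_self_allowed: bool) {
        LOCKS_HELD.with(|held| {
            if held.borrow().contains(&self.lock_idx) && !double_lock_self_allowed {
                panic!("Tried to acquire a lock while it was held!");
            }
            held.borrow_mut().insert(self.lock_idx);
        });
    }

    pub fn lock(&self) -> TestMutexGuard<'_, T> {
        self.pre_lock(false);
        TestMutexGuard { lock_idx: self.lock_idx, guard: self.inner.lock().unwrap() }
    }
}

fn main() {
    let m = TestMutex::new(5u32);
    {
        let guard = m.lock();
        assert_eq!(*guard, 5);
        // A second m.lock() here, while `guard` is alive, would panic with
        // "Tried to acquire a lock while it was held!".
    }
    // The first guard has been dropped, so re-locking is fine.
    assert_eq!(*m.lock(), 5);
}

In the real code, `unsafe_well_ordered_double_lock_self` is the one caller that opts out of the same-lock check by passing `double_lock_self_allowed = true`, and its deliberately verbose name is there to remind callers that they remain responsible for establishing a total order between the two instances being locked.
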
@@ -86,13 +86,7 @@ fn get_construction_location(backtrace: &Backtrace) -> String { let symbol_name = symbol.name().unwrap().as_str().unwrap(); if !sync_mutex_constr_regex.is_match(symbol_name) { if found_debug_sync { - if let Some(col) = symbol.colno() { - return format!("{}:{}:{}", symbol.filename().unwrap().display(), symbol.lineno().unwrap(), col); - } else { - // Windows debug symbols don't support column numbers, so fall back to - // line numbers only if no `colno` is available - return format!("{}:{}", symbol.filename().unwrap().display(), symbol.lineno().unwrap()); - } + return (format!("{}:{}", symbol.filename().unwrap().display(), symbol.lineno().unwrap()), symbol.colno()); } } else { found_debug_sync = true; } } @@ -113,42 +107,51 @@ impl LockMetadata { #[cfg(feature = "backtrace")] { - let lock_constr_location = get_construction_location(&res._lock_construction_bt); + let (lock_constr_location, lock_constr_colno) = + get_construction_location(&res._lock_construction_bt); LOCKS_INIT.call_once(|| { unsafe { LOCKS = Some(StdMutex::new(HashMap::new())); } }); let mut locks = unsafe { LOCKS.as_ref() }.unwrap().lock().unwrap(); match locks.entry(lock_constr_location) { - hash_map::Entry::Occupied(e) => return Arc::clone(e.get()), + hash_map::Entry::Occupied(e) => { + assert_eq!(lock_constr_colno, + get_construction_location(&e.get()._lock_construction_bt).1, + "Because Windows doesn't support column number results in backtraces, we cannot construct two mutexes on the same line or we risk lockorder detection false positives."); + return Arc::clone(e.get()) + }, hash_map::Entry::Vacant(e) => { e.insert(Arc::clone(&res)); }, } } res } - // Returns whether we were a recursive lock (only relevant for read) - fn _pre_lock(this: &Arc, read: bool) -> bool { - let mut inserted = false; + fn pre_lock(this: &Arc, _double_lock_self_allowed: bool) { LOCKS_HELD.with(|held| { // For each lock which is currently locked, check that no lock's locked-before // set includes the lock we're about to lock, which would imply a lockorder // inversion. - for (locked_idx, _locked) in held.borrow().iter() { - if read && *locked_idx == this.lock_idx { - // Recursive read locks are explicitly allowed - return; + for (locked_idx, locked) in held.borrow().iter() { + if *locked_idx == this.lock_idx { + // Note that with `feature = "backtrace"` set, we may be looking at different + // instances of the same lock. Still, doing so is quite risky, a total order + // must be maintained, and doing so across a set of otherwise-identical mutexes + // is fraught with issues. + #[cfg(feature = "backtrace")] + debug_assert!(_double_lock_self_allowed, + "Tried to acquire a lock while it was held!\nLock constructed at {}", + get_construction_location(&this._lock_construction_bt).0); + #[cfg(not(feature = "backtrace"))] + panic!("Tried to acquire a lock while it was held!"); } } for (locked_idx, locked) in held.borrow().iter() { - if !read && *locked_idx == this.lock_idx { - // With `feature = "backtrace"` set, we may be looking at different instances - // of the same lock. 
- debug_assert!(cfg!(feature = "backtrace"), "Tried to acquire a lock while it was held!"); - } for (locked_dep_idx, _locked_dep) in locked.locked_before.lock().unwrap().iter() { if *locked_dep_idx == this.lock_idx && *locked_dep_idx != locked.lock_idx { #[cfg(feature = "backtrace")] panic!("Tried to violate existing lockorder.\nMutex that should be locked after the current lock was created at the following backtrace.\nNote that to get a backtrace for the lockorder violation, you should set RUST_BACKTRACE=1\nLock being taken constructed at: {} ({}):\n{:?}\nLock constructed at: {} ({})\n{:?}\n\nLock dep created at:\n{:?}\n\n", - get_construction_location(&this._lock_construction_bt), this.lock_idx, this._lock_construction_bt, - get_construction_location(&locked._lock_construction_bt), locked.lock_idx, locked._lock_construction_bt, + get_construction_location(&this._lock_construction_bt).0, + this.lock_idx, this._lock_construction_bt, + get_construction_location(&locked._lock_construction_bt).0, + locked.lock_idx, locked._lock_construction_bt, _locked_dep._lockdep_trace); #[cfg(not(feature = "backtrace"))] panic!("Tried to violate existing lockorder. Build with the backtrace feature for more info."); @@ -162,14 +165,9 @@ impl LockMetadata { } } held.borrow_mut().insert(this.lock_idx, Arc::clone(this)); - inserted = true; }); - inserted } - fn pre_lock(this: &Arc) { Self::_pre_lock(this, false); } - fn pre_read_lock(this: &Arc) -> bool { Self::_pre_lock(this, true) } - fn held_by_thread(this: &Arc) -> LockHeldState { let mut res = LockHeldState::NotHeldByThread; LOCKS_HELD.with(|held| { @@ -249,7 +247,7 @@ impl Mutex { } pub fn lock<'a>(&'a self) -> LockResult> { - LockMetadata::pre_lock(&self.deps); + LockMetadata::pre_lock(&self.deps, false); self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).map_err(|_| ()) } @@ -262,11 +260,17 @@ impl Mutex { } } -impl LockTestExt for Mutex { +impl<'a, T: 'a> LockTestExt<'a> for Mutex { #[inline] fn held_by_thread(&self) -> LockHeldState { LockMetadata::held_by_thread(&self.deps) } + type ExclLock = MutexGuard<'a, T>; + #[inline] + fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard { + LockMetadata::pre_lock(&self.deps, true); + self.inner.lock().map(|lock| MutexGuard { mutex: self, lock }).unwrap() + } } pub struct RwLock { @@ -276,7 +280,6 @@ pub struct RwLock { pub struct RwLockReadGuard<'a, T: Sized + 'a> { lock: &'a RwLock, - first_lock: bool, guard: StdRwLockReadGuard<'a, T>, } @@ -295,12 +298,6 @@ impl Deref for RwLockReadGuard<'_, T> { impl Drop for RwLockReadGuard<'_, T> { fn drop(&mut self) { - if !self.first_lock { - // Note that its not strictly true that the first taken read lock will get unlocked - // last, but in practice our locks are always taken as RAII, so it should basically - // always be true. - return; - } LOCKS_HELD.with(|held| { held.borrow_mut().remove(&self.lock.deps.lock_idx); }); @@ -335,12 +332,16 @@ impl RwLock { } pub fn read<'a>(&'a self) -> LockResult> { - let first_lock = LockMetadata::pre_read_lock(&self.deps); - self.inner.read().map(|guard| RwLockReadGuard { lock: self, guard, first_lock }).map_err(|_| ()) + // Note that while we could be taking a recursive read lock here, Rust's `RwLock` may + // deadlock trying to take a second read lock if another thread is waiting on the write + // lock. This behavior is platform dependent, but our in-tree `FairRwLock` guarantees + // such a deadlock. 
+ LockMetadata::pre_lock(&self.deps, false); + self.inner.read().map(|guard| RwLockReadGuard { lock: self, guard }).map_err(|_| ()) } pub fn write<'a>(&'a self) -> LockResult> { - LockMetadata::pre_lock(&self.deps); + LockMetadata::pre_lock(&self.deps, false); self.inner.write().map(|guard| RwLockWriteGuard { lock: self, guard }).map_err(|_| ()) } @@ -353,11 +354,17 @@ impl RwLock { } } -impl LockTestExt for RwLock { +impl<'a, T: 'a> LockTestExt<'a> for RwLock { #[inline] fn held_by_thread(&self) -> LockHeldState { LockMetadata::held_by_thread(&self.deps) } + type ExclLock = RwLockWriteGuard<'a, T>; + #[inline] + fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<'a, T> { + LockMetadata::pre_lock(&self.deps, true); + self.inner.write().map(|guard| RwLockWriteGuard { lock: self, guard }).unwrap() + } } pub type FairRwLock = RwLock; diff --git a/lightning/src/sync/fairrwlock.rs b/lightning/src/sync/fairrwlock.rs index a9519ac2..de609d5b 100644 --- a/lightning/src/sync/fairrwlock.rs +++ b/lightning/src/sync/fairrwlock.rs @@ -50,10 +50,15 @@ impl FairRwLock { } } -impl LockTestExt for FairRwLock { +impl<'a, T: 'a> LockTestExt<'a> for FairRwLock { #[inline] fn held_by_thread(&self) -> LockHeldState { // fairrwlock is only built in non-test modes, so we should never support tests. LockHeldState::Unsupported } + type ExclLock = RwLockWriteGuard<'a, T>; + #[inline] + fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard<'a, T> { + self.write().unwrap() + } } diff --git a/lightning/src/sync/mod.rs b/lightning/src/sync/mod.rs index 50ef40e2..1b2b9a73 100644 --- a/lightning/src/sync/mod.rs +++ b/lightning/src/sync/mod.rs @@ -7,8 +7,17 @@ pub(crate) enum LockHeldState { Unsupported, } -pub(crate) trait LockTestExt { +pub(crate) trait LockTestExt<'a> { fn held_by_thread(&self) -> LockHeldState; + type ExclLock; + /// If two instances of the same mutex are being taken at the same time, it's very easy to have + /// a lockorder inversion and risk deadlock. Thus, we default to disabling such locks. + /// + /// However, sometimes they cannot be avoided. In such cases, this method exists to take a + /// mutex while avoiding a test failure. It is deliberately verbose and includes the term + /// "unsafe" to indicate that special care needs to be taken to ensure no deadlocks are + /// possible. 
+ fn unsafe_well_ordered_double_lock_self(&'a self) -> Self::ExclLock; } #[cfg(all(feature = "std", not(feature = "_bench_unstable"), test))] @@ -27,13 +36,19 @@ pub use {std::sync::{Arc, Mutex, Condvar, MutexGuard, RwLock, RwLockReadGuard, R #[cfg(all(feature = "std", any(feature = "_bench_unstable", not(test))))] mod ext_impl { use super::*; - impl LockTestExt for Mutex { + impl<'a, T: 'a> LockTestExt<'a> for Mutex { #[inline] fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported } + type ExclLock = MutexGuard<'a, T>; + #[inline] + fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard { self.lock().unwrap() } } - impl LockTestExt for RwLock { + impl<'a, T: 'a> LockTestExt<'a> for RwLock { #[inline] fn held_by_thread(&self) -> LockHeldState { LockHeldState::Unsupported } + type ExclLock = RwLockWriteGuard<'a, T>; + #[inline] + fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard { self.write().unwrap() } } } diff --git a/lightning/src/sync/nostd_sync.rs b/lightning/src/sync/nostd_sync.rs index e17aa6ab..858f60db 100644 --- a/lightning/src/sync/nostd_sync.rs +++ b/lightning/src/sync/nostd_sync.rs @@ -62,12 +62,15 @@ impl Mutex { } } -impl LockTestExt for Mutex { +impl<'a, T: 'a> LockTestExt<'a> for Mutex { #[inline] fn held_by_thread(&self) -> LockHeldState { if self.lock().is_err() { return LockHeldState::HeldByThread; } else { return LockHeldState::NotHeldByThread; } } + type ExclLock = MutexGuard<'a, T>; + #[inline] + fn unsafe_well_ordered_double_lock_self(&'a self) -> MutexGuard { self.lock().unwrap() } } pub struct RwLock { @@ -125,12 +128,15 @@ impl RwLock { } } -impl LockTestExt for RwLock { +impl<'a, T: 'a> LockTestExt<'a> for RwLock { #[inline] fn held_by_thread(&self) -> LockHeldState { if self.write().is_err() { return LockHeldState::HeldByThread; } else { return LockHeldState::NotHeldByThread; } } + type ExclLock = RwLockWriteGuard<'a, T>; + #[inline] + fn unsafe_well_ordered_double_lock_self(&'a self) -> RwLockWriteGuard { self.write().unwrap() } } pub type FairRwLock = RwLock; diff --git a/lightning/src/sync/test_lockorder_checks.rs b/lightning/src/sync/test_lockorder_checks.rs index a3f746b1..6d72410b 100644 --- a/lightning/src/sync/test_lockorder_checks.rs +++ b/lightning/src/sync/test_lockorder_checks.rs @@ -15,6 +15,8 @@ fn recursive_lock_fail() { } #[test] +#[should_panic] +#[cfg(not(feature = "backtrace"))] fn recursive_read() { let lock = RwLock::new(()); let _a = lock.read().unwrap(); @@ -66,23 +68,6 @@ fn read_lockorder_fail() { } } -#[test] -fn read_recursive_no_lockorder() { - // Like the above, but note that no lockorder is implied when we recursively read-lock a - // RwLock, causing this to pass just fine. - let a = RwLock::new(()); - let b = RwLock::new(()); - let _outer = a.read().unwrap(); - { - let _a = a.read().unwrap(); - let _b = b.read().unwrap(); - } - { - let _b = b.read().unwrap(); - let _a = a.read().unwrap(); - } -} - #[test] #[should_panic] fn read_write_lockorder_fail() {
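
Stepping back from the individual hunks: the new `PartialEq for ChannelMonitor` near the top of this diff needs to hold both monitors' inner mutexes at once, and it picks the acquisition order by comparing the two objects' memory addresses so that every thread agrees which lock comes first. A minimal standalone sketch of that idea, using an invented `Monitor` type rather than the LDK one:

use std::sync::{Arc, Mutex};
use std::thread;

struct Monitor {
    inner: Mutex<Vec<u64>>,
}

impl PartialEq for Monitor {
    fn eq(&self, other: &Self) -> bool {
        // Comparing an object with itself would otherwise try to take the same
        // mutex twice, so short-circuit that case up front.
        if std::ptr::eq(self, other) { return true; }
        // Sort the two locks by the objects' addresses so every thread takes
        // them in the same order. This relies on neither object moving in
        // memory while the locks are held, which holds here because both sit
        // behind shared references for the duration of `eq`.
        let ord = (self as *const Self as usize) < (other as *const Self as usize);
        let (first, second) = if ord { (self, other) } else { (other, self) };
        let a = first.inner.lock().unwrap();
        let b = second.inner.lock().unwrap();
        // Equality is symmetric, so comparing in swapped order is still correct.
        *a == *b
    }
}

fn main() {
    let m1 = Arc::new(Monitor { inner: Mutex::new(vec![1, 2, 3]) });
    let m2 = Arc::new(Monitor { inner: Mutex::new(vec![1, 2, 3]) });
    let (a, b) = (Arc::clone(&m1), Arc::clone(&m2));
    // Two threads compare the same pair in opposite argument order. Because
    // both acquire the locks in address order, neither can end up holding one
    // lock while waiting for the other.
    let t1 = thread::spawn(move || assert!(*a == *b));
    let t2 = thread::spawn(move || assert!(*m2 == *m1));
    t1.join().unwrap();
    t2.join().unwrap();
}
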
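In the same spirit, several `ChannelManager` internals in this diff (`send_payment_along_path`, `channel_monitor_updated`, `process_pending_monitor_events`) no longer take `total_consistency_lock` themselves, since re-taking the read lock while a caller already holds it is exactly the kind of recursive read lock this change refuses. Instead they assert that the caller holds it with `debug_assert!(self.total_consistency_lock.try_write().is_err())`. A small sketch of that calling convention, with an invented `Manager` type standing in for `ChannelManager`:

use std::sync::RwLock;

struct Manager {
    total_consistency_lock: RwLock<()>,
}

impl Manager {
    // Public entry point: takes the read side of the consistency lock once,
    // then calls into internal helpers that expect it to already be held.
    fn public_entry_point(&self) {
        let _read_guard = self.total_consistency_lock.read().unwrap();
        self.internal_helper();
    }

    fn internal_helper(&self) {
        // If the write side can be taken here, no thread holds a read or write
        // guard, so the documented calling convention was violated. try_write()
        // never blocks, so this is safe to evaluate even on the thread that
        // holds the read lock.
        debug_assert!(self.total_consistency_lock.try_write().is_err());
        // ... work that must not run concurrently with a writer ...
    }
}

fn main() {
    let manager = Manager { total_consistency_lock: RwLock::new(()) };
    manager.public_entry_point();
}

The assertion is a debug-build heuristic: a failed `try_write()` only proves that some read or write guard is outstanding somewhere, not that the current thread is the one holding it.
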