Merge pull request #1731 from valentinewallace/2022-09-router-test-utils-mod
author valentinewallace <valentinewallace@users.noreply.github.com>
Thu, 22 Sep 2022 15:51:47 +0000 (11:51 -0400)
committer GitHub <noreply@github.com>
Thu, 22 Sep 2022 15:51:47 +0000 (11:51 -0400)
Move router test utils into their own module

lightning/src/ln/channelmanager.rs
lightning/src/ln/onion_route_tests.rs

lightning/src/ln/channelmanager.rs
index 4b9d67b889be28c9c9c0ab0ad45d10d44f11bc88..72f1f4d596f80597736ec1bb59e9b27673eddf96 100644
@@ -401,16 +401,6 @@ pub(super) struct ChannelHolder<Signer: Sign> {
        /// SCIDs being added once the funding transaction is confirmed at the channel's required
        /// confirmation depth.
        pub(super) short_to_chan_info: HashMap<u64, (PublicKey, [u8; 32])>,
-       /// SCID/SCID Alias -> forward infos. Key of 0 means payments received.
-       ///
-       /// Note that because we may have an SCID Alias as the key we can have two entries per channel,
-       /// though in practice we probably won't be receiving HTLCs for a channel both via the alias
-       /// and via the classic SCID.
-       ///
-       /// Note that while this is held in the same mutex as the channels themselves, no consistency
-       /// guarantees are made about the existence of a channel with the short id here, nor the short
-       /// ids in the PendingHTLCInfo!
-       pub(super) forward_htlcs: HashMap<u64, Vec<HTLCForwardInfo>>,
        /// Map from payment hash to the payment data and any HTLCs which are to us and can be
        /// failed/claimed by the user.
        ///
@@ -677,6 +667,38 @@ pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L> = ChannelManage
 /// essentially you should default to using a SimpleRefChannelManager, and use a
 /// SimpleArcChannelManager when you require a ChannelManager with a static lifetime, such as when
 /// you're using lightning-net-tokio.
+//
+// Lock order:
+// The tree structure below illustrates the lock order requirements for the different locks of the
+// `ChannelManager`. Locks can be held at the same time if they are on the same branch in the tree,
+// and should then be taken in the order of the lowest to the highest level in the tree.
+// Note that locks on different branches shall not be taken at the same time, as doing so will
+// create a new lock order for those specific locks in the order they were taken.
+//
+// Lock order tree:
+//
+// `total_consistency_lock`
+//  |
+//  |__`forward_htlcs`
+//  |
+//  |__`channel_state`
+//  |   |
+//  |   |__`id_to_peer`
+//  |   |
+//  |   |__`per_peer_state`
+//  |       |
+//  |       |__`outbound_scid_aliases`
+//  |       |
+//  |       |__`pending_inbound_payments`
+//  |           |
+//  |           |__`pending_outbound_payments`
+//  |               |
+//  |               |__`best_block`
+//  |               |
+//  |               |__`pending_events`
+//  |                   |
+//  |                   |__`pending_background_events`
+//
 pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
        where M::Target: chain::Watch<Signer>,
         T::Target: BroadcasterInterface,
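
The tree above is the rule the rest of the diff enforces: locks on one branch are taken top-down, and locks on sibling branches are never held together. Below is a minimal standalone sketch of that rule, using illustrative stand-in types rather than the real `ChannelManager` fields.

use std::sync::{Mutex, RwLock};

// Illustrative stand-in for the lock layout described by the tree above.
struct LockTree {
    total_consistency_lock: RwLock<()>,
    forward_htlcs: Mutex<Vec<u64>>,   // one branch under `total_consistency_lock`
    channel_state: Mutex<Vec<u64>>,   // sibling branch
    per_peer_state: RwLock<Vec<u64>>, // nested under `channel_state`
}

impl LockTree {
    // Correct: walk a single branch from the top down.
    fn same_branch(&self) {
        let _top = self.total_consistency_lock.read().unwrap();
        let _cs = self.channel_state.lock().unwrap();
        let _pp = self.per_peer_state.read().unwrap();
    } // guards released here

    // Correct: release `channel_state` before crossing to the `forward_htlcs` branch.
    fn cross_branch(&self) {
        {
            let _cs = self.channel_state.lock().unwrap();
            // ... work requiring channel state ...
        } // `channel_state` released
        let _fwd = self.forward_htlcs.lock().unwrap();
        // Holding `channel_state` here as well would pin a new cross-branch order.
    }
}

fn main() {
    let t = LockTree {
        total_consistency_lock: RwLock::new(()),
        forward_htlcs: Mutex::new(Vec::new()),
        channel_state: Mutex::new(Vec::new()),
        per_peer_state: RwLock::new(Vec::new()),
    };
    t.same_branch();
    t.cross_branch();
}
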
@@ -690,12 +712,14 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        chain_monitor: M,
        tx_broadcaster: T,
 
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        #[cfg(test)]
        pub(super) best_block: RwLock<BestBlock>,
        #[cfg(not(test))]
        best_block: RwLock<BestBlock>,
        secp_ctx: Secp256k1<secp256k1::All>,
 
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        #[cfg(any(test, feature = "_test_utils"))]
        pub(super) channel_state: Mutex<ChannelHolder<Signer>>,
        #[cfg(not(any(test, feature = "_test_utils")))]
@@ -705,7 +729,8 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// expose them to users via a PaymentReceived event. HTLCs which do not meet the requirements
        /// here are failed when we process them as pending-forwardable-HTLCs, and entries are removed
        /// after we generate a PaymentReceived upon receipt of all MPP parts or when they time out.
-       /// Locked *after* channel_state.
+       ///
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        pending_inbound_payments: Mutex<HashMap<PaymentHash, PendingInboundPayment>>,
 
        /// The session_priv bytes and retry metadata of outbound payments which are pending resolution.
@@ -719,13 +744,30 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        ///
        /// See `PendingOutboundPayment` documentation for more info.
        ///
-       /// Locked *after* channel_state.
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        pending_outbound_payments: Mutex<HashMap<PaymentId, PendingOutboundPayment>>,
 
+       /// SCID/SCID Alias -> forward infos. Key of 0 means payments received.
+       ///
+       /// Note that because we may have an SCID Alias as the key we can have two entries per channel,
+       /// though in practice we probably won't be receiving HTLCs for a channel both via the alias
+       /// and via the classic SCID.
+       ///
+       /// Note that no consistency guarantees are made about the existence of a channel with the
+       /// `short_channel_id` here, nor the `short_channel_id` in the `PendingHTLCInfo`!
+       ///
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
+       #[cfg(test)]
+       pub(super) forward_htlcs: Mutex<HashMap<u64, Vec<HTLCForwardInfo>>>,
+       #[cfg(not(test))]
+       forward_htlcs: Mutex<HashMap<u64, Vec<HTLCForwardInfo>>>,
+
        /// The set of outbound SCID aliases across all our channels, including unconfirmed channels
        /// and some closed channels which reached a usable state prior to being closed. This is used
        /// only to avoid duplicates, and is not persisted explicitly to disk, but rebuilt from the
        /// active channel list on load.
+       ///
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        outbound_scid_aliases: Mutex<HashSet<u64>>,
 
        /// `channel_id` -> `counterparty_node_id`.
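
The paired `#[cfg(test)]`/`#[cfg(not(test))]` declarations above let tests reach `forward_htlcs` directly while keeping the field private in normal builds. A small standalone sketch of that visibility pattern, with a hypothetical `Manager` type standing in for `ChannelManager`:

mod manager {
    use std::sync::Mutex;

    // The field is `pub(super)` only when built for tests, and private otherwise,
    // so tests can inspect it without widening the production surface.
    pub struct Manager {
        #[cfg(test)]
        pub(super) forward_htlcs: Mutex<Vec<u64>>,
        #[cfg(not(test))]
        forward_htlcs: Mutex<Vec<u64>>,
    }

    impl Manager {
        pub fn new() -> Self {
            Manager { forward_htlcs: Mutex::new(Vec::new()) }
        }
        pub fn queue(&self, scid: u64) {
            self.forward_htlcs.lock().unwrap().push(scid);
        }
        pub fn pending(&self) -> usize {
            self.forward_htlcs.lock().unwrap().len()
        }
    }
}

fn main() {
    let m = manager::Manager::new();
    m.queue(42);
    assert_eq!(m.pending(), 1);
}
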
@@ -745,6 +787,8 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// We should add `counterparty_node_id`s to `MonitorEvent`s, and eventually rely on it in the
        /// future. That would make this map redundant, as only the `ChannelManager::per_peer_state` is
        /// required to access the channel with the `counterparty_node_id`.
+       ///
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        id_to_peer: Mutex<HashMap<[u8; 32], PublicKey>>,
 
        our_network_key: SecretKey,
@@ -776,10 +820,12 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// operate on the inner value freely. Sadly, this prevents parallel operation when opening a
        /// new channel.
        ///
-       /// If also holding `channel_state` lock, must lock `channel_state` prior to `per_peer_state`.
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        per_peer_state: RwLock<HashMap<PublicKey, Mutex<PeerState>>>,
 
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        pending_events: Mutex<Vec<events::Event>>,
+       /// See `ChannelManager` struct-level documentation for lock order requirements.
        pending_background_events: Mutex<Vec<BackgroundEvent>>,
        /// Used when we have to take a BIG lock to make sure everything is self-consistent.
        /// Essentially just when we're serializing ourselves out.
@@ -1595,13 +1641,13 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        channel_state: Mutex::new(ChannelHolder{
                                by_id: HashMap::new(),
                                short_to_chan_info: HashMap::new(),
-                               forward_htlcs: HashMap::new(),
                                claimable_htlcs: HashMap::new(),
                                pending_msg_events: Vec::new(),
                        }),
                        outbound_scid_aliases: Mutex::new(HashSet::new()),
                        pending_inbound_payments: Mutex::new(HashMap::new()),
                        pending_outbound_payments: Mutex::new(HashMap::new()),
+                       forward_htlcs: Mutex::new(HashMap::new()),
                        id_to_peer: Mutex::new(HashMap::new()),
 
                        our_network_key: keys_manager.get_node_secret(Recipient::Node).unwrap(),
@@ -1884,7 +1930,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                for htlc_source in failed_htlcs.drain(..) {
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(*counterparty_node_id), channel_id: *channel_id };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       self.fail_htlc_backwards_internal(htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
 
                let _ = handle_error!(self, result, *counterparty_node_id);
@@ -1942,7 +1988,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for htlc_source in failed_htlcs.drain(..) {
                        let (source, payment_hash, counterparty_node_id, channel_id) = htlc_source;
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id: channel_id };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       self.fail_htlc_backwards_internal(source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
                if let Some((funding_txo, monitor_update)) = monitor_update_option {
                        // There isn't anything we can do if we get an update failure - we're already
@@ -3002,10 +3048,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut phantom_receives: Vec<(u64, OutPoint, Vec<(PendingHTLCInfo, u64)>)> = Vec::new();
                let mut handle_errors = Vec::new();
                {
-                       let mut channel_state_lock = self.channel_state.lock().unwrap();
-                       let channel_state = &mut *channel_state_lock;
+                       let mut forward_htlcs = HashMap::new();
+                       mem::swap(&mut forward_htlcs, &mut self.forward_htlcs.lock().unwrap());
 
-                       for (short_chan_id, mut pending_forwards) in channel_state.forward_htlcs.drain() {
+                       for (short_chan_id, mut pending_forwards) in forward_htlcs {
+                               let mut channel_state_lock = self.channel_state.lock().unwrap();
+                               let channel_state = &mut *channel_state_lock;
                                if short_chan_id != 0 {
                                        let forward_chan_id = match channel_state.short_to_chan_info.get(&short_chan_id) {
                                                Some((_cp_id, chan_id)) => chan_id.clone(),
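
The forwarding loop above no longer drains `forward_htlcs` while holding `channel_state`; it swaps the whole map out under a short-lived lock and iterates the owned copy, taking `channel_state` per entry afterwards. A minimal sketch of that swap-and-drain pattern with illustrative types:

use std::collections::HashMap;
use std::mem;
use std::sync::Mutex;

// Take the whole pending-forwards map while holding the lock only for the swap,
// then process the owned copy with the lock already released.
fn drain_pending(forward_htlcs: &Mutex<HashMap<u64, Vec<String>>>) {
    let mut pending = HashMap::new();
    mem::swap(&mut pending, &mut forward_htlcs.lock().unwrap());
    // The temporary guard is dropped at the end of the statement above;
    // other locks may now be taken per entry.
    for (short_chan_id, htlcs) in pending {
        println!("processing {} HTLC(s) for SCID {}", htlcs.len(), short_chan_id);
    }
}

fn main() {
    let forward_htlcs = Mutex::new(HashMap::from([(42u64, vec!["htlc".to_string()])]));
    drain_pending(&forward_htlcs);
    assert!(forward_htlcs.lock().unwrap().is_empty());
}
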
@@ -3403,7 +3451,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                }
 
                for (htlc_source, payment_hash, failure_reason, destination) in failed_forwards.drain(..) {
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source, &payment_hash, failure_reason, destination);
+                       self.fail_htlc_backwards_internal(htlc_source, &payment_hash, failure_reason, destination);
                }
                self.forward_htlcs(&mut phantom_receives);
 
@@ -3627,7 +3675,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                        for htlc_source in timed_out_mpp_htlcs.drain(..) {
                                let receiver = HTLCDestination::FailedPayment { payment_hash: htlc_source.1 };
-                               self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), HTLCSource::PreviousHopData(htlc_source.0.clone()), &htlc_source.1, HTLCFailReason::Reason { failure_code: 23, data: Vec::new() }, receiver );
+                               self.fail_htlc_backwards_internal(HTLCSource::PreviousHopData(htlc_source.0.clone()), &htlc_source.1, HTLCFailReason::Reason { failure_code: 23, data: Vec::new() }, receiver );
                        }
 
                        for (err, counterparty_node_id) in handle_errors.drain(..) {
@@ -3653,15 +3701,16 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash) {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let mut channel_state = Some(self.channel_state.lock().unwrap());
-               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(payment_hash);
+               let removed_source = {
+                       let mut channel_state = self.channel_state.lock().unwrap();
+                       channel_state.claimable_htlcs.remove(payment_hash)
+               };
                if let Some((_, mut sources)) = removed_source {
                        for htlc in sources.drain(..) {
-                               if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }
                                let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
                                htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
                                                self.best_block.read().unwrap().height()));
-                               self.fail_htlc_backwards_internal(channel_state.take().unwrap(),
+                               self.fail_htlc_backwards_internal(
                                                HTLCSource::PreviousHopData(htlc.prev_hop), payment_hash,
                                                HTLCFailReason::Reason { failure_code: 0x4000 | 15, data: htlc_msat_height_data },
                                                HTLCDestination::FailedPayment { payment_hash: *payment_hash });
@@ -3724,9 +3773,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                counterparty_node_id: &PublicKey
        ) {
                for (htlc_src, payment_hash) in htlcs_to_fail.drain(..) {
-                       let mut channel_state = self.channel_state.lock().unwrap();
                        let (failure_code, onion_failure_data) =
-                               match channel_state.by_id.entry(channel_id) {
+                               match self.channel_state.lock().unwrap().by_id.entry(channel_id) {
                                        hash_map::Entry::Occupied(chan_entry) => {
                                                self.get_htlc_inbound_temp_fail_err_and_data(0x1000|7, &chan_entry.get())
                                        },
@@ -3734,17 +3782,22 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                };
 
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id.clone()), channel_id };
-                       self.fail_htlc_backwards_internal(channel_state, htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data }, receiver);
+                       self.fail_htlc_backwards_internal(htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data }, receiver);
                }
        }
 
        /// Fails an HTLC backwards to the sender of it to us.
-       /// Note that while we take a channel_state lock as input, we do *not* assume consistency here.
-       /// There are several callsites that do stupid things like loop over a list of payment_hashes
-       /// to fail and take the channel_state lock for each iteration (as we take ownership and may
-       /// drop it). In other words, no assumptions are made that entries in claimable_htlcs point to
-       /// still-available channels.
-       fn fail_htlc_backwards_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason, destination: HTLCDestination) {
+       /// Note that we do not assume that channels corresponding to failed HTLCs are still available.
+       fn fail_htlc_backwards_internal(&self, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason, destination: HTLCDestination) {
+               #[cfg(debug_assertions)]
+               {
+                       // Ensure that the `channel_state` lock is not held when calling this function.
+                       // This ensures that future code doesn't introduce a lock_order requirement for
+                       // `forward_htlcs` to be locked after the `channel_state` lock, which calling this
+                       // function with the `channel_state` locked would.
+                       assert!(self.channel_state.try_lock().is_ok());
+               }
+
                //TODO: There is a timing attack here where if a node fails an HTLC back to us they can
                //identify whether we sent it or not based on the (I presume) very different runtime
                //between the branches here. We should make this async and move it into the forward HTLCs
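
The `debug_assertions` block above uses `try_lock` to assert that `channel_state` is not already held on entry, so future callers cannot quietly reintroduce a `channel_state`-before-`forward_htlcs` ordering. A standalone sketch of the same check, with a hypothetical helper rather than the LDK function:

use std::sync::Mutex;

// Hypothetical helper mirroring the debug-only check: `try_lock` never blocks,
// so if it fails, something already holds the lock on this call path.
fn assert_not_held<T>(lock: &Mutex<T>) {
    #[cfg(debug_assertions)]
    {
        assert!(lock.try_lock().is_ok());
    }
    let _ = lock; // silences "unused" warnings in release builds
}

fn main() {
    let channel_state = Mutex::new(0u32);
    assert_not_held(&channel_state); // passes: nothing holds the lock
    let _guard = channel_state.lock().unwrap();
    // Calling `assert_not_held(&channel_state)` while `_guard` is alive would
    // panic in debug builds, which is exactly the regression the assert catches.
}
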
@@ -3783,7 +3836,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                        return;
                                }
-                               mem::drop(channel_state_lock);
                                let mut retry = if let Some(payment_params_data) = payment_params {
                                        let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
                                        Some(RouteParameters {
@@ -3904,10 +3956,11 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                };
 
                                let mut forward_event = None;
-                               if channel_state_lock.forward_htlcs.is_empty() {
+                               let mut forward_htlcs = self.forward_htlcs.lock().unwrap();
+                               if forward_htlcs.is_empty() {
                                        forward_event = Some(Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS));
                                }
-                               match channel_state_lock.forward_htlcs.entry(short_channel_id) {
+                               match forward_htlcs.entry(short_channel_id) {
                                        hash_map::Entry::Occupied(mut entry) => {
                                                entry.get_mut().push(HTLCForwardInfo::FailHTLC { htlc_id, err_packet });
                                        },
@@ -3915,7 +3968,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                entry.insert(vec!(HTLCForwardInfo::FailHTLC { htlc_id, err_packet }));
                                        }
                                }
-                               mem::drop(channel_state_lock);
+                               mem::drop(forward_htlcs);
                                let mut pending_events = self.pending_events.lock().unwrap();
                                if let Some(time) = forward_event {
                                        pending_events.push(events::Event::PendingHTLCsForwardable {
@@ -3953,8 +4006,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let mut channel_state = Some(self.channel_state.lock().unwrap());
-               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&payment_hash);
+               let removed_source = self.channel_state.lock().unwrap().claimable_htlcs.remove(&payment_hash);
                if let Some((payment_purpose, mut sources)) = removed_source {
                        assert!(!sources.is_empty());
 
@@ -3972,8 +4024,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        let mut claimable_amt_msat = 0;
                        let mut expected_amt_msat = None;
                        let mut valid_mpp = true;
+                       let mut errs = Vec::new();
+                       let mut claimed_any_htlcs = false;
+                       let mut channel_state_lock = self.channel_state.lock().unwrap();
+                       let channel_state = &mut *channel_state_lock;
                        for htlc in sources.iter() {
-                               if let None = channel_state.as_ref().unwrap().short_to_chan_info.get(&htlc.prev_hop.short_channel_id) {
+                               if let None = channel_state.short_to_chan_info.get(&htlc.prev_hop.short_channel_id) {
                                        valid_mpp = false;
                                        break;
                                }
@@ -4006,21 +4062,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        expected_amt_msat.unwrap(), claimable_amt_msat);
                                return;
                        }
-
-                       let mut errs = Vec::new();
-                       let mut claimed_any_htlcs = false;
-                       for htlc in sources.drain(..) {
-                               if !valid_mpp {
-                                       if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }
-                                       let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
-                                       htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
-                                                       self.best_block.read().unwrap().height()));
-                                       self.fail_htlc_backwards_internal(channel_state.take().unwrap(),
-                                                                        HTLCSource::PreviousHopData(htlc.prev_hop), &payment_hash,
-                                                                        HTLCFailReason::Reason { failure_code: 0x4000|15, data: htlc_msat_height_data },
-                                                                        HTLCDestination::FailedPayment { payment_hash } );
-                               } else {
-                                       match self.claim_funds_from_hop(channel_state.as_mut().unwrap(), htlc.prev_hop, payment_preimage) {
+                       if valid_mpp {
+                               for htlc in sources.drain(..) {
+                                       match self.claim_funds_from_hop(&mut channel_state_lock, htlc.prev_hop, payment_preimage) {
                                                ClaimFundsFromHop::MonitorUpdateFail(pk, err, _) => {
                                                        if let msgs::ErrorAction::IgnoreError = err.err.action {
                                                                // We got a temporary failure updating monitor, but will claim the
@@ -4040,6 +4084,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        }
                                }
                        }
+                       mem::drop(channel_state_lock);
+                       if !valid_mpp {
+                               for htlc in sources.drain(..) {
+                                       let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
+                                       htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
+                                               self.best_block.read().unwrap().height()));
+                                       self.fail_htlc_backwards_internal(
+                                               HTLCSource::PreviousHopData(htlc.prev_hop), &payment_hash,
+                                               HTLCFailReason::Reason { failure_code: 0x4000|15, data: htlc_msat_height_data },
+                                               HTLCDestination::FailedPayment { payment_hash } );
+                               }
+                       }
 
                        if claimed_any_htlcs {
                                self.pending_events.lock().unwrap().push(events::Event::PaymentClaimed {
@@ -4049,10 +4105,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                });
                        }
 
-                       // Now that we've done the entire above loop in one lock, we can handle any errors
-                       // which were generated.
-                       channel_state.take();
-
+                       // Now we can handle any errors which were generated.
                        for (counterparty_node_id, err) in errs.drain(..) {
                                let res: Result<(), _> = Err(err);
                                let _ = handle_error!(self, res, counterparty_node_id);
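
In `claim_funds` the claim loop now runs with `channel_state` held, while the invalid-MPP failure path and the deferred error handling run only after `mem::drop(channel_state_lock)`, so `fail_htlc_backwards_internal` is never entered with that lock held. A sketch of that drop-then-handle ordering, with illustrative types and names:

use std::mem;
use std::sync::Mutex;

// Illustrative claim flow: claims happen under the lock, failures and error
// handling happen only after the guard has been explicitly dropped.
fn claim_or_fail(channel_state: &Mutex<Vec<u64>>, valid_mpp: bool) {
    let channel_state_lock = channel_state.lock().unwrap();
    let sources: Vec<u64> = channel_state_lock.clone();
    if valid_mpp {
        // ... claim each HTLC while `channel_state` is still held ...
    }
    mem::drop(channel_state_lock); // released before any failure handling
    if !valid_mpp {
        for src in sources {
            // Stand-in for `fail_htlc_backwards_internal`, which must run unlocked.
            println!("failing HTLC back for source {}", src);
        }
    }
}

fn main() {
    claim_or_fail(&Mutex::new(vec![1, 2, 3]), false);
}
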
@@ -4299,7 +4352,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                self.finalize_claims(finalized_claims);
                for failure in pending_failures.drain(..) {
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id: funding_txo.to_channel_id() };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2, receiver);
+                       self.fail_htlc_backwards_internal(failure.0, &failure.1, failure.2, receiver);
                }
        }
 
@@ -4659,7 +4712,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                };
                for htlc_source in dropped_htlcs.drain(..) {
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id.clone()), channel_id: msg.channel_id };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       self.fail_htlc_backwards_internal(htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
 
                let _ = handle_error!(self, result, *counterparty_node_id);
@@ -4861,12 +4914,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for &mut (prev_short_channel_id, prev_funding_outpoint, ref mut pending_forwards) in per_source_pending_forwards {
                        let mut forward_event = None;
                        if !pending_forwards.is_empty() {
-                               let mut channel_state = self.channel_state.lock().unwrap();
-                               if channel_state.forward_htlcs.is_empty() {
+                               let mut forward_htlcs = self.forward_htlcs.lock().unwrap();
+                               if forward_htlcs.is_empty() {
                                        forward_event = Some(Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS))
                                }
                                for (forward_info, prev_htlc_id) in pending_forwards.drain(..) {
-                                       match channel_state.forward_htlcs.entry(match forward_info.routing {
+                                       match forward_htlcs.entry(match forward_info.routing {
                                                        PendingHTLCRouting::Forward { short_channel_id, .. } => short_channel_id,
                                                        PendingHTLCRouting::Receive { .. } => 0,
                                                        PendingHTLCRouting::ReceiveKeysend { .. } => 0,
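
The entry-API match above either appends to an existing pending-forwards list or inserts a fresh one, keyed by outbound SCID with key 0 reserved for HTLCs to be received locally. A minimal sketch of that push-or-insert pattern with illustrative types:

use std::collections::hash_map::Entry;
use std::collections::HashMap;

// Key 0 collects HTLCs to be received locally; any other key is an outbound SCID.
fn queue_forward(forward_htlcs: &mut HashMap<u64, Vec<String>>, scid: u64, htlc: String) {
    match forward_htlcs.entry(scid) {
        Entry::Occupied(mut entry) => entry.get_mut().push(htlc),
        Entry::Vacant(entry) => { entry.insert(vec![htlc]); },
    }
}

fn main() {
    let mut forward_htlcs = HashMap::new();
    queue_forward(&mut forward_htlcs, 0, "payment to us".to_string());
    queue_forward(&mut forward_htlcs, 1337, "forward via SCID 1337".to_string());
    queue_forward(&mut forward_htlcs, 1337, "second HTLC on the same SCID".to_string());
    assert_eq!(forward_htlcs[&1337].len(), 2);
}
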
@@ -4947,7 +5000,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        {
                                for failure in pending_failures.drain(..) {
                                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(*counterparty_node_id), channel_id: channel_outpoint.to_channel_id() };
-                                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2, receiver);
+                                       self.fail_htlc_backwards_internal(failure.0, &failure.1, failure.2, receiver);
                                }
                                self.forward_htlcs(&mut [(short_channel_id, channel_outpoint, pending_forwards)]);
                                self.finalize_claims(finalized_claim_htlcs);
@@ -5105,7 +5158,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                } else {
                                                        log_trace!(self.logger, "Failing HTLC with hash {} from our monitor", log_bytes!(htlc_update.payment_hash.0));
                                                        let receiver = HTLCDestination::NextHopChannel { node_id: counterparty_node_id, channel_id: funding_outpoint.to_channel_id() };
-                                                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                                                       self.fail_htlc_backwards_internal(htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                                                }
                                        },
                                        MonitorEvent::CommitmentTxConfirmed(funding_outpoint) |
@@ -5839,7 +5892,7 @@ where
                self.handle_init_event_channel_failures(failed_channels);
 
                for (source, payment_hash, reason, destination) in timed_out_htlcs.drain(..) {
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, reason, destination);
+                       self.fail_htlc_backwards_internal(source, &payment_hash, reason, destination);
                }
        }
 
@@ -6538,29 +6591,37 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                        best_block.block_hash().write(writer)?;
                }
 
-               let channel_state = self.channel_state.lock().unwrap();
-               let mut unfunded_channels = 0;
-               for (_, channel) in channel_state.by_id.iter() {
-                       if !channel.is_funding_initiated() {
-                               unfunded_channels += 1;
+               {
+                       // Take `channel_state` lock temporarily to avoid creating a lock order that requires
+                       // that the `forward_htlcs` lock is taken after `channel_state`
+                       let channel_state = self.channel_state.lock().unwrap();
+                       let mut unfunded_channels = 0;
+                       for (_, channel) in channel_state.by_id.iter() {
+                               if !channel.is_funding_initiated() {
+                                       unfunded_channels += 1;
+                               }
                        }
-               }
-               ((channel_state.by_id.len() - unfunded_channels) as u64).write(writer)?;
-               for (_, channel) in channel_state.by_id.iter() {
-                       if channel.is_funding_initiated() {
-                               channel.write(writer)?;
+                       ((channel_state.by_id.len() - unfunded_channels) as u64).write(writer)?;
+                       for (_, channel) in channel_state.by_id.iter() {
+                               if channel.is_funding_initiated() {
+                                       channel.write(writer)?;
+                               }
                        }
                }
 
-               (channel_state.forward_htlcs.len() as u64).write(writer)?;
-               for (short_channel_id, pending_forwards) in channel_state.forward_htlcs.iter() {
-                       short_channel_id.write(writer)?;
-                       (pending_forwards.len() as u64).write(writer)?;
-                       for forward in pending_forwards {
-                               forward.write(writer)?;
+               {
+                       let forward_htlcs = self.forward_htlcs.lock().unwrap();
+                       (forward_htlcs.len() as u64).write(writer)?;
+                       for (short_channel_id, pending_forwards) in forward_htlcs.iter() {
+                               short_channel_id.write(writer)?;
+                               (pending_forwards.len() as u64).write(writer)?;
+                               for forward in pending_forwards {
+                                       forward.write(writer)?;
+                               }
                        }
                }
 
+               let channel_state = self.channel_state.lock().unwrap();
                let mut htlc_purposes: Vec<&events::PaymentPurpose> = Vec::new();
                (channel_state.claimable_htlcs.len() as u64).write(writer)?;
                for (payment_hash, (purpose, previous_hops)) in channel_state.claimable_htlcs.iter() {
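
Serialization above scopes the first `channel_state` borrow to its own block, releases it, locks `forward_htlcs`, and only then re-takes `channel_state` for the claimable HTLCs, keeping the two locks off each other's branch. A minimal sketch of scoping guards this way, writing into an illustrative byte buffer rather than the LDK serialization format:

use std::sync::Mutex;

fn write_all(channel_state: &Mutex<Vec<u8>>, forward_htlcs: &Mutex<Vec<u8>>) -> Vec<u8> {
    let mut out = Vec::new();
    {
        let cs = channel_state.lock().unwrap();
        out.extend_from_slice(&cs); // channel data first
    } // `channel_state` released before the next lock is taken
    {
        let fwd = forward_htlcs.lock().unwrap();
        out.extend_from_slice(&fwd); // then pending forwards
    }
    // `channel_state` may now be re-taken for anything still needed from it.
    out
}

fn main() {
    let cs = Mutex::new(vec![1, 2]);
    let fwd = Mutex::new(vec![3]);
    assert_eq!(write_all(&cs, &fwd), vec![1, 2, 3]);
}
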
@@ -7165,7 +7226,6 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        channel_state: Mutex::new(ChannelHolder {
                                by_id,
                                short_to_chan_info,
-                               forward_htlcs,
                                claimable_htlcs,
                                pending_msg_events: Vec::new(),
                        }),
@@ -7173,6 +7233,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        pending_inbound_payments: Mutex::new(pending_inbound_payments),
                        pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()),
 
+                       forward_htlcs: Mutex::new(forward_htlcs),
                        outbound_scid_aliases: Mutex::new(outbound_scid_aliases),
                        id_to_peer: Mutex::new(id_to_peer),
                        fake_scid_rand_bytes: fake_scid_rand_bytes.unwrap(),
@@ -7200,7 +7261,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                for htlc_source in failed_htlcs.drain(..) {
                        let (source, payment_hash, counterparty_node_id, channel_id) = htlc_source;
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id };
-                       channel_manager.fail_htlc_backwards_internal(channel_manager.channel_state.lock().unwrap(), source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       channel_manager.fail_htlc_backwards_internal(source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
 
                //TODO: Broadcast channel update for closed channels, but only after we've made a
lightning/src/ln/onion_route_tests.rs
index a15e06b5d0c74a7f83604b22f80433bc6101ecdc..76d6723466b52b427d1567526c4a46740928bb5e 100644
@@ -540,7 +540,7 @@ fn test_onion_failure() {
        }, || {}, true, Some(17), None, None);
 
        run_onion_failure_test("final_incorrect_cltv_expiry", 1, &nodes, &route, &payment_hash, &payment_secret, |_| {}, || {
-               for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().forward_htlcs.iter_mut() {
+               for (_, pending_forwards) in nodes[1].node.forward_htlcs.lock().unwrap().iter_mut() {
                        for f in pending_forwards.iter_mut() {
                                match f {
                                        &mut HTLCForwardInfo::AddHTLC { ref mut forward_info, .. } =>
@@ -553,7 +553,7 @@ fn test_onion_failure() {
 
        run_onion_failure_test("final_incorrect_htlc_amount", 1, &nodes, &route, &payment_hash, &payment_secret, |_| {}, || {
                // violate amt_to_forward > msg.amount_msat
-               for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().forward_htlcs.iter_mut() {
+               for (_, pending_forwards) in nodes[1].node.forward_htlcs.lock().unwrap().iter_mut() {
                        for f in pending_forwards.iter_mut() {
                                match f {
                                        &mut HTLCForwardInfo::AddHTLC { ref mut forward_info, .. } =>
@@ -1021,8 +1021,8 @@ fn test_phantom_onion_hmac_failure() {
 
        // Modify the payload so the phantom hop's HMAC is bogus.
        let sha256_of_onion = {
-               let mut channel_state = nodes[1].node.channel_state.lock().unwrap();
-               let mut pending_forward = channel_state.forward_htlcs.get_mut(&phantom_scid).unwrap();
+               let mut forward_htlcs = nodes[1].node.forward_htlcs.lock().unwrap();
+               let mut pending_forward = forward_htlcs.get_mut(&phantom_scid).unwrap();
                match pending_forward[0] {
                        HTLCForwardInfo::AddHTLC {
                                forward_info: PendingHTLCInfo {
@@ -1081,7 +1081,7 @@ fn test_phantom_invalid_onion_payload() {
        commitment_signed_dance!(nodes[1], nodes[0], &update_0.commitment_signed, false, true);
 
        // Modify the onion packet to have an invalid payment amount.
-       for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().forward_htlcs.iter_mut() {
+       for (_, pending_forwards) in nodes[1].node.forward_htlcs.lock().unwrap().iter_mut() {
                for f in pending_forwards.iter_mut() {
                        match f {
                                &mut HTLCForwardInfo::AddHTLC {
@@ -1152,7 +1152,7 @@ fn test_phantom_final_incorrect_cltv_expiry() {
        commitment_signed_dance!(nodes[1], nodes[0], &update_0.commitment_signed, false, true);
 
        // Modify the payload so the phantom hop's HMAC is bogus.
-       for (_, pending_forwards) in nodes[1].node.channel_state.lock().unwrap().forward_htlcs.iter_mut() {
+       for (_, pending_forwards) in nodes[1].node.forward_htlcs.lock().unwrap().iter_mut() {
                for f in pending_forwards.iter_mut() {
                        match f {
                                &mut HTLCForwardInfo::AddHTLC {