Remove unnecessary aquiring of the `channel_state` lock
authorViktor Tigerström <11711198+ViktorTigerstrom@users.noreply.github.com>
Mon, 8 Aug 2022 19:54:06 +0000 (21:54 +0200)
committerViktor Tigerström <11711198+ViktorTigerstrom@users.noreply.github.com>
Sun, 18 Sep 2022 21:13:56 +0000 (23:13 +0200)
lightning/src/ln/channelmanager.rs

index 39866d3d8858e474f00d2940cca9c8d9c67fd4f5..cce6a7f97c8b0ae376fcd05871961538427730a0 100644 (file)
@@ -1887,7 +1887,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                for htlc_source in failed_htlcs.drain(..) {
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(*counterparty_node_id), channel_id: *channel_id };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       self.fail_htlc_backwards_internal(htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
 
                let _ = handle_error!(self, result, *counterparty_node_id);
@@ -1945,7 +1945,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for htlc_source in failed_htlcs.drain(..) {
                        let (source, payment_hash, counterparty_node_id, channel_id) = htlc_source;
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id: channel_id };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       self.fail_htlc_backwards_internal(source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
                if let Some((funding_txo, monitor_update)) = monitor_update_option {
                        // There isn't anything we can do if we get an update failure - we're already
@@ -3005,13 +3005,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut phantom_receives: Vec<(u64, OutPoint, Vec<(PendingHTLCInfo, u64)>)> = Vec::new();
                let mut handle_errors = Vec::new();
                {
-                       let mut channel_state_lock = self.channel_state.lock().unwrap();
-                       let channel_state = &mut *channel_state_lock;
-
                        let mut forward_htlcs = HashMap::new();
                        mem::swap(&mut forward_htlcs, &mut self.forward_htlcs.lock().unwrap());
 
                        for (short_chan_id, mut pending_forwards) in forward_htlcs {
+                               let mut channel_state_lock = self.channel_state.lock().unwrap();
+                               let channel_state = &mut *channel_state_lock;
                                if short_chan_id != 0 {
                                        let forward_chan_id = match channel_state.short_to_chan_info.get(&short_chan_id) {
                                                Some((_cp_id, chan_id)) => chan_id.clone(),
@@ -3409,7 +3408,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                }
 
                for (htlc_source, payment_hash, failure_reason, destination) in failed_forwards.drain(..) {
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source, &payment_hash, failure_reason, destination);
+                       self.fail_htlc_backwards_internal(htlc_source, &payment_hash, failure_reason, destination);
                }
                self.forward_htlcs(&mut phantom_receives);
 
@@ -3633,7 +3632,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                        for htlc_source in timed_out_mpp_htlcs.drain(..) {
                                let receiver = HTLCDestination::FailedPayment { payment_hash: htlc_source.1 };
-                               self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), HTLCSource::PreviousHopData(htlc_source.0.clone()), &htlc_source.1, HTLCFailReason::Reason { failure_code: 23, data: Vec::new() }, receiver );
+                               self.fail_htlc_backwards_internal(HTLCSource::PreviousHopData(htlc_source.0.clone()), &htlc_source.1, HTLCFailReason::Reason { failure_code: 23, data: Vec::new() }, receiver );
                        }
 
                        for (err, counterparty_node_id) in handle_errors.drain(..) {
@@ -3659,15 +3658,16 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash) {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let mut channel_state = Some(self.channel_state.lock().unwrap());
-               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(payment_hash);
+               let removed_source = {
+                       let mut channel_state = self.channel_state.lock().unwrap();
+                       channel_state.claimable_htlcs.remove(payment_hash)
+               };
                if let Some((_, mut sources)) = removed_source {
                        for htlc in sources.drain(..) {
-                               if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }
                                let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
                                htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
                                                self.best_block.read().unwrap().height()));
-                               self.fail_htlc_backwards_internal(channel_state.take().unwrap(),
+                               self.fail_htlc_backwards_internal(
                                                HTLCSource::PreviousHopData(htlc.prev_hop), payment_hash,
                                                HTLCFailReason::Reason { failure_code: 0x4000 | 15, data: htlc_msat_height_data },
                                                HTLCDestination::FailedPayment { payment_hash: *payment_hash });
@@ -3730,9 +3730,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                counterparty_node_id: &PublicKey
        ) {
                for (htlc_src, payment_hash) in htlcs_to_fail.drain(..) {
-                       let mut channel_state = self.channel_state.lock().unwrap();
                        let (failure_code, onion_failure_data) =
-                               match channel_state.by_id.entry(channel_id) {
+                               match self.channel_state.lock().unwrap().by_id.entry(channel_id) {
                                        hash_map::Entry::Occupied(chan_entry) => {
                                                self.get_htlc_inbound_temp_fail_err_and_data(0x1000|7, &chan_entry.get())
                                        },
@@ -3740,17 +3739,22 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                };
 
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id.clone()), channel_id };
-                       self.fail_htlc_backwards_internal(channel_state, htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data }, receiver);
+                       self.fail_htlc_backwards_internal(htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data }, receiver);
                }
        }
 
        /// Fails an HTLC backwards to the sender of it to us.
-       /// Note that while we take a channel_state lock as input, we do *not* assume consistency here.
-       /// There are several callsites that do stupid things like loop over a list of payment_hashes
-       /// to fail and take the channel_state lock for each iteration (as we take ownership and may
-       /// drop it). In other words, no assumptions are made that entries in claimable_htlcs point to
-       /// still-available channels.
-       fn fail_htlc_backwards_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason, destination: HTLCDestination) {
+       /// Note that we do not assume that channels corresponding to failed HTLCs are still available.
+       fn fail_htlc_backwards_internal(&self, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason,destination: HTLCDestination) {
+               #[cfg(debug_assertions)]
+               {
+                       // Ensure that the `channel_state` lock is not held when calling this function.
+                       // This ensures that future code doesn't introduce a lock_order requirement for
+                       // `forward_htlcs` to be locked after the `channel_state` lock, which calling this
+                       // function with the `channel_state` locked would.
+                       assert!(self.channel_state.try_lock().is_ok());
+               }
+
                //TODO: There is a timing attack here where if a node fails an HTLC back to us they can
                //identify whether we sent it or not based on the (I presume) very different runtime
                //between the branches here. We should make this async and move it into the forward HTLCs
@@ -3789,7 +3793,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                        return;
                                }
-                               mem::drop(channel_state_lock);
                                let mut retry = if let Some(payment_params_data) = payment_params {
                                        let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
                                        Some(RouteParameters {
@@ -3923,7 +3926,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        }
                                }
                                mem::drop(forward_htlcs);
-                               mem::drop(channel_state_lock);
                                let mut pending_events = self.pending_events.lock().unwrap();
                                if let Some(time) = forward_event {
                                        pending_events.push(events::Event::PendingHTLCsForwardable {
@@ -3961,8 +3963,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let mut channel_state = Some(self.channel_state.lock().unwrap());
-               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&payment_hash);
+               let removed_source = self.channel_state.lock().unwrap().claimable_htlcs.remove(&payment_hash);
                if let Some((payment_purpose, mut sources)) = removed_source {
                        assert!(!sources.is_empty());
 
@@ -3980,8 +3981,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        let mut claimable_amt_msat = 0;
                        let mut expected_amt_msat = None;
                        let mut valid_mpp = true;
+                       let mut errs = Vec::new();
+                       let mut claimed_any_htlcs = false;
+                       let mut channel_state_lock = self.channel_state.lock().unwrap();
+                       let channel_state = &mut *channel_state_lock;
                        for htlc in sources.iter() {
-                               if let None = channel_state.as_ref().unwrap().short_to_chan_info.get(&htlc.prev_hop.short_channel_id) {
+                               if let None = channel_state.short_to_chan_info.get(&htlc.prev_hop.short_channel_id) {
                                        valid_mpp = false;
                                        break;
                                }
@@ -4014,21 +4019,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        expected_amt_msat.unwrap(), claimable_amt_msat);
                                return;
                        }
-
-                       let mut errs = Vec::new();
-                       let mut claimed_any_htlcs = false;
-                       for htlc in sources.drain(..) {
-                               if !valid_mpp {
-                                       if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }
-                                       let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
-                                       htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
-                                                       self.best_block.read().unwrap().height()));
-                                       self.fail_htlc_backwards_internal(channel_state.take().unwrap(),
-                                                                        HTLCSource::PreviousHopData(htlc.prev_hop), &payment_hash,
-                                                                        HTLCFailReason::Reason { failure_code: 0x4000|15, data: htlc_msat_height_data },
-                                                                        HTLCDestination::FailedPayment { payment_hash } );
-                               } else {
-                                       match self.claim_funds_from_hop(channel_state.as_mut().unwrap(), htlc.prev_hop, payment_preimage) {
+                       if valid_mpp {
+                               for htlc in sources.drain(..) {
+                                       match self.claim_funds_from_hop(&mut channel_state_lock, htlc.prev_hop, payment_preimage) {
                                                ClaimFundsFromHop::MonitorUpdateFail(pk, err, _) => {
                                                        if let msgs::ErrorAction::IgnoreError = err.err.action {
                                                                // We got a temporary failure updating monitor, but will claim the
@@ -4048,6 +4041,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        }
                                }
                        }
+                       mem::drop(channel_state_lock);
+                       if !valid_mpp {
+                               for htlc in sources.drain(..) {
+                                       let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
+                                       htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
+                                               self.best_block.read().unwrap().height()));
+                                       self.fail_htlc_backwards_internal(
+                                               HTLCSource::PreviousHopData(htlc.prev_hop), &payment_hash,
+                                               HTLCFailReason::Reason { failure_code: 0x4000|15, data: htlc_msat_height_data },
+                                               HTLCDestination::FailedPayment { payment_hash } );
+                               }
+                       }
 
                        if claimed_any_htlcs {
                                self.pending_events.lock().unwrap().push(events::Event::PaymentClaimed {
@@ -4057,10 +4062,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                });
                        }
 
-                       // Now that we've done the entire above loop in one lock, we can handle any errors
-                       // which were generated.
-                       channel_state.take();
-
+                       // Now we can handle any errors which were generated.
                        for (counterparty_node_id, err) in errs.drain(..) {
                                let res: Result<(), _> = Err(err);
                                let _ = handle_error!(self, res, counterparty_node_id);
@@ -4307,7 +4309,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                self.finalize_claims(finalized_claims);
                for failure in pending_failures.drain(..) {
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id: funding_txo.to_channel_id() };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2, receiver);
+                       self.fail_htlc_backwards_internal(failure.0, &failure.1, failure.2, receiver);
                }
        }
 
@@ -4667,7 +4669,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                };
                for htlc_source in dropped_htlcs.drain(..) {
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id.clone()), channel_id: msg.channel_id };
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       self.fail_htlc_backwards_internal(htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
 
                let _ = handle_error!(self, result, *counterparty_node_id);
@@ -4869,7 +4871,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for &mut (prev_short_channel_id, prev_funding_outpoint, ref mut pending_forwards) in per_source_pending_forwards {
                        let mut forward_event = None;
                        if !pending_forwards.is_empty() {
-                               let mut channel_state = self.channel_state.lock().unwrap();
                                let mut forward_htlcs = self.forward_htlcs.lock().unwrap();
                                if forward_htlcs.is_empty() {
                                        forward_event = Some(Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS))
@@ -4956,7 +4957,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        {
                                for failure in pending_failures.drain(..) {
                                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(*counterparty_node_id), channel_id: channel_outpoint.to_channel_id() };
-                                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2, receiver);
+                                       self.fail_htlc_backwards_internal(failure.0, &failure.1, failure.2, receiver);
                                }
                                self.forward_htlcs(&mut [(short_channel_id, channel_outpoint, pending_forwards)]);
                                self.finalize_claims(finalized_claim_htlcs);
@@ -5114,7 +5115,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                } else {
                                                        log_trace!(self.logger, "Failing HTLC with hash {} from our monitor", log_bytes!(htlc_update.payment_hash.0));
                                                        let receiver = HTLCDestination::NextHopChannel { node_id: counterparty_node_id, channel_id: funding_outpoint.to_channel_id() };
-                                                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                                                       self.fail_htlc_backwards_internal(htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                                                }
                                        },
                                        MonitorEvent::CommitmentTxConfirmed(funding_outpoint) |
@@ -5848,7 +5849,7 @@ where
                self.handle_init_event_channel_failures(failed_channels);
 
                for (source, payment_hash, reason, destination) in timed_out_htlcs.drain(..) {
-                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, reason, destination);
+                       self.fail_htlc_backwards_internal(source, &payment_hash, reason, destination);
                }
        }
 
@@ -7210,7 +7211,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                for htlc_source in failed_htlcs.drain(..) {
                        let (source, payment_hash, counterparty_node_id, channel_id) = htlc_source;
                        let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id };
-                       channel_manager.fail_htlc_backwards_internal(channel_manager.channel_state.lock().unwrap(), source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
+                       channel_manager.fail_htlc_backwards_internal(source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
                }
 
                //TODO: Broadcast channel update for closed channels, but only after we've made a