]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Refactor incoming HTLC accept checks out from Channel::update_add_htlc
authorWilmer Paulino <wilmer@wilmerpaulino.com>
Fri, 8 Mar 2024 08:19:40 +0000 (00:19 -0800)
committerWilmer Paulino <wilmer@wilmerpaulino.com>
Wed, 27 Mar 2024 21:28:04 +0000 (14:28 -0700)
In the future, we plan to completely remove
`decode_update_add_htlc_onion` and replace it with a batched variant.
This refactor, while improving readability in its current form, does not
feature any functional changes and allows us to reuse the incoming HTLC
acceptance checks in the batched variant.

lightning/src/ln/channel.rs
lightning/src/ln/channelmanager.rs

index 31b2dab38d1299191305380ab25d6b3c8fc21e55..cddd6a5e533201a976dff5caf41f1b4158304eb7 100644 (file)
@@ -4121,20 +4121,12 @@ impl<SP: Deref> Channel<SP> where
                Ok(self.get_announcement_sigs(node_signer, chain_hash, user_config, best_block.height, logger))
        }
 
-       pub fn update_add_htlc<F, FE: Deref, L: Deref>(
-               &mut self, msg: &msgs::UpdateAddHTLC, mut pending_forward_status: PendingHTLCStatus,
-               create_pending_htlc_status: F, fee_estimator: &LowerBoundedFeeEstimator<FE>, logger: &L
-       ) -> Result<(), ChannelError>
-       where F: for<'a> Fn(&'a Self, PendingHTLCStatus, u16) -> PendingHTLCStatus,
-               FE::Target: FeeEstimator, L::Target: Logger,
-       {
+       pub fn update_add_htlc(
+               &mut self, msg: &msgs::UpdateAddHTLC, pending_forward_status: PendingHTLCStatus,
+       ) -> Result<(), ChannelError> {
                if !matches!(self.context.channel_state, ChannelState::ChannelReady(_)) {
                        return Err(ChannelError::Close("Got add HTLC message when channel was not in an operational state".to_owned()));
                }
-               // We can't accept HTLCs sent after we've sent a shutdown.
-               if self.context.channel_state.is_local_shutdown_sent() {
-                       pending_forward_status = create_pending_htlc_status(self, pending_forward_status, 0x4000|8);
-               }
                // If the remote has sent a shutdown prior to adding this HTLC, then they are in violation of the spec.
                if self.context.channel_state.is_remote_shutdown_sent() {
                        return Err(ChannelError::Close("Got add HTLC message when channel was not in an operational state".to_owned()));
@@ -4153,7 +4145,6 @@ impl<SP: Deref> Channel<SP> where
                }
 
                let inbound_stats = self.context.get_inbound_pending_htlc_stats(None);
-               let outbound_stats = self.context.get_outbound_pending_htlc_stats(None);
                if inbound_stats.pending_htlcs + 1 > self.context.holder_max_accepted_htlcs as u32 {
                        return Err(ChannelError::Close(format!("Remote tried to push more than our max accepted HTLCs ({})", self.context.holder_max_accepted_htlcs)));
                }
@@ -4182,34 +4173,6 @@ impl<SP: Deref> Channel<SP> where
                        }
                }
 
-               let max_dust_htlc_exposure_msat = self.context.get_max_dust_htlc_exposure_msat(fee_estimator);
-               let (htlc_timeout_dust_limit, htlc_success_dust_limit) = if self.context.get_channel_type().supports_anchors_zero_fee_htlc_tx() {
-                       (0, 0)
-               } else {
-                       let dust_buffer_feerate = self.context.get_dust_buffer_feerate(None) as u64;
-                       (dust_buffer_feerate * htlc_timeout_tx_weight(self.context.get_channel_type()) / 1000,
-                               dust_buffer_feerate * htlc_success_tx_weight(self.context.get_channel_type()) / 1000)
-               };
-               let exposure_dust_limit_timeout_sats = htlc_timeout_dust_limit + self.context.counterparty_dust_limit_satoshis;
-               if msg.amount_msat / 1000 < exposure_dust_limit_timeout_sats {
-                       let on_counterparty_tx_dust_htlc_exposure_msat = inbound_stats.on_counterparty_tx_dust_exposure_msat + outbound_stats.on_counterparty_tx_dust_exposure_msat + msg.amount_msat;
-                       if on_counterparty_tx_dust_htlc_exposure_msat > max_dust_htlc_exposure_msat {
-                               log_info!(logger, "Cannot accept value that would put our exposure to dust HTLCs at {} over the limit {} on counterparty commitment tx",
-                                       on_counterparty_tx_dust_htlc_exposure_msat, max_dust_htlc_exposure_msat);
-                               pending_forward_status = create_pending_htlc_status(self, pending_forward_status, 0x1000|7);
-                       }
-               }
-
-               let exposure_dust_limit_success_sats = htlc_success_dust_limit + self.context.holder_dust_limit_satoshis;
-               if msg.amount_msat / 1000 < exposure_dust_limit_success_sats {
-                       let on_holder_tx_dust_htlc_exposure_msat = inbound_stats.on_holder_tx_dust_exposure_msat + outbound_stats.on_holder_tx_dust_exposure_msat + msg.amount_msat;
-                       if on_holder_tx_dust_htlc_exposure_msat > max_dust_htlc_exposure_msat {
-                               log_info!(logger, "Cannot accept value that would put our exposure to dust HTLCs at {} over the limit {} on holder commitment tx",
-                                       on_holder_tx_dust_htlc_exposure_msat, max_dust_htlc_exposure_msat);
-                               pending_forward_status = create_pending_htlc_status(self, pending_forward_status, 0x1000|7);
-                       }
-               }
-
                let pending_value_to_self_msat =
                        self.context.value_to_self_msat + inbound_stats.pending_htlcs_value_msat - removed_outbound_total_msat;
                let pending_remote_value_msat =
@@ -4243,23 +4206,7 @@ impl<SP: Deref> Channel<SP> where
                } else {
                        0
                };
-               if !self.context.is_outbound() {
-                       // `Some(())` is for the fee spike buffer we keep for the remote. This deviates from
-                       // the spec because the fee spike buffer requirement doesn't exist on the receiver's
-                       // side, only on the sender's. Note that with anchor outputs we are no longer as
-                       // sensitive to fee spikes, so we need to account for them.
-                       let htlc_candidate = HTLCCandidate::new(msg.amount_msat, HTLCInitiator::RemoteOffered);
-                       let mut remote_fee_cost_incl_stuck_buffer_msat = self.context.next_remote_commit_tx_fee_msat(htlc_candidate, Some(()));
-                       if !self.context.get_channel_type().supports_anchors_zero_fee_htlc_tx() {
-                               remote_fee_cost_incl_stuck_buffer_msat *= FEE_SPIKE_BUFFER_FEE_INCREASE_MULTIPLE;
-                       }
-                       if pending_remote_value_msat.saturating_sub(msg.amount_msat).saturating_sub(self.context.holder_selected_channel_reserve_satoshis * 1000).saturating_sub(anchor_outputs_value_msat) < remote_fee_cost_incl_stuck_buffer_msat {
-                               // Note that if the pending_forward_status is not updated here, then it's because we're already failing
-                               // the HTLC, i.e. its status is already set to failing.
-                               log_info!(logger, "Attempting to fail HTLC due to fee spike buffer violation in channel {}. Rebalancing is required.", &self.context.channel_id());
-                               pending_forward_status = create_pending_htlc_status(self, pending_forward_status, 0x1000|7);
-                       }
-               } else {
+               if self.context.is_outbound() {
                        // Check that they won't violate our local required channel reserve by adding this HTLC.
                        let htlc_candidate = HTLCCandidate::new(msg.amount_msat, HTLCInitiator::RemoteOffered);
                        let local_commit_tx_fee_msat = self.context.next_local_commit_tx_fee_msat(htlc_candidate, None);
@@ -6148,6 +6095,86 @@ impl<SP: Deref> Channel<SP> where
                        })
        }
 
+       pub fn can_accept_incoming_htlc<F: Deref, L: Deref>(
+               &self, msg: &msgs::UpdateAddHTLC, fee_estimator: &LowerBoundedFeeEstimator<F>, logger: L
+       ) -> Result<(), (&'static str, u16)>
+       where
+               F::Target: FeeEstimator,
+               L::Target: Logger
+       {
+               if self.context.channel_state.is_local_shutdown_sent() {
+                       return Err(("Shutdown was already sent", 0x4000|8))
+               }
+
+               let inbound_stats = self.context.get_inbound_pending_htlc_stats(None);
+               let outbound_stats = self.context.get_outbound_pending_htlc_stats(None);
+               let max_dust_htlc_exposure_msat = self.context.get_max_dust_htlc_exposure_msat(fee_estimator);
+               let (htlc_timeout_dust_limit, htlc_success_dust_limit) = if self.context.get_channel_type().supports_anchors_zero_fee_htlc_tx() {
+                       (0, 0)
+               } else {
+                       let dust_buffer_feerate = self.context.get_dust_buffer_feerate(None) as u64;
+                       (dust_buffer_feerate * htlc_timeout_tx_weight(self.context.get_channel_type()) / 1000,
+                               dust_buffer_feerate * htlc_success_tx_weight(self.context.get_channel_type()) / 1000)
+               };
+               let exposure_dust_limit_timeout_sats = htlc_timeout_dust_limit + self.context.counterparty_dust_limit_satoshis;
+               if msg.amount_msat / 1000 < exposure_dust_limit_timeout_sats {
+                       let on_counterparty_tx_dust_htlc_exposure_msat = inbound_stats.on_counterparty_tx_dust_exposure_msat + outbound_stats.on_counterparty_tx_dust_exposure_msat + msg.amount_msat;
+                       if on_counterparty_tx_dust_htlc_exposure_msat > max_dust_htlc_exposure_msat {
+                               log_info!(logger, "Cannot accept value that would put our exposure to dust HTLCs at {} over the limit {} on counterparty commitment tx",
+                                       on_counterparty_tx_dust_htlc_exposure_msat, max_dust_htlc_exposure_msat);
+                               return Err(("Exceeded our dust exposure limit on counterparty commitment tx", 0x1000|7))
+                       }
+               }
+
+               let exposure_dust_limit_success_sats = htlc_success_dust_limit + self.context.holder_dust_limit_satoshis;
+               if msg.amount_msat / 1000 < exposure_dust_limit_success_sats {
+                       let on_holder_tx_dust_htlc_exposure_msat = inbound_stats.on_holder_tx_dust_exposure_msat + outbound_stats.on_holder_tx_dust_exposure_msat + msg.amount_msat;
+                       if on_holder_tx_dust_htlc_exposure_msat > max_dust_htlc_exposure_msat {
+                               log_info!(logger, "Cannot accept value that would put our exposure to dust HTLCs at {} over the limit {} on holder commitment tx",
+                                       on_holder_tx_dust_htlc_exposure_msat, max_dust_htlc_exposure_msat);
+                               return Err(("Exceeded our dust exposure limit on holder commitment tx", 0x1000|7))
+                       }
+               }
+
+               let anchor_outputs_value_msat = if self.context.get_channel_type().supports_anchors_zero_fee_htlc_tx() {
+                       ANCHOR_OUTPUT_VALUE_SATOSHI * 2 * 1000
+               } else {
+                       0
+               };
+
+               let mut removed_outbound_total_msat = 0;
+               for ref htlc in self.context.pending_outbound_htlcs.iter() {
+                       if let OutboundHTLCState::AwaitingRemoteRevokeToRemove(OutboundHTLCOutcome::Success(_)) = htlc.state {
+                               removed_outbound_total_msat += htlc.amount_msat;
+                       } else if let OutboundHTLCState::AwaitingRemovedRemoteRevoke(OutboundHTLCOutcome::Success(_)) = htlc.state {
+                               removed_outbound_total_msat += htlc.amount_msat;
+                       }
+               }
+
+               let pending_value_to_self_msat =
+                       self.context.value_to_self_msat + inbound_stats.pending_htlcs_value_msat - removed_outbound_total_msat;
+               let pending_remote_value_msat =
+                       self.context.channel_value_satoshis * 1000 - pending_value_to_self_msat;
+
+               if !self.context.is_outbound() {
+                       // `Some(())` is for the fee spike buffer we keep for the remote. This deviates from
+                       // the spec because the fee spike buffer requirement doesn't exist on the receiver's
+                       // side, only on the sender's. Note that with anchor outputs we are no longer as
+                       // sensitive to fee spikes, so we need to account for them.
+                       let htlc_candidate = HTLCCandidate::new(msg.amount_msat, HTLCInitiator::RemoteOffered);
+                       let mut remote_fee_cost_incl_stuck_buffer_msat = self.context.next_remote_commit_tx_fee_msat(htlc_candidate, Some(()));
+                       if !self.context.get_channel_type().supports_anchors_zero_fee_htlc_tx() {
+                               remote_fee_cost_incl_stuck_buffer_msat *= FEE_SPIKE_BUFFER_FEE_INCREASE_MULTIPLE;
+                       }
+                       if pending_remote_value_msat.saturating_sub(msg.amount_msat).saturating_sub(self.context.holder_selected_channel_reserve_satoshis * 1000).saturating_sub(anchor_outputs_value_msat) < remote_fee_cost_incl_stuck_buffer_msat {
+                               log_info!(logger, "Attempting to fail HTLC due to fee spike buffer violation in channel {}. Rebalancing is required.", &self.context.channel_id());
+                               return Err(("Fee spike buffer violation", 0x1000|7));
+                       }
+               }
+
+               Ok(())
+       }
+
        pub fn get_cur_holder_commitment_transaction_number(&self) -> u64 {
                self.context.cur_holder_commitment_transaction_number + 1
        }
index 683b570dae0d2c0b79849843296f518d36fe635e..3ab5911f560c2f789baca46691ea2baa1476993b 100644 (file)
@@ -6814,7 +6814,7 @@ where
                match peer_state.channel_by_id.entry(msg.channel_id) {
                        hash_map::Entry::Occupied(mut chan_phase_entry) => {
                                if let ChannelPhase::Funded(chan) = chan_phase_entry.get_mut() {
-                                       let pending_forward_info = match decoded_hop_res {
+                                       let mut pending_forward_info = match decoded_hop_res {
                                                Ok((next_hop, shared_secret, next_packet_pk_opt)) =>
                                                        self.construct_pending_htlc_status(
                                                                msg, counterparty_node_id, shared_secret, next_hop,
@@ -6822,44 +6822,45 @@ where
                                                        ),
                                                Err(e) => PendingHTLCStatus::Fail(e)
                                        };
-                                       let create_pending_htlc_status = |chan: &Channel<SP>, pending_forward_info: PendingHTLCStatus, error_code: u16| {
+                                       let logger = WithChannelContext::from(&self.logger, &chan.context);
+                                       // If the update_add is completely bogus, the call will Err and we will close,
+                                       // but if we've sent a shutdown and they haven't acknowledged it yet, we just
+                                       // want to reject the new HTLC and fail it backwards instead of forwarding.
+                                       if let Err((_, error_code)) = chan.can_accept_incoming_htlc(&msg, &self.fee_estimator, &logger) {
                                                if msg.blinding_point.is_some() {
-                                                       return PendingHTLCStatus::Fail(HTLCFailureMsg::Malformed(
-                                                                       msgs::UpdateFailMalformedHTLC {
-                                                                               channel_id: msg.channel_id,
-                                                                               htlc_id: msg.htlc_id,
-                                                                               sha256_of_onion: [0; 32],
-                                                                               failure_code: INVALID_ONION_BLINDING,
-                                                                       }
-                                                       ))
-                                               }
-                                               // If the update_add is completely bogus, the call will Err and we will close,
-                                               // but if we've sent a shutdown and they haven't acknowledged it yet, we just
-                                               // want to reject the new HTLC and fail it backwards instead of forwarding.
-                                               match pending_forward_info {
-                                                       PendingHTLCStatus::Forward(PendingHTLCInfo {
-                                                               ref incoming_shared_secret, ref routing, ..
-                                                       }) => {
-                                                               let reason = if routing.blinded_failure().is_some() {
-                                                                       HTLCFailReason::reason(INVALID_ONION_BLINDING, vec![0; 32])
-                                                               } else if (error_code & 0x1000) != 0 {
-                                                                       let (real_code, error_data) = self.get_htlc_inbound_temp_fail_err_and_data(error_code, chan);
-                                                                       HTLCFailReason::reason(real_code, error_data)
-                                                               } else {
-                                                                       HTLCFailReason::from_failure_code(error_code)
-                                                               }.get_encrypted_failure_packet(incoming_shared_secret, &None);
-                                                               let msg = msgs::UpdateFailHTLC {
+                                                       pending_forward_info = PendingHTLCStatus::Fail(HTLCFailureMsg::Malformed(
+                                                               msgs::UpdateFailMalformedHTLC {
                                                                        channel_id: msg.channel_id,
                                                                        htlc_id: msg.htlc_id,
-                                                                       reason
-                                                               };
-                                                               PendingHTLCStatus::Fail(HTLCFailureMsg::Relay(msg))
-                                                       },
-                                                       _ => pending_forward_info
+                                                                       sha256_of_onion: [0; 32],
+                                                                       failure_code: INVALID_ONION_BLINDING,
+                                                               }
+                                                       ))
+                                               } else {
+                                                       match pending_forward_info {
+                                                               PendingHTLCStatus::Forward(PendingHTLCInfo {
+                                                                       ref incoming_shared_secret, ref routing, ..
+                                                               }) => {
+                                                                       let reason = if routing.blinded_failure().is_some() {
+                                                                               HTLCFailReason::reason(INVALID_ONION_BLINDING, vec![0; 32])
+                                                                       } else if (error_code & 0x1000) != 0 {
+                                                                               let (real_code, error_data) = self.get_htlc_inbound_temp_fail_err_and_data(error_code, chan);
+                                                                               HTLCFailReason::reason(real_code, error_data)
+                                                                       } else {
+                                                                               HTLCFailReason::from_failure_code(error_code)
+                                                                       }.get_encrypted_failure_packet(incoming_shared_secret, &None);
+                                                                       let msg = msgs::UpdateFailHTLC {
+                                                                               channel_id: msg.channel_id,
+                                                                               htlc_id: msg.htlc_id,
+                                                                               reason
+                                                                       };
+                                                                       pending_forward_info = PendingHTLCStatus::Fail(HTLCFailureMsg::Relay(msg));
+                                                               },
+                                                               _ => {},
+                                                       }
                                                }
-                                       };
-                                       let logger = WithChannelContext::from(&self.logger, &chan.context);
-                                       try_chan_phase_entry!(self, chan.update_add_htlc(&msg, pending_forward_info, create_pending_htlc_status, &self.fee_estimator, &&logger), chan_phase_entry);
+                                       }
+                                       try_chan_phase_entry!(self, chan.update_add_htlc(&msg, pending_forward_info), chan_phase_entry);
                                } else {
                                        return try_chan_phase_entry!(self, Err(ChannelError::Close(
                                                "Got an update_add_htlc message for an unfunded channel!".into())), chan_phase_entry);