Extract onion_utils::build_onion_payloads_callback helper.
authorValentine Wallace <vwallace@protonmail.com>
Fri, 10 May 2024 21:31:51 +0000 (17:31 -0400)
committerValentine Wallace <vwallace@protonmail.com>
Thu, 16 May 2024 22:08:44 +0000 (15:08 -0700)
Will be useful when we want to calculate the total size of the payloads without
actually allocating for them.

lightning/src/ln/onion_utils.rs

index a90dd837e56792ca329a582252d85b7ec1e60ef9..2aba2c75748ed4ad74813206b7e302964d594075 100644 (file)
@@ -7,6 +7,7 @@
 // You may not use this file except in accordance with one or both of these
 // licenses.
 
+use crate::blinded_path::BlindedHop;
 use crate::crypto::chacha20::ChaCha20;
 use crate::crypto::streams::ChaChaReader;
 use crate::ln::channelmanager::{HTLCSource, RecipientOnionFields};
@@ -14,7 +15,7 @@ use crate::ln::msgs;
 use crate::ln::types::{PaymentHash, PaymentPreimage};
 use crate::ln::wire::Encode;
 use crate::routing::gossip::NetworkUpdate;
-use crate::routing::router::{BlindedTail, Path, RouteHop};
+use crate::routing::router::{Path, RouteHop};
 use crate::sign::NodeSigner;
 use crate::util::errors::{self, APIError};
 use crate::util::logger::Logger;
@@ -178,14 +179,56 @@ pub(super) fn build_onion_payloads<'a>(
        path: &'a Path, total_msat: u64, recipient_onion: &'a RecipientOnionFields,
        starting_htlc_offset: u32, keysend_preimage: &Option<PaymentPreimage>,
 ) -> Result<(Vec<msgs::OutboundOnionPayload<'a>>, u64, u32), APIError> {
-       let mut cur_value_msat = 0u64;
-       let mut cur_cltv = starting_htlc_offset;
-       let mut last_short_channel_id = 0;
        let mut res: Vec<msgs::OutboundOnionPayload> = Vec::with_capacity(
                path.hops.len() + path.blinded_tail.as_ref().map_or(0, |t| t.hops.len()),
        );
+       let blinded_tail_with_hop_iter = path.blinded_tail.as_ref().map(|bt| BlindedTailHopIter {
+               hops: bt.hops.iter(),
+               blinding_point: bt.blinding_point,
+               final_value_msat: bt.final_value_msat,
+               excess_final_cltv_expiry_delta: bt.excess_final_cltv_expiry_delta,
+       });
+
+       let (value_msat, cltv) = build_onion_payloads_callback(
+               path.hops.iter(),
+               blinded_tail_with_hop_iter,
+               total_msat,
+               recipient_onion,
+               starting_htlc_offset,
+               keysend_preimage,
+               |action, payload| match action {
+                       PayloadCallbackAction::PushBack => res.push(payload),
+                       PayloadCallbackAction::PushFront => res.insert(0, payload),
+               },
+       )?;
+       Ok((res, value_msat, cltv))
+}
 
-       for (idx, hop) in path.hops.iter().rev().enumerate() {
+struct BlindedTailHopIter<'a, I: Iterator<Item = &'a BlindedHop>> {
+       hops: I,
+       blinding_point: PublicKey,
+       final_value_msat: u64,
+       excess_final_cltv_expiry_delta: u32,
+}
+enum PayloadCallbackAction {
+       PushBack,
+       PushFront,
+}
+fn build_onion_payloads_callback<'a, H, B, F>(
+       hops: H, mut blinded_tail: Option<BlindedTailHopIter<'a, B>>, total_msat: u64,
+       recipient_onion: &'a RecipientOnionFields, starting_htlc_offset: u32,
+       keysend_preimage: &Option<PaymentPreimage>, mut callback: F,
+) -> Result<(u64, u32), APIError>
+where
+       H: DoubleEndedIterator<Item = &'a RouteHop>,
+       B: ExactSizeIterator<Item = &'a BlindedHop>,
+       F: FnMut(PayloadCallbackAction, msgs::OutboundOnionPayload<'a>),
+{
+       let mut cur_value_msat = 0u64;
+       let mut cur_cltv = starting_htlc_offset;
+       let mut last_short_channel_id = 0;
+
+       for (idx, hop) in hops.rev().enumerate() {
                // First hop gets special values so that it can check, on receipt, that everything is
                // exactly as it should be (and the next hop isn't trying to probe to find out if we're
                // the intended recipient).
@@ -196,45 +239,55 @@ pub(super) fn build_onion_payloads<'a>(
                        cur_cltv
                };
                if idx == 0 {
-                       if let Some(BlindedTail {
+                       if let Some(BlindedTailHopIter {
                                blinding_point,
                                hops,
                                final_value_msat,
                                excess_final_cltv_expiry_delta,
                                ..
-                       }) = &path.blinded_tail
+                       }) = blinded_tail.take()
                        {
-                               let mut blinding_point = Some(*blinding_point);
-                               for (i, blinded_hop) in hops.iter().enumerate() {
-                                       if i == hops.len() - 1 {
+                               let mut blinding_point = Some(blinding_point);
+                               let hops_len = hops.len();
+                               for (i, blinded_hop) in hops.enumerate() {
+                                       if i == hops_len - 1 {
                                                cur_value_msat += final_value_msat;
-                                               res.push(msgs::OutboundOnionPayload::BlindedReceive {
-                                                       sender_intended_htlc_amt_msat: *final_value_msat,
-                                                       total_msat,
-                                                       cltv_expiry_height: cur_cltv + excess_final_cltv_expiry_delta,
-                                                       encrypted_tlvs: &blinded_hop.encrypted_payload,
-                                                       intro_node_blinding_point: blinding_point.take(),
-                                                       keysend_preimage: *keysend_preimage,
-                                                       custom_tlvs: &recipient_onion.custom_tlvs,
-                                               });
+                                               callback(
+                                                       PayloadCallbackAction::PushBack,
+                                                       msgs::OutboundOnionPayload::BlindedReceive {
+                                                               sender_intended_htlc_amt_msat: final_value_msat,
+                                                               total_msat,
+                                                               cltv_expiry_height: cur_cltv + excess_final_cltv_expiry_delta,
+                                                               encrypted_tlvs: &blinded_hop.encrypted_payload,
+                                                               intro_node_blinding_point: blinding_point.take(),
+                                                               keysend_preimage: *keysend_preimage,
+                                                               custom_tlvs: &recipient_onion.custom_tlvs,
+                                                       },
+                                               );
                                        } else {
-                                               res.push(msgs::OutboundOnionPayload::BlindedForward {
-                                                       encrypted_tlvs: &blinded_hop.encrypted_payload,
-                                                       intro_node_blinding_point: blinding_point.take(),
-                                               });
+                                               callback(
+                                                       PayloadCallbackAction::PushBack,
+                                                       msgs::OutboundOnionPayload::BlindedForward {
+                                                               encrypted_tlvs: &blinded_hop.encrypted_payload,
+                                                               intro_node_blinding_point: blinding_point.take(),
+                                                       },
+                                               );
                                        }
                                }
                        } else {
-                               res.push(msgs::OutboundOnionPayload::Receive {
-                                       payment_data: recipient_onion.payment_secret.map(|payment_secret| {
-                                               msgs::FinalOnionHopData { payment_secret, total_msat }
-                                       }),
-                                       payment_metadata: recipient_onion.payment_metadata.as_ref(),
-                                       keysend_preimage: *keysend_preimage,
-                                       custom_tlvs: &recipient_onion.custom_tlvs,
-                                       sender_intended_htlc_amt_msat: value_msat,
-                                       cltv_expiry_height: cltv,
-                               });
+                               callback(
+                                       PayloadCallbackAction::PushBack,
+                                       msgs::OutboundOnionPayload::Receive {
+                                               payment_data: recipient_onion.payment_secret.map(|payment_secret| {
+                                                       msgs::FinalOnionHopData { payment_secret, total_msat }
+                                               }),
+                                               payment_metadata: recipient_onion.payment_metadata.as_ref(),
+                                               keysend_preimage: *keysend_preimage,
+                                               custom_tlvs: &recipient_onion.custom_tlvs,
+                                               sender_intended_htlc_amt_msat: value_msat,
+                                               cltv_expiry_height: cltv,
+                                       },
+                               );
                        }
                } else {
                        let payload = msgs::OutboundOnionPayload::Forward {
@@ -242,7 +295,7 @@ pub(super) fn build_onion_payloads<'a>(
                                amt_to_forward: value_msat,
                                outgoing_cltv_value: cltv,
                        };
-                       res.insert(0, payload);
+                       callback(PayloadCallbackAction::PushFront, payload);
                }
                cur_value_msat += hop.fee_msat;
                if cur_value_msat >= 21000000 * 100000000 * 1000 {
@@ -254,7 +307,7 @@ pub(super) fn build_onion_payloads<'a>(
                }
                last_short_channel_id = hop.short_channel_id;
        }
-       Ok((res, cur_value_msat, cur_cltv))
+       Ok((cur_value_msat, cur_cltv))
 }
 
 /// Length of the onion data packet. Before TLV-based onions this was 20 65-byte hops, though now