Extract onion_utils::build_onion_payloads_callback helper.
[rust-lightning] / lightning / src / ln / onion_utils.rs
index a946dae6779077396f03aeedc578fa0cadde2f90..2aba2c75748ed4ad74813206b7e302964d594075 100644 (file)
@@ -7,14 +7,15 @@
 // You may not use this file except in accordance with one or both of these
 // licenses.
 
+use crate::blinded_path::BlindedHop;
 use crate::crypto::chacha20::ChaCha20;
 use crate::crypto::streams::ChaChaReader;
 use crate::ln::channelmanager::{HTLCSource, RecipientOnionFields};
 use crate::ln::msgs;
+use crate::ln::types::{PaymentHash, PaymentPreimage};
 use crate::ln::wire::Encode;
-use crate::ln::{PaymentHash, PaymentPreimage};
 use crate::routing::gossip::NetworkUpdate;
-use crate::routing::router::{BlindedTail, Path, RouteHop};
+use crate::routing::router::{Path, RouteHop};
 use crate::sign::NodeSigner;
 use crate::util::errors::{self, APIError};
 use crate::util::logger::Logger;
@@ -30,10 +31,11 @@ use bitcoin::secp256k1::ecdh::SharedSecret;
 use bitcoin::secp256k1::{PublicKey, Scalar, Secp256k1, SecretKey};
 
 use crate::io::{Cursor, Read};
-use crate::prelude::*;
-use core::convert::{AsMut, TryInto};
 use core::ops::Deref;
 
+#[allow(unused_imports)]
+use crate::prelude::*;
+
 pub(crate) struct OnionKeys {
        #[cfg(test)]
        pub(crate) shared_secret: SharedSecret,
@@ -173,18 +175,60 @@ pub(super) fn construct_onion_keys<T: secp256k1::Signing>(
 }
 
 /// returns the hop data, as well as the first-hop value_msat and CLTV value we should send.
-pub(super) fn build_onion_payloads(
-       path: &Path, total_msat: u64, mut recipient_onion: RecipientOnionFields,
+pub(super) fn build_onion_payloads<'a>(
+       path: &'a Path, total_msat: u64, recipient_onion: &'a RecipientOnionFields,
        starting_htlc_offset: u32, keysend_preimage: &Option<PaymentPreimage>,
-) -> Result<(Vec<msgs::OutboundOnionPayload>, u64, u32), APIError> {
-       let mut cur_value_msat = 0u64;
-       let mut cur_cltv = starting_htlc_offset;
-       let mut last_short_channel_id = 0;
+) -> Result<(Vec<msgs::OutboundOnionPayload<'a>>, u64, u32), APIError> {
        let mut res: Vec<msgs::OutboundOnionPayload> = Vec::with_capacity(
                path.hops.len() + path.blinded_tail.as_ref().map_or(0, |t| t.hops.len()),
        );
+       let blinded_tail_with_hop_iter = path.blinded_tail.as_ref().map(|bt| BlindedTailHopIter {
+               hops: bt.hops.iter(),
+               blinding_point: bt.blinding_point,
+               final_value_msat: bt.final_value_msat,
+               excess_final_cltv_expiry_delta: bt.excess_final_cltv_expiry_delta,
+       });
+
+       let (value_msat, cltv) = build_onion_payloads_callback(
+               path.hops.iter(),
+               blinded_tail_with_hop_iter,
+               total_msat,
+               recipient_onion,
+               starting_htlc_offset,
+               keysend_preimage,
+               |action, payload| match action {
+                       PayloadCallbackAction::PushBack => res.push(payload),
+                       PayloadCallbackAction::PushFront => res.insert(0, payload),
+               },
+       )?;
+       Ok((res, value_msat, cltv))
+}
 
-       for (idx, hop) in path.hops.iter().rev().enumerate() {
+struct BlindedTailHopIter<'a, I: Iterator<Item = &'a BlindedHop>> {
+       hops: I,
+       blinding_point: PublicKey,
+       final_value_msat: u64,
+       excess_final_cltv_expiry_delta: u32,
+}
+enum PayloadCallbackAction {
+       PushBack,
+       PushFront,
+}
+fn build_onion_payloads_callback<'a, H, B, F>(
+       hops: H, mut blinded_tail: Option<BlindedTailHopIter<'a, B>>, total_msat: u64,
+       recipient_onion: &'a RecipientOnionFields, starting_htlc_offset: u32,
+       keysend_preimage: &Option<PaymentPreimage>, mut callback: F,
+) -> Result<(u64, u32), APIError>
+where
+       H: DoubleEndedIterator<Item = &'a RouteHop>,
+       B: ExactSizeIterator<Item = &'a BlindedHop>,
+       F: FnMut(PayloadCallbackAction, msgs::OutboundOnionPayload<'a>),
+{
+       let mut cur_value_msat = 0u64;
+       let mut cur_cltv = starting_htlc_offset;
+       let mut last_short_channel_id = 0;
+
+       for (idx, hop) in hops.rev().enumerate() {
                // First hop gets special values so that it can check, on receipt, that everything is
                // exactly as it should be (and the next hop isn't trying to probe to find out if we're
                // the intended recipient).
@@ -195,47 +239,55 @@ pub(super) fn build_onion_payloads(
                        cur_cltv
                };
                if idx == 0 {
-                       if let Some(BlindedTail {
+                       if let Some(BlindedTailHopIter {
                                blinding_point,
                                hops,
                                final_value_msat,
                                excess_final_cltv_expiry_delta,
                                ..
-                       }) = &path.blinded_tail
+                       }) = blinded_tail.take()
                        {
-                               let mut blinding_point = Some(*blinding_point);
-                               for (i, blinded_hop) in hops.iter().enumerate() {
-                                       if i == hops.len() - 1 {
+                               let mut blinding_point = Some(blinding_point);
+                               let hops_len = hops.len();
+                               for (i, blinded_hop) in hops.enumerate() {
+                                       if i == hops_len - 1 {
                                                cur_value_msat += final_value_msat;
-                                               res.push(msgs::OutboundOnionPayload::BlindedReceive {
-                                                       sender_intended_htlc_amt_msat: *final_value_msat,
-                                                       total_msat,
-                                                       cltv_expiry_height: cur_cltv + excess_final_cltv_expiry_delta,
-                                                       encrypted_tlvs: blinded_hop.encrypted_payload.clone(),
-                                                       intro_node_blinding_point: blinding_point.take(),
-                                                       keysend_preimage: *keysend_preimage,
-                                                       custom_tlvs: recipient_onion.custom_tlvs.clone(),
-                                               });
+                                               callback(
+                                                       PayloadCallbackAction::PushBack,
+                                                       msgs::OutboundOnionPayload::BlindedReceive {
+                                                               sender_intended_htlc_amt_msat: final_value_msat,
+                                                               total_msat,
+                                                               cltv_expiry_height: cur_cltv + excess_final_cltv_expiry_delta,
+                                                               encrypted_tlvs: &blinded_hop.encrypted_payload,
+                                                               intro_node_blinding_point: blinding_point.take(),
+                                                               keysend_preimage: *keysend_preimage,
+                                                               custom_tlvs: &recipient_onion.custom_tlvs,
+                                                       },
+                                               );
                                        } else {
-                                               res.push(msgs::OutboundOnionPayload::BlindedForward {
-                                                       encrypted_tlvs: blinded_hop.encrypted_payload.clone(),
-                                                       intro_node_blinding_point: blinding_point.take(),
-                                               });
+                                               callback(
+                                                       PayloadCallbackAction::PushBack,
+                                                       msgs::OutboundOnionPayload::BlindedForward {
+                                                               encrypted_tlvs: &blinded_hop.encrypted_payload,
+                                                               intro_node_blinding_point: blinding_point.take(),
+                                                       },
+                                               );
                                        }
                                }
                        } else {
-                               res.push(msgs::OutboundOnionPayload::Receive {
-                                       payment_data: if let Some(secret) = recipient_onion.payment_secret.take() {
-                                               Some(msgs::FinalOnionHopData { payment_secret: secret, total_msat })
-                                       } else {
-                                               None
+                               callback(
+                                       PayloadCallbackAction::PushBack,
+                                       msgs::OutboundOnionPayload::Receive {
+                                               payment_data: recipient_onion.payment_secret.map(|payment_secret| {
+                                                       msgs::FinalOnionHopData { payment_secret, total_msat }
+                                               }),
+                                               payment_metadata: recipient_onion.payment_metadata.as_ref(),
+                                               keysend_preimage: *keysend_preimage,
+                                               custom_tlvs: &recipient_onion.custom_tlvs,
+                                               sender_intended_htlc_amt_msat: value_msat,
+                                               cltv_expiry_height: cltv,
                                        },
-                                       payment_metadata: recipient_onion.payment_metadata.take(),
-                                       keysend_preimage: *keysend_preimage,
-                                       custom_tlvs: recipient_onion.custom_tlvs.clone(),
-                                       sender_intended_htlc_amt_msat: value_msat,
-                                       cltv_expiry_height: cltv,
-                               });
+                               );
                        }
                } else {
                        let payload = msgs::OutboundOnionPayload::Forward {
@@ -243,7 +295,7 @@ pub(super) fn build_onion_payloads(
                                amt_to_forward: value_msat,
                                outgoing_cltv_value: cltv,
                        };
-                       res.insert(0, payload);
+                       callback(PayloadCallbackAction::PushFront, payload);
                }
                cur_value_msat += hop.fee_msat;
                if cur_value_msat >= 21000000 * 100000000 * 1000 {
@@ -255,7 +307,7 @@ pub(super) fn build_onion_payloads(
                }
                last_short_channel_id = hop.short_channel_id;
        }
-       Ok((res, cur_value_msat, cur_cltv))
+       Ok((cur_value_msat, cur_cltv))
 }
 
 /// Length of the onion data packet. Before TLV-based onions this was 20 65-byte hops, though now
@@ -1123,7 +1175,7 @@ where
 /// `cur_block_height` should be set to the best known block height + 1.
 pub fn create_payment_onion<T: secp256k1::Signing>(
        secp_ctx: &Secp256k1<T>, path: &Path, session_priv: &SecretKey, total_msat: u64,
-       recipient_onion: RecipientOnionFields, cur_block_height: u32, payment_hash: &PaymentHash,
+       recipient_onion: &RecipientOnionFields, cur_block_height: u32, payment_hash: &PaymentHash,
        keysend_preimage: &Option<PaymentPreimage>, prng_seed: [u8; 32],
 ) -> Result<(msgs::OnionPacket, u64, u32), APIError> {
        let onion_keys = construct_onion_keys(&secp_ctx, &path, &session_priv).map_err(|_| {
@@ -1239,11 +1291,13 @@ mod tests {
        use crate::io;
        use crate::ln::features::{ChannelFeatures, NodeFeatures};
        use crate::ln::msgs;
-       use crate::ln::PaymentHash;
-       use crate::prelude::*;
+       use crate::ln::types::PaymentHash;
        use crate::routing::router::{Path, Route, RouteHop};
        use crate::util::ser::{VecWriter, Writeable, Writer};
 
+       #[allow(unused_imports)]
+       use crate::prelude::*;
+
        use bitcoin::hashes::hex::FromHex;
        use bitcoin::secp256k1::Secp256k1;
        use bitcoin::secp256k1::{PublicKey, SecretKey};