Merge pull request #2387 from vladimirfomene/add_extra_fields_to_ChannelClosed_event
[rust-lightning] / lightning / src / ln / onion_utils.rs
index 3b62c856334596b85bf58b8ee96b1a0eed8f36f6..7e0ccbe9652e960cf0087c238adae8b6f4ce6602 100644 (file)
@@ -149,11 +149,11 @@ pub(super) fn construct_onion_keys<T: secp256k1::Signing>(secp_ctx: &Secp256k1<T
 }
 
 /// returns the hop data, as well as the first-hop value_msat and CLTV value we should send.
-pub(super) fn build_onion_payloads(path: &Path, total_msat: u64, mut recipient_onion: RecipientOnionFields, starting_htlc_offset: u32, keysend_preimage: &Option<PaymentPreimage>) -> Result<(Vec<msgs::OnionHopData>, u64, u32), APIError> {
+pub(super) fn build_onion_payloads(path: &Path, total_msat: u64, mut recipient_onion: RecipientOnionFields, starting_htlc_offset: u32, keysend_preimage: &Option<PaymentPreimage>) -> Result<(Vec<msgs::OutboundOnionPayload>, u64, u32), APIError> {
        let mut cur_value_msat = 0u64;
        let mut cur_cltv = starting_htlc_offset;
        let mut last_short_channel_id = 0;
-       let mut res: Vec<msgs::OnionHopData> = Vec::with_capacity(path.hops.len());
+       let mut res: Vec<msgs::OutboundOnionPayload> = Vec::with_capacity(path.hops.len());
 
        for (idx, hop) in path.hops.iter().rev().enumerate() {
                // First hop gets special values so that it can check, on receipt, that everything is
@@ -161,25 +161,26 @@ pub(super) fn build_onion_payloads(path: &Path, total_msat: u64, mut recipient_o
                // the intended recipient).
                let value_msat = if cur_value_msat == 0 { hop.fee_msat } else { cur_value_msat };
                let cltv = if cur_cltv == starting_htlc_offset { hop.cltv_expiry_delta + starting_htlc_offset } else { cur_cltv };
-               res.insert(0, msgs::OnionHopData {
-                       format: if idx == 0 {
-                               msgs::OnionHopDataFormat::FinalNode {
-                                       payment_data: if let Some(secret) = recipient_onion.payment_secret.take() {
-                                               Some(msgs::FinalOnionHopData {
-                                                       payment_secret: secret,
-                                                       total_msat,
-                                               })
-                                       } else { None },
-                                       payment_metadata: recipient_onion.payment_metadata.take(),
-                                       keysend_preimage: *keysend_preimage,
-                               }
-                       } else {
-                               msgs::OnionHopDataFormat::NonFinalNode {
-                                       short_channel_id: last_short_channel_id,
-                               }
-                       },
-                       amt_to_forward: value_msat,
-                       outgoing_cltv_value: cltv,
+               res.insert(0, if idx == 0 {
+                       msgs::OutboundOnionPayload::Receive {
+                               payment_data: if let Some(secret) = recipient_onion.payment_secret.take() {
+                                       Some(msgs::FinalOnionHopData {
+                                               payment_secret: secret,
+                                               total_msat,
+                                       })
+                               } else { None },
+                               payment_metadata: recipient_onion.payment_metadata.take(),
+                               keysend_preimage: *keysend_preimage,
+                               custom_tlvs: recipient_onion.custom_tlvs.clone(),
+                               amt_msat: value_msat,
+                               outgoing_cltv_value: cltv,
+                       }
+               } else {
+                       msgs::OutboundOnionPayload::Forward {
+                               short_channel_id: last_short_channel_id,
+                               amt_to_forward: value_msat,
+                               outgoing_cltv_value: cltv,
+                       }
                });
                cur_value_msat += hop.fee_msat;
                if cur_value_msat >= 21000000 * 100000000 * 1000 {
@@ -208,7 +209,10 @@ fn shift_slice_right(arr: &mut [u8], amt: usize) {
        }
 }
 
-pub(super) fn construct_onion_packet(payloads: Vec<msgs::OnionHopData>, onion_keys: Vec<OnionKeys>, prng_seed: [u8; 32], associated_data: &PaymentHash) -> Result<msgs::OnionPacket, ()> {
+pub(super) fn construct_onion_packet(
+       payloads: Vec<msgs::OutboundOnionPayload>, onion_keys: Vec<OnionKeys>, prng_seed: [u8; 32],
+       associated_data: &PaymentHash
+) -> Result<msgs::OnionPacket, ()> {
        let mut packet_data = [0; ONION_DATA_LEN];
 
        let mut chacha = ChaCha20::new(&prng_seed, &[0; 8]);
@@ -645,7 +649,7 @@ impl_writeable_tlv_based_enum!(HTLCFailReasonRepr,
        },
        (1, Reason) => {
                (0, failure_code, required),
-               (2, data, vec_type),
+               (2, data, required_vec),
        },
 ;);
 
@@ -763,11 +767,11 @@ impl NextPacketBytes for Vec<u8> {
 pub(crate) enum Hop {
        /// This onion payload was for us, not for forwarding to a next-hop. Contains information for
        /// verifying the incoming payment.
-       Receive(msgs::OnionHopData),
+       Receive(msgs::InboundOnionPayload),
        /// This onion payload needs to be forwarded to a next-hop.
        Forward {
                /// Onion payload data used in forwarding the payment.
-               next_hop_data: msgs::OnionHopData,
+               next_hop_data: msgs::InboundOnionPayload,
                /// HMAC of the next hop's onion packet.
                next_hop_hmac: [u8; 32],
                /// Bytes of the onion packet we're forwarding.
@@ -988,10 +992,8 @@ mod tests {
                // with raw hex instead of our in-memory enums, as the payloads contains custom types, and
                // we have no way of representing that with our enums.
                let payloads = vec!(
-                       RawOnionHopData::new(msgs::OnionHopData {
-                               format: msgs::OnionHopDataFormat::NonFinalNode {
-                                       short_channel_id: 1,
-                               },
+                       RawOnionHopData::new(msgs::OutboundOnionPayload::Forward {
+                               short_channel_id: 1,
                                amt_to_forward: 15000,
                                outgoing_cltv_value: 1500,
                        }),
@@ -1013,17 +1015,13 @@ mod tests {
                        RawOnionHopData {
                                data: hex::decode("52020236b00402057806080000000000000002fd02013c0102030405060708090a0b0c0d0e0f0102030405060708090a0b0c0d0e0f0102030405060708090a0b0c0d0e0f0102030405060708090a0b0c0d0e0f").unwrap(),
                        },
-                       RawOnionHopData::new(msgs::OnionHopData {
-                               format: msgs::OnionHopDataFormat::NonFinalNode {
-                                       short_channel_id: 3,
-                               },
+                       RawOnionHopData::new(msgs::OutboundOnionPayload::Forward {
+                               short_channel_id: 3,
                                amt_to_forward: 12500,
                                outgoing_cltv_value: 1250,
                        }),
-                       RawOnionHopData::new(msgs::OnionHopData {
-                               format: msgs::OnionHopDataFormat::NonFinalNode {
-                                       short_channel_id: 4,
-                               },
+                       RawOnionHopData::new(msgs::OutboundOnionPayload::Forward {
+                               short_channel_id: 4,
                                amt_to_forward: 10000,
                                outgoing_cltv_value: 1000,
                        }),
@@ -1101,7 +1099,7 @@ mod tests {
                data: Vec<u8>
        }
        impl RawOnionHopData {
-               fn new(orig: msgs::OnionHopData) -> Self {
+               fn new(orig: msgs::OutboundOnionPayload) -> Self {
                        Self { data: orig.encode() }
                }
        }