Merge pull request #3132 from jkczyz/2024-06-bolt12-unannounced
[rust-lightning] / lightning / src / blinded_path / message.rs
index 1f3f5a1fa38e70bdac2c7d01923fa06e04513f2e..bdbb4be4541d2ce8f3aa34840263de963b23ec57 100644 (file)
@@ -30,7 +30,7 @@ use core::mem;
 use core::ops::Deref;
 
 /// An intermediate node, and possibly a short channel id leading to the next node.
-#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
 pub struct ForwardNode {
        /// This node's pubkey.
        pub node_id: PublicKey,
@@ -106,6 +106,8 @@ pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
 
 // Advance the blinded onion message path by one hop, so make the second hop into the new
 // introduction node.
+//
+// Will only modify `path` when returning `Ok`.
 pub(crate) fn advance_path_by_one<NS: Deref, NL: Deref, T>(
        path: &mut BlindedPath, node_signer: &NS, node_id_lookup: &NL, secp_ctx: &Secp256k1<T>
 ) -> Result<(), ()>
@@ -116,8 +118,8 @@ where
 {
        let control_tlvs_ss = node_signer.ecdh(Recipient::Node, &path.blinding_point, None)?;
        let rho = onion_utils::gen_rho_from_shared_secret(&control_tlvs_ss.secret_bytes());
-       let encrypted_control_tlvs = path.blinded_hops.remove(0).encrypted_payload;
-       let mut s = Cursor::new(&encrypted_control_tlvs);
+       let encrypted_control_tlvs = &path.blinded_hops.get(0).ok_or(())?.encrypted_payload;
+       let mut s = Cursor::new(encrypted_control_tlvs);
        let mut reader = FixedLengthReader::new(&mut s, encrypted_control_tlvs.len() as u64);
        match ChaChaPolyReadAdapter::read(&mut reader, rho) {
                Ok(ChaChaPolyReadAdapter {
@@ -139,6 +141,7 @@ where
                        };
                        mem::swap(&mut path.blinding_point, &mut new_blinding_point);
                        path.introduction_node = IntroductionNode::NodeId(next_node_id);
+                       path.blinded_hops.remove(0);
                        Ok(())
                },
                _ => Err(())