Merge pull request #2634 from TheBlueMatt/2023-09-claimable-unwrap
[rust-lightning] / lightning / src / onion_message / messenger.rs
index a3613605cdebbd1ff71014d827b70782e043a160..03dc6a5b002f01be6e706396190f1eddeb84f839 100644 (file)
@@ -15,7 +15,9 @@ use bitcoin::hashes::hmac::{Hmac, HmacEngine};
 use bitcoin::hashes::sha256::Hash as Sha256;
 use bitcoin::secp256k1::{self, PublicKey, Scalar, Secp256k1, SecretKey};
 
-use crate::blinded_path::{BlindedPath, ForwardTlvs, ReceiveTlvs, utils};
+use crate::blinded_path::BlindedPath;
+use crate::blinded_path::message::{advance_path_by_one, ForwardTlvs, ReceiveTlvs};
+use crate::blinded_path::utils;
 use crate::sign::{EntropySource, KeysManager, NodeSigner, Recipient};
 use crate::events::OnionMessageProvider;
 use crate::ln::features::{InitFeatures, NodeFeatures};
@@ -244,6 +246,64 @@ pub trait CustomOnionMessageHandler {
        fn read_custom_message<R: io::Read>(&self, message_type: u64, buffer: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError>;
 }
 
+
+/// Create an onion message with contents `message` to the destination of `path`.
+/// Returns (introduction_node_id, onion_msg)
+pub fn create_onion_message<ES: Deref, NS: Deref, T: CustomOnionMessageContents>(
+       entropy_source: &ES, node_signer: &NS, secp_ctx: &Secp256k1<secp256k1::All>,
+       path: OnionMessagePath, message: OnionMessageContents<T>, reply_path: Option<BlindedPath>,
+) -> Result<(PublicKey, msgs::OnionMessage), SendError>
+where
+       ES::Target: EntropySource,
+       NS::Target: NodeSigner,
+{
+       let OnionMessagePath { intermediate_nodes, mut destination } = path;
+       if let Destination::BlindedPath(BlindedPath { ref blinded_hops, .. }) = destination {
+               if blinded_hops.len() < 2 {
+                       return Err(SendError::TooFewBlindedHops);
+               }
+       }
+
+       if message.tlv_type() < 64 { return Err(SendError::InvalidMessage) }
+
+       // If we are sending straight to a blinded path and we are the introduction node, we need to
+       // advance the blinded path by 1 hop so the second hop is the new introduction node.
+       if intermediate_nodes.len() == 0 {
+               if let Destination::BlindedPath(ref mut blinded_path) = destination {
+                       let our_node_id = node_signer.get_node_id(Recipient::Node)
+                               .map_err(|()| SendError::GetNodeIdFailed)?;
+                       if blinded_path.introduction_node_id == our_node_id {
+                               advance_path_by_one(blinded_path, node_signer, &secp_ctx)
+                                       .map_err(|()| SendError::BlindedPathAdvanceFailed)?;
+                       }
+               }
+       }
+
+       let blinding_secret_bytes = entropy_source.get_secure_random_bytes();
+       let blinding_secret = SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted");
+       let (introduction_node_id, blinding_point) = if intermediate_nodes.len() != 0 {
+               (intermediate_nodes[0], PublicKey::from_secret_key(&secp_ctx, &blinding_secret))
+       } else {
+               match destination {
+                       Destination::Node(pk) => (pk, PublicKey::from_secret_key(&secp_ctx, &blinding_secret)),
+                       Destination::BlindedPath(BlindedPath { introduction_node_id, blinding_point, .. }) =>
+                               (introduction_node_id, blinding_point),
+               }
+       };
+       let (packet_payloads, packet_keys) = packet_payloads_and_keys(
+               &secp_ctx, &intermediate_nodes, destination, message, reply_path, &blinding_secret)
+               .map_err(|e| SendError::Secp256k1(e))?;
+
+       let prng_seed = entropy_source.get_secure_random_bytes();
+       let onion_routing_packet = construct_onion_message_packet(
+               packet_payloads, packet_keys, prng_seed).map_err(|()| SendError::TooBigPacket)?;
+
+       Ok((introduction_node_id, msgs::OnionMessage {
+               blinding_point,
+               onion_routing_packet
+       }))
+}
+
 impl<ES: Deref, NS: Deref, L: Deref, MR: Deref, OMH: Deref, CMH: Deref>
 OnionMessenger<ES, NS, L, MR, OMH, CMH>
 where
@@ -281,53 +341,17 @@ where
                &self, path: OnionMessagePath, message: OnionMessageContents<T>,
                reply_path: Option<BlindedPath>
        ) -> Result<(), SendError> {
-               let OnionMessagePath { intermediate_nodes, mut destination } = path;
-               if let Destination::BlindedPath(BlindedPath { ref blinded_hops, .. }) = destination {
-                       if blinded_hops.len() < 2 {
-                               return Err(SendError::TooFewBlindedHops);
-                       }
-               }
-
-               if message.tlv_type() < 64 { return Err(SendError::InvalidMessage) }
-
-               // If we are sending straight to a blinded path and we are the introduction node, we need to
-               // advance the blinded path by 1 hop so the second hop is the new introduction node.
-               if intermediate_nodes.len() == 0 {
-                       if let Destination::BlindedPath(ref mut blinded_path) = destination {
-                               let our_node_id = self.node_signer.get_node_id(Recipient::Node)
-                                       .map_err(|()| SendError::GetNodeIdFailed)?;
-                               if blinded_path.introduction_node_id == our_node_id {
-                                       blinded_path.advance_message_path_by_one(&self.node_signer, &self.secp_ctx)
-                                               .map_err(|()| SendError::BlindedPathAdvanceFailed)?;
-                               }
-                       }
-               }
-
-               let blinding_secret_bytes = self.entropy_source.get_secure_random_bytes();
-               let blinding_secret = SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted");
-               let (introduction_node_id, blinding_point) = if intermediate_nodes.len() != 0 {
-                       (intermediate_nodes[0], PublicKey::from_secret_key(&self.secp_ctx, &blinding_secret))
-               } else {
-                       match destination {
-                               Destination::Node(pk) => (pk, PublicKey::from_secret_key(&self.secp_ctx, &blinding_secret)),
-                               Destination::BlindedPath(BlindedPath { introduction_node_id, blinding_point, .. }) =>
-                                       (introduction_node_id, blinding_point),
-                       }
-               };
-               let (packet_payloads, packet_keys) = packet_payloads_and_keys(
-                       &self.secp_ctx, &intermediate_nodes, destination, message, reply_path, &blinding_secret)
-                       .map_err(|e| SendError::Secp256k1(e))?;
-
-               let prng_seed = self.entropy_source.get_secure_random_bytes();
-               let onion_routing_packet = construct_onion_message_packet(
-                       packet_payloads, packet_keys, prng_seed).map_err(|()| SendError::TooBigPacket)?;
+               let (introduction_node_id, onion_msg) = create_onion_message(
+                       &self.entropy_source, &self.node_signer, &self.secp_ctx,
+                       path, message, reply_path
+               )?;
 
                let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
                if outbound_buffer_full(&introduction_node_id, &pending_per_peer_msgs) { return Err(SendError::BufferFull) }
                match pending_per_peer_msgs.entry(introduction_node_id) {
                        hash_map::Entry::Vacant(_) => Err(SendError::InvalidFirstHop),
                        hash_map::Entry::Occupied(mut e) => {
-                               e.get_mut().push_back(msgs::OnionMessage { blinding_point, onion_routing_packet });
+                               e.get_mut().push_back(onion_msg);
                                Ok(())
                        }
                }
@@ -490,7 +514,7 @@ where
                                // unwrapping the onion layers to get to the final payload. Since we don't have the option
                                // of creating blinded paths with dummy hops currently, we should be ok to not handle this
                                // for now.
-                               let new_pubkey = match onion_utils::next_hop_packet_pubkey(&self.secp_ctx, msg.onion_routing_packet.public_key, &onion_decode_ss) {
+                               let new_pubkey = match onion_utils::next_hop_pubkey(&self.secp_ctx, msg.onion_routing_packet.public_key, &onion_decode_ss) {
                                        Ok(pk) => pk,
                                        Err(e) => {
                                                log_trace!(self.logger, "Failed to compute next hop packet pubkey: {}", e);
@@ -507,21 +531,16 @@ where
                                        blinding_point: match next_blinding_override {
                                                Some(blinding_point) => blinding_point,
                                                None => {
-                                                       let blinding_factor = {
-                                                               let mut sha = Sha256::engine();
-                                                               sha.input(&msg.blinding_point.serialize()[..]);
-                                                               sha.input(control_tlvs_ss.as_ref());
-                                                               Sha256::from_engine(sha).into_inner()
-                                                       };
-                                                       let next_blinding_point = msg.blinding_point;
-                                                       match next_blinding_point.mul_tweak(&self.secp_ctx, &Scalar::from_be_bytes(blinding_factor).unwrap()) {
+                                                       match onion_utils::next_hop_pubkey(
+                                                               &self.secp_ctx, msg.blinding_point, control_tlvs_ss.as_ref()
+                                                       ) {
                                                                Ok(bp) => bp,
                                                                Err(e) => {
                                                                        log_trace!(self.logger, "Failed to compute next blinding point: {}", e);
                                                                        return
                                                                }
                                                        }
-                                               },
+                                               }
                                        },
                                        onion_routing_packet: outgoing_packet,
                                };
@@ -653,46 +672,48 @@ fn packet_payloads_and_keys<T: CustomOnionMessageContents, S: secp256k1::Signing
        let mut blinded_path_idx = 0;
        let mut prev_control_tlvs_ss = None;
        let mut final_control_tlvs = None;
-       utils::construct_keys_callback(secp_ctx, unblinded_path, Some(destination), session_priv, |_, onion_packet_ss, ephemeral_pubkey, control_tlvs_ss, unblinded_pk_opt, enc_payload_opt| {
-               if num_unblinded_hops != 0 && unblinded_path_idx < num_unblinded_hops {
-                       if let Some(ss) = prev_control_tlvs_ss.take() {
-                               payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(
-                                       ForwardTlvs {
-                                               next_node_id: unblinded_pk_opt.unwrap(),
-                                               next_blinding_override: None,
-                                       }
-                               )), ss));
+       utils::construct_keys_callback(secp_ctx, unblinded_path.iter(), Some(destination), session_priv,
+               |_, onion_packet_ss, ephemeral_pubkey, control_tlvs_ss, unblinded_pk_opt, enc_payload_opt| {
+                       if num_unblinded_hops != 0 && unblinded_path_idx < num_unblinded_hops {
+                               if let Some(ss) = prev_control_tlvs_ss.take() {
+                                       payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(
+                                               ForwardTlvs {
+                                                       next_node_id: unblinded_pk_opt.unwrap(),
+                                                       next_blinding_override: None,
+                                               }
+                                       )), ss));
+                               }
+                               prev_control_tlvs_ss = Some(control_tlvs_ss);
+                               unblinded_path_idx += 1;
+                       } else if let Some((intro_node_id, blinding_pt)) = intro_node_id_blinding_pt.take() {
+                               if let Some(control_tlvs_ss) = prev_control_tlvs_ss.take() {
+                                       payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs {
+                                               next_node_id: intro_node_id,
+                                               next_blinding_override: Some(blinding_pt),
+                                       })), control_tlvs_ss));
+                               }
                        }
-                       prev_control_tlvs_ss = Some(control_tlvs_ss);
-                       unblinded_path_idx += 1;
-               } else if let Some((intro_node_id, blinding_pt)) = intro_node_id_blinding_pt.take() {
-                       if let Some(control_tlvs_ss) = prev_control_tlvs_ss.take() {
-                               payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs {
-                                       next_node_id: intro_node_id,
-                                       next_blinding_override: Some(blinding_pt),
-                               })), control_tlvs_ss));
+                       if blinded_path_idx < num_blinded_hops.saturating_sub(1) && enc_payload_opt.is_some() {
+                               payloads.push((Payload::Forward(ForwardControlTlvs::Blinded(enc_payload_opt.unwrap())),
+                                       control_tlvs_ss));
+                               blinded_path_idx += 1;
+                       } else if let Some(encrypted_payload) = enc_payload_opt {
+                               final_control_tlvs = Some(ReceiveControlTlvs::Blinded(encrypted_payload));
+                               prev_control_tlvs_ss = Some(control_tlvs_ss);
                        }
-               }
-               if blinded_path_idx < num_blinded_hops.saturating_sub(1) && enc_payload_opt.is_some() {
-                       payloads.push((Payload::Forward(ForwardControlTlvs::Blinded(enc_payload_opt.unwrap())),
-                               control_tlvs_ss));
-                       blinded_path_idx += 1;
-               } else if let Some(encrypted_payload) = enc_payload_opt {
-                       final_control_tlvs = Some(ReceiveControlTlvs::Blinded(encrypted_payload));
-                       prev_control_tlvs_ss = Some(control_tlvs_ss);
-               }
 
-               let (rho, mu) = onion_utils::gen_rho_mu_from_shared_secret(onion_packet_ss.as_ref());
-               onion_packet_keys.push(onion_utils::OnionKeys {
-                       #[cfg(test)]
-                       shared_secret: onion_packet_ss,
-                       #[cfg(test)]
-                       blinding_factor: [0; 32],
-                       ephemeral_pubkey,
-                       rho,
-                       mu,
-               });
-       })?;
+                       let (rho, mu) = onion_utils::gen_rho_mu_from_shared_secret(onion_packet_ss.as_ref());
+                       onion_packet_keys.push(onion_utils::OnionKeys {
+                               #[cfg(test)]
+                               shared_secret: onion_packet_ss,
+                               #[cfg(test)]
+                               blinding_factor: [0; 32],
+                               ephemeral_pubkey,
+                               rho,
+                               mu,
+                       });
+               }
+       )?;
 
        if let Some(control_tlvs) = final_control_tlvs {
                payloads.push((Payload::Receive {