Generalize BlindedPath::introduction_node_id field
[rust-lightning] / lightning / src / onion_message / messenger.rs
index 85e19cc56e535bb364117c271871922683292a26..5334d6db25e2e88d4906c896c7549c53a9d50d4e 100644 (file)
@@ -15,7 +15,7 @@ use bitcoin::hashes::hmac::{Hmac, HmacEngine};
 use bitcoin::hashes::sha256::Hash as Sha256;
 use bitcoin::secp256k1::{self, PublicKey, Scalar, Secp256k1, SecretKey};
 
-use crate::blinded_path::BlindedPath;
+use crate::blinded_path::{BlindedPath, IntroductionNode};
 use crate::blinded_path::message::{advance_path_by_one, ForwardTlvs, NextHop, ReceiveTlvs};
 use crate::blinded_path::utils;
 use crate::events::{Event, EventHandler, EventsProvider};
@@ -444,7 +444,12 @@ impl Destination {
        fn first_node(&self) -> PublicKey {
                match self {
                        Destination::Node(node_id) => *node_id,
-                       Destination::BlindedPath(BlindedPath { introduction_node_id: node_id, .. }) => *node_id,
+                       Destination::BlindedPath(BlindedPath { introduction_node, .. }) => {
+                               match introduction_node {
+                                       IntroductionNode::NodeId(pubkey) => *pubkey,
+                                       IntroductionNode::DirectedShortChannelId(..) => todo!(),
+                               }
+                       },
                }
        }
 }
@@ -569,9 +574,13 @@ where
        // advance the blinded path by 1 hop so the second hop is the new introduction node.
        if intermediate_nodes.len() == 0 {
                if let Destination::BlindedPath(ref mut blinded_path) = destination {
+                       let introduction_node_id = match blinded_path.introduction_node {
+                               IntroductionNode::NodeId(pubkey) => pubkey,
+                               IntroductionNode::DirectedShortChannelId(..) => todo!(),
+                       };
                        let our_node_id = node_signer.get_node_id(Recipient::Node)
                                .map_err(|()| SendError::GetNodeIdFailed)?;
-                       if blinded_path.introduction_node_id == our_node_id {
+                       if introduction_node_id == our_node_id {
                                advance_path_by_one(blinded_path, node_signer, &secp_ctx)
                                        .map_err(|()| SendError::BlindedPathAdvanceFailed)?;
                        }
@@ -583,10 +592,14 @@ where
        let (first_node_id, blinding_point) = if let Some(first_node_id) = intermediate_nodes.first() {
                (*first_node_id, PublicKey::from_secret_key(&secp_ctx, &blinding_secret))
        } else {
-               match destination {
-                       Destination::Node(pk) => (pk, PublicKey::from_secret_key(&secp_ctx, &blinding_secret)),
-                       Destination::BlindedPath(BlindedPath { introduction_node_id, blinding_point, .. }) =>
-                               (introduction_node_id, blinding_point),
+               match &destination {
+                       Destination::Node(pk) => (*pk, PublicKey::from_secret_key(&secp_ctx, &blinding_secret)),
+                       Destination::BlindedPath(BlindedPath { introduction_node, blinding_point, .. }) => {
+                               match introduction_node {
+                                       IntroductionNode::NodeId(pubkey) => (*pubkey, *blinding_point),
+                                       IntroductionNode::DirectedShortChannelId(..) => todo!(),
+                               }
+                       }
                }
        };
        let (packet_payloads, packet_keys) = packet_payloads_and_keys(
@@ -1136,9 +1149,16 @@ fn packet_payloads_and_keys<T: OnionMessageContents, S: secp256k1::Signing + sec
        let mut payloads = Vec::with_capacity(num_hops);
        let mut onion_packet_keys = Vec::with_capacity(num_hops);
 
-       let (mut intro_node_id_blinding_pt, num_blinded_hops) = if let Destination::BlindedPath(BlindedPath {
-               introduction_node_id, blinding_point, blinded_hops }) = &destination {
-               (Some((*introduction_node_id, *blinding_point)), blinded_hops.len()) } else { (None, 0) };
+       let (mut intro_node_id_blinding_pt, num_blinded_hops) = match &destination {
+               Destination::Node(_) => (None, 0),
+               Destination::BlindedPath(BlindedPath { introduction_node, blinding_point, blinded_hops }) => {
+                       let introduction_node_id = match introduction_node {
+                               IntroductionNode::NodeId(pubkey) => pubkey,
+                               IntroductionNode::DirectedShortChannelId(..) => todo!(),
+                       };
+                       (Some((*introduction_node_id, *blinding_point)), blinded_hops.len())
+               },
+       };
        let num_unblinded_hops = num_hops - num_blinded_hops;
 
        let mut unblinded_path_idx = 0;