Support NextHop::ShortChannelId in BlindedPath
[rust-lightning] / lightning / src / blinded_path / message.rs
index 2549673b00f1a390306f2e9ca7ae891450a35274..1a0f63d46000401cd5ffbec7a87c1dd42bf7e786 100644 (file)
@@ -1,24 +1,41 @@
+//! Data structures and methods for constructing [`BlindedPath`]s to send a message over.
+//!
+//! [`BlindedPath`]: crate::blinded_path::BlindedPath
+
 use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey};
 
-use crate::blinded_path::{BlindedHop, BlindedPath};
+#[allow(unused_imports)]
+use crate::prelude::*;
+
+use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode, NextMessageHop, NodeIdLookUp};
 use crate::blinded_path::utils;
 use crate::io;
 use crate::io::Cursor;
 use crate::ln::onion_utils;
-use crate::onion_message::ControlTlvs;
-use crate::prelude::*;
+use crate::onion_message::packet::ControlTlvs;
 use crate::sign::{NodeSigner, Recipient};
-use crate::util::chacha20poly1305rfc::ChaChaPolyReadAdapter;
+use crate::crypto::streams::ChaChaPolyReadAdapter;
 use crate::util::ser::{FixedLengthReader, LengthReadableArgs, Writeable, Writer};
 
 use core::mem;
 use core::ops::Deref;
 
+/// An intermediate node, and possibly a short channel id leading to the next node.
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub struct ForwardNode {
+       /// This node's pubkey.
+       pub node_id: PublicKey,
+       /// The channel between `node_id` and the next hop. If set, the constructed [`BlindedHop`]'s
+       /// `encrypted_payload` will use this instead of the next [`ForwardNode::node_id`] for a more
+       /// compact representation.
+       pub short_channel_id: Option<u64>,
+}
+
 /// TLVs to encode in an intermediate onion message packet's hop data. When provided in a blinded
 /// route, they are encoded into [`BlindedHop::encrypted_payload`].
 pub(crate) struct ForwardTlvs {
-       /// The node id of the next hop in the onion message's path.
-       pub(crate) next_node_id: PublicKey,
+       /// The next hop in the onion message's path.
+       pub(crate) next_hop: NextMessageHop,
        /// Senders to a blinded path use this value to concatenate the route they find to the
        /// introduction node with the blinded path.
        pub(crate) next_blinding_override: Option<PublicKey>,
@@ -34,9 +51,14 @@ pub(crate) struct ReceiveTlvs {
 
 impl Writeable for ForwardTlvs {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               let (next_node_id, short_channel_id) = match self.next_hop {
+                       NextMessageHop::NodeId(pubkey) => (Some(pubkey), None),
+                       NextMessageHop::ShortChannelId(scid) => (None, Some(scid)),
+               };
                // TODO: write padding
                encode_tlv_stream!(writer, {
-                       (4, self.next_node_id, required),
+                       (2, short_channel_id, option),
+                       (4, next_node_id, option),
                        (8, self.next_blinding_override, option)
                });
                Ok(())
@@ -53,54 +75,52 @@ impl Writeable for ReceiveTlvs {
        }
 }
 
-/// Construct blinded onion message hops for the given `unblinded_path`.
+/// Construct blinded onion message hops for the given `intermediate_nodes` and `recipient_node_id`.
 pub(super) fn blinded_hops<T: secp256k1::Signing + secp256k1::Verification>(
-       secp_ctx: &Secp256k1<T>, unblinded_path: &[PublicKey], session_priv: &SecretKey
+       secp_ctx: &Secp256k1<T>, intermediate_nodes: &[ForwardNode], recipient_node_id: PublicKey,
+       session_priv: &SecretKey
 ) -> Result<Vec<BlindedHop>, secp256k1::Error> {
-       let mut blinded_hops = Vec::with_capacity(unblinded_path.len());
-
-       let mut prev_ss_and_blinded_node_id = None;
-       utils::construct_keys_callback(secp_ctx, unblinded_path, None, session_priv, |blinded_node_id, _, _, encrypted_payload_ss, unblinded_pk, _| {
-               if let Some((prev_ss, prev_blinded_node_id)) = prev_ss_and_blinded_node_id {
-                       if let Some(pk) = unblinded_pk {
-                               let payload = ForwardTlvs {
-                                       next_node_id: pk,
-                                       next_blinding_override: None,
-                               };
-                               blinded_hops.push(BlindedHop {
-                                       blinded_node_id: prev_blinded_node_id,
-                                       encrypted_payload: utils::encrypt_payload(payload, prev_ss),
-                               });
-                       } else { debug_assert!(false); }
-               }
-               prev_ss_and_blinded_node_id = Some((encrypted_payload_ss, blinded_node_id));
-       })?;
+       let pks = intermediate_nodes.iter().map(|node| &node.node_id)
+               .chain(core::iter::once(&recipient_node_id));
+       let tlvs = pks.clone()
+               .skip(1) // The first node's TLVs contains the next node's pubkey
+               .zip(intermediate_nodes.iter().map(|node| node.short_channel_id))
+               .map(|(pubkey, scid)| match scid {
+                       Some(scid) => NextMessageHop::ShortChannelId(scid),
+                       None => NextMessageHop::NodeId(*pubkey),
+               })
+               .map(|next_hop| ControlTlvs::Forward(ForwardTlvs { next_hop, next_blinding_override: None }))
+               .chain(core::iter::once(ControlTlvs::Receive(ReceiveTlvs { path_id: None })));
 
-       if let Some((final_ss, final_blinded_node_id)) = prev_ss_and_blinded_node_id {
-               let final_payload = ReceiveTlvs { path_id: None };
-               blinded_hops.push(BlindedHop {
-                       blinded_node_id: final_blinded_node_id,
-                       encrypted_payload: utils::encrypt_payload(final_payload, final_ss),
-               });
-       } else { debug_assert!(false) }
-
-       Ok(blinded_hops)
+       utils::construct_blinded_hops(secp_ctx, pks, tlvs, session_priv)
 }
 
 // Advance the blinded onion message path by one hop, so make the second hop into the new
 // introduction node.
-pub(crate) fn advance_path_by_one<NS: Deref, T: secp256k1::Signing + secp256k1::Verification>(
-       path: &mut BlindedPath, node_signer: &NS, secp_ctx: &Secp256k1<T>
-) -> Result<(), ()> where NS::Target: NodeSigner {
+pub(crate) fn advance_path_by_one<NS: Deref, NL: Deref, T>(
+       path: &mut BlindedPath, node_signer: &NS, node_id_lookup: &NL, secp_ctx: &Secp256k1<T>
+) -> Result<(), ()>
+where
+       NS::Target: NodeSigner,
+       NL::Target: NodeIdLookUp,
+       T: secp256k1::Signing + secp256k1::Verification,
+{
        let control_tlvs_ss = node_signer.ecdh(Recipient::Node, &path.blinding_point, None)?;
        let rho = onion_utils::gen_rho_from_shared_secret(&control_tlvs_ss.secret_bytes());
        let encrypted_control_tlvs = path.blinded_hops.remove(0).encrypted_payload;
        let mut s = Cursor::new(&encrypted_control_tlvs);
        let mut reader = FixedLengthReader::new(&mut s, encrypted_control_tlvs.len() as u64);
        match ChaChaPolyReadAdapter::read(&mut reader, rho) {
-               Ok(ChaChaPolyReadAdapter { readable: ControlTlvs::Forward(ForwardTlvs {
-                       mut next_node_id, next_blinding_override,
-               })}) => {
+               Ok(ChaChaPolyReadAdapter {
+                       readable: ControlTlvs::Forward(ForwardTlvs { next_hop, next_blinding_override })
+               }) => {
+                       let next_node_id = match next_hop {
+                               NextMessageHop::NodeId(pubkey) => pubkey,
+                               NextMessageHop::ShortChannelId(scid) => match node_id_lookup.next_node_id(scid) {
+                                       Some(pubkey) => pubkey,
+                                       None => return Err(()),
+                               },
+                       };
                        let mut new_blinding_point = match next_blinding_override {
                                Some(blinding_point) => blinding_point,
                                None => {
@@ -109,7 +129,7 @@ pub(crate) fn advance_path_by_one<NS: Deref, T: secp256k1::Signing + secp256k1::
                                }
                        };
                        mem::swap(&mut path.blinding_point, &mut new_blinding_point);
-                       mem::swap(&mut path.introduction_node_id, &mut next_node_id);
+                       path.introduction_node = IntroductionNode::NodeId(next_node_id);
                        Ok(())
                },
                _ => Err(())