Generalize onion message ForwardTlvs::next_node_id
[rust-lightning] / lightning / src / onion_message / packet.rs
index d19bd30edd212be4b19286d5ef5f0bf0f1870a8f..510f0ea025a0d615b0f54292d865602ac7e103a6 100644 (file)
@@ -13,13 +13,13 @@ use bitcoin::secp256k1::PublicKey;
 use bitcoin::secp256k1::ecdh::SharedSecret;
 
 use crate::blinded_path::BlindedPath;
-use crate::blinded_path::message::{ForwardTlvs, ReceiveTlvs};
+use crate::blinded_path::message::{ForwardTlvs, NextHop, ReceiveTlvs};
 use crate::blinded_path::utils::Padding;
 use crate::ln::msgs::DecodeError;
 use crate::ln::onion_utils;
 use super::messenger::CustomOnionMessageHandler;
 use super::offers::OffersMessage;
-use crate::util::chacha20poly1305rfc::{ChaChaPolyReadAdapter, ChaChaPolyWriteAdapter};
+use crate::crypto::streams::{ChaChaPolyReadAdapter, ChaChaPolyWriteAdapter};
 use crate::util::logger::Logger;
 use crate::util::ser::{BigSize, FixedLengthReader, LengthRead, LengthReadable, LengthReadableArgs, Readable, ReadableArgs, Writeable, Writer};
 
@@ -33,7 +33,7 @@ pub(super) const SMALL_PACKET_HOP_DATA_LEN: usize = 1300;
 pub(super) const BIG_PACKET_HOP_DATA_LEN: usize = 32768;
 
 /// Packet of hop data for next peer
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
 pub struct Packet {
        /// Bolt 04 version number
        pub version: u8,
@@ -103,51 +103,51 @@ impl LengthReadable for Packet {
 /// Onion message payloads contain "control" TLVs and "data" TLVs. Control TLVs are used to route
 /// the onion message from hop to hop and for path verification, whereas data TLVs contain the onion
 /// message content itself, such as an invoice request.
-pub(super) enum Payload<T: CustomOnionMessageContents> {
+pub(super) enum Payload<T: OnionMessageContents> {
        /// This payload is for an intermediate hop.
        Forward(ForwardControlTlvs),
        /// This payload is for the final hop.
        Receive {
                control_tlvs: ReceiveControlTlvs,
                reply_path: Option<BlindedPath>,
-               message: OnionMessageContents<T>,
+               message: T,
        }
 }
 
+/// The contents of an [`OnionMessage`] as read from the wire.
+///
+/// [`OnionMessage`]: crate::ln::msgs::OnionMessage
 #[derive(Debug)]
-/// The contents of an onion message. In the context of offers, this would be the invoice, invoice
-/// request, or invoice error.
-pub enum OnionMessageContents<T: CustomOnionMessageContents> {
+pub enum ParsedOnionMessageContents<T: OnionMessageContents> {
        /// A message related to BOLT 12 Offers.
        Offers(OffersMessage),
        /// A custom onion message specified by the user.
        Custom(T),
 }
 
-impl<T: CustomOnionMessageContents> OnionMessageContents<T> {
+impl<T: OnionMessageContents> OnionMessageContents for ParsedOnionMessageContents<T> {
        /// Returns the type that was used to decode the message payload.
        ///
        /// This is not exported to bindings users as methods on non-cloneable enums are not currently exportable
-       pub fn tlv_type(&self) -> u64 {
+       fn tlv_type(&self) -> u64 {
                match self {
-                       &OnionMessageContents::Offers(ref msg) => msg.tlv_type(),
-                       &OnionMessageContents::Custom(ref msg) => msg.tlv_type(),
+                       &ParsedOnionMessageContents::Offers(ref msg) => msg.tlv_type(),
+                       &ParsedOnionMessageContents::Custom(ref msg) => msg.tlv_type(),
                }
        }
 }
 
-/// This is not exported to bindings users as methods on non-cloneable enums are not currently exportable
-impl<T: CustomOnionMessageContents> Writeable for OnionMessageContents<T> {
+impl<T: OnionMessageContents> Writeable for ParsedOnionMessageContents<T> {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                match self {
-                       OnionMessageContents::Offers(msg) => Ok(msg.write(w)?),
-                       OnionMessageContents::Custom(msg) => Ok(msg.write(w)?),
+                       ParsedOnionMessageContents::Offers(msg) => Ok(msg.write(w)?),
+                       ParsedOnionMessageContents::Custom(msg) => Ok(msg.write(w)?),
                }
        }
 }
 
-/// The contents of a custom onion message.
-pub trait CustomOnionMessageContents: Writeable {
+/// The contents of an onion message.
+pub trait OnionMessageContents: Writeable + core::fmt::Debug {
        /// Returns the TLV type identifying the message contents. MUST be >= 64.
        fn tlv_type(&self) -> u64;
 }
@@ -173,7 +173,7 @@ pub(super) enum ReceiveControlTlvs {
 }
 
 // Uses the provided secret to simultaneously encode and encrypt the unblinded control TLVs.
-impl<T: CustomOnionMessageContents> Writeable for (Payload<T>, [u8; 32]) {
+impl<T: OnionMessageContents> Writeable for (Payload<T>, [u8; 32]) {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                match &self.0 {
                        Payload::Forward(ForwardControlTlvs::Blinded(encrypted_bytes)) => {
@@ -212,8 +212,8 @@ impl<T: CustomOnionMessageContents> Writeable for (Payload<T>, [u8; 32]) {
 }
 
 // Uses the provided secret to simultaneously decode and decrypt the control TLVs and data TLV.
-impl<H: CustomOnionMessageHandler + ?Sized, L: Logger + ?Sized>
-ReadableArgs<(SharedSecret, &H, &L)> for Payload<<H as CustomOnionMessageHandler>::CustomMessage> {
+impl<H: CustomOnionMessageHandler + ?Sized, L: Logger + ?Sized> ReadableArgs<(SharedSecret, &H, &L)>
+for Payload<ParsedOnionMessageContents<<H as CustomOnionMessageHandler>::CustomMessage>> {
        fn read<R: Read>(r: &mut R, args: (SharedSecret, &H, &L)) -> Result<Self, DecodeError> {
                let (encrypted_tlvs_ss, handler, logger) = args;
 
@@ -236,12 +236,12 @@ ReadableArgs<(SharedSecret, &H, &L)> for Payload<<H as CustomOnionMessageHandler
                        match msg_type {
                                tlv_type if OffersMessage::is_known_type(tlv_type) => {
                                        let msg = OffersMessage::read(msg_reader, (tlv_type, logger))?;
-                                       message = Some(OnionMessageContents::Offers(msg));
+                                       message = Some(ParsedOnionMessageContents::Offers(msg));
                                        Ok(true)
                                },
                                _ => match handler.read_custom_message(msg_type, msg_reader)? {
                                        Some(msg) => {
-                                               message = Some(OnionMessageContents::Custom(msg));
+                                               message = Some(ParsedOnionMessageContents::Custom(msg));
                                                Ok(true)
                                        },
                                        None => Ok(false),
@@ -284,20 +284,26 @@ impl Readable for ControlTlvs {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
                _init_and_read_tlv_stream!(r, {
                        (1, _padding, option),
-                       (2, _short_channel_id, option),
+                       (2, short_channel_id, option),
                        (4, next_node_id, option),
                        (6, path_id, option),
                        (8, next_blinding_override, option),
                });
                let _padding: Option<Padding> = _padding;
-               let _short_channel_id: Option<u64> = _short_channel_id;
 
-               let valid_fwd_fmt  = next_node_id.is_some() && path_id.is_none();
-               let valid_recv_fmt = next_node_id.is_none() && next_blinding_override.is_none();
+               let next_hop = match (short_channel_id, next_node_id) {
+                       (Some(_), Some(_)) => return Err(DecodeError::InvalidValue),
+                       (Some(scid), None) => Some(NextHop::ShortChannelId(scid)),
+                       (None, Some(pubkey)) => Some(NextHop::NodeId(pubkey)),
+                       (None, None) => None,
+               };
+
+               let valid_fwd_fmt = next_hop.is_some() && path_id.is_none();
+               let valid_recv_fmt = next_hop.is_none() && next_blinding_override.is_none();
 
                let payload_fmt = if valid_fwd_fmt {
                        ControlTlvs::Forward(ForwardTlvs {
-                               next_node_id: next_node_id.unwrap(),
+                               next_hop: next_hop.unwrap(),
                                next_blinding_override,
                        })
                } else if valid_recv_fmt {