Onion message payload for async payments
[rust-lightning] / lightning / src / onion_message / packet.rs
index 7483888d137132ec2b0bdd3719a461280c4aa6cb..47b1a0313d791a43c5d52c286b919bba2f151b1a 100644 (file)
 use bitcoin::secp256k1::PublicKey;
 use bitcoin::secp256k1::ecdh::SharedSecret;
 
-use crate::blinded_path::BlindedPath;
-use crate::blinded_path::message::{ForwardTlvs, NextHop, ReceiveTlvs};
+use crate::blinded_path::{BlindedPath, NextMessageHop};
+use crate::blinded_path::message::{ForwardTlvs, ReceiveTlvs};
 use crate::blinded_path::utils::Padding;
 use crate::ln::msgs::DecodeError;
 use crate::ln::onion_utils;
+use super::async_payments::AsyncPaymentsMessage;
 use super::messenger::CustomOnionMessageHandler;
 use super::offers::OffersMessage;
 use crate::crypto::streams::{ChaChaPolyReadAdapter, ChaChaPolyWriteAdapter};
@@ -124,10 +125,12 @@ pub(super) enum Payload<T: OnionMessageContents> {
 /// The contents of an [`OnionMessage`] as read from the wire.
 ///
 /// [`OnionMessage`]: crate::ln::msgs::OnionMessage
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub enum ParsedOnionMessageContents<T: OnionMessageContents> {
        /// A message related to BOLT 12 Offers.
        Offers(OffersMessage),
+       /// A message related to async payments.
+       AsyncPayments(AsyncPaymentsMessage),
        /// A custom onion message specified by the user.
        Custom(T),
 }
@@ -139,15 +142,24 @@ impl<T: OnionMessageContents> OnionMessageContents for ParsedOnionMessageContent
        fn tlv_type(&self) -> u64 {
                match self {
                        &ParsedOnionMessageContents::Offers(ref msg) => msg.tlv_type(),
+                       &ParsedOnionMessageContents::AsyncPayments(ref msg) => msg.tlv_type(),
                        &ParsedOnionMessageContents::Custom(ref msg) => msg.tlv_type(),
                }
        }
+       fn msg_type(&self) -> &'static str {
+               match self {
+                       ParsedOnionMessageContents::Offers(ref msg) => msg.msg_type(),
+                       ParsedOnionMessageContents::AsyncPayments(ref msg) => msg.msg_type(),
+                       ParsedOnionMessageContents::Custom(ref msg) => msg.msg_type(),
+               }
+       }
 }
 
 impl<T: OnionMessageContents> Writeable for ParsedOnionMessageContents<T> {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                match self {
                        ParsedOnionMessageContents::Offers(msg) => Ok(msg.write(w)?),
+                       ParsedOnionMessageContents::AsyncPayments(msg) => Ok(msg.write(w)?),
                        ParsedOnionMessageContents::Custom(msg) => Ok(msg.write(w)?),
                }
        }
@@ -157,6 +169,9 @@ impl<T: OnionMessageContents> Writeable for ParsedOnionMessageContents<T> {
 pub trait OnionMessageContents: Writeable + core::fmt::Debug {
        /// Returns the TLV type identifying the message contents. MUST be >= 64.
        fn tlv_type(&self) -> u64;
+
+       /// Returns the message type
+       fn msg_type(&self) -> &'static str;
 }
 
 /// Forward control TLVs in their blinded and unblinded form.
@@ -246,6 +261,11 @@ for Payload<ParsedOnionMessageContents<<H as CustomOnionMessageHandler>::CustomM
                                        message = Some(ParsedOnionMessageContents::Offers(msg));
                                        Ok(true)
                                },
+                               tlv_type if AsyncPaymentsMessage::is_known_type(tlv_type) => {
+                                       let msg = AsyncPaymentsMessage::read(msg_reader, tlv_type)?;
+                                       message = Some(ParsedOnionMessageContents::AsyncPayments(msg));
+                                       Ok(true)
+                               },
                                _ => match handler.read_custom_message(msg_type, msg_reader)? {
                                        Some(msg) => {
                                                message = Some(ParsedOnionMessageContents::Custom(msg));
@@ -300,8 +320,8 @@ impl Readable for ControlTlvs {
 
                let next_hop = match (short_channel_id, next_node_id) {
                        (Some(_), Some(_)) => return Err(DecodeError::InvalidValue),
-                       (Some(scid), None) => Some(NextHop::ShortChannelId(scid)),
-                       (None, Some(pubkey)) => Some(NextHop::NodeId(pubkey)),
+                       (Some(scid), None) => Some(NextMessageHop::ShortChannelId(scid)),
+                       (None, Some(pubkey)) => Some(NextMessageHop::NodeId(pubkey)),
                        (None, None) => None,
                };