Merge pull request #3028 from jkczyz/2024-04-offer-id-followups
[rust-lightning] / lightning / src / ln / wire.rs
index e7db446a86f17f036ceeccba9aca52c4aa7ea652..55e31399ae10074d2e72f1b04a4c1c71ed1808df 100644 (file)
@@ -9,12 +9,12 @@
 
 //! Wire encoding/decoding for Lightning messages according to [BOLT #1], and for
 //! custom message through the [`CustomMessageReader`] trait.
-//! 
-//! [BOLT #1]: https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md
+//!
+//! [BOLT #1]: https://github.com/lightning/bolts/blob/master/01-messaging.md
 
-use io;
-use ln::msgs;
-use util::ser::{Readable, Writeable, Writer};
+use crate::io;
+use crate::ln::msgs;
+use crate::util::ser::{Readable, Writeable, Writer};
 
 /// Trait to be implemented by custom message (unrelated to the channel/gossip LN layers)
 /// decoders.
@@ -28,23 +28,57 @@ pub trait CustomMessageReader {
        fn read<R: io::Read>(&self, message_type: u16, buffer: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError>;
 }
 
-/// A Lightning message returned by [`read()`] when decoding bytes received over the wire. Each
+// TestEq is a dummy trait which requires PartialEq when built in testing, and otherwise is
+// blanket-implemented for all types.
+
+#[cfg(test)]
+pub trait TestEq : PartialEq {}
+#[cfg(test)]
+impl<T: PartialEq> TestEq for T {}
+
+#[cfg(not(test))]
+pub(crate) trait TestEq {}
+#[cfg(not(test))]
+impl<T> TestEq for T {}
+
+
+/// A Lightning message returned by [`read`] when decoding bytes received over the wire. Each
 /// variant contains a message from [`msgs`] or otherwise the message type if unknown.
 #[allow(missing_docs)]
 #[derive(Debug)]
-pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
+#[cfg_attr(test, derive(PartialEq))]
+pub(crate) enum Message<T> where T: core::fmt::Debug + Type + TestEq {
        Init(msgs::Init),
        Error(msgs::ErrorMessage),
        Warning(msgs::WarningMessage),
        Ping(msgs::Ping),
        Pong(msgs::Pong),
        OpenChannel(msgs::OpenChannel),
+       OpenChannelV2(msgs::OpenChannelV2),
        AcceptChannel(msgs::AcceptChannel),
+       AcceptChannelV2(msgs::AcceptChannelV2),
        FundingCreated(msgs::FundingCreated),
        FundingSigned(msgs::FundingSigned),
-       FundingLocked(msgs::FundingLocked),
+       Stfu(msgs::Stfu),
+       #[cfg(splicing)]
+       Splice(msgs::Splice),
+       #[cfg(splicing)]
+       SpliceAck(msgs::SpliceAck),
+       #[cfg(splicing)]
+       SpliceLocked(msgs::SpliceLocked),
+       TxAddInput(msgs::TxAddInput),
+       TxAddOutput(msgs::TxAddOutput),
+       TxRemoveInput(msgs::TxRemoveInput),
+       TxRemoveOutput(msgs::TxRemoveOutput),
+       TxComplete(msgs::TxComplete),
+       TxSignatures(msgs::TxSignatures),
+       TxInitRbf(msgs::TxInitRbf),
+       TxAckRbf(msgs::TxAckRbf),
+       TxAbort(msgs::TxAbort),
+       ChannelReady(msgs::ChannelReady),
        Shutdown(msgs::Shutdown),
        ClosingSigned(msgs::ClosingSigned),
+       OnionMessage(msgs::OnionMessage),
        UpdateAddHTLC(msgs::UpdateAddHTLC),
        UpdateFulfillHTLC(msgs::UpdateFulfillHTLC),
        UpdateFailHTLC(msgs::UpdateFailHTLC),
@@ -65,13 +99,70 @@ pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
        /// A message that could not be decoded because its type is unknown.
        Unknown(u16),
        /// A message that was produced by a [`CustomMessageReader`] and is to be handled by a
-       /// [`::ln::peer_handler::CustomMessageHandler`].
+       /// [`crate::ln::peer_handler::CustomMessageHandler`].
        Custom(T),
 }
 
-impl<T> Message<T> where T: core::fmt::Debug + Type {
+impl<T> Writeable for Message<T> where T: core::fmt::Debug + Type + TestEq {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               match self {
+                       &Message::Init(ref msg) => msg.write(writer),
+                       &Message::Error(ref msg) => msg.write(writer),
+                       &Message::Warning(ref msg) => msg.write(writer),
+                       &Message::Ping(ref msg) => msg.write(writer),
+                       &Message::Pong(ref msg) => msg.write(writer),
+                       &Message::OpenChannel(ref msg) => msg.write(writer),
+                       &Message::OpenChannelV2(ref msg) => msg.write(writer),
+                       &Message::AcceptChannel(ref msg) => msg.write(writer),
+                       &Message::AcceptChannelV2(ref msg) => msg.write(writer),
+                       &Message::FundingCreated(ref msg) => msg.write(writer),
+                       &Message::FundingSigned(ref msg) => msg.write(writer),
+                       &Message::Stfu(ref msg) => msg.write(writer),
+                       #[cfg(splicing)]
+                       &Message::Splice(ref msg) => msg.write(writer),
+                       #[cfg(splicing)]
+                       &Message::SpliceAck(ref msg) => msg.write(writer),
+                       #[cfg(splicing)]
+                       &Message::SpliceLocked(ref msg) => msg.write(writer),
+                       &Message::TxAddInput(ref msg) => msg.write(writer),
+                       &Message::TxAddOutput(ref msg) => msg.write(writer),
+                       &Message::TxRemoveInput(ref msg) => msg.write(writer),
+                       &Message::TxRemoveOutput(ref msg) => msg.write(writer),
+                       &Message::TxComplete(ref msg) => msg.write(writer),
+                       &Message::TxSignatures(ref msg) => msg.write(writer),
+                       &Message::TxInitRbf(ref msg) => msg.write(writer),
+                       &Message::TxAckRbf(ref msg) => msg.write(writer),
+                       &Message::TxAbort(ref msg) => msg.write(writer),
+                       &Message::ChannelReady(ref msg) => msg.write(writer),
+                       &Message::Shutdown(ref msg) => msg.write(writer),
+                       &Message::ClosingSigned(ref msg) => msg.write(writer),
+                       &Message::OnionMessage(ref msg) => msg.write(writer),
+                       &Message::UpdateAddHTLC(ref msg) => msg.write(writer),
+                       &Message::UpdateFulfillHTLC(ref msg) => msg.write(writer),
+                       &Message::UpdateFailHTLC(ref msg) => msg.write(writer),
+                       &Message::UpdateFailMalformedHTLC(ref msg) => msg.write(writer),
+                       &Message::CommitmentSigned(ref msg) => msg.write(writer),
+                       &Message::RevokeAndACK(ref msg) => msg.write(writer),
+                       &Message::UpdateFee(ref msg) => msg.write(writer),
+                       &Message::ChannelReestablish(ref msg) => msg.write(writer),
+                       &Message::AnnouncementSignatures(ref msg) => msg.write(writer),
+                       &Message::ChannelAnnouncement(ref msg) => msg.write(writer),
+                       &Message::NodeAnnouncement(ref msg) => msg.write(writer),
+                       &Message::ChannelUpdate(ref msg) => msg.write(writer),
+                       &Message::QueryShortChannelIds(ref msg) => msg.write(writer),
+                       &Message::ReplyShortChannelIdsEnd(ref msg) => msg.write(writer),
+                       &Message::QueryChannelRange(ref msg) => msg.write(writer),
+                       &Message::ReplyChannelRange(ref msg) => msg.write(writer),
+                       &Message::GossipTimestampFilter(ref msg) => msg.write(writer),
+                       &Message::Unknown(_) => { Ok(()) },
+                       &Message::Custom(ref msg) => msg.write(writer),
+               }
+       }
+}
+
+impl<T> Type for Message<T> where T: core::fmt::Debug + Type + TestEq {
        /// Returns the type that was used to decode the message payload.
-       pub fn type_id(&self) -> u16 {
+       fn type_id(&self) -> u16 {
                match self {
                        &Message::Init(ref msg) => msg.type_id(),
                        &Message::Error(ref msg) => msg.type_id(),
@@ -79,12 +170,31 @@ impl<T> Message<T> where T: core::fmt::Debug + Type {
                        &Message::Ping(ref msg) => msg.type_id(),
                        &Message::Pong(ref msg) => msg.type_id(),
                        &Message::OpenChannel(ref msg) => msg.type_id(),
+                       &Message::OpenChannelV2(ref msg) => msg.type_id(),
                        &Message::AcceptChannel(ref msg) => msg.type_id(),
+                       &Message::AcceptChannelV2(ref msg) => msg.type_id(),
                        &Message::FundingCreated(ref msg) => msg.type_id(),
                        &Message::FundingSigned(ref msg) => msg.type_id(),
-                       &Message::FundingLocked(ref msg) => msg.type_id(),
+                       &Message::Stfu(ref msg) => msg.type_id(),
+                       #[cfg(splicing)]
+                       &Message::Splice(ref msg) => msg.type_id(),
+                       #[cfg(splicing)]
+                       &Message::SpliceAck(ref msg) => msg.type_id(),
+                       #[cfg(splicing)]
+                       &Message::SpliceLocked(ref msg) => msg.type_id(),
+                       &Message::TxAddInput(ref msg) => msg.type_id(),
+                       &Message::TxAddOutput(ref msg) => msg.type_id(),
+                       &Message::TxRemoveInput(ref msg) => msg.type_id(),
+                       &Message::TxRemoveOutput(ref msg) => msg.type_id(),
+                       &Message::TxComplete(ref msg) => msg.type_id(),
+                       &Message::TxSignatures(ref msg) => msg.type_id(),
+                       &Message::TxInitRbf(ref msg) => msg.type_id(),
+                       &Message::TxAckRbf(ref msg) => msg.type_id(),
+                       &Message::TxAbort(ref msg) => msg.type_id(),
+                       &Message::ChannelReady(ref msg) => msg.type_id(),
                        &Message::Shutdown(ref msg) => msg.type_id(),
                        &Message::ClosingSigned(ref msg) => msg.type_id(),
+                       &Message::OnionMessage(ref msg) => msg.type_id(),
                        &Message::UpdateAddHTLC(ref msg) => msg.type_id(),
                        &Message::UpdateFulfillHTLC(ref msg) => msg.type_id(),
                        &Message::UpdateFailHTLC(ref msg) => msg.type_id(),
@@ -106,7 +216,9 @@ impl<T> Message<T> where T: core::fmt::Debug + Type {
                        &Message::Custom(ref msg) => msg.type_id(),
                }
        }
+}
 
+impl<T> Message<T> where T: core::fmt::Debug + Type + TestEq {
        /// Returns whether the message's type is even, indicating both endpoints must support it.
        pub fn is_even(&self) -> bool {
                (self.type_id() & 1) == 0
@@ -118,7 +230,7 @@ impl<T> Message<T> where T: core::fmt::Debug + Type {
 ///
 /// # Errors
 ///
-/// Returns an error if the message payload code not be decoded as the specified type.
+/// Returns an error if the message payload could not be decoded as the specified type.
 pub(crate) fn read<R: io::Read, T, H: core::ops::Deref>(buffer: &mut R, custom_reader: H)
 -> Result<Message<T>, (msgs::DecodeError, Option<u16>)> where
        T: core::fmt::Debug + Type + Writeable,
@@ -152,17 +264,65 @@ fn do_read<R: io::Read, T, H: core::ops::Deref>(buffer: &mut R, message_type: u1
                msgs::OpenChannel::TYPE => {
                        Ok(Message::OpenChannel(Readable::read(buffer)?))
                },
+               msgs::OpenChannelV2::TYPE => {
+                       Ok(Message::OpenChannelV2(Readable::read(buffer)?))
+               },
                msgs::AcceptChannel::TYPE => {
                        Ok(Message::AcceptChannel(Readable::read(buffer)?))
                },
+               msgs::AcceptChannelV2::TYPE => {
+                       Ok(Message::AcceptChannelV2(Readable::read(buffer)?))
+               },
                msgs::FundingCreated::TYPE => {
                        Ok(Message::FundingCreated(Readable::read(buffer)?))
                },
                msgs::FundingSigned::TYPE => {
                        Ok(Message::FundingSigned(Readable::read(buffer)?))
                },
-               msgs::FundingLocked::TYPE => {
-                       Ok(Message::FundingLocked(Readable::read(buffer)?))
+               #[cfg(splicing)]
+               msgs::Splice::TYPE => {
+                       Ok(Message::Splice(Readable::read(buffer)?))
+               },
+               msgs::Stfu::TYPE => {
+                       Ok(Message::Stfu(Readable::read(buffer)?))
+               },
+               #[cfg(splicing)]
+               msgs::SpliceAck::TYPE => {
+                       Ok(Message::SpliceAck(Readable::read(buffer)?))
+               },
+               #[cfg(splicing)]
+               msgs::SpliceLocked::TYPE => {
+                       Ok(Message::SpliceLocked(Readable::read(buffer)?))
+               },
+               msgs::TxAddInput::TYPE => {
+                       Ok(Message::TxAddInput(Readable::read(buffer)?))
+               },
+               msgs::TxAddOutput::TYPE => {
+                       Ok(Message::TxAddOutput(Readable::read(buffer)?))
+               },
+               msgs::TxRemoveInput::TYPE => {
+                       Ok(Message::TxRemoveInput(Readable::read(buffer)?))
+               },
+               msgs::TxRemoveOutput::TYPE => {
+                       Ok(Message::TxRemoveOutput(Readable::read(buffer)?))
+               },
+               msgs::TxComplete::TYPE => {
+                       Ok(Message::TxComplete(Readable::read(buffer)?))
+               },
+               msgs::TxSignatures::TYPE => {
+                       Ok(Message::TxSignatures(Readable::read(buffer)?))
+               },
+               msgs::TxInitRbf::TYPE => {
+                       Ok(Message::TxInitRbf(Readable::read(buffer)?))
+               },
+               msgs::TxAckRbf::TYPE => {
+                       Ok(Message::TxAckRbf(Readable::read(buffer)?))
+               },
+               msgs::TxAbort::TYPE => {
+                       Ok(Message::TxAbort(Readable::read(buffer)?))
+               },
+               msgs::ChannelReady::TYPE => {
+                       Ok(Message::ChannelReady(Readable::read(buffer)?))
                },
                msgs::Shutdown::TYPE => {
                        Ok(Message::Shutdown(Readable::read(buffer)?))
@@ -170,6 +330,9 @@ fn do_read<R: io::Read, T, H: core::ops::Deref>(buffer: &mut R, message_type: u1
                msgs::ClosingSigned::TYPE => {
                        Ok(Message::ClosingSigned(Readable::read(buffer)?))
                },
+               msgs::OnionMessage::TYPE => {
+                       Ok(Message::OnionMessage(Readable::read(buffer)?))
+               },
                msgs::UpdateAddHTLC::TYPE => {
                        Ok(Message::UpdateAddHTLC(Readable::read(buffer)?))
                },
@@ -252,6 +415,7 @@ mod encode {
 
 pub(crate) use self::encode::Encode;
 
+#[cfg(not(test))]
 /// Defines a type identifier for sending messages over the wire.
 ///
 /// Messages implementing this trait specify a type and must be [`Writeable`].
@@ -260,10 +424,28 @@ pub trait Type: core::fmt::Debug + Writeable {
        fn type_id(&self) -> u16;
 }
 
+#[cfg(test)]
+pub trait Type: core::fmt::Debug + Writeable + PartialEq {
+       fn type_id(&self) -> u16;
+}
+
+#[cfg(any(feature = "_test_utils", fuzzing, test))]
+impl Type for () {
+       fn type_id(&self) -> u16 { unreachable!(); }
+}
+
+#[cfg(test)]
+impl<T: core::fmt::Debug + Writeable + PartialEq> Type for T where T: Encode {
+       fn type_id(&self) -> u16 { T::TYPE }
+}
+
+#[cfg(not(test))]
 impl<T: core::fmt::Debug + Writeable> Type for T where T: Encode {
-       fn type_id(&self) -> u16 {
-               T::TYPE
-       }
+       fn type_id(&self) -> u16 { T::TYPE }
+}
+
+impl Encode for msgs::Stfu {
+       const TYPE: u16 = 2;
 }
 
 impl Encode for msgs::Init {
@@ -302,7 +484,7 @@ impl Encode for msgs::FundingSigned {
        const TYPE: u16 = 35;
 }
 
-impl Encode for msgs::FundingLocked {
+impl Encode for msgs::ChannelReady {
        const TYPE: u16 = 36;
 }
 
@@ -314,6 +496,67 @@ impl Encode for msgs::ClosingSigned {
        const TYPE: u16 = 39;
 }
 
+impl Encode for msgs::OpenChannelV2 {
+       const TYPE: u16 = 64;
+}
+
+impl Encode for msgs::AcceptChannelV2 {
+       const TYPE: u16 = 65;
+}
+
+impl Encode for msgs::Splice {
+       // TODO(splicing) Double check with finalized spec; draft spec contains 74, which is probably wrong as it is used by tx_Abort; CLN uses 75
+       const TYPE: u16 = 75;
+}
+
+impl Encode for msgs::SpliceAck {
+       const TYPE: u16 = 76;
+}
+
+impl Encode for msgs::SpliceLocked {
+       const TYPE: u16 = 77;
+}
+
+impl Encode for msgs::TxAddInput {
+       const TYPE: u16 = 66;
+}
+
+impl Encode for msgs::TxAddOutput {
+       const TYPE: u16 = 67;
+}
+
+impl Encode for msgs::TxRemoveInput {
+       const TYPE: u16 = 68;
+}
+
+impl Encode for msgs::TxRemoveOutput {
+       const TYPE: u16 = 69;
+}
+
+impl Encode for msgs::TxComplete {
+       const TYPE: u16 = 70;
+}
+
+impl Encode for msgs::TxSignatures {
+       const TYPE: u16 = 71;
+}
+
+impl Encode for msgs::TxInitRbf {
+       const TYPE: u16 = 72;
+}
+
+impl Encode for msgs::TxAckRbf {
+       const TYPE: u16 = 73;
+}
+
+impl Encode for msgs::TxAbort {
+       const TYPE: u16 = 74;
+}
+
+impl Encode for msgs::OnionMessage {
+       const TYPE: u16 = 513;
+}
+
 impl Encode for msgs::UpdateAddHTLC {
        const TYPE: u16 = 128;
 }
@@ -385,9 +628,8 @@ impl Encode for msgs::GossipTimestampFilter {
 #[cfg(test)]
 mod tests {
        use super::*;
-       use prelude::*;
-       use core::convert::TryInto;
-       use ::ln::peer_handler::IgnoringMessageHandler;
+       use crate::prelude::*;
+       use crate::ln::peer_handler::IgnoringMessageHandler;
 
        // Big-endian wire encoding of Pong message (type = 19, byteslen = 2).
        const ENCODED_PONG: [u8; 6] = [0u8, 19u8, 0u8, 2u8, 0u8, 0u8];
@@ -471,10 +713,6 @@ mod tests {
                }
        }
 
-       impl Type for () {
-               fn type_id(&self) -> u16 { unreachable!(); }
-       }
-
        #[test]
        fn is_even_message_type() {
                let message = Message::<()>::Unknown(42);