Merge pull request #1763 from gcomte/feature/derive-eq
[rust-lightning] / lightning / src / ln / wire.rs
index 05bdb5694d8b4c4202348e4623917af41fe0e3e8..1191a8d3d531477977cad80776bb29f6c49ed2c8 100644 (file)
@@ -9,8 +9,8 @@
 
 //! Wire encoding/decoding for Lightning messages according to [BOLT #1], and for
 //! custom message through the [`CustomMessageReader`] trait.
-//! 
-//! [BOLT #1]: https://github.com/lightningnetwork/lightning-rfc/blob/master/01-messaging.md
+//!
+//! [BOLT #1]: https://github.com/lightning/bolts/blob/master/01-messaging.md
 
 use io;
 use ln::msgs;
@@ -20,7 +20,7 @@ use util::ser::{Readable, Writeable, Writer};
 /// decoders.
 pub trait CustomMessageReader {
        /// The type of the message decoded by the implementation.
-       type CustomMessage: core::fmt::Debug + Type + Writeable;
+       type CustomMessage: Type;
        /// Decodes a custom message to `CustomMessageType`. If the given message type is known to the
        /// implementation and the message could be decoded, must return `Ok(Some(message))`. If the
        /// message type is unknown to the implementation, must return `Ok(None)`. If a decoding error
@@ -28,22 +28,39 @@ pub trait CustomMessageReader {
        fn read<R: io::Read>(&self, message_type: u16, buffer: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError>;
 }
 
+// TestEq is a dummy trait which requires PartialEq when built in testing, and otherwise is
+// blanket-implemented for all types.
+
+#[cfg(test)]
+pub trait TestEq : PartialEq {}
+#[cfg(test)]
+impl<T: PartialEq> TestEq for T {}
+
+#[cfg(not(test))]
+pub(crate) trait TestEq {}
+#[cfg(not(test))]
+impl<T> TestEq for T {}
+
+
 /// A Lightning message returned by [`read()`] when decoding bytes received over the wire. Each
 /// variant contains a message from [`msgs`] or otherwise the message type if unknown.
 #[allow(missing_docs)]
 #[derive(Debug)]
-pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
+#[cfg_attr(test, derive(PartialEq))]
+pub(crate) enum Message<T> where T: core::fmt::Debug + Type + TestEq {
        Init(msgs::Init),
        Error(msgs::ErrorMessage),
+       Warning(msgs::WarningMessage),
        Ping(msgs::Ping),
        Pong(msgs::Pong),
        OpenChannel(msgs::OpenChannel),
        AcceptChannel(msgs::AcceptChannel),
        FundingCreated(msgs::FundingCreated),
        FundingSigned(msgs::FundingSigned),
-       FundingLocked(msgs::FundingLocked),
+       ChannelReady(msgs::ChannelReady),
        Shutdown(msgs::Shutdown),
        ClosingSigned(msgs::ClosingSigned),
+       OnionMessage(msgs::OnionMessage),
        UpdateAddHTLC(msgs::UpdateAddHTLC),
        UpdateFulfillHTLC(msgs::UpdateFulfillHTLC),
        UpdateFailHTLC(msgs::UpdateFailHTLC),
@@ -62,32 +79,29 @@ pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
        ReplyChannelRange(msgs::ReplyChannelRange),
        GossipTimestampFilter(msgs::GossipTimestampFilter),
        /// A message that could not be decoded because its type is unknown.
-       Unknown(MessageType),
+       Unknown(u16),
        /// A message that was produced by a [`CustomMessageReader`] and is to be handled by a
        /// [`::ln::peer_handler::CustomMessageHandler`].
        Custom(T),
 }
 
-/// A number identifying a message to determine how it is encoded on the wire.
-#[derive(Clone, Copy, Debug)]
-pub struct MessageType(u16);
-
-impl<T> Message<T> where T: core::fmt::Debug + Type {
-       #[allow(dead_code)] // This method is only used in tests
+impl<T> Message<T> where T: core::fmt::Debug + Type + TestEq {
        /// Returns the type that was used to decode the message payload.
-       pub fn type_id(&self) -> MessageType {
+       pub fn type_id(&self) -> u16 {
                match self {
                        &Message::Init(ref msg) => msg.type_id(),
                        &Message::Error(ref msg) => msg.type_id(),
+                       &Message::Warning(ref msg) => msg.type_id(),
                        &Message::Ping(ref msg) => msg.type_id(),
                        &Message::Pong(ref msg) => msg.type_id(),
                        &Message::OpenChannel(ref msg) => msg.type_id(),
                        &Message::AcceptChannel(ref msg) => msg.type_id(),
                        &Message::FundingCreated(ref msg) => msg.type_id(),
                        &Message::FundingSigned(ref msg) => msg.type_id(),
-                       &Message::FundingLocked(ref msg) => msg.type_id(),
+                       &Message::ChannelReady(ref msg) => msg.type_id(),
                        &Message::Shutdown(ref msg) => msg.type_id(),
                        &Message::ClosingSigned(ref msg) => msg.type_id(),
+                       &Message::OnionMessage(ref msg) => msg.type_id(),
                        &Message::UpdateAddHTLC(ref msg) => msg.type_id(),
                        &Message::UpdateFulfillHTLC(ref msg) => msg.type_id(),
                        &Message::UpdateFailHTLC(ref msg) => msg.type_id(),
@@ -109,18 +123,10 @@ impl<T> Message<T> where T: core::fmt::Debug + Type {
                        &Message::Custom(ref msg) => msg.type_id(),
                }
        }
-}
 
-impl MessageType {
-       /// Returns whether the message type is even, indicating both endpoints must support it.
+       /// Returns whether the message's type is even, indicating both endpoints must support it.
        pub fn is_even(&self) -> bool {
-               (self.0 & 1) == 0
-       }
-}
-
-impl ::core::fmt::Display for MessageType {
-       fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
-               write!(f, "{}", self.0)
+               (self.type_id() & 1) == 0
        }
 }
 
@@ -130,15 +136,20 @@ impl ::core::fmt::Display for MessageType {
 /// # Errors
 ///
 /// Returns an error if the message payload code not be decoded as the specified type.
-pub(crate) fn read<R: io::Read, T, H: core::ops::Deref>(
-       buffer: &mut R,
-       custom_reader: H,
-) -> Result<Message<T>, msgs::DecodeError>
-where
+pub(crate) fn read<R: io::Read, T, H: core::ops::Deref>(buffer: &mut R, custom_reader: H)
+-> Result<Message<T>, (msgs::DecodeError, Option<u16>)> where
+       T: core::fmt::Debug + Type + Writeable,
+       H::Target: CustomMessageReader<CustomMessage = T>,
+{
+       let message_type = <u16 as Readable>::read(buffer).map_err(|e| (e, None))?;
+       do_read(buffer, message_type, custom_reader).map_err(|e| (e, Some(message_type)))
+}
+
+fn do_read<R: io::Read, T, H: core::ops::Deref>(buffer: &mut R, message_type: u16, custom_reader: H)
+-> Result<Message<T>, msgs::DecodeError> where
        T: core::fmt::Debug + Type + Writeable,
        H::Target: CustomMessageReader<CustomMessage = T>,
 {
-       let message_type = <u16 as Readable>::read(buffer)?;
        match message_type {
                msgs::Init::TYPE => {
                        Ok(Message::Init(Readable::read(buffer)?))
@@ -146,6 +157,9 @@ where
                msgs::ErrorMessage::TYPE => {
                        Ok(Message::Error(Readable::read(buffer)?))
                },
+               msgs::WarningMessage::TYPE => {
+                       Ok(Message::Warning(Readable::read(buffer)?))
+               },
                msgs::Ping::TYPE => {
                        Ok(Message::Ping(Readable::read(buffer)?))
                },
@@ -164,8 +178,8 @@ where
                msgs::FundingSigned::TYPE => {
                        Ok(Message::FundingSigned(Readable::read(buffer)?))
                },
-               msgs::FundingLocked::TYPE => {
-                       Ok(Message::FundingLocked(Readable::read(buffer)?))
+               msgs::ChannelReady::TYPE => {
+                       Ok(Message::ChannelReady(Readable::read(buffer)?))
                },
                msgs::Shutdown::TYPE => {
                        Ok(Message::Shutdown(Readable::read(buffer)?))
@@ -173,6 +187,9 @@ where
                msgs::ClosingSigned::TYPE => {
                        Ok(Message::ClosingSigned(Readable::read(buffer)?))
                },
+               msgs::OnionMessage::TYPE => {
+                       Ok(Message::OnionMessage(Readable::read(buffer)?))
+               },
                msgs::UpdateAddHTLC::TYPE => {
                        Ok(Message::UpdateAddHTLC(Readable::read(buffer)?))
                },
@@ -228,7 +245,7 @@ where
                        if let Some(custom) = custom_reader.read(message_type, buffer)? {
                                Ok(Message::Custom(custom))
                        } else {
-                               Ok(Message::Unknown(MessageType(message_type)))
+                               Ok(Message::Unknown(message_type))
                        }
                },
        }
@@ -241,7 +258,7 @@ where
 ///
 /// Returns an I/O error if the write could not be completed.
 pub(crate) fn write<M: Type + Writeable, W: Writer>(message: &M, buffer: &mut W) -> Result<(), io::Error> {
-       message.type_id().0.write(buffer)?;
+       message.type_id().write(buffer)?;
        message.write(buffer)
 }
 
@@ -255,18 +272,33 @@ mod encode {
 
 pub(crate) use self::encode::Encode;
 
+#[cfg(not(test))]
 /// Defines a type identifier for sending messages over the wire.
 ///
 /// Messages implementing this trait specify a type and must be [`Writeable`].
-pub trait Type {
+pub trait Type: core::fmt::Debug + Writeable {
        /// Returns the type identifying the message payload.
-       fn type_id(&self) -> MessageType;
+       fn type_id(&self) -> u16;
 }
 
-impl<T> Type for T where T: Encode {
-       fn type_id(&self) -> MessageType {
-               MessageType(T::TYPE)
-       }
+#[cfg(test)]
+pub trait Type: core::fmt::Debug + Writeable + PartialEq {
+       fn type_id(&self) -> u16;
+}
+
+#[cfg(any(feature = "_test_utils", fuzzing, test))]
+impl Type for () {
+       fn type_id(&self) -> u16 { unreachable!(); }
+}
+
+#[cfg(test)]
+impl<T: core::fmt::Debug + Writeable + PartialEq> Type for T where T: Encode {
+       fn type_id(&self) -> u16 { T::TYPE }
+}
+
+#[cfg(not(test))]
+impl<T: core::fmt::Debug + Writeable> Type for T where T: Encode {
+       fn type_id(&self) -> u16 { T::TYPE }
 }
 
 impl Encode for msgs::Init {
@@ -277,6 +309,10 @@ impl Encode for msgs::ErrorMessage {
        const TYPE: u16 = 17;
 }
 
+impl Encode for msgs::WarningMessage {
+       const TYPE: u16 = 1;
+}
+
 impl Encode for msgs::Ping {
        const TYPE: u16 = 18;
 }
@@ -301,7 +337,7 @@ impl Encode for msgs::FundingSigned {
        const TYPE: u16 = 35;
 }
 
-impl Encode for msgs::FundingLocked {
+impl Encode for msgs::ChannelReady {
        const TYPE: u16 = 36;
 }
 
@@ -313,6 +349,10 @@ impl Encode for msgs::ClosingSigned {
        const TYPE: u16 = 39;
 }
 
+impl Encode for msgs::OnionMessage {
+       const TYPE: u16 = 513;
+}
+
 impl Encode for msgs::UpdateAddHTLC {
        const TYPE: u16 = 128;
 }
@@ -436,7 +476,7 @@ mod tests {
                let mut reader = io::Cursor::new(buffer);
                let message = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
                match message {
-                       Message::Unknown(MessageType(::core::u16::MAX)) => (),
+                       Message::Unknown(::core::u16::MAX) => (),
                        _ => panic!("Expected message type {}; found: {}", ::core::u16::MAX, message.type_id()),
                }
        }
@@ -472,14 +512,14 @@ mod tests {
 
        #[test]
        fn is_even_message_type() {
-               let message = Message::<()>::Unknown(MessageType(42));
-               assert!(message.type_id().is_even());
+               let message = Message::<()>::Unknown(42);
+               assert!(message.is_even());
        }
 
        #[test]
        fn is_odd_message_type() {
-               let message = Message::<()>::Unknown(MessageType(43));
-               assert!(!message.type_id().is_even());
+               let message = Message::<()>::Unknown(43);
+               assert!(!message.is_even());
        }
 
        #[test]
@@ -500,7 +540,7 @@ mod tests {
                let mut reader = io::Cursor::new(buffer);
                let decoded_msg = read(&mut reader, &IgnoringMessageHandler{}).unwrap();
                match decoded_msg {
-                       Message::Init(msgs::Init { features }) => {
+                       Message::Init(msgs::Init { features, .. }) => {
                                assert!(features.supports_variable_length_onion());
                                assert!(features.supports_upfront_shutdown_script());
                                assert!(features.supports_gossip_queries());
@@ -549,8 +589,8 @@ mod tests {
        const CUSTOM_MESSAGE_TYPE : u16 = 9000;
 
        impl Type for TestCustomMessage {
-               fn type_id(&self) -> MessageType {
-                       MessageType(CUSTOM_MESSAGE_TYPE)
+               fn type_id(&self) -> u16 {
+                       CUSTOM_MESSAGE_TYPE
                }
        }
 
@@ -584,7 +624,7 @@ mod tests {
                let decoded_msg = read(&mut reader, &TestCustomMessageReader{}).unwrap();
                match decoded_msg {
                        Message::Custom(custom) => {
-                               assert_eq!(custom.type_id().0, CUSTOM_MESSAGE_TYPE);
+                               assert_eq!(custom.type_id(), CUSTOM_MESSAGE_TYPE);
                                assert_eq!(custom, TestCustomMessage {});
                        },
                        _ => panic!("Expected custom message, found message type: {}", decoded_msg.type_id()),