Require `PartialEq` for `wire::Message` in `cfg(test)`
authorMatt Corallo <git@bluematt.me>
Tue, 12 Apr 2022 19:05:15 +0000 (19:05 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 10 May 2022 23:40:20 +0000 (23:40 +0000)
...and implement wire::Type for `()` for `feature = "_test_utils"`.

lightning/src/ln/wire.rs

index e7db446a86f17f036ceeccba9aca52c4aa7ea652..8fd5c16f36261ce3782d72bb81846b3e5b503376 100644 (file)
@@ -28,11 +28,26 @@ pub trait CustomMessageReader {
        fn read<R: io::Read>(&self, message_type: u16, buffer: &mut R) -> Result<Option<Self::CustomMessage>, msgs::DecodeError>;
 }
 
+// TestEq is a dummy trait which requires PartialEq when built in testing, and otherwise is
+// blanket-implemented for all types.
+
+#[cfg(test)]
+pub trait TestEq : PartialEq {}
+#[cfg(test)]
+impl<T: PartialEq> TestEq for T {}
+
+#[cfg(not(test))]
+pub(crate) trait TestEq {}
+#[cfg(not(test))]
+impl<T> TestEq for T {}
+
+
 /// A Lightning message returned by [`read()`] when decoding bytes received over the wire. Each
 /// variant contains a message from [`msgs`] or otherwise the message type if unknown.
 #[allow(missing_docs)]
 #[derive(Debug)]
-pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
+#[cfg_attr(test, derive(PartialEq))]
+pub(crate) enum Message<T> where T: core::fmt::Debug + Type + TestEq {
        Init(msgs::Init),
        Error(msgs::ErrorMessage),
        Warning(msgs::WarningMessage),
@@ -69,7 +84,7 @@ pub(crate) enum Message<T> where T: core::fmt::Debug + Type {
        Custom(T),
 }
 
-impl<T> Message<T> where T: core::fmt::Debug + Type {
+impl<T> Message<T> where T: core::fmt::Debug + Type + TestEq {
        /// Returns the type that was used to decode the message payload.
        pub fn type_id(&self) -> u16 {
                match self {
@@ -252,6 +267,7 @@ mod encode {
 
 pub(crate) use self::encode::Encode;
 
+#[cfg(not(test))]
 /// Defines a type identifier for sending messages over the wire.
 ///
 /// Messages implementing this trait specify a type and must be [`Writeable`].
@@ -260,10 +276,24 @@ pub trait Type: core::fmt::Debug + Writeable {
        fn type_id(&self) -> u16;
 }
 
+#[cfg(test)]
+pub trait Type: core::fmt::Debug + Writeable + PartialEq {
+       fn type_id(&self) -> u16;
+}
+
+#[cfg(any(feature = "_test_utils", fuzzing, test))]
+impl Type for () {
+       fn type_id(&self) -> u16 { unreachable!(); }
+}
+
+#[cfg(test)]
+impl<T: core::fmt::Debug + Writeable + PartialEq> Type for T where T: Encode {
+       fn type_id(&self) -> u16 { T::TYPE }
+}
+
+#[cfg(not(test))]
 impl<T: core::fmt::Debug + Writeable> Type for T where T: Encode {
-       fn type_id(&self) -> u16 {
-               T::TYPE
-       }
+       fn type_id(&self) -> u16 { T::TYPE }
 }
 
 impl Encode for msgs::Init {
@@ -471,10 +501,6 @@ mod tests {
                }
        }
 
-       impl Type for () {
-               fn type_id(&self) -> u16 { unreachable!(); }
-       }
-
        #[test]
        fn is_even_message_type() {
                let message = Message::<()>::Unknown(42);