From 45c1411b161562e9a4db7feda07646d72df7bafc Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 12 Apr 2022 19:05:15 +0000 Subject: [PATCH] Require `PartialEq` for `wire::Message` in `cfg(test)` ...and implement wire::Type for `()` for `feature = "_test_utils"`. --- lightning/src/ln/wire.rs | 44 ++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/lightning/src/ln/wire.rs b/lightning/src/ln/wire.rs index e7db446a..8fd5c16f 100644 --- a/lightning/src/ln/wire.rs +++ b/lightning/src/ln/wire.rs @@ -28,11 +28,26 @@ pub trait CustomMessageReader { fn read(&self, message_type: u16, buffer: &mut R) -> Result, msgs::DecodeError>; } +// TestEq is a dummy trait which requires PartialEq when built in testing, and otherwise is +// blanket-implemented for all types. + +#[cfg(test)] +pub trait TestEq : PartialEq {} +#[cfg(test)] +impl TestEq for T {} + +#[cfg(not(test))] +pub(crate) trait TestEq {} +#[cfg(not(test))] +impl TestEq for T {} + + /// A Lightning message returned by [`read()`] when decoding bytes received over the wire. Each /// variant contains a message from [`msgs`] or otherwise the message type if unknown. #[allow(missing_docs)] #[derive(Debug)] -pub(crate) enum Message where T: core::fmt::Debug + Type { +#[cfg_attr(test, derive(PartialEq))] +pub(crate) enum Message where T: core::fmt::Debug + Type + TestEq { Init(msgs::Init), Error(msgs::ErrorMessage), Warning(msgs::WarningMessage), @@ -69,7 +84,7 @@ pub(crate) enum Message where T: core::fmt::Debug + Type { Custom(T), } -impl Message where T: core::fmt::Debug + Type { +impl Message where T: core::fmt::Debug + Type + TestEq { /// Returns the type that was used to decode the message payload. pub fn type_id(&self) -> u16 { match self { @@ -252,6 +267,7 @@ mod encode { pub(crate) use self::encode::Encode; +#[cfg(not(test))] /// Defines a type identifier for sending messages over the wire. /// /// Messages implementing this trait specify a type and must be [`Writeable`]. @@ -260,10 +276,24 @@ pub trait Type: core::fmt::Debug + Writeable { fn type_id(&self) -> u16; } +#[cfg(test)] +pub trait Type: core::fmt::Debug + Writeable + PartialEq { + fn type_id(&self) -> u16; +} + +#[cfg(any(feature = "_test_utils", fuzzing, test))] +impl Type for () { + fn type_id(&self) -> u16 { unreachable!(); } +} + +#[cfg(test)] +impl Type for T where T: Encode { + fn type_id(&self) -> u16 { T::TYPE } +} + +#[cfg(not(test))] impl Type for T where T: Encode { - fn type_id(&self) -> u16 { - T::TYPE - } + fn type_id(&self) -> u16 { T::TYPE } } impl Encode for msgs::Init { @@ -471,10 +501,6 @@ mod tests { } } - impl Type for () { - fn type_id(&self) -> u16 { unreachable!(); } - } - #[test] fn is_even_message_type() { let message = Message::<()>::Unknown(42); -- 2.30.2