Move message type parity logic to the wire module
authorJeffrey Czyz <jkczyz@gmail.com>
Fri, 24 Jan 2020 14:43:58 +0000 (06:43 -0800)
committerJeffrey Czyz <jkczyz@gmail.com>
Wed, 5 Feb 2020 20:13:13 +0000 (12:13 -0800)
Create a MessageType abstraction and use it throughout the wire module's
external interfaces. Include an is_even method for clients to determine
how to handle unknown messages.

lightning/src/ln/peer_handler.rs
lightning/src/ln/wire.rs

index 30e208155c5699596ee92988b37790f10ff904a2..e25e50c28b5f1190edea81c429bae659b6a0c092 100644 (file)
@@ -773,12 +773,11 @@ impl<Descriptor: SocketDescriptor, CM: Deref> PeerManager<Descriptor, CM> where
                                                                                        },
 
                                                                                        // Unknown messages:
-                                                                                       wire::Message::Unknown(msg_type) => {
+                                                                                       wire::Message::Unknown(msg_type) if msg_type.is_even() => {
                                                                                                // Fail the channel if message is an even, unknown type as per BOLT #1.
-                                                                                               if (msg_type & 1) == 0 {
-                                                                                                       return Err(PeerHandleError{ no_connection_possible: true });
-                                                                                               }
+                                                                                               return Err(PeerHandleError{ no_connection_possible: true });
                                                                                        },
+                                                                                       wire::Message::Unknown(_) => {},
                                                                                }
                                                                        }
                                                                }
index 963eaeb54a7f22ce780f23c0e7d705ab69cc8fdd..9dedf580f5651ad6732e3958fbf75c22603d1ea5 100644 (file)
@@ -47,12 +47,18 @@ pub enum Message {
        NodeAnnouncement(msgs::NodeAnnouncement),
        ChannelUpdate(msgs::ChannelUpdate),
        /// A message that could not be decoded because its type is unknown.
-       Unknown(u16),
+       Unknown(MessageType),
+}
+
+/// A number identifying a message to determine how it is encoded on the wire.
+#[derive(Clone, Copy)]
+pub struct MessageType {
+       number: u16,
 }
 
 impl Message {
        /// Returns the type that was used to decode the message payload.
-       pub fn type_id(&self) -> u16 {
+       pub fn type_id(&self) -> MessageType {
                match self {
                        &Message::Init(ref msg) => msg.type_id(),
                        &Message::Error(ref msg) => msg.type_id(),
@@ -82,6 +88,19 @@ impl Message {
        }
 }
 
+impl MessageType {
+       /// Returns whether the message type is even, indicating both endpoints must support it.
+       pub fn is_even(&self) -> bool {
+               (self.number & 1) == 0
+       }
+}
+
+impl ::std::fmt::Display for MessageType {
+       fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+               write!(f, "{}", self.number)
+       }
+}
+
 /// Reads a message from the data buffer consisting of a 2-byte big-endian type and a
 /// variable-length payload conforming to the type.
 ///
@@ -161,7 +180,7 @@ pub fn read<R: ::std::io::Read>(buffer: &mut R) -> Result<Message, msgs::DecodeE
                        Ok(Message::ChannelUpdate(Readable::read(buffer)?))
                },
                _ => {
-                       Ok(Message::Unknown(message_type))
+                       Ok(Message::Unknown(MessageType { number: message_type }))
                },
        }
 }
@@ -189,8 +208,8 @@ pub trait Encode {
 
        /// Returns the type identifying the message payload. Convenience method for accessing
        /// [`TYPE`](TYPE).
-       fn type_id(&self) -> u16 {
-               Self::TYPE
+       fn type_id(&self) -> MessageType {
+               MessageType { number: Self::TYPE }
        }
 }
 
@@ -339,7 +358,7 @@ mod tests {
                let mut reader = ::std::io::Cursor::new(buffer);
                let message = read(&mut reader).unwrap();
                match message {
-                       Message::Unknown(::std::u16::MAX) => (),
+                       Message::Unknown(MessageType { number: ::std::u16::MAX }) => (),
                        _ => panic!("Expected message type {}; found: {}", ::std::u16::MAX, message.type_id()),
                }
        }
@@ -372,4 +391,16 @@ mod tests {
                        _ => panic!("Expected pong message; found message type: {}", decoded_message.type_id()),
                }
        }
+
+       #[test]
+       fn is_even_message_type() {
+               let message = Message::Unknown(MessageType { number: 42 });
+               assert!(message.type_id().is_even());
+       }
+
+       #[test]
+       fn is_odd_message_type() {
+               let message = Message::Unknown(MessageType { number: 43 });
+               assert!(!message.type_id().is_even());
+       }
 }