Implement serialize/deserialize for Router.
[rust-lightning] / src / ln / msgs.rs
index 494abf1953616f8dd46f720d40dc735f4d31dfe2..4c6524afbe62aa04cb9bda89d937578c4fbc19f4 100644 (file)
@@ -26,7 +26,7 @@ use std::{cmp, fmt};
 use std::io::Read;
 use std::result::Result;
 
-use util::{byte_utils, events};
+use util::events;
 use util::ser::{Readable, Writeable, Writer};
 
 use ln::channelmanager::{PaymentPreimage, PaymentHash};
@@ -47,7 +47,6 @@ pub enum DecodeError {
        /// node_announcement included more than one address of a given type!
        ExtraAddressesPerType,
        /// A length descriptor in the packet didn't describe the later data correctly
-       /// (currently only generated in node_announcement)
        BadLengthDescriptor,
        /// Error from std::io
        Io(::std::io::Error),
@@ -336,7 +335,7 @@ pub struct AnnouncementSignatures {
 }
 
 /// An address which can be used to connect to a remote peer
-#[derive(Clone)]
+#[derive(PartialEq, Clone)]
 pub enum NetAddress {
        /// An IPv4 address/port on which the peer is listenting.
        IPv4 {
@@ -382,9 +381,84 @@ impl NetAddress {
                        &NetAddress::OnionV3 {..} => { 4 },
                }
        }
+
+       /// Strict byte-length of address descriptor, 1-byte type not recorded
+       fn len(&self) -> u16 {
+               match self {
+                       &NetAddress::IPv4 { .. } => { 6 },
+                       &NetAddress::IPv6 { .. } => { 18 },
+                       &NetAddress::OnionV2 { .. } => { 12 },
+                       &NetAddress::OnionV3 { .. } => { 37 },
+               }
+       }
 }
 
-#[derive(Clone)]
+impl Writeable for NetAddress {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               match self {
+                       &NetAddress::IPv4 { ref addr, ref port } => {
+                               1u8.write(writer)?;
+                               addr.write(writer)?;
+                               port.write(writer)?;
+                       },
+                       &NetAddress::IPv6 { ref addr, ref port } => {
+                               2u8.write(writer)?;
+                               addr.write(writer)?;
+                               port.write(writer)?;
+                       },
+                       &NetAddress::OnionV2 { ref addr, ref port } => {
+                               3u8.write(writer)?;
+                               addr.write(writer)?;
+                               port.write(writer)?;
+                       },
+                       &NetAddress::OnionV3 { ref ed25519_pubkey, ref checksum, ref version, ref port } => {
+                               4u8.write(writer)?;
+                               ed25519_pubkey.write(writer)?;
+                               checksum.write(writer)?;
+                               version.write(writer)?;
+                               port.write(writer)?;
+                       }
+               }
+               Ok(())
+       }
+}
+
+impl<R: ::std::io::Read>  Readable<R> for Result<NetAddress, u8> {
+       fn read(reader: &mut R) -> Result<Result<NetAddress, u8>, DecodeError> {
+               let byte = <u8 as Readable<R>>::read(reader)?;
+               match byte {
+                       1 => {
+                               Ok(Ok(NetAddress::IPv4 {
+                                       addr: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       2 => {
+                               Ok(Ok(NetAddress::IPv6 {
+                                       addr: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       3 => {
+                               Ok(Ok(NetAddress::OnionV2 {
+                                       addr: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       4 => {
+                               Ok(Ok(NetAddress::OnionV3 {
+                                       ed25519_pubkey: Readable::read(reader)?,
+                                       checksum: Readable::read(reader)?,
+                                       version: Readable::read(reader)?,
+                                       port: Readable::read(reader)?,
+                               }))
+                       },
+                       _ => return Ok(Err(byte)),
+               }
+       }
+}
+
+#[derive(PartialEq, Clone)]
 // Only exposed as broadcast of node_announcement should be filtered by node_id
 /// The unsigned part of a node_announcement
 pub struct UnsignedNodeAnnouncement {
@@ -401,7 +475,7 @@ pub struct UnsignedNodeAnnouncement {
        pub(crate) excess_address_data: Vec<u8>,
        pub(crate) excess_data: Vec<u8>,
 }
-#[derive(Clone)]
+#[derive(PartialEq, Clone)]
 /// A node_announcement message to be sent or received from a peer
 pub struct NodeAnnouncement {
        pub(crate) signature: Signature,
@@ -1192,38 +1266,17 @@ impl Writeable for UnsignedNodeAnnouncement {
                w.write_all(&self.rgb)?;
                self.alias.write(w)?;
 
-               let mut addr_slice = Vec::with_capacity(self.addresses.len() * 18);
                let mut addrs_to_encode = self.addresses.clone();
                addrs_to_encode.sort_unstable_by(|a, b| { a.get_id().cmp(&b.get_id()) });
                addrs_to_encode.dedup_by(|a, b| { a.get_id() == b.get_id() });
-               for addr in addrs_to_encode.iter() {
-                       match addr {
-                               &NetAddress::IPv4{addr, port} => {
-                                       addr_slice.push(1);
-                                       addr_slice.extend_from_slice(&addr);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                               &NetAddress::IPv6{addr, port} => {
-                                       addr_slice.push(2);
-                                       addr_slice.extend_from_slice(&addr);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                               &NetAddress::OnionV2{addr, port} => {
-                                       addr_slice.push(3);
-                                       addr_slice.extend_from_slice(&addr);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                               &NetAddress::OnionV3{ed25519_pubkey, checksum, version, port} => {
-                                       addr_slice.push(4);
-                                       addr_slice.extend_from_slice(&ed25519_pubkey);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(checksum));
-                                       addr_slice.push(version);
-                                       addr_slice.extend_from_slice(&byte_utils::be16_to_array(port));
-                               },
-                       }
+               let mut addr_len = 0;
+               for addr in &addrs_to_encode {
+                       addr_len += 1 + addr.len();
+               }
+               (addr_len + self.excess_address_data.len() as u16).write(w)?;
+               for addr in addrs_to_encode {
+                       addr.write(w)?;
                }
-               ((addr_slice.len() + self.excess_address_data.len()) as u16).write(w)?;
-               w.write_all(&addr_slice[..])?;
                w.write_all(&self.excess_address_data[..])?;
                w.write_all(&self.excess_data[..])?;
                Ok(())
@@ -1242,112 +1295,77 @@ impl<R: Read> Readable<R> for UnsignedNodeAnnouncement {
                r.read_exact(&mut rgb)?;
                let alias: [u8; 32] = Readable::read(r)?;
 
-               let addrlen: u16 = Readable::read(r)?;
+               let addr_len: u16 = Readable::read(r)?;
+               let mut addresses: Vec<NetAddress> = Vec::with_capacity(4);
                let mut addr_readpos = 0;
-               let mut addresses = Vec::with_capacity(4);
-               let mut f: u8 = 0;
-               let mut excess = 0;
+               let mut excess = false;
+               let mut excess_byte = 0;
                loop {
-                       if addrlen <= addr_readpos { break; }
-                       f = Readable::read(r)?;
-                       match f {
-                               1 => {
-                                       if addresses.len() > 0 {
-                                               return Err(DecodeError::ExtraAddressesPerType);
-                                       }
-                                       if addrlen < addr_readpos + 1 + 6 {
-                                               return Err(DecodeError::BadLengthDescriptor);
-                                       }
-                                       addresses.push(NetAddress::IPv4 {
-                                               addr: {
-                                                       let mut addr = [0; 4];
-                                                       r.read_exact(&mut addr)?;
-                                                       addr
+                       if addr_len <= addr_readpos { break; }
+                       match Readable::read(r) {
+                               Ok(Ok(addr)) => {
+                                       match addr {
+                                               NetAddress::IPv4 { .. } => {
+                                                       if addresses.len() > 0 {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
                                                },
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 6
-                               },
-                               2 => {
-                                       if addresses.len() > 1 || (addresses.len() == 1 && addresses[0].get_id() != 1) {
-                                               return Err(DecodeError::ExtraAddressesPerType);
-                                       }
-                                       if addrlen < addr_readpos + 1 + 18 {
-                                               return Err(DecodeError::BadLengthDescriptor);
-                                       }
-                                       addresses.push(NetAddress::IPv6 {
-                                               addr: {
-                                                       let mut addr = [0; 16];
-                                                       r.read_exact(&mut addr)?;
-                                                       addr
+                                               NetAddress::IPv6 { .. } => {
+                                                       if addresses.len() > 1 || (addresses.len() == 1 && addresses[0].get_id() != 1) {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
                                                },
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 18
-                               },
-                               3 => {
-                                       if addresses.len() > 2 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 2) {
-                                               return Err(DecodeError::ExtraAddressesPerType);
-                                       }
-                                       if addrlen < addr_readpos + 1 + 12 {
-                                               return Err(DecodeError::BadLengthDescriptor);
-                                       }
-                                       addresses.push(NetAddress::OnionV2 {
-                                               addr: {
-                                                       let mut addr = [0; 10];
-                                                       r.read_exact(&mut addr)?;
-                                                       addr
+                                               NetAddress::OnionV2 { .. } => {
+                                                       if addresses.len() > 2 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 2) {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
+                                               },
+                                               NetAddress::OnionV3 { .. } => {
+                                                       if addresses.len() > 3 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 3) {
+                                                               return Err(DecodeError::ExtraAddressesPerType);
+                                                       }
                                                },
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 12
-                               },
-                               4 => {
-                                       if addresses.len() > 3 || (addresses.len() > 0 && addresses.last().unwrap().get_id() > 3) {
-                                               return Err(DecodeError::ExtraAddressesPerType);
                                        }
-                                       if addrlen < addr_readpos + 1 + 37 {
+                                       if addr_len < addr_readpos + 1 + addr.len() {
                                                return Err(DecodeError::BadLengthDescriptor);
                                        }
-                                       addresses.push(NetAddress::OnionV3 {
-                                               ed25519_pubkey: Readable::read(r)?,
-                                               checksum: Readable::read(r)?,
-                                               version: Readable::read(r)?,
-                                               port: Readable::read(r)?,
-                                       });
-                                       addr_readpos += 1 + 37
+                                       addr_readpos += (1 + addr.len()) as u16;
+                                       addresses.push(addr);
                                },
-                               _ => { excess = 1; break; }
+                               Ok(Err(unknown_descriptor)) => {
+                                       excess = true;
+                                       excess_byte = unknown_descriptor;
+                                       break;
+                               },
+                               Err(DecodeError::ShortRead) => return Err(DecodeError::BadLengthDescriptor),
+                               Err(e) => return Err(e),
                        }
                }
 
                let mut excess_data = vec![];
-               let excess_address_data = if addr_readpos < addrlen {
-                       let mut excess_address_data = vec![0; (addrlen - addr_readpos) as usize];
-                       r.read_exact(&mut excess_address_data[excess..])?;
-                       if excess == 1 {
-                               excess_address_data[0] = f;
+               let excess_address_data = if addr_readpos < addr_len {
+                       let mut excess_address_data = vec![0; (addr_len - addr_readpos) as usize];
+                       r.read_exact(&mut excess_address_data[if excess { 1 } else { 0 }..])?;
+                       if excess {
+                               excess_address_data[0] = excess_byte;
                        }
                        excess_address_data
                } else {
-                       if excess == 1 {
-                               excess_data.push(f);
+                       if excess {
+                               excess_data.push(excess_byte);
                        }
                        Vec::new()
                };
-
+               r.read_to_end(&mut excess_data)?;
                Ok(UnsignedNodeAnnouncement {
-                       features: features,
-                       timestamp: timestamp,
-                       node_id: node_id,
-                       rgb: rgb,
-                       alias: alias,
-                       addresses: addresses,
-                       excess_address_data: excess_address_data,
-                       excess_data: {
-                               r.read_to_end(&mut excess_data)?;
-                               excess_data
-                       },
+                       features,
+                       timestamp,
+                       node_id,
+                       rgb,
+                       alias,
+                       addresses,
+                       excess_address_data,
+                       excess_data,
                })
        }
 }