Merge pull request #298 from TheBlueMatt/2019-01-271-cleanup
[rust-lightning] / src / ln / router.rs
index 3de73ccc059ae8159b09f856ca4508fbedf7a278..f4a702fb709387f859b5fe54c82457dc5229ded1 100644 (file)
@@ -4,7 +4,7 @@
 //! interrogate it to get routes for your own payments.
 
 use secp256k1::key::PublicKey;
-use secp256k1::{Secp256k1,Message};
+use secp256k1::Secp256k1;
 use secp256k1;
 
 use bitcoin::util::hash::Sha256dHash;
@@ -15,7 +15,7 @@ use chain::chaininterface::{ChainError, ChainWatchInterface};
 use ln::channelmanager;
 use ln::msgs::{DecodeError,ErrorAction,HandleError,RoutingMessageHandler,NetAddress,GlobalFeatures};
 use ln::msgs;
-use util::ser::{Writeable, Readable};
+use util::ser::{Writeable, Readable, Writer, ReadableArgs};
 use util::logger::Logger;
 
 use std::cmp;
@@ -25,7 +25,7 @@ use std::collections::btree_map::Entry as BtreeEntry;
 use std;
 
 /// A hop in a route
-#[derive(Clone)]
+#[derive(Clone, PartialEq)]
 pub struct RouteHop {
        /// The node_id of the node at this hop.
        pub pubkey: PublicKey,
@@ -39,7 +39,7 @@ pub struct RouteHop {
 }
 
 /// A route from us through the network to a destination
-#[derive(Clone)]
+#[derive(Clone, PartialEq)]
 pub struct Route {
        /// The list of hops, NOT INCLUDING our own, where the last hop is the destination. Thus, this
        /// must always be at least length one. By protocol rules, this may not currently exceed 20 in
@@ -78,6 +78,7 @@ impl<R: ::std::io::Read> Readable<R> for Route {
        }
 }
 
+#[derive(PartialEq)]
 struct DirectionalChannelInfo {
        src_node_id: PublicKey,
        last_update: u32,
@@ -96,6 +97,18 @@ impl std::fmt::Display for DirectionalChannelInfo {
        }
 }
 
+impl_writeable!(DirectionalChannelInfo, 0, {
+       src_node_id,
+       last_update,
+       enabled,
+       cltv_expiry_delta,
+       htlc_minimum_msat,
+       fee_base_msat,
+       fee_proportional_millionths,
+       last_update_message
+});
+
+#[derive(PartialEq)]
 struct ChannelInfo {
        features: GlobalFeatures,
        one_to_two: DirectionalChannelInfo,
@@ -112,6 +125,14 @@ impl std::fmt::Display for ChannelInfo {
        }
 }
 
+impl_writeable!(ChannelInfo, 0, {
+       features,
+       one_to_two,
+       two_to_one,
+       announcement_message
+});
+
+#[derive(PartialEq)]
 struct NodeInfo {
        #[cfg(feature = "non_bitcoin_chain_hash_routing")]
        channels: Vec<(u64, Sha256dHash)>,
@@ -138,6 +159,68 @@ impl std::fmt::Display for NodeInfo {
        }
 }
 
+impl Writeable for NodeInfo {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               (self.channels.len() as u64).write(writer)?;
+               for ref chan in self.channels.iter() {
+                       chan.write(writer)?;
+               }
+               self.lowest_inbound_channel_fee_base_msat.write(writer)?;
+               self.lowest_inbound_channel_fee_proportional_millionths.write(writer)?;
+               self.features.write(writer)?;
+               self.last_update.write(writer)?;
+               self.rgb.write(writer)?;
+               self.alias.write(writer)?;
+               (self.addresses.len() as u64).write(writer)?;
+               for ref addr in &self.addresses {
+                       addr.write(writer)?;
+               }
+               self.announcement_message.write(writer)?;
+               Ok(())
+       }
+}
+
+const MAX_ALLOC_SIZE: u64 = 64*1024;
+
+impl<R: ::std::io::Read> Readable<R> for NodeInfo {
+       fn read(reader: &mut R) -> Result<NodeInfo, DecodeError> {
+               let channels_count: u64 = Readable::read(reader)?;
+               let mut channels = Vec::with_capacity(cmp::min(channels_count, MAX_ALLOC_SIZE / 8) as usize);
+               for _ in 0..channels_count {
+                       channels.push(Readable::read(reader)?);
+               }
+               let lowest_inbound_channel_fee_base_msat = Readable::read(reader)?;
+               let lowest_inbound_channel_fee_proportional_millionths = Readable::read(reader)?;
+               let features = Readable::read(reader)?;
+               let last_update = Readable::read(reader)?;
+               let rgb = Readable::read(reader)?;
+               let alias = Readable::read(reader)?;
+               let addresses_count: u64 = Readable::read(reader)?;
+               let mut addresses = Vec::with_capacity(cmp::min(addresses_count, MAX_ALLOC_SIZE / 40) as usize);
+               for _ in 0..addresses_count {
+                       match Readable::read(reader) {
+                               Ok(Ok(addr)) => { addresses.push(addr); },
+                               Ok(Err(_)) => return Err(DecodeError::InvalidValue),
+                               Err(DecodeError::ShortRead) => return Err(DecodeError::BadLengthDescriptor),
+                               _ => unreachable!(),
+                       }
+               }
+               let announcement_message = Readable::read(reader)?;
+               Ok(NodeInfo {
+                       channels,
+                       lowest_inbound_channel_fee_base_msat,
+                       lowest_inbound_channel_fee_proportional_millionths,
+                       features,
+                       last_update,
+                       rgb,
+                       alias,
+                       addresses,
+                       announcement_message
+               })
+       }
+}
+
+#[derive(PartialEq)]
 struct NetworkMap {
        #[cfg(feature = "non_bitcoin_chain_hash_routing")]
        channels: BTreeMap<(u64, Sha256dHash), ChannelInfo>,
@@ -147,6 +230,49 @@ struct NetworkMap {
        our_node_id: PublicKey,
        nodes: BTreeMap<PublicKey, NodeInfo>,
 }
+
+impl Writeable for NetworkMap {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               (self.channels.len() as u64).write(writer)?;
+               for (ref chan_id, ref chan_info) in self.channels.iter() {
+                       (*chan_id).write(writer)?;
+                       chan_info.write(writer)?;
+               }
+               self.our_node_id.write(writer)?;
+               (self.nodes.len() as u64).write(writer)?;
+               for (ref node_id, ref node_info) in self.nodes.iter() {
+                       node_id.write(writer)?;
+                       node_info.write(writer)?;
+               }
+               Ok(())
+       }
+}
+
+impl<R: ::std::io::Read> Readable<R> for NetworkMap {
+       fn read(reader: &mut R) -> Result<NetworkMap, DecodeError> {
+               let channels_count: u64 = Readable::read(reader)?;
+               let mut channels = BTreeMap::new();
+               for _ in 0..channels_count {
+                       let chan_id: u64 = Readable::read(reader)?;
+                       let chan_info = Readable::read(reader)?;
+                       channels.insert(chan_id, chan_info);
+               }
+               let our_node_id = Readable::read(reader)?;
+               let nodes_count: u64 = Readable::read(reader)?;
+               let mut nodes = BTreeMap::new();
+               for _ in 0..nodes_count {
+                       let node_id = Readable::read(reader)?;
+                       let node_info = Readable::read(reader)?;
+                       nodes.insert(node_id, node_info);
+               }
+               Ok(NetworkMap {
+                       channels,
+                       our_node_id,
+                       nodes,
+               })
+       }
+}
+
 struct MutNetworkMap<'a> {
        #[cfg(feature = "non_bitcoin_chain_hash_routing")]
        channels: &'a mut BTreeMap<(u64, Sha256dHash), ChannelInfo>,
@@ -228,6 +354,51 @@ pub struct Router {
        logger: Arc<Logger>,
 }
 
+const SERIALIZATION_VERSION: u8 = 1;
+const MIN_SERIALIZATION_VERSION: u8 = 1;
+
+impl Writeable for Router {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               writer.write_all(&[SERIALIZATION_VERSION; 1])?;
+               writer.write_all(&[MIN_SERIALIZATION_VERSION; 1])?;
+
+               let network = self.network_map.read().unwrap();
+               network.write(writer)?;
+               Ok(())
+       }
+}
+
+/// Arguments for the creation of a Router that are not deserialized.
+/// At a high-level, the process for deserializing a Router and resuming normal operation is:
+/// 1) Deserialize the Router by filling in this struct and calling <Router>::read(reaser, args).
+/// 2) Register the new Router with your ChainWatchInterface
+pub struct RouterReadArgs {
+       /// The ChainWatchInterface for use in the Router in the future.
+       ///
+       /// No calls to the ChainWatchInterface will be made during deserialization.
+       pub chain_monitor: Arc<ChainWatchInterface>,
+       /// The Logger for use in the ChannelManager and which may be used to log information during
+       /// deserialization.
+       pub logger: Arc<Logger>,
+}
+
+impl<R: ::std::io::Read> ReadableArgs<R, RouterReadArgs> for Router {
+       fn read(reader: &mut R, args: RouterReadArgs) -> Result<Router, DecodeError> {
+               let _ver: u8 = Readable::read(reader)?;
+               let min_ver: u8 = Readable::read(reader)?;
+               if min_ver > SERIALIZATION_VERSION {
+                       return Err(DecodeError::UnknownVersion);
+               }
+               let network_map = Readable::read(reader)?;
+               Ok(Router {
+                       secp_ctx: Secp256k1::verification_only(),
+                       network_map: RwLock::new(network_map),
+                       chain_monitor: args.chain_monitor,
+                       logger: args.logger,
+               })
+       }
+}
+
 macro_rules! secp_verify_sig {
        ( $secp_ctx: expr, $msg: expr, $sig: expr, $pubkey: expr ) => {
                match $secp_ctx.verify($msg, $sig, $pubkey) {
@@ -239,7 +410,7 @@ macro_rules! secp_verify_sig {
 
 impl RoutingMessageHandler for Router {
        fn handle_node_announcement(&self, msg: &msgs::NodeAnnouncement) -> Result<bool, HandleError> {
-               let msg_hash = Message::from_slice(&Sha256dHash::from_data(&msg.contents.encode()[..])[..]).unwrap();
+               let msg_hash = hash_to_message!(&Sha256dHash::from_data(&msg.contents.encode()[..])[..]);
                secp_verify_sig!(self.secp_ctx, &msg_hash, &msg.signature, &msg.contents.node_id);
 
                if msg.contents.features.requires_unknown_bits() {
@@ -272,7 +443,7 @@ impl RoutingMessageHandler for Router {
                        return Err(HandleError{err: "Channel announcement node had a channel with itself", action: Some(ErrorAction::IgnoreError)});
                }
 
-               let msg_hash = Message::from_slice(&Sha256dHash::from_data(&msg.contents.encode()[..])[..]).unwrap();
+               let msg_hash = hash_to_message!(&Sha256dHash::from_data(&msg.contents.encode()[..])[..]);
                secp_verify_sig!(self.secp_ctx, &msg_hash, &msg.node_signature_1, &msg.contents.node_id_1);
                secp_verify_sig!(self.secp_ctx, &msg_hash, &msg.node_signature_2, &msg.contents.node_id_2);
                secp_verify_sig!(self.secp_ctx, &msg_hash, &msg.bitcoin_signature_1, &msg.contents.bitcoin_key_1);
@@ -284,10 +455,11 @@ impl RoutingMessageHandler for Router {
 
                let checked_utxo = match self.chain_monitor.get_chain_utxo(msg.contents.chain_hash, msg.contents.short_channel_id) {
                        Ok((script_pubkey, _value)) => {
-                               let expected_script = Builder::new().push_opcode(opcodes::All::OP_PUSHNUM_2)
+                               let expected_script = Builder::new().push_opcode(opcodes::all::OP_PUSHNUM_2)
                                                                    .push_slice(&msg.contents.bitcoin_key_1.serialize())
                                                                    .push_slice(&msg.contents.bitcoin_key_2.serialize())
-                                                                   .push_opcode(opcodes::All::OP_PUSHNUM_2).push_opcode(opcodes::All::OP_CHECKMULTISIG).into_script().to_v0_p2wsh();
+                                                                   .push_opcode(opcodes::all::OP_PUSHNUM_2)
+                                                                   .push_opcode(opcodes::all::OP_CHECKMULTISIG).into_script().to_v0_p2wsh();
                                if script_pubkey != expected_script {
                                        return Err(HandleError{err: "Channel announcement keys didn't match on-chain script", action: Some(ErrorAction::IgnoreError)});
                                }
@@ -447,7 +619,7 @@ impl RoutingMessageHandler for Router {
                                                };
                                        }
                                }
-                               let msg_hash = Message::from_slice(&Sha256dHash::from_data(&msg.contents.encode()[..])[..]).unwrap();
+                               let msg_hash = hash_to_message!(&Sha256dHash::from_data(&msg.contents.encode()[..])[..]);
                                if msg.contents.flags & 1 == 1 {
                                        dest_node_id = channel.one_to_two.src_node_id.clone();
                                        secp_verify_sig!(self.secp_ctx, &msg_hash, &msg.signature, &channel.two_to_one.src_node_id);
@@ -708,7 +880,7 @@ impl Router {
                        // $directional_info.
                        ( $chan_id: expr, $dest_node_id: expr, $directional_info: expr, $starting_fee_msat: expr ) => {
                                //TODO: Explore simply adding fee to hit htlc_minimum_msat
-                               if $starting_fee_msat as u64 + final_value_msat > $directional_info.htlc_minimum_msat {
+                               if $starting_fee_msat as u64 + final_value_msat >= $directional_info.htlc_minimum_msat {
                                        let proportional_fee_millions = ($starting_fee_msat + final_value_msat).checked_mul($directional_info.fee_proportional_millionths as u64);
                                        if let Some(new_fee) = proportional_fee_millions.and_then(|part| {
                                                        ($directional_info.fee_base_msat as u64).checked_add(part / 1000000) })
@@ -844,7 +1016,9 @@ mod tests {
        use ln::router::{Router,NodeInfo,NetworkMap,ChannelInfo,DirectionalChannelInfo,RouteHint};
        use ln::msgs::GlobalFeatures;
        use util::test_utils;
+       use util::test_utils::TestVecWriter;
        use util::logger::Logger;
+       use util::ser::{Writeable, Readable};
 
        use bitcoin::util::hash::Sha256dHash;
        use bitcoin::network::constants::Network;
@@ -859,7 +1033,7 @@ mod tests {
        #[test]
        fn route_test() {
                let secp_ctx = Secp256k1::new();
-               let our_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0101010101010101010101010101010101010101010101010101010101010101").unwrap()[..]).unwrap());
+               let our_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0101010101010101010101010101010101010101010101010101010101010101").unwrap()[..]).unwrap());
                let logger: Arc<Logger> = Arc::new(test_utils::TestLogger::new());
                let chain_monitor = Arc::new(chaininterface::ChainWatchInterfaceUtil::new(Network::Testnet, Arc::clone(&logger)));
                let router = Router::new(our_id, chain_monitor, Arc::clone(&logger));
@@ -921,14 +1095,14 @@ mod tests {
                // chan11 1-to-2: enabled, 0 fee
                // chan11 2-to-1: enabled, 0 fee
 
-               let node1 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0202020202020202020202020202020202020202020202020202020202020202").unwrap()[..]).unwrap());
-               let node2 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0303030303030303030303030303030303030303030303030303030303030303").unwrap()[..]).unwrap());
-               let node3 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0404040404040404040404040404040404040404040404040404040404040404").unwrap()[..]).unwrap());
-               let node4 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0505050505050505050505050505050505050505050505050505050505050505").unwrap()[..]).unwrap());
-               let node5 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0606060606060606060606060606060606060606060606060606060606060606").unwrap()[..]).unwrap());
-               let node6 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0707070707070707070707070707070707070707070707070707070707070707").unwrap()[..]).unwrap());
-               let node7 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0808080808080808080808080808080808080808080808080808080808080808").unwrap()[..]).unwrap());
-               let node8 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&secp_ctx, &hex::decode("0909090909090909090909090909090909090909090909090909090909090909").unwrap()[..]).unwrap());
+               let node1 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0202020202020202020202020202020202020202020202020202020202020202").unwrap()[..]).unwrap());
+               let node2 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0303030303030303030303030303030303030303030303030303030303030303").unwrap()[..]).unwrap());
+               let node3 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0404040404040404040404040404040404040404040404040404040404040404").unwrap()[..]).unwrap());
+               let node4 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0505050505050505050505050505050505050505050505050505050505050505").unwrap()[..]).unwrap());
+               let node5 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0606060606060606060606060606060606060606060606060606060606060606").unwrap()[..]).unwrap());
+               let node6 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0707070707070707070707070707070707070707070707070707070707070707").unwrap()[..]).unwrap());
+               let node7 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0808080808080808080808080808080808080808080808080808080808080808").unwrap()[..]).unwrap());
+               let node8 = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0909090909090909090909090909090909090909090909090909090909090909").unwrap()[..]).unwrap());
 
                let zero_hash = Sha256dHash::from_data(&[0; 32]);
 
@@ -1438,5 +1612,14 @@ mod tests {
                        assert_eq!(route.hops[4].fee_msat, 2000);
                        assert_eq!(route.hops[4].cltv_expiry_delta, 42);
                }
+
+               { // Test Router serialization/deserialization
+                       let mut w = TestVecWriter(Vec::new());
+                       let network = router.network_map.read().unwrap();
+                       assert!(!network.channels.is_empty());
+                       assert!(!network.nodes.is_empty());
+                       network.write(&mut w).unwrap();
+                       assert!(<NetworkMap>::read(&mut ::std::io::Cursor::new(&w.0)).unwrap() == *network);
+               }
        }
 }