Parameterize NetGraphMsgHandler with NetworkGraph
[rust-lightning] / lightning / src / routing / network_graph.rs
index 364c9fd8865154e020e9403f897e0c6546354eb9..93701541b3d96ee7169375aed9436b5524c570c9 100644 (file)
@@ -186,7 +186,7 @@ impl_writeable_tlv_based_enum_upgradable!(NetworkUpdate,
        },
 );
 
-impl<C: Deref, L: Deref> EventHandler for NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> EventHandler for NetGraphMsgHandler<G, C, L>
 where C::Target: chain::Access, L::Target: Logger {
        fn handle_event(&self, event: &Event) {
                if let Event::PaymentPathFailed { payment_hash: _, rejected_by_dest: _, network_update, .. } = event {
@@ -205,19 +205,19 @@ where C::Target: chain::Access, L::Target: Logger {
 ///
 /// Serves as an [`EventHandler`] for applying updates from [`Event::PaymentPathFailed`] to the
 /// [`NetworkGraph`].
-pub struct NetGraphMsgHandler<C: Deref, L: Deref>
+pub struct NetGraphMsgHandler<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref>
 where C::Target: chain::Access, L::Target: Logger
 {
        secp_ctx: Secp256k1<secp256k1::VerifyOnly>,
        /// Representation of the payment channel network
-       pub network_graph: NetworkGraph,
+       pub network_graph: G,
        chain_access: Option<C>,
        full_syncs_requested: AtomicUsize,
        pending_events: Mutex<Vec<MessageSendEvent>>,
        logger: L,
 }
 
-impl<C: Deref, L: Deref> NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> NetGraphMsgHandler<G, C, L>
 where C::Target: chain::Access, L::Target: Logger
 {
        /// Creates a new tracker of the actual state of the network of channels and nodes,
@@ -225,7 +225,7 @@ where C::Target: chain::Access, L::Target: Logger
        /// Chain monitor is used to make sure announced channels exist on-chain,
        /// channel data is correct, and that the announcement is signed with
        /// channel owners' keys.
-       pub fn new(network_graph: NetworkGraph, chain_access: Option<C>, logger: L) -> Self {
+       pub fn new(network_graph: G, chain_access: Option<C>, logger: L) -> Self {
                NetGraphMsgHandler {
                        secp_ctx: Secp256k1::verification_only(),
                        network_graph,
@@ -288,7 +288,7 @@ macro_rules! secp_verify_sig {
        };
 }
 
-impl<C: Deref, L: Deref> RoutingMessageHandler for NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> RoutingMessageHandler for NetGraphMsgHandler<G, C, L>
 where C::Target: chain::Access, L::Target: Logger
 {
        fn handle_node_announcement(&self, msg: &msgs::NodeAnnouncement) -> Result<bool, LightningError> {
@@ -554,7 +554,7 @@ where C::Target: chain::Access, L::Target: Logger
        }
 }
 
-impl<C: Deref, L: Deref> MessageSendEventsProvider for NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> MessageSendEventsProvider for NetGraphMsgHandler<G, C, L>
 where
        C::Target: chain::Access,
        L::Target: Logger,
@@ -1255,18 +1255,24 @@ mod tests {
        use prelude::*;
        use sync::Arc;
 
-       fn create_net_graph_msg_handler() -> (Secp256k1<All>, NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>) {
+       fn create_network_graph() -> NetworkGraph {
+               let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
+               NetworkGraph::new(genesis_hash)
+       }
+
+       fn create_net_graph_msg_handler(network_graph: &NetworkGraph) -> (
+               Secp256k1<All>, NetGraphMsgHandler<&NetworkGraph, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>
+       ) {
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(test_utils::TestLogger::new());
-               let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
-               let network_graph = NetworkGraph::new(genesis_hash);
                let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
                (secp_ctx, net_graph_msg_handler)
        }
 
        #[test]
        fn request_full_sync_finite_times() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0202020202020202020202020202020202020202020202020202020202020202").unwrap()[..]).unwrap());
 
                assert!(net_graph_msg_handler.should_request_full_sync(&node_id));
@@ -1279,7 +1285,8 @@ mod tests {
 
        #[test]
        fn handling_node_announcements() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
@@ -1422,15 +1429,14 @@ mod tests {
 
                // Test if the UTXO lookups were not supported
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let mut net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
+               let mut net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, None, Arc::clone(&logger));
                match net_graph_msg_handler.handle_channel_announcement(&valid_announcement) {
                        Ok(res) => assert!(res),
                        _ => panic!()
                };
 
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&unsigned_announcement.short_channel_id) {
+                       match network_graph.read_only().channels().get(&unsigned_announcement.short_channel_id) {
                                None => panic!(),
                                Some(_) => ()
                        };
@@ -1447,7 +1453,7 @@ mod tests {
                let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
                *chain_source.utxo_ret.lock().unwrap() = Err(chain::AccessError::UnknownTx);
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), Arc::clone(&logger));
+               net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, Some(chain_source.clone()), Arc::clone(&logger));
                unsigned_announcement.short_channel_id += 1;
 
                msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_announcement.encode()[..])[..]);
@@ -1482,8 +1488,7 @@ mod tests {
                };
 
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&unsigned_announcement.short_channel_id) {
+                       match network_graph.read_only().channels().get(&unsigned_announcement.short_channel_id) {
                                None => panic!(),
                                Some(_) => ()
                        };
@@ -1513,8 +1518,7 @@ mod tests {
                        _ => panic!()
                };
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&unsigned_announcement.short_channel_id) {
+                       match network_graph.read_only().channels().get(&unsigned_announcement.short_channel_id) {
                                Some(channel_entry) => {
                                        assert_eq!(channel_entry.features, ChannelFeatures::empty());
                                },
@@ -1572,7 +1576,7 @@ mod tests {
                let logger: Arc<Logger> = Arc::new(test_utils::TestLogger::new());
                let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), Arc::clone(&logger));
+               let net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, Some(chain_source.clone()), Arc::clone(&logger));
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
@@ -1644,8 +1648,7 @@ mod tests {
                };
 
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&short_channel_id) {
+                       match network_graph.read_only().channels().get(&short_channel_id) {
                                None => panic!(),
                                Some(channel_info) => {
                                        assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144);
@@ -1741,7 +1744,7 @@ mod tests {
                let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
                let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
                let network_graph = NetworkGraph::new(genesis_hash);
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), &logger);
+               let net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, Some(chain_source.clone()), &logger);
                let secp_ctx = Secp256k1::new();
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
@@ -1753,7 +1756,6 @@ mod tests {
 
                let short_channel_id = 0;
                let chain_hash = genesis_block(Network::Testnet).header.block_hash();
-               let network_graph = &net_graph_msg_handler.network_graph;
 
                {
                        // There is no nodes in the table at the beginning.
@@ -1883,7 +1885,8 @@ mod tests {
 
        #[test]
        fn getting_next_channel_announcements() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_1_privkey);
@@ -2017,7 +2020,8 @@ mod tests {
 
        #[test]
        fn getting_next_node_announcements() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_1_privkey);
@@ -2134,7 +2138,8 @@ mod tests {
 
        #[test]
        fn network_graph_serialization() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
@@ -2191,17 +2196,17 @@ mod tests {
                        Err(_) => panic!()
                };
 
-               let network = &net_graph_msg_handler.network_graph;
                let mut w = test_utils::TestVecWriter(Vec::new());
-               assert!(!network.read_only().nodes().is_empty());
-               assert!(!network.read_only().channels().is_empty());
-               network.write(&mut w).unwrap();
-               assert!(<NetworkGraph>::read(&mut io::Cursor::new(&w.0)).unwrap() == *network);
+               assert!(!network_graph.read_only().nodes().is_empty());
+               assert!(!network_graph.read_only().channels().is_empty());
+               network_graph.write(&mut w).unwrap();
+               assert!(<NetworkGraph>::read(&mut io::Cursor::new(&w.0)).unwrap() == network_graph);
        }
 
        #[test]
        fn calling_sync_routing_table() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey_1 = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_privkey_1);
 
@@ -2238,7 +2243,8 @@ mod tests {
                // The initial implementation allows syncing with the first 5 peers after
                // which should_request_full_sync will return false
                {
-                       let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+                       let network_graph = create_network_graph();
+                       let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                        let init_msg = Init { features: InitFeatures::known() };
                        for n in 1..7 {
                                let node_privkey = &SecretKey::from_slice(&[n; 32]).unwrap();
@@ -2257,7 +2263,8 @@ mod tests {
 
        #[test]
        fn handling_reply_channel_range() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey_1 = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_privkey_1);
 
@@ -2305,7 +2312,8 @@ mod tests {
 
        #[test]
        fn handling_reply_short_channel_ids() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id = PublicKey::from_secret_key(&secp_ctx, node_privkey);
 
@@ -2334,7 +2342,8 @@ mod tests {
 
        #[test]
        fn handling_query_channel_range() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
 
                let chain_hash = genesis_block(Network::Testnet).header.block_hash();
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
@@ -2596,7 +2605,7 @@ mod tests {
        }
 
        fn do_handling_query_channel_range(
-               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               net_graph_msg_handler: &NetGraphMsgHandler<&NetworkGraph, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
                test_node_id: &PublicKey,
                msg: QueryChannelRange,
                expected_ok: bool,
@@ -2645,7 +2654,8 @@ mod tests {
 
        #[test]
        fn handling_query_short_channel_ids() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id = PublicKey::from_secret_key(&secp_ctx, node_privkey);