ChannelManager+Router++ Logger Arc --> Deref
[rust-lightning] / lightning / src / routing / network_graph.rs
index dd6d7d7ed734aadbd7f89f93308aefd0223a9f79..3168a1fd8503be206063bf23fa9abbd177180699 100644 (file)
@@ -17,33 +17,34 @@ use util::ser::{Writeable, Readable, Writer};
 use util::logger::Logger;
 
 use std::cmp;
-use std::sync::{RwLock,Arc};
+use std::sync::RwLock;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::collections::BTreeMap;
 use std::collections::btree_map::Entry as BtreeEntry;
 use std;
+use std::ops::Deref;
 
 /// Receives and validates network updates from peers,
 /// stores authentic and relevant data as a network graph.
 /// This network graph is then used for routing payments.
 /// Provides interface to help with initial routing sync by
 /// serving historical announcements.
-pub struct NetGraphMsgHandler {
+pub struct NetGraphMsgHandler<C: Deref, L: Deref> where C::Target: ChainWatchInterface, L::Target: Logger {
        secp_ctx: Secp256k1<secp256k1::VerifyOnly>,
        /// Representation of the payment channel network
        pub network_graph: RwLock<NetworkGraph>,
-       chain_monitor: Arc<ChainWatchInterface>,
+       chain_monitor: C,
        full_syncs_requested: AtomicUsize,
-       logger: Arc<Logger>,
+       logger: L,
 }
 
-impl NetGraphMsgHandler {
+impl<C: Deref, L: Deref> NetGraphMsgHandler<C, L> where C::Target: ChainWatchInterface, L::Target: Logger {
        /// Creates a new tracker of the actual state of the network of channels and nodes,
        /// assuming a fresh network graph.
        /// Chain monitor is used to make sure announced channels exist on-chain,
        /// channel data is correct, and that the announcement is signed with
        /// channel owners' keys.
-       pub fn new(chain_monitor: Arc<ChainWatchInterface>, logger: Arc<Logger>) -> Self {
+       pub fn new(chain_monitor: C, logger: L) -> Self {
                NetGraphMsgHandler {
                        secp_ctx: Secp256k1::verification_only(),
                        network_graph: RwLock::new(NetworkGraph {
@@ -52,19 +53,19 @@ impl NetGraphMsgHandler {
                        }),
                        full_syncs_requested: AtomicUsize::new(0),
                        chain_monitor,
-                       logger: logger.clone(),
+                       logger,
                }
        }
 
        /// Creates a new tracker of the actual state of the network of channels and nodes,
        /// assuming an existing Network Graph.
-       pub fn from_net_graph(chain_monitor: Arc<ChainWatchInterface>, logger: Arc<Logger>, network_graph: RwLock<NetworkGraph>) -> Self {
+       pub fn from_net_graph(chain_monitor: C, logger: L, network_graph: RwLock<NetworkGraph>) -> Self {
                NetGraphMsgHandler {
                        secp_ctx: Secp256k1::verification_only(),
-                       network_graph: network_graph,
+                       network_graph,
                        full_syncs_requested: AtomicUsize::new(0),
                        chain_monitor,
-                       logger: logger.clone(),
+                       logger,
                }
        }
 }
@@ -79,7 +80,7 @@ macro_rules! secp_verify_sig {
        };
 }
 
-impl RoutingMessageHandler for NetGraphMsgHandler {
+impl<C: Deref + Sync + Send, L: Deref + Sync + Send> RoutingMessageHandler for NetGraphMsgHandler<C, L> where C::Target: ChainWatchInterface, L::Target: Logger {
        fn handle_node_announcement(&self, msg: &msgs::NodeAnnouncement) -> Result<bool, LightningError> {
                self.network_graph.write().unwrap().update_node_from_announcement(msg, Some(&self.secp_ctx))
        }
@@ -115,7 +116,7 @@ impl RoutingMessageHandler for NetGraphMsgHandler {
                        },
                };
                let result = self.network_graph.write().unwrap().update_channel_from_announcement(msg, checked_utxo, Some(&self.secp_ctx));
-               log_trace!(self, "Added channel_announcement for {}{}", msg.contents.short_channel_id, if !msg.contents.excess_data.is_empty() { " with excess uninterpreted data!" } else { "" });
+               log_trace!(self.logger, "Added channel_announcement for {}{}", msg.contents.short_channel_id, if !msg.contents.excess_data.is_empty() { " with excess uninterpreted data!" } else { "" });
                result
        }
 
@@ -217,6 +218,9 @@ pub struct DirectionalChannelInfo {
        /// Fees charged when the channel is used for routing
        pub fees: RoutingFees,
        /// Most recent update for the channel received from the network
+       /// Mostly redundant with the data we store in fields explicitly.
+       /// Everything else is useful only for sending out for initial routing sync.
+       /// Not stored if contains excess data to prevent DoS.
        pub last_update_message: Option<msgs::ChannelUpdate>,
 }
 
@@ -308,10 +312,10 @@ impl Writeable for RoutingFees {
 /// Information received in the latest node_announcement from this node.
 pub struct NodeAnnouncementInfo {
        /// Protocol features the node announced support for
-      pub features: NodeFeatures,
+       pub features: NodeFeatures,
        /// When the last known update to the node state was issued.
        /// Value is opaque, as set in the announcement.
-      pub last_update: u32,
+       pub last_update: u32,
        /// Color assigned to the node
        pub rgb: [u8; 3],
        /// Moniker assigned to the node.
@@ -375,8 +379,8 @@ impl Readable for NodeAnnouncementInfo {
 pub struct NodeInfo {
        /// All valid channels a node has announced
        pub channels: Vec<u64>,
-       /// Lowest fees enabling routing via any of the known channels to a node.
-       /// The two fields (flat and proportional fee) are independent,
+       /// Lowest fees enabling routing via any of the enabled, known channels to a node.
+       /// The two fields (flat and proportional fee) are independent,
        /// meaning they don't have to refer to the same channel.
        pub lowest_inbound_channel_fees: Option<RoutingFees>,
        /// More information about a node from node_announcement.
@@ -709,34 +713,28 @@ impl NetworkGraph {
                                proportional_millionths
                        });
                } else if chan_was_enabled {
-                       let mut lowest_inbound_channel_fee_base_msat = u32::max_value();
-                       let mut lowest_inbound_channel_fee_proportional_millionths = u32::max_value();
-
-                       {
-                               let node = self.nodes.get(&dest_node_id).unwrap();
-
-                               for chan_id in node.channels.iter() {
-                                       let chan = self.channels.get(chan_id).unwrap();
-                                       // Since direction was enabled, the channel indeed had directional info
-                                       let chan_info;
-                                       if chan.node_one == dest_node_id {
-                                               chan_info = chan.two_to_one.as_ref().unwrap();
-                                       } else {
-                                               chan_info = chan.one_to_two.as_ref().unwrap();
+                       let node = self.nodes.get_mut(&dest_node_id).unwrap();
+                       let mut lowest_inbound_channel_fees = None;
+
+                       for chan_id in node.channels.iter() {
+                               let chan = self.channels.get(chan_id).unwrap();
+                               let chan_info_opt;
+                               if chan.node_one == dest_node_id {
+                                       chan_info_opt = chan.two_to_one.as_ref();
+                               } else {
+                                       chan_info_opt = chan.one_to_two.as_ref();
+                               }
+                               if let Some(chan_info) = chan_info_opt {
+                                       if chan_info.enabled {
+                                               let fees = lowest_inbound_channel_fees.get_or_insert(RoutingFees {
+                                                       base_msat: u32::max_value(), proportional_millionths: u32::max_value() });
+                                               fees.base_msat = cmp::min(fees.base_msat, chan_info.fees.base_msat);
+                                               fees.proportional_millionths = cmp::min(fees.proportional_millionths, chan_info.fees.proportional_millionths);
                                        }
-                                       lowest_inbound_channel_fee_base_msat = cmp::min(lowest_inbound_channel_fee_base_msat, chan_info.fees.base_msat);
-                                       lowest_inbound_channel_fee_proportional_millionths = cmp::min(lowest_inbound_channel_fee_proportional_millionths, chan_info.fees.proportional_millionths);
                                }
                        }
 
-                       //TODO: satisfy the borrow-checker without a double-map-lookup :(
-                       let mut_node = self.nodes.get_mut(&dest_node_id).unwrap();
-                       if mut_node.channels.len() > 0 {
-                               mut_node.lowest_inbound_channel_fees = Some(RoutingFees {
-                                       base_msat: lowest_inbound_channel_fee_base_msat,
-                                       proportional_millionths: lowest_inbound_channel_fee_proportional_millionths
-                               });
-                       }
+                       node.lowest_inbound_channel_fees = lowest_inbound_channel_fees;
                }
 
                Ok(msg.contents.excess_data.is_empty())
@@ -769,7 +767,7 @@ mod tests {
        use ln::features::{ChannelFeatures, NodeFeatures};
        use routing::network_graph::{NetGraphMsgHandler, NetworkGraph};
        use ln::msgs::{RoutingMessageHandler, UnsignedNodeAnnouncement, NodeAnnouncement,
-          UnsignedChannelAnnouncement, ChannelAnnouncement, UnsignedChannelUpdate, ChannelUpdate, HTLCFailChannelUpdate};
+               UnsignedChannelAnnouncement, ChannelAnnouncement, UnsignedChannelUpdate, ChannelUpdate, HTLCFailChannelUpdate};
        use util::test_utils;
        use util::logger::Logger;
        use util::ser::{Readable, Writeable};
@@ -789,10 +787,10 @@ mod tests {
 
        use std::sync::Arc;
 
-       fn create_net_graph_msg_handler() -> (Secp256k1<All>, NetGraphMsgHandler) {
+       fn create_net_graph_msg_handler() -> (Secp256k1<All>, NetGraphMsgHandler<Arc<chaininterface::ChainWatchInterfaceUtil>, Arc<test_utils::TestLogger>>) {
                let secp_ctx = Secp256k1::new();
-               let logger: Arc<Logger> = Arc::new(test_utils::TestLogger::new());
-               let chain_monitor = Arc::new(chaininterface::ChainWatchInterfaceUtil::new(Network::Testnet, Arc::clone(&logger)));
+               let logger = Arc::new(test_utils::TestLogger::new());
+               let chain_monitor = Arc::new(chaininterface::ChainWatchInterfaceUtil::new(Network::Testnet));
                let net_graph_msg_handler = NetGraphMsgHandler::new(chain_monitor, Arc::clone(&logger));
                (secp_ctx, net_graph_msg_handler)
        }
@@ -848,7 +846,7 @@ mod tests {
                        // Announce a channel to add a corresponding node.
                        let unsigned_announcement = UnsignedChannelAnnouncement {
                                features: ChannelFeatures::known(),
-                               chain_hash: genesis_block(Network::Testnet).header.bitcoin_hash(),
+                               chain_hash: genesis_block(Network::Testnet).header.bitcoin_hash(),
                                short_channel_id: 0,
                                node_id_1,
                                node_id_2,