Individually lock NetworkGraph fields
[rust-lightning] / lightning / src / routing / network_graph.rs
index 8accdab60882321dea801497643478b702a21b47..3b593351fac56832abc1fde3e7271fe7a1252938 100644 (file)
@@ -51,11 +51,11 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024;
 const MAX_SCIDS_PER_REPLY: usize = 8000;
 
 /// Represents the network as nodes and channels between them
-#[derive(Clone, PartialEq)]
 pub struct NetworkGraph {
        genesis_hash: BlockHash,
-       channels: BTreeMap<u64, ChannelInfo>,
-       nodes: BTreeMap<PublicKey, NodeInfo>,
+       // Lock order: channels -> nodes
+       channels: RwLock<BTreeMap<u64, ChannelInfo>>,
+       nodes: RwLock<BTreeMap<PublicKey, NodeInfo>>,
 }
 
 /// A simple newtype for RwLockReadGuard<'a, NetworkGraph>.
@@ -193,7 +193,8 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
        fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
                let network_graph = self.network_graph.read().unwrap();
                let mut result = Vec::with_capacity(batch_amount as usize);
-               let mut iter = network_graph.get_channels().range(starting_point..);
+               let channels = network_graph.get_channels();
+               let mut iter = channels.range(starting_point..);
                while result.len() < batch_amount as usize {
                        if let Some((_, ref chan)) = iter.next() {
                                if chan.announcement_message.is_some() {
@@ -221,12 +222,13 @@ impl<C: Deref , L: Deref > RoutingMessageHandler for NetGraphMsgHandler<C, L> wh
        fn get_next_node_announcements(&self, starting_point: Option<&PublicKey>, batch_amount: u8) -> Vec<NodeAnnouncement> {
                let network_graph = self.network_graph.read().unwrap();
                let mut result = Vec::with_capacity(batch_amount as usize);
+               let nodes = network_graph.get_nodes();
                let mut iter = if let Some(pubkey) = starting_point {
-                               let mut iter = network_graph.get_nodes().range((*pubkey)..);
+                               let mut iter = nodes.range((*pubkey)..);
                                iter.next();
                                iter
                        } else {
-                               network_graph.get_nodes().range(..)
+                               nodes.range(..)
                        };
                while result.len() < batch_amount as usize {
                        if let Some((_, ref node)) = iter.next() {
@@ -616,13 +618,15 @@ impl Writeable for NetworkGraph {
                write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
 
                self.genesis_hash.write(writer)?;
-               (self.channels.len() as u64).write(writer)?;
-               for (ref chan_id, ref chan_info) in self.channels.iter() {
+               let channels = self.channels.read().unwrap();
+               (channels.len() as u64).write(writer)?;
+               for (ref chan_id, ref chan_info) in channels.iter() {
                        (*chan_id).write(writer)?;
                        chan_info.write(writer)?;
                }
-               (self.nodes.len() as u64).write(writer)?;
-               for (ref node_id, ref node_info) in self.nodes.iter() {
+               let nodes = self.nodes.read().unwrap();
+               (nodes.len() as u64).write(writer)?;
+               for (ref node_id, ref node_info) in nodes.iter() {
                        node_id.write(writer)?;
                        node_info.write(writer)?;
                }
@@ -655,8 +659,8 @@ impl Readable for NetworkGraph {
 
                Ok(NetworkGraph {
                        genesis_hash,
-                       channels,
-                       nodes,
+                       channels: RwLock::new(channels),
+                       nodes: RwLock::new(nodes),
                })
        }
 }
@@ -664,36 +668,49 @@ impl Readable for NetworkGraph {
 impl fmt::Display for NetworkGraph {
        fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
                writeln!(f, "Network map\n[Channels]")?;
-               for (key, val) in self.channels.iter() {
+               for (key, val) in self.channels.read().unwrap().iter() {
                        writeln!(f, " {}: {}", key, val)?;
                }
                writeln!(f, "[Nodes]")?;
-               for (key, val) in self.nodes.iter() {
+               for (key, val) in self.nodes.read().unwrap().iter() {
                        writeln!(f, " {}: {}", log_pubkey!(key), val)?;
                }
                Ok(())
        }
 }
 
+impl PartialEq for NetworkGraph {
+       fn eq(&self, other: &Self) -> bool {
+               self.genesis_hash == other.genesis_hash &&
+                       *self.channels.read().unwrap() == *other.channels.read().unwrap() &&
+                       *self.nodes.read().unwrap() == *other.nodes.read().unwrap()
+       }
+}
+
 impl NetworkGraph {
        /// Returns all known valid channels' short ids along with announced channel info.
        ///
        /// (C-not exported) because we have no mapping for `BTreeMap`s
-       pub fn get_channels<'a>(&'a self) -> &'a BTreeMap<u64, ChannelInfo> { &self.channels }
+       pub fn get_channels(&self) -> RwLockReadGuard<'_, BTreeMap<u64, ChannelInfo>> {
+               self.channels.read().unwrap()
+       }
+
        /// Returns all known nodes' public keys along with announced node info.
        ///
        /// (C-not exported) because we have no mapping for `BTreeMap`s
-       pub fn get_nodes<'a>(&'a self) -> &'a BTreeMap<PublicKey, NodeInfo> { &self.nodes }
+       pub fn get_nodes(&self) -> RwLockReadGuard<'_, BTreeMap<PublicKey, NodeInfo>> {
+               self.nodes.read().unwrap()
+       }
 
        /// Get network addresses by node id.
        /// Returns None if the requested node is completely unknown,
        /// or if node announcement for the node was never received.
        ///
        /// (C-not exported) as there is no practical way to track lifetimes of returned values.
-       pub fn get_addresses<'a>(&'a self, pubkey: &PublicKey) -> Option<&'a Vec<NetAddress>> {
-               if let Some(node) = self.nodes.get(pubkey) {
+       pub fn get_addresses(&self, pubkey: &PublicKey) -> Option<Vec<NetAddress>> {
+               if let Some(node) = self.nodes.read().unwrap().get(pubkey) {
                        if let Some(node_info) = node.announcement_info.as_ref() {
-                               return Some(&node_info.addresses)
+                               return Some(node_info.addresses.clone())
                        }
                }
                None
@@ -703,8 +720,8 @@ impl NetworkGraph {
        pub fn new(genesis_hash: BlockHash) -> NetworkGraph {
                Self {
                        genesis_hash,
-                       channels: BTreeMap::new(),
-                       nodes: BTreeMap::new(),
+                       channels: RwLock::new(BTreeMap::new()),
+                       nodes: RwLock::new(BTreeMap::new()),
                }
        }
 
@@ -729,7 +746,7 @@ impl NetworkGraph {
        }
 
        fn update_node_from_announcement_intern(&mut self, msg: &msgs::UnsignedNodeAnnouncement, full_msg: Option<&msgs::NodeAnnouncement>) -> Result<(), LightningError> {
-               match self.nodes.get_mut(&msg.node_id) {
+               match self.nodes.write().unwrap().get_mut(&msg.node_id) {
                        None => Err(LightningError{err: "No existing channels for node_announcement".to_owned(), action: ErrorAction::IgnoreError}),
                        Some(node) => {
                                if let Some(node_info) = node.announcement_info.as_ref() {
@@ -838,7 +855,9 @@ impl NetworkGraph {
                                        { full_msg.cloned() } else { None },
                        };
 
-               match self.channels.entry(msg.short_channel_id) {
+               let mut channels = self.channels.write().unwrap();
+               let mut nodes = self.nodes.write().unwrap();
+               match channels.entry(msg.short_channel_id) {
                        BtreeEntry::Occupied(mut entry) => {
                                //TODO: because asking the blockchain if short_channel_id is valid is only optional
                                //in the blockchain API, we need to handle it smartly here, though it's unclear
@@ -852,7 +871,7 @@ impl NetworkGraph {
                                        // b) we don't track UTXOs of channels we know about and remove them if they
                                        //    get reorg'd out.
                                        // c) it's unclear how to do so without exposing ourselves to massive DoS risk.
-                                       Self::remove_channel_in_nodes(&mut self.nodes, &entry.get(), msg.short_channel_id);
+                                       Self::remove_channel_in_nodes(&mut nodes, &entry.get(), msg.short_channel_id);
                                        *entry.get_mut() = chan_info;
                                } else {
                                        return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreAndLog(Level::Trace)})
@@ -865,7 +884,7 @@ impl NetworkGraph {
 
                macro_rules! add_channel_to_node {
                        ( $node_id: expr ) => {
-                               match self.nodes.entry($node_id) {
+                               match nodes.entry($node_id) {
                                        BtreeEntry::Occupied(node_entry) => {
                                                node_entry.into_mut().channels.push(msg.short_channel_id);
                                        },
@@ -891,12 +910,14 @@ impl NetworkGraph {
        /// May cause the removal of nodes too, if this was their last channel.
        /// If not permanent, makes channels unavailable for routing.
        pub fn close_channel_from_update(&mut self, short_channel_id: u64, is_permanent: bool) {
+               let mut channels = self.channels.write().unwrap();
                if is_permanent {
-                       if let Some(chan) = self.channels.remove(&short_channel_id) {
-                               Self::remove_channel_in_nodes(&mut self.nodes, &chan, short_channel_id);
+                       if let Some(chan) = channels.remove(&short_channel_id) {
+                               let mut nodes = self.nodes.write().unwrap();
+                               Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
                        }
                } else {
-                       if let Some(chan) = self.channels.get_mut(&short_channel_id) {
+                       if let Some(chan) = channels.get_mut(&short_channel_id) {
                                if let Some(one_to_two) = chan.one_to_two.as_mut() {
                                        one_to_two.enabled = false;
                                }
@@ -937,7 +958,8 @@ impl NetworkGraph {
                let chan_enabled = msg.flags & (1 << 1) != (1 << 1);
                let chan_was_enabled;
 
-               match self.channels.get_mut(&msg.short_channel_id) {
+               let mut channels = self.channels.write().unwrap();
+               match channels.get_mut(&msg.short_channel_id) {
                        None => return Err(LightningError{err: "Couldn't find channel for update".to_owned(), action: ErrorAction::IgnoreError}),
                        Some(channel) => {
                                if let OptionalField::Present(htlc_maximum_msat) = msg.htlc_maximum_msat {
@@ -1000,8 +1022,9 @@ impl NetworkGraph {
                        }
                }
 
+               let mut nodes = self.nodes.write().unwrap();
                if chan_enabled {
-                       let node = self.nodes.get_mut(&dest_node_id).unwrap();
+                       let node = nodes.get_mut(&dest_node_id).unwrap();
                        let mut base_msat = msg.fee_base_msat;
                        let mut proportional_millionths = msg.fee_proportional_millionths;
                        if let Some(fees) = node.lowest_inbound_channel_fees {
@@ -1013,11 +1036,11 @@ impl NetworkGraph {
                                proportional_millionths
                        });
                } else if chan_was_enabled {
-                       let node = self.nodes.get_mut(&dest_node_id).unwrap();
+                       let node = nodes.get_mut(&dest_node_id).unwrap();
                        let mut lowest_inbound_channel_fees = None;
 
                        for chan_id in node.channels.iter() {
-                               let chan = self.channels.get(chan_id).unwrap();
+                               let chan = channels.get(chan_id).unwrap();
                                let chan_info_opt;
                                if chan.node_one == dest_node_id {
                                        chan_info_opt = chan.two_to_one.as_ref();
@@ -1268,7 +1291,7 @@ mod tests {
                        match network.get_channels().get(&unsigned_announcement.short_channel_id) {
                                None => panic!(),
                                Some(_) => ()
-                       }
+                       };
                }
 
                // If we receive announcement for the same channel (with UTXO lookups disabled),
@@ -1320,7 +1343,7 @@ mod tests {
                        match network.get_channels().get(&unsigned_announcement.short_channel_id) {
                                None => panic!(),
                                Some(_) => ()
-                       }
+                       };
                }
 
                // If we receive announcement for the same channel (but TX is not confirmed),
@@ -1353,7 +1376,7 @@ mod tests {
                                        assert_eq!(channel_entry.features, ChannelFeatures::empty());
                                },
                                _ => panic!()
-                       }
+                       };
                }
 
                // Don't relay valid channels with excess data
@@ -1484,7 +1507,7 @@ mod tests {
                                        assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144);
                                        assert!(channel_info.two_to_one.is_none());
                                }
-                       }
+                       };
                }
 
                unsigned_channel_update.timestamp += 100;
@@ -1645,7 +1668,7 @@ mod tests {
                                Some(channel_info) => {
                                        assert!(channel_info.one_to_two.is_some());
                                }
-                       }
+                       };
                }
 
                let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {
@@ -1663,7 +1686,7 @@ mod tests {
                                Some(channel_info) => {
                                        assert!(!channel_info.one_to_two.as_ref().unwrap().enabled);
                                }
-                       }
+                       };
                }
 
                let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {