Track a `counter` for each node in our network graph
[rust-lightning] / lightning / src / routing / gossip.rs
index be1c31ff8633047c9a3f93919a4407b0025be086..f033b320ed6a989846587f42a68d1dc82da39265 100644 (file)
@@ -40,7 +40,6 @@ use crate::prelude::*;
 use core::{cmp, fmt};
 use core::convert::TryFrom;
 use crate::sync::{RwLock, RwLockReadGuard, LockTestExt};
-#[cfg(feature = "std")]
 use core::sync::atomic::{AtomicUsize, Ordering};
 use crate::sync::Mutex;
 use core::ops::{Bound, Deref};
@@ -184,6 +183,8 @@ pub struct NetworkGraph<L: Deref> where L::Target: Logger {
        // Lock order: channels -> nodes
        channels: RwLock<IndexedMap<u64, ChannelInfo>>,
        nodes: RwLock<IndexedMap<NodeId, NodeInfo>>,
+       removed_node_counters: Mutex<Vec<u32>>,
+       next_node_counter: AtomicUsize,
        // Lock order: removed_channels -> removed_nodes
        //
        // NOTE: In the following `removed_*` maps, we use seconds since UNIX epoch to track time instead
@@ -1230,7 +1231,7 @@ impl Readable for NodeAlias {
        }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, Eq)]
 /// Details about a node in the network, known from the network announcement.
 pub struct NodeInfo {
        /// All valid channels a node has announced
@@ -1238,7 +1239,19 @@ pub struct NodeInfo {
        /// More information about a node from node_announcement.
        /// Optional because we store a Node entry after learning about it from
        /// a channel announcement, but before receiving a node announcement.
-       pub announcement_info: Option<NodeAnnouncementInfo>
+       pub announcement_info: Option<NodeAnnouncementInfo>,
+       /// In memory, each node is assigned a unique ID. They are eagerly reused, ensuring they remain
+       /// relatively dense.
+       ///
+       /// These IDs allow the router to avoid a `HashMap` lookup by simply using this value as an
+       /// index in a `Vec`, skipping a big step in some of the hottest code when routing.
+       pub(crate) node_counter: u32,
+}
+
+impl PartialEq for NodeInfo {
+       fn eq(&self, o: &NodeInfo) -> bool {
+               self.channels == o.channels && self.announcement_info == o.announcement_info
+       }
 }
 
 impl NodeInfo {
@@ -1308,6 +1321,7 @@ impl Readable for NodeInfo {
                Ok(NodeInfo {
                        announcement_info: announcement_info_wrap.map(|w| w.0),
                        channels,
+                       node_counter: u32::max_value(),
                })
        }
 }
@@ -1317,6 +1331,8 @@ const MIN_SERIALIZATION_VERSION: u8 = 1;
 
 impl<L: Deref> Writeable for NetworkGraph<L> where L::Target: Logger {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               self.test_node_counter_consistency();
+
                write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
 
                self.chain_hash.write(writer)?;
@@ -1355,11 +1371,13 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                        channels.insert(chan_id, chan_info);
                }
                let nodes_count: u64 = Readable::read(reader)?;
+               if nodes_count > u32::max_value() as u64 / 2 { return Err(DecodeError::InvalidValue); }
                // In Nov, 2023 there were about 69K channels; we cap allocations to 1.5x that.
                let mut nodes = IndexedMap::with_capacity(cmp::min(nodes_count as usize, 103500));
-               for _ in 0..nodes_count {
+               for i in 0..nodes_count {
                        let node_id = Readable::read(reader)?;
-                       let node_info = Readable::read(reader)?;
+                       let mut node_info: NodeInfo = Readable::read(reader)?;
+                       node_info.node_counter = i as u32;
                        nodes.insert(node_id, node_info);
                }
 
@@ -1374,6 +1392,8 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                        logger,
                        channels: RwLock::new(channels),
                        nodes: RwLock::new(nodes),
+                       removed_node_counters: Mutex::new(Vec::new()),
+                       next_node_counter: AtomicUsize::new(nodes_count as usize),
                        last_rapid_gossip_sync_timestamp: Mutex::new(last_rapid_gossip_sync_timestamp),
                        removed_nodes: Mutex::new(new_hash_map()),
                        removed_channels: Mutex::new(new_hash_map()),
@@ -1419,6 +1439,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        logger,
                        channels: RwLock::new(IndexedMap::new()),
                        nodes: RwLock::new(IndexedMap::new()),
+                       next_node_counter: AtomicUsize::new(0),
+                       removed_node_counters: Mutex::new(Vec::new()),
                        last_rapid_gossip_sync_timestamp: Mutex::new(None),
                        removed_channels: Mutex::new(new_hash_map()),
                        removed_nodes: Mutex::new(new_hash_map()),
@@ -1426,8 +1448,33 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                }
        }
 
+       fn test_node_counter_consistency(&self) {
+               #[cfg(debug_assertions)] {
+                       let nodes = self.nodes.read().unwrap();
+                       let removed_node_counters = self.removed_node_counters.lock().unwrap();
+                       let next_counter = self.next_node_counter.load(Ordering::Acquire);
+                       assert!(next_counter < (u32::max_value() as usize) / 2);
+                       let mut used_node_counters = vec![0u8; next_counter / 8 + 1];
+
+                       for counter in removed_node_counters.iter() {
+                               let pos = (*counter as usize) / 8;
+                               let bit = 1 << (counter % 8);
+                               assert_eq!(used_node_counters[pos] & bit, 0);
+                               used_node_counters[pos] |= bit;
+                       }
+                       for (_, node) in nodes.unordered_iter() {
+                               assert!((node.node_counter as usize) < next_counter);
+                               let pos = (node.node_counter as usize) / 8;
+                               let bit = 1 << (node.node_counter % 8);
+                               assert_eq!(used_node_counters[pos] & bit, 0);
+                               used_node_counters[pos] |= bit;
+                       }
+               }
+       }
+
        /// Returns a read-only view of the network graph.
        pub fn read_only(&'_ self) -> ReadOnlyNetworkGraph<'_> {
+               self.test_node_counter_consistency();
                let channels = self.channels.read().unwrap();
                let nodes = self.nodes.read().unwrap();
                ReadOnlyNetworkGraph {
@@ -1609,7 +1656,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        // b) we don't track UTXOs of channels we know about and remove them if they
                                        //    get reorg'd out.
                                        // c) it's unclear how to do so without exposing ourselves to massive DoS risk.
-                                       Self::remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
+                                       self.remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
                                        *entry.get_mut() = channel_info;
                                } else {
                                        return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
@@ -1626,9 +1673,13 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        node_entry.into_mut().channels.push(short_channel_id);
                                },
                                IndexedMapEntry::Vacant(node_entry) => {
+                                       let mut removed_node_counters = self.removed_node_counters.lock().unwrap();
+                                       let node_counter = removed_node_counters.pop()
+                                               .unwrap_or(self.next_node_counter.fetch_add(1, Ordering::Relaxed) as u32);
                                        node_entry.insert(NodeInfo {
                                                channels: vec!(short_channel_id),
                                                announcement_info: None,
+                                               node_counter,
                                        });
                                }
                        };
@@ -1747,7 +1798,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                if let Some(chan) = channels.remove(&short_channel_id) {
                        let mut nodes = self.nodes.write().unwrap();
                        self.removed_channels.lock().unwrap().insert(short_channel_id, current_time_unix);
-                       Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
+                       self.remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
                }
        }
 
@@ -1766,6 +1817,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                let mut removed_nodes = self.removed_nodes.lock().unwrap();
 
                if let Some(node) = nodes.remove(&node_id) {
+                       let mut removed_node_counters = self.removed_node_counters.lock().unwrap();
                        for scid in node.channels.iter() {
                                if let Some(chan_info) = channels.remove(scid) {
                                        let other_node_id = if node_id == chan_info.node_one { chan_info.node_two } else { chan_info.node_one };
@@ -1774,12 +1826,14 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                                        *scid != *chan_id
                                                });
                                                if other_node_entry.get().channels.is_empty() {
+                                                       removed_node_counters.push(other_node_entry.get().node_counter);
                                                        other_node_entry.remove_entry();
                                                }
                                        }
                                        removed_channels.insert(*scid, current_time_unix);
                                }
                        }
+                       removed_node_counters.push(node.node_counter);
                        removed_nodes.insert(node_id, current_time_unix);
                }
        }
@@ -1855,7 +1909,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        let mut nodes = self.nodes.write().unwrap();
                        for scid in scids_to_remove {
                                let info = channels.remove(&scid).expect("We just accessed this scid, it should be present");
-                               Self::remove_channel_in_nodes(&mut nodes, &info, scid);
+                               self.remove_channel_in_nodes(&mut nodes, &info, scid);
                                self.removed_channels.lock().unwrap().insert(scid, Some(current_time_unix));
                        }
                }
@@ -2037,7 +2091,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                Ok(())
        }
 
-       fn remove_channel_in_nodes(nodes: &mut IndexedMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
+       fn remove_channel_in_nodes(&self, nodes: &mut IndexedMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
                macro_rules! remove_from_node {
                        ($node_id: expr) => {
                                if let IndexedMapEntry::Occupied(mut entry) = nodes.entry($node_id) {
@@ -2045,6 +2099,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                                short_channel_id != *chan_id
                                        });
                                        if entry.get().channels.is_empty() {
+                                               self.removed_node_counters.lock().unwrap().push(entry.get().node_counter);
                                                entry.remove_entry();
                                        }
                                } else {
@@ -3471,6 +3526,7 @@ pub(crate) mod tests {
                let valid_node_info = NodeInfo {
                        channels: Vec::new(),
                        announcement_info: Some(valid_node_ann_info),
+                       node_counter: 0,
                };
 
                let mut encoded_valid_node_info = Vec::new();