Drop the `dist` `HashMap` in routing, replacing it with a `Vec`.
[rust-lightning] / lightning / src / routing / gossip.rs
index 9b4e41ae174d7b8a73cb8f64b7918f62a0dbf01b..93a521bab05a77304565818ad8520567aaefc560 100644 (file)
@@ -40,7 +40,6 @@ use crate::prelude::*;
 use core::{cmp, fmt};
 use core::convert::TryFrom;
 use crate::sync::{RwLock, RwLockReadGuard, LockTestExt};
-#[cfg(feature = "std")]
 use core::sync::atomic::{AtomicUsize, Ordering};
 use crate::sync::Mutex;
 use core::ops::{Bound, Deref};
@@ -65,7 +64,7 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024;
 const MAX_SCIDS_PER_REPLY: usize = 8000;
 
 /// Represents the compressed public key of a node
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, PartialEq, Eq)]
 pub struct NodeId([u8; PUBLIC_KEY_SIZE]);
 
 impl NodeId {
@@ -107,14 +106,6 @@ impl core::hash::Hash for NodeId {
        }
 }
 
-impl Eq for NodeId {}
-
-impl PartialEq for NodeId {
-       fn eq(&self, other: &Self) -> bool {
-               self.0[..] == other.0[..]
-       }
-}
-
 impl cmp::PartialOrd for NodeId {
        fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
                Some(self.cmp(other))
@@ -174,6 +165,8 @@ pub struct NetworkGraph<L: Deref> where L::Target: Logger {
        // Lock order: channels -> nodes
        channels: RwLock<IndexedMap<u64, ChannelInfo>>,
        nodes: RwLock<IndexedMap<NodeId, NodeInfo>>,
+       removed_node_counters: Mutex<Vec<u32>>,
+       next_node_counter: AtomicUsize,
        // Lock order: removed_channels -> removed_nodes
        //
        // NOTE: In the following `removed_*` maps, we use seconds since UNIX epoch to track time instead
@@ -201,6 +194,7 @@ pub struct NetworkGraph<L: Deref> where L::Target: Logger {
 pub struct ReadOnlyNetworkGraph<'a> {
        channels: RwLockReadGuard<'a, IndexedMap<u64, ChannelInfo>>,
        nodes: RwLockReadGuard<'a, IndexedMap<NodeId, NodeInfo>>,
+       max_node_counter: u32,
 }
 
 /// Update to the [`NetworkGraph`] based on payment failure information conveyed via the Onion
@@ -839,7 +833,7 @@ impl Readable for ChannelUpdateInfo {
        }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, Eq)]
 /// Details about a channel (both directions).
 /// Received within a channel announcement.
 pub struct ChannelInfo {
@@ -864,6 +858,24 @@ pub struct ChannelInfo {
        /// (which we can probably assume we are - no-std environments probably won't have a full
        /// network graph in memory!).
        announcement_received_time: u64,
+
+       /// XXX docs
+       pub(crate) node_one_counter: u32,
+       /// XXX docs
+       pub(crate) node_two_counter: u32,
+}
+
+impl PartialEq for ChannelInfo {
+       fn eq(&self, o: &ChannelInfo) -> bool {
+               self.features == o.features &&
+                       self.node_one == o.node_one &&
+                       self.one_to_two == o.one_to_two &&
+                       self.node_two == o.node_two &&
+                       self.two_to_one == o.two_to_one &&
+                       self.capacity_sats == o.capacity_sats &&
+                       self.announcement_message == o.announcement_message &&
+                       self.announcement_received_time == o.announcement_received_time
+       }
 }
 
 impl ChannelInfo {
@@ -980,6 +992,8 @@ impl Readable for ChannelInfo {
                        capacity_sats: _init_tlv_based_struct_field!(capacity_sats, required),
                        announcement_message: _init_tlv_based_struct_field!(announcement_message, required),
                        announcement_received_time: _init_tlv_based_struct_field!(announcement_received_time, (default_value, 0)),
+                       node_one_counter: u32::max_value(),
+                       node_two_counter: u32::max_value(),
                })
        }
 }
@@ -1039,6 +1053,14 @@ impl<'a> DirectedChannelInfo<'a> {
        /// Refers to the `node_id` receiving the payment from the previous hop.
        #[inline]
        pub(super) fn target(&self) -> &'a NodeId { if self.from_node_one { &self.channel.node_two } else { &self.channel.node_one } }
+
+       /// Returns the source node's counter
+       #[inline]
+       pub(super) fn source_counter(&self) -> u32 { if self.from_node_one { self.channel.node_one_counter } else { self.channel.node_two_counter } }
+
+       /// Returns the target node's counter
+       #[inline]
+       pub(super) fn target_counter(&self) -> u32 { if self.from_node_one { self.channel.node_two_counter } else { self.channel.node_one_counter } }
 }
 
 impl<'a> fmt::Debug for DirectedChannelInfo<'a> {
@@ -1220,7 +1242,7 @@ impl Readable for NodeAlias {
        }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, Eq)]
 /// Details about a node in the network, known from the network announcement.
 pub struct NodeInfo {
        /// All valid channels a node has announced
@@ -1228,7 +1250,15 @@ pub struct NodeInfo {
        /// More information about a node from node_announcement.
        /// Optional because we store a Node entry after learning about it from
        /// a channel announcement, but before receiving a node announcement.
-       pub announcement_info: Option<NodeAnnouncementInfo>
+       pub announcement_info: Option<NodeAnnouncementInfo>,
+       /// XXX: Docs
+       pub(crate) node_counter: u32,
+}
+
+impl PartialEq for NodeInfo {
+       fn eq(&self, o: &NodeInfo) -> bool {
+               self.channels == o.channels && self.announcement_info == o.announcement_info
+       }
 }
 
 impl fmt::Display for NodeInfo {
@@ -1286,6 +1316,7 @@ impl Readable for NodeInfo {
                Ok(NodeInfo {
                        announcement_info: announcement_info_wrap.map(|w| w.0),
                        channels,
+                       node_counter: u32::max_value(),
                })
        }
 }
@@ -1295,6 +1326,8 @@ const MIN_SERIALIZATION_VERSION: u8 = 1;
 
 impl<L: Deref> Writeable for NetworkGraph<L> where L::Target: Logger {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               self.test_node_counter_consistency();
+
                write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
 
                self.chain_hash.write(writer)?;
@@ -1329,18 +1362,27 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                let mut channels = IndexedMap::with_capacity(cmp::min(channels_count as usize, 22500));
                for _ in 0..channels_count {
                        let chan_id: u64 = Readable::read(reader)?;
-                       let chan_info = Readable::read(reader)?;
+                       let chan_info: ChannelInfo = Readable::read(reader)?;
                        channels.insert(chan_id, chan_info);
                }
                let nodes_count: u64 = Readable::read(reader)?;
+               if nodes_count > u32::max_value() as u64 / 2 { return Err(DecodeError::InvalidValue); }
                // In Nov, 2023 there were about 69K channels; we cap allocations to 1.5x that.
                let mut nodes = IndexedMap::with_capacity(cmp::min(nodes_count as usize, 103500));
-               for _ in 0..nodes_count {
+               for i in 0..nodes_count {
                        let node_id = Readable::read(reader)?;
-                       let node_info = Readable::read(reader)?;
+                       let mut node_info: NodeInfo = Readable::read(reader)?;
+                       node_info.node_counter = i as u32;
                        nodes.insert(node_id, node_info);
                }
 
+               for (_, chan) in channels.unordered_iter_mut() {
+                       chan.node_one_counter =
+                               nodes.get(&chan.node_one).ok_or(DecodeError::InvalidValue)?.node_counter;
+                       chan.node_two_counter =
+                               nodes.get(&chan.node_two).ok_or(DecodeError::InvalidValue)?.node_counter;
+               }
+
                let mut last_rapid_gossip_sync_timestamp: Option<u32> = None;
                read_tlv_fields!(reader, {
                        (1, last_rapid_gossip_sync_timestamp, option),
@@ -1352,6 +1394,8 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                        logger,
                        channels: RwLock::new(channels),
                        nodes: RwLock::new(nodes),
+                       removed_node_counters: Mutex::new(Vec::new()),
+                       next_node_counter: AtomicUsize::new(nodes_count as usize),
                        last_rapid_gossip_sync_timestamp: Mutex::new(last_rapid_gossip_sync_timestamp),
                        removed_nodes: Mutex::new(HashMap::new()),
                        removed_channels: Mutex::new(HashMap::new()),
@@ -1397,6 +1441,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        logger,
                        channels: RwLock::new(IndexedMap::new()),
                        nodes: RwLock::new(IndexedMap::new()),
+                       next_node_counter: AtomicUsize::new(0),
+                       removed_node_counters: Mutex::new(Vec::new()),
                        last_rapid_gossip_sync_timestamp: Mutex::new(None),
                        removed_channels: Mutex::new(HashMap::new()),
                        removed_nodes: Mutex::new(HashMap::new()),
@@ -1404,13 +1450,45 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                }
        }
 
+       fn test_node_counter_consistency(&self) {
+               #[cfg(debug_assertions)] {
+                       let channels = self.channels.read().unwrap();
+                       let nodes = self.nodes.read().unwrap();
+                       let removed_node_counters = self.removed_node_counters.lock().unwrap();
+                       let next_counter = self.next_node_counter.load(Ordering::Acquire);
+                       assert!(next_counter < (u32::max_value() as usize) / 2);
+                       let mut used_node_counters = vec![0u8; next_counter / 8 + 1];
+
+                       for counter in removed_node_counters.iter() {
+                               let pos = (*counter as usize) / 8;
+                               let bit = 1 << (counter % 8);
+                               assert_eq!(used_node_counters[pos] & bit, 0);
+                               used_node_counters[pos] |= bit;
+                       }
+                       for (_, node) in nodes.unordered_iter() {
+                               assert!((node.node_counter as usize) < next_counter);
+                               let pos = (node.node_counter as usize) / 8;
+                               let bit = 1 << (node.node_counter % 8);
+                               assert_eq!(used_node_counters[pos] & bit, 0);
+                               used_node_counters[pos] |= bit;
+                       }
+
+                       for (_, chan) in channels.unordered_iter() {
+                               assert_eq!(chan.node_one_counter, nodes.get(&chan.node_one).unwrap().node_counter);
+                               assert_eq!(chan.node_two_counter, nodes.get(&chan.node_two).unwrap().node_counter);
+                       }
+               }
+       }
+
        /// Returns a read-only view of the network graph.
        pub fn read_only(&'_ self) -> ReadOnlyNetworkGraph<'_> {
+               self.test_node_counter_consistency();
                let channels = self.channels.read().unwrap();
                let nodes = self.nodes.read().unwrap();
                ReadOnlyNetworkGraph {
                        channels,
                        nodes,
+                       max_node_counter: (self.next_node_counter.load(Ordering::Acquire) as u32).saturating_sub(1),
                }
        }
 
@@ -1559,6 +1637,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        capacity_sats: None,
                        announcement_message: None,
                        announcement_received_time: timestamp,
+                       node_one_counter: u32::max_value(),
+                       node_two_counter: u32::max_value(),
                };
 
                self.add_channel_between_nodes(short_channel_id, channel_info, None)
@@ -1573,7 +1653,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
 
                log_gossip!(self.logger, "Adding channel {} between nodes {} and {}", short_channel_id, node_id_a, node_id_b);
 
-               match channels.entry(short_channel_id) {
+               let channel_entry = channels.entry(short_channel_id);
+               let channel_info = match channel_entry {
                        IndexedMapEntry::Occupied(mut entry) => {
                                //TODO: because asking the blockchain if short_channel_id is valid is only optional
                                //in the blockchain API, we need to handle it smartly here, though it's unclear
@@ -1587,26 +1668,37 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        // b) we don't track UTXOs of channels we know about and remove them if they
                                        //    get reorg'd out.
                                        // c) it's unclear how to do so without exposing ourselves to massive DoS risk.
-                                       Self::remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
+                                       self.remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
                                        *entry.get_mut() = channel_info;
+                                       entry.into_mut()
                                } else {
                                        return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
                                }
                        },
                        IndexedMapEntry::Vacant(entry) => {
-                               entry.insert(channel_info);
+                               entry.insert(channel_info)
                        }
                };
 
-               for current_node_id in [node_id_a, node_id_b].iter() {
+               let mut node_counter_id = [
+                       (&mut channel_info.node_one_counter, node_id_a),
+                       (&mut channel_info.node_two_counter, node_id_b)
+               ];
+               for (node_counter, current_node_id) in node_counter_id.iter_mut() {
                        match nodes.entry(current_node_id.clone()) {
                                IndexedMapEntry::Occupied(node_entry) => {
-                                       node_entry.into_mut().channels.push(short_channel_id);
+                                       let node = node_entry.into_mut();
+                                       node.channels.push(short_channel_id);
+                                       **node_counter = node.node_counter;
                                },
                                IndexedMapEntry::Vacant(node_entry) => {
+                                       let mut removed_node_counters = self.removed_node_counters.lock().unwrap();
+                                       **node_counter = removed_node_counters.pop()
+                                               .unwrap_or(self.next_node_counter.fetch_add(1, Ordering::Relaxed) as u32);
                                        node_entry.insert(NodeInfo {
                                                channels: vec!(short_channel_id),
                                                announcement_info: None,
+                                               node_counter: **node_counter,
                                        });
                                }
                        };
@@ -1697,6 +1789,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        announcement_message: if msg.excess_data.len() <= MAX_EXCESS_BYTES_FOR_RELAY
                                { full_msg.cloned() } else { None },
                        announcement_received_time,
+                       node_one_counter: u32::max_value(),
+                       node_two_counter: u32::max_value(),
                };
 
                self.add_channel_between_nodes(msg.short_channel_id, chan_info, utxo_value)?;
@@ -1725,7 +1819,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                if let Some(chan) = channels.remove(&short_channel_id) {
                        let mut nodes = self.nodes.write().unwrap();
                        self.removed_channels.lock().unwrap().insert(short_channel_id, current_time_unix);
-                       Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
+                       self.remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
                }
        }
 
@@ -1744,6 +1838,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                let mut removed_nodes = self.removed_nodes.lock().unwrap();
 
                if let Some(node) = nodes.remove(&node_id) {
+                       let mut removed_node_counters = self.removed_node_counters.lock().unwrap();
                        for scid in node.channels.iter() {
                                if let Some(chan_info) = channels.remove(scid) {
                                        let other_node_id = if node_id == chan_info.node_one { chan_info.node_two } else { chan_info.node_one };
@@ -1752,12 +1847,14 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                                        *scid != *chan_id
                                                });
                                                if other_node_entry.get().channels.is_empty() {
+                                                       removed_node_counters.push(other_node_entry.get().node_counter);
                                                        other_node_entry.remove_entry();
                                                }
                                        }
                                        removed_channels.insert(*scid, current_time_unix);
                                }
                        }
+                       removed_node_counters.push(node.node_counter);
                        removed_nodes.insert(node_id, current_time_unix);
                }
        }
@@ -1833,7 +1930,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        let mut nodes = self.nodes.write().unwrap();
                        for scid in scids_to_remove {
                                let info = channels.remove(&scid).expect("We just accessed this scid, it should be present");
-                               Self::remove_channel_in_nodes(&mut nodes, &info, scid);
+                               self.remove_channel_in_nodes(&mut nodes, &info, scid);
                                self.removed_channels.lock().unwrap().insert(scid, Some(current_time_unix));
                        }
                }
@@ -2012,7 +2109,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                Ok(())
        }
 
-       fn remove_channel_in_nodes(nodes: &mut IndexedMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
+       fn remove_channel_in_nodes(&self, nodes: &mut IndexedMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
                macro_rules! remove_from_node {
                        ($node_id: expr) => {
                                if let IndexedMapEntry::Occupied(mut entry) = nodes.entry($node_id) {
@@ -2020,6 +2117,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                                short_channel_id != *chan_id
                                        });
                                        if entry.get().channels.is_empty() {
+                                               self.removed_node_counters.lock().unwrap().push(entry.get().node_counter);
                                                entry.remove_entry();
                                        }
                                } else {
@@ -2077,6 +2175,11 @@ impl ReadOnlyNetworkGraph<'_> {
                self.nodes.get(&NodeId::from_pubkey(&pubkey))
                        .and_then(|node| node.announcement_info.as_ref().map(|ann| ann.addresses().to_vec()))
        }
+
+       /// Gets the maximum possible node_counter for a node in this graph
+       pub(crate) fn max_node_counter(&self) -> u32 {
+               self.max_node_counter
+       }
 }
 
 #[cfg(test)]
@@ -3373,6 +3476,8 @@ pub(crate) mod tests {
                        capacity_sats: None,
                        announcement_message: None,
                        announcement_received_time: 87654,
+                       node_one_counter: 0,
+                       node_two_counter: 1,
                };
 
                let mut encoded_chan_info: Vec<u8> = Vec::new();
@@ -3391,6 +3496,8 @@ pub(crate) mod tests {
                        capacity_sats: None,
                        announcement_message: None,
                        announcement_received_time: 87654,
+                       node_one_counter: 0,
+                       node_two_counter: 1,
                };
 
                let mut encoded_chan_info: Vec<u8> = Vec::new();
@@ -3445,6 +3552,7 @@ pub(crate) mod tests {
                let valid_node_info = NodeInfo {
                        channels: Vec::new(),
                        announcement_info: Some(valid_node_ann_info),
+                       node_counter: 0,
                };
 
                let mut encoded_valid_node_info = Vec::new();