Track a `counter` for each node in our network graph
[rust-lightning] / lightning / src / routing / gossip.rs
index f556fabad0cbcdc9f3d94924bd236080c692c865..1dfee3249881e5ebcdab332a58f2e331a1551b33 100644 (file)
@@ -40,7 +40,6 @@ use crate::prelude::*;
 use core::{cmp, fmt};
 use core::convert::TryFrom;
 use crate::sync::{RwLock, RwLockReadGuard, LockTestExt};
-#[cfg(feature = "std")]
 use core::sync::atomic::{AtomicUsize, Ordering};
 use crate::sync::Mutex;
 use core::ops::{Bound, Deref};
@@ -174,6 +173,8 @@ pub struct NetworkGraph<L: Deref> where L::Target: Logger {
        // Lock order: channels -> nodes
        channels: RwLock<IndexedMap<u64, ChannelInfo>>,
        nodes: RwLock<IndexedMap<NodeId, NodeInfo>>,
+       removed_node_counters: Mutex<Vec<u32>>,
+       next_node_counter: AtomicUsize,
        // Lock order: removed_channels -> removed_nodes
        //
        // NOTE: In the following `removed_*` maps, we use seconds since UNIX epoch to track time instead
@@ -870,31 +871,31 @@ impl ChannelInfo {
        /// Returns a [`DirectedChannelInfo`] for the channel directed to the given `target` from a
        /// returned `source`, or `None` if `target` is not one of the channel's counterparties.
        pub fn as_directed_to(&self, target: &NodeId) -> Option<(DirectedChannelInfo, &NodeId)> {
-               let (direction, source) = {
+               let (direction, source, outbound) = {
                        if target == &self.node_one {
-                               (self.two_to_one.as_ref(), &self.node_two)
+                               (self.two_to_one.as_ref(), &self.node_two, false)
                        } else if target == &self.node_two {
-                               (self.one_to_two.as_ref(), &self.node_one)
+                               (self.one_to_two.as_ref(), &self.node_one, true)
                        } else {
                                return None;
                        }
                };
-               direction.map(|dir| (DirectedChannelInfo::new(self, dir), source))
+               direction.map(|dir| (DirectedChannelInfo::new(self, dir, outbound), source))
        }
 
        /// Returns a [`DirectedChannelInfo`] for the channel directed from the given `source` to a
        /// returned `target`, or `None` if `source` is not one of the channel's counterparties.
        pub fn as_directed_from(&self, source: &NodeId) -> Option<(DirectedChannelInfo, &NodeId)> {
-               let (direction, target) = {
+               let (direction, target, outbound) = {
                        if source == &self.node_one {
-                               (self.one_to_two.as_ref(), &self.node_two)
+                               (self.one_to_two.as_ref(), &self.node_two, true)
                        } else if source == &self.node_two {
-                               (self.two_to_one.as_ref(), &self.node_one)
+                               (self.two_to_one.as_ref(), &self.node_one, false)
                        } else {
                                return None;
                        }
                };
-               direction.map(|dir| (DirectedChannelInfo::new(self, dir), target))
+               direction.map(|dir| (DirectedChannelInfo::new(self, dir, outbound), target))
        }
 
        /// Returns a [`ChannelUpdateInfo`] based on the direction implied by the channel_flag.
@@ -990,51 +991,55 @@ impl Readable for ChannelInfo {
 pub struct DirectedChannelInfo<'a> {
        channel: &'a ChannelInfo,
        direction: &'a ChannelUpdateInfo,
-       htlc_maximum_msat: u64,
-       effective_capacity: EffectiveCapacity,
+       /// The direction this channel is in - if set, it indicates that we're traversing the channel
+       /// from [`ChannelInfo::node_one`] to [`ChannelInfo::node_two`].
+       from_node_one: bool,
 }
 
 impl<'a> DirectedChannelInfo<'a> {
        #[inline]
-       fn new(channel: &'a ChannelInfo, direction: &'a ChannelUpdateInfo) -> Self {
-               let mut htlc_maximum_msat = direction.htlc_maximum_msat;
-               let capacity_msat = channel.capacity_sats.map(|capacity_sats| capacity_sats * 1000);
-
-               let effective_capacity = match capacity_msat {
-                       Some(capacity_msat) => {
-                               htlc_maximum_msat = cmp::min(htlc_maximum_msat, capacity_msat);
-                               EffectiveCapacity::Total { capacity_msat, htlc_maximum_msat: htlc_maximum_msat }
-                       },
-                       None => EffectiveCapacity::AdvertisedMaxHTLC { amount_msat: htlc_maximum_msat },
-               };
-
-               Self {
-                       channel, direction, htlc_maximum_msat, effective_capacity
-               }
+       fn new(channel: &'a ChannelInfo, direction: &'a ChannelUpdateInfo, from_node_one: bool) -> Self {
+               Self { channel, direction, from_node_one }
        }
 
        /// Returns information for the channel.
        #[inline]
        pub fn channel(&self) -> &'a ChannelInfo { self.channel }
 
-       /// Returns the maximum HTLC amount allowed over the channel in the direction.
-       #[inline]
-       pub fn htlc_maximum_msat(&self) -> u64 {
-               self.htlc_maximum_msat
-       }
-
        /// Returns the [`EffectiveCapacity`] of the channel in the direction.
        ///
        /// This is either the total capacity from the funding transaction, if known, or the
        /// `htlc_maximum_msat` for the direction as advertised by the gossip network, if known,
        /// otherwise.
+       #[inline]
        pub fn effective_capacity(&self) -> EffectiveCapacity {
-               self.effective_capacity
+               let mut htlc_maximum_msat = self.direction().htlc_maximum_msat;
+               let capacity_msat = self.channel.capacity_sats.map(|capacity_sats| capacity_sats * 1000);
+
+               match capacity_msat {
+                       Some(capacity_msat) => {
+                               htlc_maximum_msat = cmp::min(htlc_maximum_msat, capacity_msat);
+                               EffectiveCapacity::Total { capacity_msat, htlc_maximum_msat }
+                       },
+                       None => EffectiveCapacity::AdvertisedMaxHTLC { amount_msat: htlc_maximum_msat },
+               }
        }
 
        /// Returns information for the direction.
        #[inline]
        pub(super) fn direction(&self) -> &'a ChannelUpdateInfo { self.direction }
+
+       /// Returns the `node_id` of the source hop.
+       ///
+       /// Refers to the `node_id` forwarding the payment to the next hop.
+       #[inline]
+       pub(super) fn source(&self) -> &'a NodeId { if self.from_node_one { &self.channel.node_one } else { &self.channel.node_two } }
+
+       /// Returns the `node_id` of the target hop.
+       ///
+       /// Refers to the `node_id` receiving the payment from the previous hop.
+       #[inline]
+       pub(super) fn target(&self) -> &'a NodeId { if self.from_node_one { &self.channel.node_two } else { &self.channel.node_one } }
 }
 
 impl<'a> fmt::Debug for DirectedChannelInfo<'a> {
@@ -1216,7 +1221,7 @@ impl Readable for NodeAlias {
        }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, Eq)]
 /// Details about a node in the network, known from the network announcement.
 pub struct NodeInfo {
        /// All valid channels a node has announced
@@ -1224,7 +1229,15 @@ pub struct NodeInfo {
        /// More information about a node from node_announcement.
        /// Optional because we store a Node entry after learning about it from
        /// a channel announcement, but before receiving a node announcement.
-       pub announcement_info: Option<NodeAnnouncementInfo>
+       pub announcement_info: Option<NodeAnnouncementInfo>,
+       /// XXX: Docs
+       pub(crate) node_counter: u32,
+}
+
+impl PartialEq for NodeInfo {
+       fn eq(&self, o: &NodeInfo) -> bool {
+               self.channels == o.channels && self.announcement_info == o.announcement_info
+       }
 }
 
 impl fmt::Display for NodeInfo {
@@ -1282,6 +1295,7 @@ impl Readable for NodeInfo {
                Ok(NodeInfo {
                        announcement_info: announcement_info_wrap.map(|w| w.0),
                        channels,
+                       node_counter: u32::max_value(),
                })
        }
 }
@@ -1291,6 +1305,8 @@ const MIN_SERIALIZATION_VERSION: u8 = 1;
 
 impl<L: Deref> Writeable for NetworkGraph<L> where L::Target: Logger {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               self.test_node_counter_consistency();
+
                write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
 
                self.chain_hash.write(writer)?;
@@ -1329,11 +1345,13 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                        channels.insert(chan_id, chan_info);
                }
                let nodes_count: u64 = Readable::read(reader)?;
+               if nodes_count > u32::max_value() as u64 / 2 { return Err(DecodeError::InvalidValue); }
                // In Nov, 2023 there were about 69K channels; we cap allocations to 1.5x that.
                let mut nodes = IndexedMap::with_capacity(cmp::min(nodes_count as usize, 103500));
-               for _ in 0..nodes_count {
+               for i in 0..nodes_count {
                        let node_id = Readable::read(reader)?;
-                       let node_info = Readable::read(reader)?;
+                       let mut node_info: NodeInfo = Readable::read(reader)?;
+                       node_info.node_counter = i as u32;
                        nodes.insert(node_id, node_info);
                }
 
@@ -1348,6 +1366,8 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                        logger,
                        channels: RwLock::new(channels),
                        nodes: RwLock::new(nodes),
+                       removed_node_counters: Mutex::new(Vec::new()),
+                       next_node_counter: AtomicUsize::new(nodes_count as usize),
                        last_rapid_gossip_sync_timestamp: Mutex::new(last_rapid_gossip_sync_timestamp),
                        removed_nodes: Mutex::new(HashMap::new()),
                        removed_channels: Mutex::new(HashMap::new()),
@@ -1393,6 +1413,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        logger,
                        channels: RwLock::new(IndexedMap::new()),
                        nodes: RwLock::new(IndexedMap::new()),
+                       next_node_counter: AtomicUsize::new(0),
+                       removed_node_counters: Mutex::new(Vec::new()),
                        last_rapid_gossip_sync_timestamp: Mutex::new(None),
                        removed_channels: Mutex::new(HashMap::new()),
                        removed_nodes: Mutex::new(HashMap::new()),
@@ -1400,8 +1422,33 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                }
        }
 
+       fn test_node_counter_consistency(&self) {
+               #[cfg(debug_assertions)] {
+                       let nodes = self.nodes.read().unwrap();
+                       let removed_node_counters = self.removed_node_counters.lock().unwrap();
+                       let next_counter = self.next_node_counter.load(Ordering::Acquire);
+                       assert!(next_counter < (u32::max_value() as usize) / 2);
+                       let mut used_node_counters = vec![0u8; next_counter / 8 + 1];
+
+                       for counter in removed_node_counters.iter() {
+                               let pos = (*counter as usize) / 8;
+                               let bit = 1 << (counter % 8);
+                               assert_eq!(used_node_counters[pos] & bit, 0);
+                               used_node_counters[pos] |= bit;
+                       }
+                       for (_, node) in nodes.unordered_iter() {
+                               assert!((node.node_counter as usize) < next_counter);
+                               let pos = (node.node_counter as usize) / 8;
+                               let bit = 1 << (node.node_counter % 8);
+                               assert_eq!(used_node_counters[pos] & bit, 0);
+                               used_node_counters[pos] |= bit;
+                       }
+               }
+       }
+
        /// Returns a read-only view of the network graph.
        pub fn read_only(&'_ self) -> ReadOnlyNetworkGraph<'_> {
+               self.test_node_counter_consistency();
                let channels = self.channels.read().unwrap();
                let nodes = self.nodes.read().unwrap();
                ReadOnlyNetworkGraph {
@@ -1583,7 +1630,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        // b) we don't track UTXOs of channels we know about and remove them if they
                                        //    get reorg'd out.
                                        // c) it's unclear how to do so without exposing ourselves to massive DoS risk.
-                                       Self::remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
+                                       self.remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
                                        *entry.get_mut() = channel_info;
                                } else {
                                        return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
@@ -1600,9 +1647,13 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        node_entry.into_mut().channels.push(short_channel_id);
                                },
                                IndexedMapEntry::Vacant(node_entry) => {
+                                       let mut removed_node_counters = self.removed_node_counters.lock().unwrap();
+                                       let node_counter = removed_node_counters.pop()
+                                               .unwrap_or(self.next_node_counter.fetch_add(1, Ordering::Relaxed) as u32);
                                        node_entry.insert(NodeInfo {
                                                channels: vec!(short_channel_id),
                                                announcement_info: None,
+                                               node_counter,
                                        });
                                }
                        };
@@ -1721,7 +1772,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                if let Some(chan) = channels.remove(&short_channel_id) {
                        let mut nodes = self.nodes.write().unwrap();
                        self.removed_channels.lock().unwrap().insert(short_channel_id, current_time_unix);
-                       Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
+                       self.remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
                }
        }
 
@@ -1740,6 +1791,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                let mut removed_nodes = self.removed_nodes.lock().unwrap();
 
                if let Some(node) = nodes.remove(&node_id) {
+                       let mut removed_node_counters = self.removed_node_counters.lock().unwrap();
                        for scid in node.channels.iter() {
                                if let Some(chan_info) = channels.remove(scid) {
                                        let other_node_id = if node_id == chan_info.node_one { chan_info.node_two } else { chan_info.node_one };
@@ -1748,12 +1800,14 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                                        *scid != *chan_id
                                                });
                                                if other_node_entry.get().channels.is_empty() {
+                                                       removed_node_counters.push(other_node_entry.get().node_counter);
                                                        other_node_entry.remove_entry();
                                                }
                                        }
                                        removed_channels.insert(*scid, current_time_unix);
                                }
                        }
+                       removed_node_counters.push(node.node_counter);
                        removed_nodes.insert(node_id, current_time_unix);
                }
        }
@@ -1829,7 +1883,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        let mut nodes = self.nodes.write().unwrap();
                        for scid in scids_to_remove {
                                let info = channels.remove(&scid).expect("We just accessed this scid, it should be present");
-                               Self::remove_channel_in_nodes(&mut nodes, &info, scid);
+                               self.remove_channel_in_nodes(&mut nodes, &info, scid);
                                self.removed_channels.lock().unwrap().insert(scid, Some(current_time_unix));
                        }
                }
@@ -2008,7 +2062,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                Ok(())
        }
 
-       fn remove_channel_in_nodes(nodes: &mut IndexedMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
+       fn remove_channel_in_nodes(&self, nodes: &mut IndexedMap<NodeId, NodeInfo>, chan: &ChannelInfo, short_channel_id: u64) {
                macro_rules! remove_from_node {
                        ($node_id: expr) => {
                                if let IndexedMapEntry::Occupied(mut entry) = nodes.entry($node_id) {
@@ -2016,6 +2070,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                                short_channel_id != *chan_id
                                        });
                                        if entry.get().channels.is_empty() {
+                                               self.removed_node_counters.lock().unwrap().push(entry.get().node_counter);
                                                entry.remove_entry();
                                        }
                                } else {
@@ -3441,6 +3496,7 @@ pub(crate) mod tests {
                let valid_node_info = NodeInfo {
                        channels: Vec::new(),
                        announcement_info: Some(valid_node_ann_info),
+                       node_counter: 0,
                };
 
                let mut encoded_valid_node_info = Vec::new();