Store the source and destination `node_counter`s in `ChannelInfo`
authorMatt Corallo <git@bluematt.me>
Thu, 7 Dec 2023 04:32:06 +0000 (04:32 +0000)
committerMatt Corallo <git@bluematt.me>
Sun, 10 Dec 2023 03:29:38 +0000 (03:29 +0000)
In the next commit, we'll use the new `node_counter`s to remove a
`HashMap` from the router, using a `Vec` to store all our per-node
information. In order to make finding entries in that `Vec` cheap,
here we store the source and destintaion `node_counter`s in
`ChannelInfo`, givind us the counters for both ends of a channel
without doing a second `HashMap` lookup.

lightning/src/routing/gossip.rs

index 1dfee3249881e5ebcdab332a58f2e331a1551b33..161d635ef6ee4709376bc9ae3b6013cb74786293 100644 (file)
@@ -64,7 +64,7 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024;
 const MAX_SCIDS_PER_REPLY: usize = 8000;
 
 /// Represents the compressed public key of a node
-#[derive(Clone, Copy)]
+#[derive(Clone, Copy, PartialEq, Eq)]
 pub struct NodeId([u8; PUBLIC_KEY_SIZE]);
 
 impl NodeId {
@@ -106,14 +106,6 @@ impl core::hash::Hash for NodeId {
        }
 }
 
-impl Eq for NodeId {}
-
-impl PartialEq for NodeId {
-       fn eq(&self, other: &Self) -> bool {
-               self.0[..] == other.0[..]
-       }
-}
-
 impl cmp::PartialOrd for NodeId {
        fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
                Some(self.cmp(other))
@@ -840,7 +832,7 @@ impl Readable for ChannelUpdateInfo {
        }
 }
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, Eq)]
 /// Details about a channel (both directions).
 /// Received within a channel announcement.
 pub struct ChannelInfo {
@@ -865,6 +857,24 @@ pub struct ChannelInfo {
        /// (which we can probably assume we are - no-std environments probably won't have a full
        /// network graph in memory!).
        announcement_received_time: u64,
+
+       /// XXX docs
+       pub(crate) node_one_counter: u32,
+       /// XXX docs
+       pub(crate) node_two_counter: u32,
+}
+
+impl PartialEq for ChannelInfo {
+       fn eq(&self, o: &ChannelInfo) -> bool {
+               self.features == o.features &&
+                       self.node_one == o.node_one &&
+                       self.one_to_two == o.one_to_two &&
+                       self.node_two == o.node_two &&
+                       self.two_to_one == o.two_to_one &&
+                       self.capacity_sats == o.capacity_sats &&
+                       self.announcement_message == o.announcement_message &&
+                       self.announcement_received_time == o.announcement_received_time
+       }
 }
 
 impl ChannelInfo {
@@ -981,6 +991,8 @@ impl Readable for ChannelInfo {
                        capacity_sats: _init_tlv_based_struct_field!(capacity_sats, required),
                        announcement_message: _init_tlv_based_struct_field!(announcement_message, required),
                        announcement_received_time: _init_tlv_based_struct_field!(announcement_received_time, (default_value, 0)),
+                       node_one_counter: u32::max_value(),
+                       node_two_counter: u32::max_value(),
                })
        }
 }
@@ -1341,7 +1353,7 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                let mut channels = IndexedMap::with_capacity(cmp::min(channels_count as usize, 22500));
                for _ in 0..channels_count {
                        let chan_id: u64 = Readable::read(reader)?;
-                       let chan_info = Readable::read(reader)?;
+                       let chan_info: ChannelInfo = Readable::read(reader)?;
                        channels.insert(chan_id, chan_info);
                }
                let nodes_count: u64 = Readable::read(reader)?;
@@ -1355,6 +1367,13 @@ impl<L: Deref> ReadableArgs<L> for NetworkGraph<L> where L::Target: Logger {
                        nodes.insert(node_id, node_info);
                }
 
+               for (_, chan) in channels.unordered_iter_mut() {
+                       chan.node_one_counter =
+                               nodes.get(&chan.node_one).ok_or(DecodeError::InvalidValue)?.node_counter;
+                       chan.node_two_counter =
+                               nodes.get(&chan.node_two).ok_or(DecodeError::InvalidValue)?.node_counter;
+               }
+
                let mut last_rapid_gossip_sync_timestamp: Option<u32> = None;
                read_tlv_fields!(reader, {
                        (1, last_rapid_gossip_sync_timestamp, option),
@@ -1424,6 +1443,7 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
 
        fn test_node_counter_consistency(&self) {
                #[cfg(debug_assertions)] {
+                       let channels = self.channels.read().unwrap();
                        let nodes = self.nodes.read().unwrap();
                        let removed_node_counters = self.removed_node_counters.lock().unwrap();
                        let next_counter = self.next_node_counter.load(Ordering::Acquire);
@@ -1443,6 +1463,11 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                assert_eq!(used_node_counters[pos] & bit, 0);
                                used_node_counters[pos] |= bit;
                        }
+
+                       for (_, chan) in channels.unordered_iter() {
+                               assert_eq!(chan.node_one_counter, nodes.get(&chan.node_one).unwrap().node_counter);
+                               assert_eq!(chan.node_two_counter, nodes.get(&chan.node_two).unwrap().node_counter);
+                       }
                }
        }
 
@@ -1602,6 +1627,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        capacity_sats: None,
                        announcement_message: None,
                        announcement_received_time: timestamp,
+                       node_one_counter: u32::max_value(),
+                       node_two_counter: u32::max_value(),
                };
 
                self.add_channel_between_nodes(short_channel_id, channel_info, None)
@@ -1616,7 +1643,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
 
                log_gossip!(self.logger, "Adding channel {} between nodes {} and {}", short_channel_id, node_id_a, node_id_b);
 
-               match channels.entry(short_channel_id) {
+               let channel_entry = channels.entry(short_channel_id);
+               let channel_info = match channel_entry {
                        IndexedMapEntry::Occupied(mut entry) => {
                                //TODO: because asking the blockchain if short_channel_id is valid is only optional
                                //in the blockchain API, we need to handle it smartly here, though it's unclear
@@ -1632,28 +1660,35 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                                        // c) it's unclear how to do so without exposing ourselves to massive DoS risk.
                                        self.remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id);
                                        *entry.get_mut() = channel_info;
+                                       entry.into_mut()
                                } else {
                                        return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip});
                                }
                        },
                        IndexedMapEntry::Vacant(entry) => {
-                               entry.insert(channel_info);
+                               entry.insert(channel_info)
                        }
                };
 
-               for current_node_id in [node_id_a, node_id_b].iter() {
+               let mut node_counter_id = [
+                       (&mut channel_info.node_one_counter, node_id_a),
+                       (&mut channel_info.node_two_counter, node_id_b)
+               ];
+               for (node_counter, current_node_id) in node_counter_id.iter_mut() {
                        match nodes.entry(current_node_id.clone()) {
                                IndexedMapEntry::Occupied(node_entry) => {
-                                       node_entry.into_mut().channels.push(short_channel_id);
+                                       let node = node_entry.into_mut();
+                                       node.channels.push(short_channel_id);
+                                       **node_counter = node.node_counter;
                                },
                                IndexedMapEntry::Vacant(node_entry) => {
                                        let mut removed_node_counters = self.removed_node_counters.lock().unwrap();
-                                       let node_counter = removed_node_counters.pop()
+                                       **node_counter = removed_node_counters.pop()
                                                .unwrap_or(self.next_node_counter.fetch_add(1, Ordering::Relaxed) as u32);
                                        node_entry.insert(NodeInfo {
                                                channels: vec!(short_channel_id),
                                                announcement_info: None,
-                                               node_counter,
+                                               node_counter: **node_counter,
                                        });
                                }
                        };
@@ -1744,6 +1779,8 @@ impl<L: Deref> NetworkGraph<L> where L::Target: Logger {
                        announcement_message: if msg.excess_data.len() <= MAX_EXCESS_BYTES_FOR_RELAY
                                { full_msg.cloned() } else { None },
                        announcement_received_time,
+                       node_one_counter: u32::max_value(),
+                       node_two_counter: u32::max_value(),
                };
 
                self.add_channel_between_nodes(msg.short_channel_id, chan_info, utxo_value)?;
@@ -3424,6 +3461,8 @@ pub(crate) mod tests {
                        capacity_sats: None,
                        announcement_message: None,
                        announcement_received_time: 87654,
+                       node_one_counter: 0,
+                       node_two_counter: 1,
                };
 
                let mut encoded_chan_info: Vec<u8> = Vec::new();
@@ -3442,6 +3481,8 @@ pub(crate) mod tests {
                        capacity_sats: None,
                        announcement_message: None,
                        announcement_received_time: 87654,
+                       node_one_counter: 0,
+                       node_two_counter: 1,
                };
 
                let mut encoded_chan_info: Vec<u8> = Vec::new();