From a17a159c937578abbd57e3e271be4fcc47b1620d Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 7 Dec 2023 04:32:06 +0000 Subject: [PATCH] Store the source and destination `node_counter`s in `ChannelInfo` In the next commit, we'll use the new `node_counter`s to remove a `HashMap` from the router, using a `Vec` to store all our per-node information. In order to make finding entries in that `Vec` cheap, here we store the source and destintaion `node_counter`s in `ChannelInfo`, givind us the counters for both ends of a channel without doing a second `HashMap` lookup. --- lightning/src/routing/gossip.rs | 84 ++++++++++++++++++++++++++------- 1 file changed, 67 insertions(+), 17 deletions(-) diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index 02f8a8c32..152604c5d 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -64,7 +64,7 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024; const MAX_SCIDS_PER_REPLY: usize = 8000; /// Represents the compressed public key of a node -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq)] pub struct NodeId([u8; PUBLIC_KEY_SIZE]); impl NodeId { @@ -116,14 +116,6 @@ impl core::hash::Hash for NodeId { } } -impl Eq for NodeId {} - -impl PartialEq for NodeId { - fn eq(&self, other: &Self) -> bool { - self.0[..] == other.0[..] - } -} - impl cmp::PartialOrd for NodeId { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) @@ -885,7 +877,7 @@ impl Readable for ChannelUpdateInfo { } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, Eq)] /// Details about a channel (both directions). /// Received within a channel announcement. pub struct ChannelInfo { @@ -910,6 +902,24 @@ pub struct ChannelInfo { /// (which we can probably assume we are - no-std environments probably won't have a full /// network graph in memory!). announcement_received_time: u64, + + /// The [`NodeInfo::node_counter`] of the node pointed to by [`Self::node_one`]. + pub(crate) node_one_counter: u32, + /// The [`NodeInfo::node_counter`] of the node pointed to by [`Self::node_two`]. + pub(crate) node_two_counter: u32, +} + +impl PartialEq for ChannelInfo { + fn eq(&self, o: &ChannelInfo) -> bool { + self.features == o.features && + self.node_one == o.node_one && + self.one_to_two == o.one_to_two && + self.node_two == o.node_two && + self.two_to_one == o.two_to_one && + self.capacity_sats == o.capacity_sats && + self.announcement_message == o.announcement_message && + self.announcement_received_time == o.announcement_received_time + } } impl ChannelInfo { @@ -1030,6 +1040,8 @@ impl Readable for ChannelInfo { capacity_sats: _init_tlv_based_struct_field!(capacity_sats, required), announcement_message: _init_tlv_based_struct_field!(announcement_message, required), announcement_received_time: _init_tlv_based_struct_field!(announcement_received_time, (default_value, 0)), + node_one_counter: u32::max_value(), + node_two_counter: u32::max_value(), }) } } @@ -1505,7 +1517,7 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { let mut channels = IndexedMap::with_capacity(cmp::min(channels_count as usize, 22500)); for _ in 0..channels_count { let chan_id: u64 = Readable::read(reader)?; - let chan_info = Readable::read(reader)?; + let chan_info: ChannelInfo = Readable::read(reader)?; channels.insert(chan_id, chan_info); } let nodes_count: u64 = Readable::read(reader)?; @@ -1521,6 +1533,13 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { nodes.insert(node_id, node_info); } + for (_, chan) in channels.unordered_iter_mut() { + chan.node_one_counter = + nodes.get(&chan.node_one).ok_or(DecodeError::InvalidValue)?.node_counter; + chan.node_two_counter = + nodes.get(&chan.node_two).ok_or(DecodeError::InvalidValue)?.node_counter; + } + let mut last_rapid_gossip_sync_timestamp: Option = None; read_tlv_fields!(reader, { (1, last_rapid_gossip_sync_timestamp, option), @@ -1590,6 +1609,7 @@ impl NetworkGraph where L::Target: Logger { fn test_node_counter_consistency(&self) { #[cfg(debug_assertions)] { + let channels = self.channels.read().unwrap(); let nodes = self.nodes.read().unwrap(); let removed_node_counters = self.removed_node_counters.lock().unwrap(); let next_counter = self.next_node_counter.load(Ordering::Acquire); @@ -1609,6 +1629,19 @@ impl NetworkGraph where L::Target: Logger { assert_eq!(used_node_counters[pos] & bit, 0); used_node_counters[pos] |= bit; } + + for (idx, used_bitset) in used_node_counters.iter().enumerate() { + if idx != next_counter / 8 { + assert_eq!(*used_bitset, 0xff); + } else { + assert_eq!(*used_bitset, (1u8 << (next_counter % 8)) - 1); + } + } + + for (_, chan) in channels.unordered_iter() { + assert_eq!(chan.node_one_counter, nodes.get(&chan.node_one).unwrap().node_counter); + assert_eq!(chan.node_two_counter, nodes.get(&chan.node_two).unwrap().node_counter); + } } } @@ -1773,6 +1806,8 @@ impl NetworkGraph where L::Target: Logger { capacity_sats: None, announcement_message: None, announcement_received_time: timestamp, + node_one_counter: u32::max_value(), + node_two_counter: u32::max_value(), }; self.add_channel_between_nodes(short_channel_id, channel_info, None) @@ -1787,7 +1822,7 @@ impl NetworkGraph where L::Target: Logger { log_gossip!(self.logger, "Adding channel {} between nodes {} and {}", short_channel_id, node_id_a, node_id_b); - match channels.entry(short_channel_id) { + let channel_info = match channels.entry(short_channel_id) { IndexedMapEntry::Occupied(mut entry) => { //TODO: because asking the blockchain if short_channel_id is valid is only optional //in the blockchain API, we need to handle it smartly here, though it's unclear @@ -1803,28 +1838,35 @@ impl NetworkGraph where L::Target: Logger { // c) it's unclear how to do so without exposing ourselves to massive DoS risk. self.remove_channel_in_nodes(&mut nodes, &entry.get(), short_channel_id); *entry.get_mut() = channel_info; + entry.into_mut() } else { return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip}); } }, IndexedMapEntry::Vacant(entry) => { - entry.insert(channel_info); + entry.insert(channel_info) } }; - for current_node_id in [node_id_a, node_id_b].iter() { + let mut node_counter_id = [ + (&mut channel_info.node_one_counter, node_id_a), + (&mut channel_info.node_two_counter, node_id_b) + ]; + for (node_counter, current_node_id) in node_counter_id.iter_mut() { match nodes.entry(current_node_id.clone()) { IndexedMapEntry::Occupied(node_entry) => { - node_entry.into_mut().channels.push(short_channel_id); + let node = node_entry.into_mut(); + node.channels.push(short_channel_id); + **node_counter = node.node_counter; }, IndexedMapEntry::Vacant(node_entry) => { let mut removed_node_counters = self.removed_node_counters.lock().unwrap(); - let node_counter = removed_node_counters.pop() + **node_counter = removed_node_counters.pop() .unwrap_or(self.next_node_counter.fetch_add(1, Ordering::Relaxed) as u32); node_entry.insert(NodeInfo { channels: vec!(short_channel_id), announcement_info: None, - node_counter, + node_counter: **node_counter, }); } }; @@ -1915,6 +1957,8 @@ impl NetworkGraph where L::Target: Logger { announcement_message: if msg.excess_data.len() <= MAX_EXCESS_BYTES_FOR_RELAY { full_msg.cloned() } else { None }, announcement_received_time, + node_one_counter: u32::max_value(), + node_two_counter: u32::max_value(), }; self.add_channel_between_nodes(msg.short_channel_id, chan_info, utxo_value)?; @@ -1976,6 +2020,8 @@ impl NetworkGraph where L::Target: Logger { } } removed_channels.insert(*scid, current_time_unix); + } else { + debug_assert!(false, "Channels in nodes must always have channel info"); } } removed_node_counters.push(node.node_counter); @@ -3595,6 +3641,8 @@ pub(crate) mod tests { capacity_sats: None, announcement_message: None, announcement_received_time: 87654, + node_one_counter: 0, + node_two_counter: 1, }; let mut encoded_chan_info: Vec = Vec::new(); @@ -3613,6 +3661,8 @@ pub(crate) mod tests { capacity_sats: None, announcement_message: None, announcement_received_time: 87654, + node_one_counter: 0, + node_two_counter: 1, }; let mut encoded_chan_info: Vec = Vec::new(); -- 2.39.5