From: Alec Chen Date: Thu, 16 Feb 2023 22:34:06 +0000 (-0600) Subject: Replace `BTreeSet` in `IndexedMap` with sorted `Vec` X-Git-Tag: v0.0.114-beta~12^2 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=62a88f97de725c366665a34c8183747024970fa6;p=rust-lightning Replace `BTreeSet` in `IndexedMap` with sorted `Vec` The `Vec` is sorted not on `IndexedMap::insert`, but on `IndexedMap::range` to avoid unnecessary work while reading a network graph. --- diff --git a/fuzz/src/indexedmap.rs b/fuzz/src/indexedmap.rs index 795d6175b..7cbb8957a 100644 --- a/fuzz/src/indexedmap.rs +++ b/fuzz/src/indexedmap.rs @@ -13,14 +13,27 @@ use hashbrown::HashSet; use crate::utils::test_logger; -fn check_eq(btree: &BTreeMap, indexed: &IndexedMap) { +use std::ops::{RangeBounds, Bound}; + +struct ExclLowerInclUpper(u8, u8); +impl RangeBounds for ExclLowerInclUpper { + fn start_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.0) } + fn end_bound(&self) -> Bound<&u8> { Bound::Included(&self.1) } +} +struct ExclLowerExclUpper(u8, u8); +impl RangeBounds for ExclLowerExclUpper { + fn start_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.0) } + fn end_bound(&self) -> Bound<&u8> { Bound::Excluded(&self.1) } +} + +fn check_eq(btree: &BTreeMap, mut indexed: IndexedMap) { assert_eq!(btree.len(), indexed.len()); assert_eq!(btree.is_empty(), indexed.is_empty()); let mut btree_clone = btree.clone(); assert!(btree_clone == *btree); let mut indexed_clone = indexed.clone(); - assert!(indexed_clone == *indexed); + assert!(indexed_clone == indexed); for k in 0..=255 { assert_eq!(btree.contains_key(&k), indexed.contains_key(&k)); @@ -43,16 +56,27 @@ fn check_eq(btree: &BTreeMap, indexed: &IndexedMap) { } const STRIDE: u8 = 16; - for k in 0..=255/STRIDE { - let lower_bound = k * STRIDE; - let upper_bound = lower_bound + (STRIDE - 1); - let mut btree_iter = btree.range(lower_bound..=upper_bound); - let mut indexed_iter = indexed.range(lower_bound..=upper_bound); - loop { - let b_v = btree_iter.next(); - let i_v = indexed_iter.next(); - assert_eq!(b_v, i_v); - if b_v.is_none() { break; } + for range_type in 0..4 { + for k in 0..=255/STRIDE { + let lower_bound = k * STRIDE; + let upper_bound = lower_bound + (STRIDE - 1); + macro_rules! range { ($map: expr) => { + match range_type { + 0 => $map.range(lower_bound..upper_bound), + 1 => $map.range(lower_bound..=upper_bound), + 2 => $map.range(ExclLowerInclUpper(lower_bound, upper_bound)), + 3 => $map.range(ExclLowerExclUpper(lower_bound, upper_bound)), + _ => unreachable!(), + } + } } + let mut btree_iter = range!(btree); + let mut indexed_iter = range!(indexed); + loop { + let b_v = btree_iter.next(); + let i_v = indexed_iter.next(); + assert_eq!(b_v, i_v); + if b_v.is_none() { break; } + } } } @@ -91,7 +115,7 @@ pub fn do_test(data: &[u8]) { let prev_value_i = indexed.insert(tuple[0], tuple[1]); assert_eq!(prev_value_b, prev_value_i); } - check_eq(&btree, &indexed); + check_eq(&btree, indexed.clone()); // Now, modify the maps in all the ways we have to do so, checking that the maps remain // equivalent as we go. @@ -99,7 +123,7 @@ pub fn do_test(data: &[u8]) { *v = *k; *btree.get_mut(k).unwrap() = *k; } - check_eq(&btree, &indexed); + check_eq(&btree, indexed.clone()); for k in 0..=255 { match btree.entry(k) { @@ -124,7 +148,7 @@ pub fn do_test(data: &[u8]) { }, } } - check_eq(&btree, &indexed); + check_eq(&btree, indexed); } pub fn indexedmap_test(data: &[u8], _out: Out) { diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index 6c2d70bd6..82d2cd4cb 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -390,7 +390,7 @@ where U::Target: UtxoLookup, L::Target: Logger } fn get_next_channel_announcement(&self, starting_point: u64) -> Option<(ChannelAnnouncement, Option, Option)> { - let channels = self.network_graph.channels.read().unwrap(); + let mut channels = self.network_graph.channels.write().unwrap(); for (_, ref chan) in channels.range(starting_point..) { if chan.announcement_message.is_some() { let chan_announcement = chan.announcement_message.clone().unwrap(); @@ -412,7 +412,7 @@ where U::Target: UtxoLookup, L::Target: Logger } fn get_next_node_announcement(&self, starting_point: Option<&NodeId>) -> Option { - let nodes = self.network_graph.nodes.read().unwrap(); + let mut nodes = self.network_graph.nodes.write().unwrap(); let iter = if let Some(node_id) = starting_point { nodes.range((Bound::Excluded(node_id), Bound::Unbounded)) } else { @@ -572,7 +572,7 @@ where U::Target: UtxoLookup, L::Target: Logger // (has at least one update). A peer may still want to know the channel // exists even if its not yet routable. let mut batches: Vec> = vec![Vec::with_capacity(MAX_SCIDS_PER_REPLY)]; - let channels = self.network_graph.channels.read().unwrap(); + let mut channels = self.network_graph.channels.write().unwrap(); for (_, ref chan) in channels.range(inclusive_start_scid.unwrap()..exclusive_end_scid.unwrap()) { if let Some(chan_announcement) = &chan.announcement_message { // Construct a new batch if last one is full diff --git a/lightning/src/util/indexed_map.rs b/lightning/src/util/indexed_map.rs index cccbfe7bc..3d4517251 100644 --- a/lightning/src/util/indexed_map.rs +++ b/lightning/src/util/indexed_map.rs @@ -1,10 +1,11 @@ //! This module has a map which can be iterated in a deterministic order. See the [`IndexedMap`]. use crate::prelude::{HashMap, hash_map}; -use alloc::collections::{BTreeSet, btree_set}; +use alloc::vec::Vec; +use alloc::slice::Iter; use core::hash::Hash; use core::cmp::Ord; -use core::ops::RangeBounds; +use core::ops::{Bound, RangeBounds}; /// A map which can be iterated in a deterministic order. /// @@ -21,11 +22,10 @@ use core::ops::RangeBounds; /// keys in the order defined by [`Ord`]. /// /// [`BTreeMap`]: alloc::collections::BTreeMap -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, Eq)] pub struct IndexedMap { map: HashMap, - // TODO: Explore swapping this for a sorted vec (that is only sorted on first range() call) - keys: BTreeSet, + keys: Vec, } impl IndexedMap { @@ -33,7 +33,7 @@ impl IndexedMap { pub fn new() -> Self { Self { map: HashMap::new(), - keys: BTreeSet::new(), + keys: Vec::new(), } } @@ -58,7 +58,8 @@ impl IndexedMap { pub fn remove(&mut self, key: &K) -> Option { let ret = self.map.remove(key); if let Some(_) = ret { - assert!(self.keys.remove(key), "map and keys must be consistent"); + let idx = self.keys.iter().position(|k| k == key).expect("map and keys must be consistent"); + self.keys.remove(idx); } ret } @@ -68,7 +69,7 @@ impl IndexedMap { pub fn insert(&mut self, key: K, value: V) -> Option { let ret = self.map.insert(key.clone(), value); if ret.is_none() { - assert!(self.keys.insert(key), "map and keys must be consistent"); + self.keys.push(key); } ret } @@ -109,9 +110,21 @@ impl IndexedMap { } /// Returns an iterator which iterates over the `key`/`value` pairs in a given range. - pub fn range>(&self, range: R) -> Range { + pub fn range>(&mut self, range: R) -> Range { + self.keys.sort_unstable(); + let start = match range.start_bound() { + Bound::Unbounded => 0, + Bound::Included(key) => self.keys.binary_search(key).unwrap_or_else(|index| index), + Bound::Excluded(key) => self.keys.binary_search(key).and_then(|index| Ok(index + 1)).unwrap_or_else(|index| index), + }; + let end = match range.end_bound() { + Bound::Unbounded => self.keys.len(), + Bound::Included(key) => self.keys.binary_search(key).and_then(|index| Ok(index + 1)).unwrap_or_else(|index| index), + Bound::Excluded(key) => self.keys.binary_search(key).unwrap_or_else(|index| index), + }; + Range { - inner_range: self.keys.range(range), + inner_range: self.keys[start..end].iter(), map: &self.map, } } @@ -127,9 +140,15 @@ impl IndexedMap { } } +impl PartialEq for IndexedMap { + fn eq(&self, other: &Self) -> bool { + self.map == other.map + } +} + /// An iterator over a range of values in an [`IndexedMap`] pub struct Range<'a, K: Hash + Ord, V> { - inner_range: btree_set::Range<'a, K>, + inner_range: Iter<'a, K>, map: &'a HashMap, } impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> { @@ -148,7 +167,7 @@ pub struct VacantEntry<'a, K: Hash + Ord, V> { #[cfg(not(feature = "hashbrown"))] underlying_entry: hash_map::VacantEntry<'a, K, V>, key: K, - keys: &'a mut BTreeSet, + keys: &'a mut Vec, } /// An [`Entry`] for an existing key-value pair @@ -157,7 +176,7 @@ pub struct OccupiedEntry<'a, K: Hash + Ord, V> { underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>, #[cfg(not(feature = "hashbrown"))] underlying_entry: hash_map::OccupiedEntry<'a, K, V>, - keys: &'a mut BTreeSet, + keys: &'a mut Vec, } /// A mutable reference to a position in the map. This can be used to reference, add, or update the @@ -172,7 +191,7 @@ pub enum Entry<'a, K: Hash + Ord, V> { impl<'a, K: Hash + Ord, V> VacantEntry<'a, K, V> { /// Insert a value into the position described by this entry. pub fn insert(self, value: V) -> &'a mut V { - assert!(self.keys.insert(self.key), "map and keys must be consistent"); + self.keys.push(self.key); self.underlying_entry.insert(value) } } @@ -181,7 +200,8 @@ impl<'a, K: Hash + Ord, V> OccupiedEntry<'a, K, V> { /// Remove the value at the position described by this entry. pub fn remove_entry(self) -> (K, V) { let res = self.underlying_entry.remove_entry(); - assert!(self.keys.remove(&res.0), "map and keys must be consistent"); + let idx = self.keys.iter().position(|k| k == &res.0).expect("map and keys must be consistent"); + self.keys.remove(idx); res }