Parallelize rolling filter checks/updates, reduce hashes
authorMatt Corallo <git@bluematt.me>
Tue, 27 Jul 2021 20:41:24 +0000 (20:41 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 27 Jul 2021 20:41:24 +0000 (20:41 +0000)
src/bloom.rs
src/datastore.rs

index 66a40a651ffc119c8a867d6bc80406ae62decc7c..cfd9cda5adabbfac2888d284cd8e0d61c081ca73 100644 (file)
@@ -2,12 +2,13 @@ use std::collections::hash_map::RandomState;
 use std::hash::{BuildHasher, Hash, Hasher};
 use std::time::{Duration, Instant};
 use std::marker::PhantomData;
+use std::sync::RwLock;
 
-// Constants for roughly 1 in 1 million fp with 18m entries
+// Constants for roughly 1 in 250k fp with 20m entries
 /// Number of entries in the filter (each 4 bits). 256MiB in total.
 const FILTER_SIZE: usize = 64 * 1024 * 1024 * 8;
-const HASHES: usize = 27;
-const ROLL_COUNT: usize = 1_240_000;
+const HASHES: usize = 18;
+const ROLL_COUNT: usize = 1_370_000;
 #[cfg(test)]
 const GENERATION_BITS: usize = 2;
 #[cfg(not(test))]
@@ -15,12 +16,16 @@ const GENERATION_BITS: usize = 4;
 pub const GENERATION_COUNT: usize = (1 << GENERATION_BITS) - 1;
 const ELEMENTS_PER_VAR: usize = 64 / GENERATION_BITS;
 
-pub struct RollingBloomFilter<T: Hash> {
+struct FilterState {
        last_roll: Instant,
        inserted_in_last_generations: [usize; GENERATION_COUNT - 1],
        inserted_since_last_roll: usize,
        current_generation: u8,
        bits: Vec<u64>,
+}
+
+pub struct RollingBloomFilter<T: Hash> {
+       state: RwLock<FilterState>,
        hash_keys: [RandomState; HASHES],
        _entry_type: PhantomData<T>,
 }
@@ -30,28 +35,34 @@ impl<T: Hash> RollingBloomFilter<T> {
                let mut bits = Vec::new();
                bits.resize(FILTER_SIZE * GENERATION_BITS / 64, 0);
                Self {
-                       last_roll: Instant::now(),
-                       inserted_since_last_roll: 0,
-                       inserted_in_last_generations: [0; GENERATION_COUNT - 1],
-                       current_generation: 1,
-                       bits,
+                       state: RwLock::new(FilterState {
+                               last_roll: Instant::now(),
+                               inserted_since_last_roll: 0,
+                               inserted_in_last_generations: [0; GENERATION_COUNT - 1],
+                               current_generation: 1,
+                               bits,
+                       }),
                        hash_keys: [RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
                                    RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
                                    RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
-                                   RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
-                                   RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
-                                   RandomState::new(), RandomState::new()],
+                                   RandomState::new(), RandomState::new(), RandomState::new()],
                        _entry_type: PhantomData,
                }
        }
 
        pub fn contains(&self, item: &T) -> bool {
-               for state in self.hash_keys.iter() {
+               let mut hashes = [None; HASHES];
+               for (idx, state) in self.hash_keys.iter().enumerate() {
                        let mut hasher = state.build_hasher();
                        item.hash(&mut hasher);
-                       let idx = hasher.finish() as usize;
+                       hashes[idx] = Some(hasher.finish() as usize);
+               }
 
-                       let byte = self.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
+               let state = self.state.read().unwrap();
+               for idx_opt in hashes.iter() {
+                       let idx = idx_opt.unwrap();
+
+                       let byte = state.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
                        let bits_shift = (idx % ELEMENTS_PER_VAR) * GENERATION_BITS;
                        let bits = (byte & ((GENERATION_COUNT as u64) << bits_shift)) >> bits_shift;
                        if bits == 0 { return false; }
@@ -61,20 +72,29 @@ impl<T: Hash> RollingBloomFilter<T> {
 
        pub fn get_element_count(&self) -> [usize; GENERATION_COUNT] {
                let mut res = [0; GENERATION_COUNT];
-               res[0..(GENERATION_COUNT-1)].copy_from_slice(&self.inserted_in_last_generations);
-               *res.last_mut().unwrap() = self.inserted_since_last_roll;
+               let state = self.state.read().unwrap();
+               res[0..(GENERATION_COUNT-1)].copy_from_slice(&state.inserted_in_last_generations);
+               *res.last_mut().unwrap() = state.inserted_since_last_roll;
                res
        }
 
-       pub fn insert(&mut self, item: &T, roll_duration: Duration) {
-               if Instant::now() - self.last_roll > roll_duration / GENERATION_COUNT as u32 ||
-                  self.inserted_since_last_roll > ROLL_COUNT {
-                       self.current_generation += 1;
-                       if self.current_generation == GENERATION_COUNT as u8 + 1 { self.current_generation = 1; }
-                       let remove_generation = self.current_generation;
+       pub fn insert(&self, item: &T, roll_duration: Duration) {
+               let mut hashes = [None; HASHES];
+               for (idx, state) in self.hash_keys.iter().enumerate() {
+                       let mut hasher = state.build_hasher();
+                       item.hash(&mut hasher);
+                       hashes[idx] = Some(hasher.finish() as usize);
+               }
+
+               let mut state = self.state.write().unwrap();
+               if Instant::now() - state.last_roll > roll_duration / GENERATION_COUNT as u32 ||
+                  state.inserted_since_last_roll > ROLL_COUNT {
+                       state.current_generation += 1;
+                       if state.current_generation == GENERATION_COUNT as u8 + 1 { state.current_generation = 1; }
+                       let remove_generation = state.current_generation;
 
                        for idx in 0..(FILTER_SIZE / ELEMENTS_PER_VAR) {
-                               let mut var = self.bits[idx];
+                               let mut var = state.bits[idx];
                                for i in 0..ELEMENTS_PER_VAR {
                                        let bits_shift = i * GENERATION_BITS;
                                        let bits = (var & ((GENERATION_COUNT as u64) << bits_shift)) >> bits_shift;
@@ -83,33 +103,32 @@ impl<T: Hash> RollingBloomFilter<T> {
                                                var &= !((GENERATION_COUNT as u64) << bits_shift);
                                        }
                                }
-                               self.bits[idx] = var;
+                               state.bits[idx] = var;
                        }
-                       self.last_roll = Instant::now();
+                       state.last_roll = Instant::now();
                        let mut new_generations = [0; GENERATION_COUNT - 1];
-                       new_generations[0..GENERATION_COUNT - 2].copy_from_slice(&self.inserted_in_last_generations[1..]);
-                       new_generations[GENERATION_COUNT - 2] = self.inserted_since_last_roll;
-                       self.inserted_in_last_generations = new_generations;
-                       self.inserted_since_last_roll = 0;
+                       new_generations[0..GENERATION_COUNT - 2].copy_from_slice(&state.inserted_in_last_generations[1..]);
+                       new_generations[GENERATION_COUNT - 2] = state.inserted_since_last_roll;
+                       state.inserted_in_last_generations = new_generations;
+                       state.inserted_since_last_roll = 0;
                }
 
-               for state in self.hash_keys.iter() {
-                       let mut hasher = state.build_hasher();
-                       item.hash(&mut hasher);
-                       let idx = hasher.finish() as usize;
+               let generation = state.current_generation;
+               for idx_opt in hashes.iter() {
+                       let idx = idx_opt.unwrap();
 
-                       let byte = &mut self.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
+                       let byte = &mut state.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
                        let bits_shift = (idx % ELEMENTS_PER_VAR) * GENERATION_BITS;
                        *byte &= !((GENERATION_COUNT as u64) << bits_shift);
-                       *byte |= (self.current_generation as u64) << bits_shift;
+                       *byte |= (generation as u64) << bits_shift;
                }
-               self.inserted_since_last_roll += 1;
+               state.inserted_since_last_roll += 1;
        }
 }
 
 #[test]
 fn test_bloom() {
-       let mut filter = RollingBloomFilter::new();
+       let filter = RollingBloomFilter::new();
        for i in 0..1000 {
                filter.insert(&i, Duration::from_secs(60 * 60 * 24));
        }
@@ -120,13 +139,13 @@ fn test_bloom() {
                assert!(!filter.contains(&i));
        }
        assert_eq!(filter.get_element_count(), [0, 0, 1000]);
-       filter.inserted_since_last_roll = ROLL_COUNT + 1;
+       filter.state.write().unwrap().inserted_since_last_roll = ROLL_COUNT + 1;
        filter.insert(&1000, Duration::from_secs(60 * 60 * 24));
        assert_eq!(filter.get_element_count(), [0, ROLL_COUNT + 1, 1]);
        for i in 0..1001 {
                assert!(filter.contains(&i));
        }
-       filter.inserted_since_last_roll = ROLL_COUNT + 1;
+       filter.state.write().unwrap().inserted_since_last_roll = ROLL_COUNT + 1;
        for i in 1001..2000 {
                filter.insert(&i, Duration::from_secs(60 * 60 * 24));
        }
@@ -134,7 +153,7 @@ fn test_bloom() {
        for i in 0..2000 {
                assert!(filter.contains(&i));
        }
-       filter.inserted_since_last_roll = ROLL_COUNT + 1;
+       filter.state.write().unwrap().inserted_since_last_roll = ROLL_COUNT + 1;
        filter.insert(&2000, Duration::from_secs(60 * 60 * 24));
        assert_eq!(filter.get_element_count(), [ROLL_COUNT + 1, ROLL_COUNT + 1, 1]);
        for i in 0..1000 {
index b4c40fd8957e7b15c53f33f3a4712e06d99baca1..a078f49cae9c2d9f6959bc639005e651074a84b3 100644 (file)
@@ -204,13 +204,11 @@ impl SockAddr {
 struct Nodes {
        good_node_services: [HashSet<SockAddr>; 64],
        nodes_to_state: HashMap<SockAddr, Node>,
-       timeout_nodes: RollingBloomFilter<SockAddr>,
        state_next_scan: [Vec<SockAddr>; AddressState::get_count() as usize],
 }
 struct NodesMutRef<'a> {
        good_node_services: &'a mut [HashSet<SockAddr>; 64],
        nodes_to_state: &'a mut HashMap<SockAddr, Node>,
-       timeout_nodes: &'a mut RollingBloomFilter<SockAddr>,
        state_next_scan: &'a mut [Vec<SockAddr>; AddressState::get_count() as usize],
 }
 
@@ -219,7 +217,6 @@ impl Nodes {
                NodesMutRef {
                        good_node_services: &mut self.good_node_services,
                        nodes_to_state: &mut self.nodes_to_state,
-                       timeout_nodes: &mut self.timeout_nodes,
                        state_next_scan: &mut self.state_next_scan,
                }
        }
@@ -229,6 +226,7 @@ pub struct Store {
        u64_settings: RwLock<HashMap<U64Setting, u64>>,
        subver_regex: RwLock<Arc<Regex>>,
        nodes: RwLock<Nodes>,
+       timeout_nodes: RollingBloomFilter<SockAddr>,
        start_time: Instant,
        store: String,
 }
@@ -301,7 +299,6 @@ impl Store {
                                Nodes {
                                        good_node_services,
                                        nodes_to_state: HashMap::new(),
-                                       timeout_nodes: RollingBloomFilter::new(),
                                        state_next_scan: state_vecs,
                                }
                        } }
@@ -358,6 +355,7 @@ impl Store {
                                u64_settings: RwLock::new(u64_settings),
                                subver_regex: RwLock::new(Arc::new(regex)),
                                nodes: RwLock::new(nodes),
+                               timeout_nodes: RollingBloomFilter::new(),
                                store,
                                start_time: Instant::now(),
                        })
@@ -376,7 +374,7 @@ impl Store {
                self.nodes.read().unwrap().state_next_scan[state.to_num() as usize].len()
        }
        pub fn get_bloom_node_count(&self) -> [usize; crate::bloom::GENERATION_COUNT] {
-               self.nodes.read().unwrap().timeout_nodes.get_element_count()
+               self.timeout_nodes.get_element_count()
        }
 
        pub fn get_regex(&self, _setting: RegexSetting) -> Arc<Regex> {
@@ -429,6 +427,10 @@ impl Store {
        pub fn set_node_state(&self, sockaddr: SocketAddr, state: AddressState, services: u64) -> AddressState {
                let addr: SockAddr = sockaddr.into();
 
+               if state == AddressState::Untested && self.timeout_nodes.contains(&addr) {
+                       return AddressState::Timeout;
+               }
+
                let now = (Instant::now() - self.start_time).as_secs().try_into().unwrap();
 
                let mut nodes_lock = self.nodes.write().unwrap();
@@ -441,16 +443,13 @@ impl Store {
                                           entry.get().last_services() == 0 &&
                                           state == AddressState::Timeout => {
                                entry.remove_entry();
-                               nodes.timeout_nodes.insert(&addr, Duration::from_secs(self.get_u64(U64Setting::RescanInterval(AddressState::Timeout))));
+                               self.timeout_nodes.insert(&addr, Duration::from_secs(self.get_u64(U64Setting::RescanInterval(AddressState::Timeout))));
                                return AddressState::Untested;
                        },
                        hash_map::Entry::Vacant(_) if state == AddressState::Timeout => {
-                               nodes.timeout_nodes.insert(&addr, Duration::from_secs(self.get_u64(U64Setting::RescanInterval(AddressState::Timeout))));
+                               self.timeout_nodes.insert(&addr, Duration::from_secs(self.get_u64(U64Setting::RescanInterval(AddressState::Timeout))));
                                return AddressState::Untested;
                        },
-                       hash_map::Entry::Vacant(_) if nodes.timeout_nodes.contains(&addr) => {
-                               return AddressState::Timeout;
-                       },
                        _ => {},
                }