Store nodes that timed out in a rolling bloom filter
authorMatt Corallo <git@bluematt.me>
Tue, 13 Jul 2021 13:41:44 +0000 (13:41 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 14 Jul 2021 14:17:44 +0000 (14:17 +0000)
src/bloom.rs [new file with mode: 0644]
src/datastore.rs
src/main.rs
src/printer.rs

diff --git a/src/bloom.rs b/src/bloom.rs
new file mode 100644 (file)
index 0000000..7255876
--- /dev/null
@@ -0,0 +1,143 @@
+use std::collections::hash_map::RandomState;
+use std::hash::{BuildHasher, Hash, Hasher};
+use std::time::{Duration, Instant};
+use std::marker::PhantomData;
+
+// Constants for roughly 1 in 1 million fp with 18m entries
+/// Number of entries in the filter (each 4 bits). 256MiB in total.
+const FILTER_SIZE: usize = 64 * 1024 * 1024 * 8;
+const HASHES: usize = 27;
+const ROLL_COUNT: usize = 1_240_000;
+#[cfg(test)]
+const GENERATION_BITS: usize = 2;
+#[cfg(not(test))]
+const GENERATION_BITS: usize = 4;
+pub const GENERATION_COUNT: usize = (1 << GENERATION_BITS) - 1;
+const ELEMENTS_PER_BYTE: usize = 8 / GENERATION_BITS;
+
+pub struct RollingBloomFilter<T: Hash> {
+       last_roll: Instant,
+       inserted_in_last_generations: [usize; GENERATION_COUNT - 1],
+       inserted_since_last_roll: usize,
+       current_generation: u8,
+       bits: Vec<u8>,
+       hash_keys: [RandomState; HASHES],
+       _entry_type: PhantomData<T>,
+}
+
+impl<T: Hash> RollingBloomFilter<T> {
+       pub fn new() -> Self {
+               let mut bits = Vec::new();
+               bits.resize(FILTER_SIZE * GENERATION_BITS / 8, 0);
+               Self {
+                       last_roll: Instant::now(),
+                       inserted_since_last_roll: 0,
+                       inserted_in_last_generations: [0; GENERATION_COUNT - 1],
+                       current_generation: 1,
+                       bits,
+                       hash_keys: [RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
+                                   RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
+                                   RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
+                                   RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
+                                   RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
+                                   RandomState::new(), RandomState::new()],
+                       _entry_type: PhantomData,
+               }
+       }
+
+       pub fn contains(&self, item: &T) -> bool {
+               for state in self.hash_keys.iter() {
+                       let mut hasher = state.build_hasher();
+                       item.hash(&mut hasher);
+                       let idx = hasher.finish() as usize;
+
+                       let byte = self.bits[(idx / ELEMENTS_PER_BYTE) % (FILTER_SIZE / 8)];
+                       let bits_shift = (idx % ELEMENTS_PER_BYTE) * GENERATION_BITS;
+                       let bits = (byte & ((GENERATION_COUNT as u8) << bits_shift)) >> bits_shift;
+                       if bits == 0 { return false; }
+               }
+               true
+       }
+
+       pub fn get_element_count(&self) -> [usize; GENERATION_COUNT] {
+               let mut res = [0; GENERATION_COUNT];
+               res[0..(GENERATION_COUNT-1)].copy_from_slice(&self.inserted_in_last_generations);
+               *res.last_mut().unwrap() = self.inserted_since_last_roll;
+               res
+       }
+
+       pub fn insert(&mut self, item: &T, roll_duration: Duration) {
+               if Instant::now() - self.last_roll > roll_duration / GENERATION_COUNT as u32 ||
+                  self.inserted_since_last_roll > ROLL_COUNT {
+                       self.current_generation += 1;
+                       if self.current_generation == GENERATION_COUNT as u8 + 1 { self.current_generation = 1; }
+                       let remove_generation = self.current_generation;
+
+                       for idx in 0..FILTER_SIZE {
+                               let byte = &mut self.bits[(idx / ELEMENTS_PER_BYTE) % (FILTER_SIZE / 8)];
+                               let bits_shift = (idx % ELEMENTS_PER_BYTE) * GENERATION_BITS;
+                               let bits = (*byte & ((GENERATION_COUNT as u8) << bits_shift)) >> bits_shift;
+
+                               if bits == remove_generation {
+                                       *byte &= !((GENERATION_COUNT as u8) << bits_shift);
+                               }
+                       }
+                       self.last_roll = Instant::now();
+                       let mut new_generations = [0; GENERATION_COUNT - 1];
+                       new_generations[0..GENERATION_COUNT - 2].copy_from_slice(&self.inserted_in_last_generations[1..]);
+                       new_generations[GENERATION_COUNT - 2] = self.inserted_since_last_roll;
+                       self.inserted_in_last_generations = new_generations;
+                       self.inserted_since_last_roll = 0;
+               }
+
+               for state in self.hash_keys.iter() {
+                       let mut hasher = state.build_hasher();
+                       item.hash(&mut hasher);
+                       let idx = hasher.finish() as usize;
+
+                       let byte = &mut self.bits[(idx / ELEMENTS_PER_BYTE) % (FILTER_SIZE / 8)];
+                       let bits_shift = (idx % ELEMENTS_PER_BYTE) * GENERATION_BITS;
+                       *byte &= !((GENERATION_COUNT as u8) << bits_shift);
+                       *byte |= self.current_generation << bits_shift;
+               }
+               self.inserted_since_last_roll += 1;
+       }
+}
+
+#[test]
+fn test_bloom() {
+       let mut filter = RollingBloomFilter::new();
+       for i in 0..1000 {
+               filter.insert(&i, Duration::from_secs(60 * 60 * 24));
+       }
+       for i in 0..1000 {
+               assert!(filter.contains(&i));
+       }
+       for i in 1000..2000 {
+               assert!(!filter.contains(&i));
+       }
+       assert_eq!(filter.get_element_count(), [0, 0, 1000]);
+       filter.inserted_since_last_roll = ROLL_COUNT + 1;
+       filter.insert(&1000, Duration::from_secs(60 * 60 * 24));
+       assert_eq!(filter.get_element_count(), [0, ROLL_COUNT + 1, 1]);
+       for i in 0..1001 {
+               assert!(filter.contains(&i));
+       }
+       filter.inserted_since_last_roll = ROLL_COUNT + 1;
+       for i in 1001..2000 {
+               filter.insert(&i, Duration::from_secs(60 * 60 * 24));
+       }
+       assert_eq!(filter.get_element_count(), [ROLL_COUNT + 1, ROLL_COUNT + 1, 999]);
+       for i in 0..2000 {
+               assert!(filter.contains(&i));
+       }
+       filter.inserted_since_last_roll = ROLL_COUNT + 1;
+       filter.insert(&2000, Duration::from_secs(60 * 60 * 24));
+       assert_eq!(filter.get_element_count(), [ROLL_COUNT + 1, ROLL_COUNT + 1, 1]);
+       for i in 0..1000 {
+               assert!(!filter.contains(&i));
+       }
+       for i in 1000..2001 {
+               assert!(filter.contains(&i));
+       }
+}
index 04e377be126d61e8be93f3dc558dd6cf7affe0c7..bf0470a10f6c34393915da67bd9edbcb1684b111 100644 (file)
@@ -3,7 +3,7 @@ use std::convert::TryInto;
 use std::collections::{HashSet, HashMap, hash_map};
 use std::sync::{Arc, RwLock};
 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
-use std::time::Instant;
+use std::time::{Duration, Instant};
 use std::io::{BufRead, BufReader};
 
 use bitcoin::network::address::{Address, AddrV2Message};
@@ -17,6 +17,7 @@ use tokio::io::write_all;
 
 use regex::Regex;
 
+use crate::bloom::RollingBloomFilter;
 use crate::bgp_client::BGPClient;
 
 pub const SECS_PER_SCAN_RESULTS: u64 = 15;
@@ -203,11 +204,13 @@ impl SockAddr {
 struct Nodes {
        good_node_services: [HashSet<SockAddr>; 64],
        nodes_to_state: HashMap<SockAddr, Node>,
+       timeout_nodes: RollingBloomFilter<SockAddr>,
        state_next_scan: [Vec<SockAddr>; AddressState::get_count() as usize],
 }
 struct NodesMutRef<'a> {
        good_node_services: &'a mut [HashSet<SockAddr>; 64],
        nodes_to_state: &'a mut HashMap<SockAddr, Node>,
+       timeout_nodes: &'a mut RollingBloomFilter<SockAddr>,
        state_next_scan: &'a mut [Vec<SockAddr>; AddressState::get_count() as usize],
 }
 
@@ -216,6 +219,7 @@ impl Nodes {
                NodesMutRef {
                        good_node_services: &mut self.good_node_services,
                        nodes_to_state: &mut self.nodes_to_state,
+                       timeout_nodes: &mut self.timeout_nodes,
                        state_next_scan: &mut self.state_next_scan,
                }
        }
@@ -297,6 +301,7 @@ impl Store {
                                Nodes {
                                        good_node_services,
                                        nodes_to_state: HashMap::new(),
+                                       timeout_nodes: RollingBloomFilter::new(),
                                        state_next_scan: state_vecs,
                                }
                        } }
@@ -370,6 +375,9 @@ impl Store {
        pub fn get_node_count(&self, state: AddressState) -> usize {
                self.nodes.read().unwrap().state_next_scan[state.to_num() as usize].len()
        }
+       pub fn get_bloom_node_count(&self) -> [usize; crate::bloom::GENERATION_COUNT] {
+               self.nodes.read().unwrap().timeout_nodes.get_element_count()
+       }
 
        pub fn get_regex(&self, _setting: RegexSetting) -> Arc<Regex> {
                Arc::clone(&*self.subver_regex.read().unwrap())
@@ -426,7 +434,27 @@ impl Store {
                let mut nodes_lock = self.nodes.write().unwrap();
                let nodes = nodes_lock.borrow_mut();
 
-               let state_ref = nodes.nodes_to_state.entry(addr.clone()).or_insert(Node {
+               let node_entry = nodes.nodes_to_state.entry(addr.clone());
+               match node_entry {
+                       hash_map::Entry::Occupied(entry)
+                                       if entry.get().state == AddressState::Untested &&
+                                          entry.get().last_services() == 0 &&
+                                          state == AddressState::Timeout => {
+                               entry.remove_entry();
+                               nodes.timeout_nodes.insert(&addr, Duration::from_secs(self.get_u64(U64Setting::RescanInterval(AddressState::Timeout))));
+                               return AddressState::Untested;
+                       },
+                       hash_map::Entry::Vacant(_) if state == AddressState::Timeout => {
+                               nodes.timeout_nodes.insert(&addr, Duration::from_secs(self.get_u64(U64Setting::RescanInterval(AddressState::Timeout))));
+                               return AddressState::Untested;
+                       },
+                       hash_map::Entry::Vacant(_) if nodes.timeout_nodes.contains(&addr) => {
+                               return AddressState::Timeout;
+                       },
+                       _ => {},
+               }
+
+               let state_ref = node_entry.or_insert(Node {
                        state: AddressState::Untested,
                        last_services: (0, 0),
                        last_good: now,
index 288cdd24d229168e09ee270e9fdb988b430cf23a..9c6e68155f584252805a15240e044ca8b13b06c4 100644 (file)
@@ -1,3 +1,4 @@
+mod bloom;
 mod printer;
 mod reader;
 mod peer;
index e16bfbb42bb644e3dff324bb10333453261f6b0f..c20b925e7f5f83d14316ca9d4a31926dc4c5e9dd 100644 (file)
@@ -65,9 +65,14 @@ impl Printer {
                                                                store.get_node_count(AddressState::from_num(i).unwrap())
                                                                ).as_bytes()).expect("stdout broken?");
                                        }
+                                       let generations = store.get_bloom_node_count();
+                                       out.write_all(b"Bloom filter generations contain:").expect("stdout broken?");
+                                       for generation in &generations {
+                                               out.write_all(format!(" {}", generation).as_bytes()).expect("stdout broken?");
+                                       }
 
                                        out.write_all(format!(
-                                                       "\nCurrent connections open/in progress: {}\n", stats.connection_count).as_bytes()).expect("stdout broken?");
+                                                       "\n\nCurrent connections open/in progress: {}\n", stats.connection_count).as_bytes()).expect("stdout broken?");
                                        out.write_all(format!(
                                                        "Current block count: {}\n", stats.header_count).as_bytes()).expect("stdout broken?");