Manually unroll bloom clearing loop since LLVM doesn't do it
[dnsseed-rust] / src / bloom.rs
1 use std::collections::hash_map::RandomState;
2 use std::hash::{BuildHasher, Hash, Hasher};
3 use std::time::{Duration, Instant};
4 use std::marker::PhantomData;
5
// Sizing targets roughly a 1-in-a-million false-positive rate with about 18 million entries.

/// Number of entries (slots) in the filter. Outside of tests each slot is 4 bits wide,
/// which comes to 256MiB in total.
const FILTER_SIZE: usize = 8 * 64 * 1024 * 1024;
/// Number of independently keyed hash functions applied to every item.
const HASHES: usize = 27;
/// Once more than this many items have been inserted since the last roll, the next
/// insertion starts a new generation.
const ROLL_COUNT: usize = 1_240_000;
/// Width, in bits, of the generation tag stored in each slot (regular builds).
#[cfg(not(test))]
const GENERATION_BITS: usize = 4;
/// Width, in bits, of the generation tag stored in each slot (test builds use fewer
/// generations so that expiry can be exercised with few rolls).
#[cfg(test)]
const GENERATION_BITS: usize = 2;
/// Number of distinct generation tags; the all-zero tag is reserved for "empty slot".
/// This value is also the bit mask that selects a single unshifted slot.
pub const GENERATION_COUNT: usize = (1 << GENERATION_BITS) - 1;
/// Number of slots packed into a single `u64`.
const ELEMENTS_PER_VAR: usize = 64 / GENERATION_BITS;
17
18 pub struct RollingBloomFilter<T: Hash> {
19         last_roll: Instant,
20         inserted_in_last_generations: [usize; GENERATION_COUNT - 1],
21         inserted_since_last_roll: usize,
22         current_generation: u8,
23         bits: Vec<u64>,
24         hash_keys: [RandomState; HASHES],
25         _entry_type: PhantomData<T>,
26 }
27
28 impl<T: Hash> RollingBloomFilter<T> {
29         pub fn new() -> Self {
30                 let mut bits = Vec::new();
31                 bits.resize(FILTER_SIZE * GENERATION_BITS / 64, 0);
32                 Self {
33                         last_roll: Instant::now(),
34                         inserted_since_last_roll: 0,
35                         inserted_in_last_generations: [0; GENERATION_COUNT - 1],
36                         current_generation: 1,
37                         bits,
38                         hash_keys: [RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
39                                     RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
40                                     RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
41                                     RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
42                                     RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
43                                     RandomState::new(), RandomState::new()],
44                         _entry_type: PhantomData,
45                 }
46         }
47
48         pub fn contains(&self, item: &T) -> bool {
49                 for state in self.hash_keys.iter() {
50                         let mut hasher = state.build_hasher();
51                         item.hash(&mut hasher);
52                         let idx = hasher.finish() as usize;
53
54                         let byte = self.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
55                         let bits_shift = (idx % ELEMENTS_PER_VAR) * GENERATION_BITS;
56                         let bits = (byte & ((GENERATION_COUNT as u64) << bits_shift)) >> bits_shift;
57                         if bits == 0 { return false; }
58                 }
59                 true
60         }
61
62         pub fn get_element_count(&self) -> [usize; GENERATION_COUNT] {
63                 let mut res = [0; GENERATION_COUNT];
64                 res[0..(GENERATION_COUNT-1)].copy_from_slice(&self.inserted_in_last_generations);
65                 *res.last_mut().unwrap() = self.inserted_since_last_roll;
66                 res
67         }
68
69         pub fn insert(&mut self, item: &T, roll_duration: Duration) {
70                 if Instant::now() - self.last_roll > roll_duration / GENERATION_COUNT as u32 ||
71                    self.inserted_since_last_roll > ROLL_COUNT {
72                         self.current_generation += 1;
73                         if self.current_generation == GENERATION_COUNT as u8 + 1 { self.current_generation = 1; }
74                         let remove_generation = self.current_generation;
75
76                         for idx in 0..(FILTER_SIZE / ELEMENTS_PER_VAR) {
77                                 let mut var = self.bits[idx];
78                                 for i in 0..ELEMENTS_PER_VAR {
79                                         let bits_shift = i * GENERATION_BITS;
80                                         let bits = (var & ((GENERATION_COUNT as u64) << bits_shift)) >> bits_shift;
81
82                                         if bits == remove_generation as u64 {
83                                                 var &= !((GENERATION_COUNT as u64) << bits_shift);
84                                         }
85                                 }
86                                 self.bits[idx] = var;
87                         }
88                         self.last_roll = Instant::now();
89                         let mut new_generations = [0; GENERATION_COUNT - 1];
90                         new_generations[0..GENERATION_COUNT - 2].copy_from_slice(&self.inserted_in_last_generations[1..]);
91                         new_generations[GENERATION_COUNT - 2] = self.inserted_since_last_roll;
92                         self.inserted_in_last_generations = new_generations;
93                         self.inserted_since_last_roll = 0;
94                 }
95
96                 for state in self.hash_keys.iter() {
97                         let mut hasher = state.build_hasher();
98                         item.hash(&mut hasher);
99                         let idx = hasher.finish() as usize;
100
101                         let byte = &mut self.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
102                         let bits_shift = (idx % ELEMENTS_PER_VAR) * GENERATION_BITS;
103                         *byte &= !((GENERATION_COUNT as u64) << bits_shift);
104                         *byte |= (self.current_generation as u64) << bits_shift;
105                 }
106                 self.inserted_since_last_roll += 1;
107         }
108 }
109
#[test]
fn test_bloom() {
        /// Asserts that the filter reports every value in `range` as present.
        fn assert_present(filter: &RollingBloomFilter<i32>, range: std::ops::Range<i32>) {
                for value in range {
                        assert!(filter.contains(&value));
                }
        }

        /// Asserts that the filter reports every value in `range` as absent.
        fn assert_absent(filter: &RollingBloomFilter<i32>, range: std::ops::Range<i32>) {
                for value in range {
                        assert!(!filter.contains(&value));
                }
        }

        // Long enough that no roll in this test is triggered by elapsed time.
        let one_day = Duration::from_secs(60 * 60 * 24);
        let mut filter = RollingBloomFilter::new();

        // First generation: everything inserted is found, nothing else is.
        (0..1000).for_each(|value| filter.insert(&value, one_day));
        assert_present(&filter, 0..1000);
        assert_absent(&filter, 1000..2000);
        assert_eq!(filter.get_element_count(), [0, 0, 1000]);

        // Force a roll through the insert counter; older entries must survive it.
        filter.inserted_since_last_roll = ROLL_COUNT + 1;
        filter.insert(&1000, one_day);
        assert_eq!(filter.get_element_count(), [0, ROLL_COUNT + 1, 1]);
        assert_present(&filter, 0..1001);

        // Second forced roll: only the first insert of the batch triggers it.
        filter.inserted_since_last_roll = ROLL_COUNT + 1;
        (1001..2000).for_each(|value| filter.insert(&value, one_day));
        assert_eq!(filter.get_element_count(), [ROLL_COUNT + 1, ROLL_COUNT + 1, 999]);
        assert_present(&filter, 0..2000);

        // Third forced roll reuses the first generation's tag, expiring its entries.
        filter.inserted_since_last_roll = ROLL_COUNT + 1;
        filter.insert(&2000, one_day);
        assert_eq!(filter.get_element_count(), [ROLL_COUNT + 1, ROLL_COUNT + 1, 1]);
        assert_absent(&filter, 0..1000);
        assert_present(&filter, 1000..2001);
}
145         }
146 }