Drop memory limit. It was useful to debug OOMs but is now unnecessary
[dnsseed-rust] / src / bloom.rs
1 use std::collections::hash_map::RandomState;
2 use std::hash::{BuildHasher, Hash, Hasher};
3 use std::time::{Duration, Instant};
4 use std::marker::PhantomData;
5 use std::sync::RwLock;
6
// Sized for roughly a 1-in-250k false-positive rate at 20m entries.

/// Number of slots (entries) in the filter. Each slot is `GENERATION_BITS` wide,
/// so with the production width of 4 bits the filter occupies 256MiB in total
/// (128MiB under `cfg(test)`, where slots are 2 bits wide).
const FILTER_SIZE: usize = 512 * 1024 * 1024;
/// Number of independently keyed hashes, i.e. slots touched, per item.
const HASHES: usize = 18;
/// Once the current generation has received more than this many insertions,
/// the next insertion triggers a roll to a new generation.
const ROLL_COUNT: usize = 1_370_000;

/// Width in bits of a single slot. Tests use a narrower slot to shrink the filter.
#[cfg(not(test))]
const GENERATION_BITS: usize = 4;
#[cfg(test)]
const GENERATION_BITS: usize = 2;

/// Number of generation tags a slot can hold. The all-zero value is reserved
/// for "empty", leaving `2^GENERATION_BITS - 1` usable tags.
pub const GENERATION_COUNT: usize = (1 << GENERATION_BITS) - 1;
/// Number of slots packed into each `u64` word of the bit vector.
const ELEMENTS_PER_VAR: usize = 64 / GENERATION_BITS;
18
19 struct FilterState {
20         last_roll: Instant,
21         inserted_in_last_generations: [usize; GENERATION_COUNT - 1],
22         inserted_since_last_roll: usize,
23         current_generation: u8,
24         bits: Vec<u64>,
25 }
26
27 pub struct RollingBloomFilter<T: Hash> {
28         state: RwLock<FilterState>,
29         hash_keys: [RandomState; HASHES],
30         _entry_type: PhantomData<T>,
31 }
32
33 impl<T: Hash> RollingBloomFilter<T> {
34         pub fn new() -> Self {
35                 let mut bits = Vec::new();
36                 bits.resize(FILTER_SIZE * GENERATION_BITS / 64, 0);
37                 Self {
38                         state: RwLock::new(FilterState {
39                                 last_roll: Instant::now(),
40                                 inserted_since_last_roll: 0,
41                                 inserted_in_last_generations: [0; GENERATION_COUNT - 1],
42                                 current_generation: 1,
43                                 bits,
44                         }),
45                         hash_keys: [RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
46                                     RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
47                                     RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(), RandomState::new(),
48                                     RandomState::new(), RandomState::new(), RandomState::new()],
49                         _entry_type: PhantomData,
50                 }
51         }
52
53         pub fn contains(&self, item: &T) -> bool {
54                 let mut hashes = [None; HASHES];
55                 for (idx, state) in self.hash_keys.iter().enumerate() {
56                         let mut hasher = state.build_hasher();
57                         item.hash(&mut hasher);
58                         hashes[idx] = Some(hasher.finish() as usize);
59                 }
60
61                 let state = self.state.read().unwrap();
62                 for idx_opt in hashes.iter() {
63                         let idx = idx_opt.unwrap();
64
65                         let byte = state.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
66                         let bits_shift = (idx % ELEMENTS_PER_VAR) * GENERATION_BITS;
67                         let bits = (byte & ((GENERATION_COUNT as u64) << bits_shift)) >> bits_shift;
68                         if bits == 0 { return false; }
69                 }
70                 true
71         }
72
73         pub fn get_element_count(&self) -> [usize; GENERATION_COUNT] {
74                 let mut res = [0; GENERATION_COUNT];
75                 let state = self.state.read().unwrap();
76                 res[0..(GENERATION_COUNT-1)].copy_from_slice(&state.inserted_in_last_generations);
77                 *res.last_mut().unwrap() = state.inserted_since_last_roll;
78                 res
79         }
80
81         pub fn insert(&self, item: &T, roll_duration: Duration) {
82                 let mut hashes = [None; HASHES];
83                 for (idx, state) in self.hash_keys.iter().enumerate() {
84                         let mut hasher = state.build_hasher();
85                         item.hash(&mut hasher);
86                         hashes[idx] = Some(hasher.finish() as usize);
87                 }
88
89                 let mut state = self.state.write().unwrap();
90                 if Instant::now() - state.last_roll > roll_duration / GENERATION_COUNT as u32 ||
91                    state.inserted_since_last_roll > ROLL_COUNT {
92                         state.current_generation += 1;
93                         if state.current_generation == GENERATION_COUNT as u8 + 1 { state.current_generation = 1; }
94                         let remove_generation = state.current_generation;
95
96                         for idx in 0..(FILTER_SIZE / ELEMENTS_PER_VAR) {
97                                 let mut var = state.bits[idx];
98                                 for i in 0..ELEMENTS_PER_VAR {
99                                         let bits_shift = i * GENERATION_BITS;
100                                         let bits = (var & ((GENERATION_COUNT as u64) << bits_shift)) >> bits_shift;
101
102                                         if bits == remove_generation as u64 {
103                                                 var &= !((GENERATION_COUNT as u64) << bits_shift);
104                                         }
105                                 }
106                                 state.bits[idx] = var;
107                         }
108                         state.last_roll = Instant::now();
109                         let mut new_generations = [0; GENERATION_COUNT - 1];
110                         new_generations[0..GENERATION_COUNT - 2].copy_from_slice(&state.inserted_in_last_generations[1..]);
111                         new_generations[GENERATION_COUNT - 2] = state.inserted_since_last_roll;
112                         state.inserted_in_last_generations = new_generations;
113                         state.inserted_since_last_roll = 0;
114                 }
115
116                 let generation = state.current_generation;
117                 for idx_opt in hashes.iter() {
118                         let idx = idx_opt.unwrap();
119
120                         let byte = &mut state.bits[(idx / ELEMENTS_PER_VAR) % (FILTER_SIZE / 64)];
121                         let bits_shift = (idx % ELEMENTS_PER_VAR) * GENERATION_BITS;
122                         *byte &= !((GENERATION_COUNT as u64) << bits_shift);
123                         *byte |= (generation as u64) << bits_shift;
124                 }
125                 state.inserted_since_last_roll += 1;
126         }
127 }
128
#[test]
fn test_bloom() {
    let day = Duration::from_secs(60 * 60 * 24);
    let filter: RollingBloomFilter<i32> = RollingBloomFilter::new();
    // Pretend the current generation is full so the next insert rolls.
    let force_roll = || {
        filter.state.write().unwrap().inserted_since_last_roll = ROLL_COUNT + 1;
    };

    for i in 0..1000 {
        filter.insert(&i, day);
    }
    assert!((0..1000).all(|i| filter.contains(&i)));
    assert!((1000..2000).all(|i| !filter.contains(&i)));
    assert_eq!(filter.get_element_count(), [0, 0, 1000]);

    force_roll();
    filter.insert(&1000, day);
    assert_eq!(filter.get_element_count(), [0, ROLL_COUNT + 1, 1]);
    assert!((0..1001).all(|i| filter.contains(&i)));

    force_roll();
    for i in 1001..2000 {
        filter.insert(&i, day);
    }
    assert_eq!(filter.get_element_count(), [ROLL_COUNT + 1, ROLL_COUNT + 1, 999]);
    assert!((0..2000).all(|i| filter.contains(&i)));

    // With three generations under cfg(test), this roll wraps back to
    // generation 1 and expires the items 0..1000 inserted with it.
    force_roll();
    filter.insert(&2000, day);
    assert_eq!(filter.get_element_count(), [ROLL_COUNT + 1, ROLL_COUNT + 1, 1]);
    assert!((0..1000).all(|i| !filter.contains(&i)));
    assert!((1000..2001).all(|i| filter.contains(&i)));
}