Move PersistenceNotifier to a new util module
[rust-lightning] / lightning / src / util / wakers.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 //! Utilities which allow users to block on some future notification from LDK. These are
11 //! specifically used by [`ChannelManager`] to allow waiting until the [`ChannelManager`] needs to
12 //! be re-persisted.
13 //!
14 //! [`ChannelManager`]: crate::ln::channelmanager::ChannelManager
15
16 use core::mem;
17 use core::time::Duration;
18 use sync::{Condvar, Mutex};
19
20 #[cfg(any(test, feature = "std"))]
21 use std::time::Instant;
22
/// Used to signal to the ChannelManager persister that the manager needs to be re-persisted to
/// disk/backups, through `await_persistable_update_timeout` and `await_persistable_update`.
///
/// Internally this is a single "pending" flag guarded by a mutex, paired with a condition
/// variable. [`Self::notify`] sets the flag and wakes every waiter; [`Self::wait`] and
/// [`Self::wait_timeout`] block until the flag is set and then clear it. As a result, several
/// notifications that arrive before anyone waits are coalesced into one, and each notification
/// is consumed by exactly one successful wait.
pub(crate) struct PersistenceNotifier {
	/// The `bool` is `true` while a notification is pending and has not yet been consumed.
	/// Callers never access it directly; they go through `wait`, `wait_timeout` and `notify`.
	persistence_lock: (Mutex<bool>, Condvar),
}

impl PersistenceNotifier {
	/// Creates a notifier with no pending notification.
	pub(crate) fn new() -> Self {
		Self {
			persistence_lock: (Mutex::new(false), Condvar::new()),
		}
	}

	/// Blocks until a notification is pending, then consumes it (clears the flag) and returns.
	///
	/// Returns immediately if a notification was already pending when called.
	///
	/// # Panics
	/// Panics if the internal mutex has been poisoned.
	pub(crate) fn wait(&self) {
		let &(ref mtx, ref cvar) = &self.persistence_lock;
		let mut guard = mtx.lock().unwrap();
		// Loop to tolerate spurious wakeups, as well as wakeups whose notification was already
		// consumed by another waiter that re-acquired the lock first.
		while !*guard {
			guard = cvar.wait(guard).unwrap();
		}
		*guard = false;
	}

	/// Blocks until a notification is pending or `max_wait` has elapsed, whichever comes first.
	///
	/// Returns `true` if a notification was pending; that notification is consumed. Returns
	/// `false` on timeout. The total time spent blocked is bounded by `max_wait` (plus scheduling
	/// latency), even when spurious wakeups occur.
	///
	/// # Panics
	/// Panics if the internal mutex has been poisoned.
	#[cfg(any(test, feature = "std"))]
	pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool {
		let start = Instant::now();
		let &(ref mtx, ref cvar) = &self.persistence_lock;
		let mut guard = mtx.lock().unwrap();
		loop {
			if *guard {
				*guard = false;
				return true;
			}
			// Only wait for what is left of the caller's budget. Re-waiting the full `max_wait`
			// after a spurious wakeup would let this call block far longer than requested.
			// `Condvar::wait_timeout_while` would express this directly, but needs MSRV 1.42.
			let remaining = match max_wait.checked_sub(start.elapsed()) {
				Some(remaining) if remaining > Duration::from_secs(0) => remaining,
				_ => return false,
			};
			guard = cvar.wait_timeout(guard, remaining).unwrap().0;
		}
	}

	/// Wake waiters, tracking that persistence needs to occur.
	///
	/// Marks a notification as pending and wakes every thread currently blocked in
	/// [`Self::wait`] or [`Self::wait_timeout`].
	pub(crate) fn notify(&self) {
		let &(ref mtx, ref cvar) = &self.persistence_lock;
		let mut guard = mtx.lock().unwrap();
		*guard = true;
		// Release the lock before notifying, so woken waiters don't immediately block on it.
		mem::drop(guard);
		cvar.notify_all();
	}

	/// Returns whether a notification is currently pending, without consuming it.
	#[cfg(any(test, feature = "_test_utils"))]
	pub fn needs_persist(&self) -> bool {
		let &(ref mtx, _) = &self.persistence_lock;
		let guard = mtx.lock().unwrap();
		*guard
	}
}
100
#[cfg(test)]
mod tests {
	#[cfg(feature = "std")]
	#[test]
	fn test_wait_timeout() {
		use super::*;
		use sync::Arc;
		use core::sync::atomic::{AtomicBool, Ordering};
		use std::thread;

		let persistence_notifier = Arc::new(PersistenceNotifier::new());
		let thread_notifier = Arc::clone(&persistence_notifier);

		let exit_thread = Arc::new(AtomicBool::new(false));
		let exit_thread_clone = Arc::clone(&exit_thread);
		// Keep the handle so the helper thread can be joined: it must not outlive the test, and
		// a panic inside it should fail the test rather than be silently dropped.
		let notifier_thread = thread::spawn(move || {
			// Keep signalling until the main thread asks us to stop.
			loop {
				thread_notifier.notify();

				if exit_thread_clone.load(Ordering::SeqCst) {
					break
				}
			}
		});

		// Check that we can block indefinitely until updates are available.
		persistence_notifier.wait();

		// Check that the PersistenceNotifier will return after the given duration if updates are
		// available.
		loop {
			if persistence_notifier.wait_timeout(Duration::from_millis(100)) {
				break
			}
		}

		exit_thread.store(true, Ordering::SeqCst);
		// Once joined, no further notifications can arrive. At most one notification is still
		// pending; the loop below consumes it and then must observe a timeout.
		notifier_thread.join().unwrap();

		// Check that the PersistenceNotifier will return after the given duration even if no updates
		// are available.
		loop {
			if !persistence_notifier.wait_timeout(Duration::from_millis(100)) {
				break
			}
		}
	}
}