9aeb5370b736796be865a348acbb2fb0e6c59b1f
[rust-lightning] / lightning / src / util / wakers.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 //! Utilities which allow users to block on some future notification from LDK. These are
11 //! specifically used by [`ChannelManager`] to allow waiting until the [`ChannelManager`] needs to
12 //! be re-persisted.
13 //!
14 //! [`ChannelManager`]: crate::ln::channelmanager::ChannelManager
15
16 use core::mem;
17 use core::time::Duration;
18 use sync::{Condvar, Mutex};
19
20 #[cfg(any(test, feature = "std"))]
21 use std::time::Instant;
22
/// Used to signal to one of many waiters that the condition they're waiting on has happened.
///
/// A notification is "sticky": [`Notifier::notify`] sets a flag which stays set until a waiter
/// consumes it, so a notification sent while nobody is waiting is not lost. Every waiter clears
/// the flag before returning, so each pending notification is consumed by exactly one waiter even
/// though all blocked waiters are woken.
pub(crate) struct Notifier {
        /// Users won't access the lock directly, but rather wait on its bool using
        /// `wait_timeout` and `wait`. The bool is `true` while a notification is pending.
        lock: (Mutex<bool>, Condvar),
}

impl Notifier {
        /// Creates a `Notifier` with no notification pending.
        pub(crate) fn new() -> Self {
                Self {
                        lock: (Mutex::new(false), Condvar::new()),
                }
        }

        /// Blocks the calling thread until a notification is pending, then consumes it.
        ///
        /// Returns immediately if a notification is already pending.
        pub(crate) fn wait(&self) {
                let &(ref mtx, ref cvar) = &self.lock;
                let mut guard = mtx.lock().unwrap();
                // Condition variables may wake spuriously, so the flag is re-checked after every
                // wakeup and we only return once it is actually set.
                while !*guard {
                        guard = cvar.wait(guard).unwrap();
                }
                *guard = false;
        }

        /// Blocks the calling thread until a notification is pending or `max_wait` has elapsed,
        /// whichever comes first.
        ///
        /// Returns `true` (and consumes the notification) if one was pending, `false` if the
        /// timeout expired without one.
        #[cfg(any(test, feature = "std"))]
        pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool {
                let start = Instant::now();
                let &(ref mtx, ref cvar) = &self.lock;
                let mut guard = mtx.lock().unwrap();
                let mut remaining = max_wait;
                while !*guard {
                        guard = cvar.wait_timeout(guard, remaining).unwrap().0;
                        if *guard {
                                break;
                        }
                        // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check
                        // if the desired wait time has actually passed, and if not then wait again for only
                        // the time that is left, so the total never exceeds `max_wait` by a full extra
                        // period. Note that this logic can be highly simplified through the use of
                        // `Condvar::wait_timeout_while`, if and when our MSRV is raised to 1.42.0.
                        let elapsed = start.elapsed();
                        if elapsed >= max_wait {
                                return false;
                        }
                        // Cannot underflow: `elapsed < max_wait` was just established.
                        remaining = max_wait - elapsed;
                }
                *guard = false;
                true
        }

        /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
        pub(crate) fn notify(&self) {
                let &(ref persist_mtx, ref cnd) = &self.lock;
                let mut lock = persist_mtx.lock().unwrap();
                *lock = true;
                // Release the mutex before waking anyone so woken waiters can take it right away.
                mem::drop(lock);
                cnd.notify_all();
        }

        /// Returns whether a notification is currently pending, without consuming it.
        #[cfg(any(test, feature = "_test_utils"))]
        pub fn notify_pending(&self) -> bool {
                let &(ref mtx, _) = &self.lock;
                let guard = mtx.lock().unwrap();
                *guard
        }
}
99
#[cfg(test)]
mod tests {
        /// Exercises `wait` and `wait_timeout` against a background thread that keeps setting the
        /// flag and waking waiters until it is told to stop.
        #[cfg(feature = "std")]
        #[test]
        fn test_wait_timeout() {
                use super::*;
                use sync::Arc;
                use core::sync::atomic::{AtomicBool, Ordering};
                use std::thread;

                let notifier = Arc::new(Notifier::new());
                let background_notifier = Arc::clone(&notifier);

                let stop = Arc::new(AtomicBool::new(false));
                let stop_for_thread = Arc::clone(&stop);
                thread::spawn(move || {
                        let &(ref mtx, ref cvar) = &background_notifier.lock;
                        loop {
                                // Mark a notification as pending and wake waiters while still holding the
                                // mutex; the guard is released at the end of each iteration.
                                let mut pending = mtx.lock().unwrap();
                                *pending = true;
                                cvar.notify_all();

                                // Always signal at least once before honouring the stop request.
                                if stop_for_thread.load(Ordering::SeqCst) {
                                        break;
                                }
                        }
                });

                // Blocking indefinitely must return once the background thread has signalled.
                notifier.wait();

                // With signals still arriving, a bounded wait must eventually report `true`.
                while !notifier.wait_timeout(Duration::from_millis(100)) {}

                stop.store(true, Ordering::SeqCst);

                // Once the signals stop, a bounded wait must eventually time out and report
                // `false`.
                while notifier.wait_timeout(Duration::from_millis(100)) {}
        }
}
149 }