Move events.rs into its own top-level module
[rust-lightning] / lightning / src / util / wakers.rs
index 6d8b03cbd7c4c66e9c686bad5afbec601df96ee5..f86fc376cee0202323b9924057157aaf8379f86a 100644 (file)
@@ -33,6 +33,20 @@ pub(crate) struct Notifier {
        condvar: Condvar,
 }
 
+macro_rules! check_woken {
+       ($guard: expr, $retval: expr) => { {
+               if $guard.0 {
+                       $guard.0 = false;
+                       if $guard.1.as_ref().map(|l| l.lock().unwrap().complete).unwrap_or(false) {
+                               // If we're about to return as woken, and the future state is marked complete, wipe
+                               // the future state and let the next future wait until we get a new notify.
+                               $guard.1.take();
+                       }
+                       return $retval;
+               }
+       } }
+}
+
 impl Notifier {
        pub(crate) fn new() -> Self {
                Self {
@@ -57,16 +71,9 @@ impl Notifier {
        pub(crate) fn wait(&self) {
                loop {
                        let mut guard = self.propagate_future_state_to_notify_flag();
-                       if guard.0 {
-                               guard.0 = false;
-                               return;
-                       }
+                       check_woken!(guard, ());
                        guard = self.condvar.wait(guard).unwrap();
-                       let result = guard.0;
-                       if result {
-                               guard.0 = false;
-                               return
-                       }
+                       check_woken!(guard, ());
                }
        }
 
@@ -75,24 +82,20 @@ impl Notifier {
                let current_time = Instant::now();
                loop {
                        let mut guard = self.propagate_future_state_to_notify_flag();
-                       if guard.0 {
-                               guard.0 = false;
-                               return true;
-                       }
+                       check_woken!(guard, true);
                        guard = self.condvar.wait_timeout(guard, max_wait).unwrap().0;
+                       check_woken!(guard, true);
                        // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
                        // desired wait time has actually passed, and if not then restart the loop with a reduced wait
                        // time. Note that this logic can be highly simplified through the use of
                        // `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
                        // 1.42.0.
                        let elapsed = current_time.elapsed();
-                       let result = guard.0;
-                       if result || elapsed >= max_wait {
-                               guard.0 = false;
-                               return result;
+                       if elapsed >= max_wait {
+                               return false;
                        }
                        match max_wait.checked_sub(elapsed) {
-                               None => return result,
+                               None => return false,
                                Some(_) => continue
                        }
                }
@@ -102,7 +105,10 @@ impl Notifier {
        pub(crate) fn notify(&self) {
                let mut lock = self.notify_pending.lock().unwrap();
                if let Some(future_state) = &lock.1 {
-                       future_state.lock().unwrap().complete();
+                       if future_state.lock().unwrap().complete() {
+                               lock.1 = None;
+                               return;
+                       }
                }
                lock.0 = true;
                mem::drop(lock);
@@ -149,18 +155,22 @@ impl<F: Fn() + Send> FutureCallback for F {
 }
 
 pub(crate) struct FutureState {
-       callbacks: Vec<Box<dyn FutureCallback>>,
+       // When we're tracking whether a callback counts as having woken the user's code, we check the
+       // first bool - set to false if we're just calling a Waker, and true if we're calling an actual
+       // user-provided function.
+       callbacks: Vec<(bool, Box<dyn FutureCallback>)>,
        complete: bool,
        callbacks_made: bool,
 }
 
 impl FutureState {
-       fn complete(&mut self) {
-               for callback in self.callbacks.drain(..) {
+       fn complete(&mut self) -> bool {
+               for (counts_as_call, callback) in self.callbacks.drain(..) {
                        callback.call();
-                       self.callbacks_made = true;
+                       self.callbacks_made |= counts_as_call;
                }
                self.complete = true;
+               self.callbacks_made
        }
 }
 
@@ -181,7 +191,7 @@ impl Future {
                        mem::drop(state);
                        callback.call();
                } else {
-                       state.callbacks.push(callback);
+                       state.callbacks.push((true, callback));
                }
        }
 
@@ -209,10 +219,11 @@ impl<'a> StdFuture for Future {
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                let mut state = self.state.lock().unwrap();
                if state.complete {
+                       state.callbacks_made = true;
                        Poll::Ready(())
                } else {
                        let waker = cx.waker().clone();
-                       state.callbacks.push(Box::new(StdWaker(waker)));
+                       state.callbacks.push((false, Box::new(StdWaker(waker))));
                        Poll::Pending
                }
        }
@@ -430,4 +441,95 @@ mod tests {
                assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
                assert_eq!(Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), Poll::Ready(()));
        }
+
+       #[test]
+       fn test_dropped_future_doesnt_count() {
+               // Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as
+               // having been woken, leaving the notify-required flag set.
+               let notifier = Notifier::new();
+               notifier.notify();
+
+               // If we get a future and don't touch it we're definitely still notify-required.
+               notifier.get_future();
+               assert!(notifier.wait_timeout(Duration::from_millis(1)));
+               assert!(!notifier.wait_timeout(Duration::from_millis(1)));
+
+               // Even if we poll'd once but didn't observe a `Ready`, we should be notify-required.
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert!(notifier.wait_timeout(Duration::from_millis(1)));
+
+               // However, once we do poll `Ready` it should wipe the notify-required flag.
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!notifier.wait_timeout(Duration::from_millis(1)));
+       }
+
+       #[test]
+       fn test_poll_post_notify_completes() {
+               // Tests that if we have a future state that has completed, and we haven't yet requested a
+               // new future, if we get a notify prior to requesting that second future it is generated
+               // pre-completed.
+               let notifier = Notifier::new();
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+       }
+
+       #[test]
+       fn test_poll_post_notify_completes_initial_notified() {
+               // Identical to the previous test, but the first future completes via a wake rather than an
+               // immediate `Poll::Ready`.
+               let notifier = Notifier::new();
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+
+               notifier.notify();
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+               assert!(!woken.load(Ordering::SeqCst));
+
+               let mut future = notifier.get_future();
+               let (woken, waker) = create_waker();
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Pending);
+               assert!(!woken.load(Ordering::SeqCst));
+
+               notifier.notify();
+               assert!(woken.load(Ordering::SeqCst));
+               assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(()));
+       }
 }