Unset the needs-notify bit in a Notifier when a Future is fetched
authorMatt Corallo <git@bluematt.me>
Mon, 14 Nov 2022 23:49:27 +0000 (23:49 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 16 Nov 2022 00:21:43 +0000 (00:21 +0000)
If a `Notifier` gets `notify()`ed and the a `Future` is fetched,
even though the `Future` is marked completed from the start and
the user may pass callbacks which are called, we'll never wipe the
needs-notify bit in the `Notifier`.

The solution is to keep track of the `FutureState` in the returned
`Future` even though its `complete` from the start, adding a new
flag in the `FutureState` which indicates callbacks have been made
and checking that flag when waiting or returning a second `Future`.

lightning/src/util/wakers.rs

index 60684dcadaf9f129a389f569de6f97f72e7125f7..6d8b03cbd7c4c66e9c686bad5afbec601df96ee5 100644 (file)
@@ -15,7 +15,7 @@
 
 use alloc::sync::Arc;
 use core::mem;
-use crate::sync::{Condvar, Mutex};
+use crate::sync::{Condvar, Mutex, MutexGuard};
 
 use crate::prelude::*;
 
@@ -41,9 +41,22 @@ impl Notifier {
                }
        }
 
+       fn propagate_future_state_to_notify_flag(&self) -> MutexGuard<(bool, Option<Arc<Mutex<FutureState>>>)> {
+               let mut lock = self.notify_pending.lock().unwrap();
+               if let Some(existing_state) = &lock.1 {
+                       if existing_state.lock().unwrap().callbacks_made {
+                               // If the existing `FutureState` has completed and actually made callbacks,
+                               // consider the notification flag to have been cleared and reset the future state.
+                               lock.1.take();
+                               lock.0 = false;
+                       }
+               }
+               lock
+       }
+
        pub(crate) fn wait(&self) {
                loop {
-                       let mut guard = self.notify_pending.lock().unwrap();
+                       let mut guard = self.propagate_future_state_to_notify_flag();
                        if guard.0 {
                                guard.0 = false;
                                return;
@@ -61,7 +74,7 @@ impl Notifier {
        pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool {
                let current_time = Instant::now();
                loop {
-                       let mut guard = self.notify_pending.lock().unwrap();
+                       let mut guard = self.propagate_future_state_to_notify_flag();
                        if guard.0 {
                                guard.0 = false;
                                return true;
@@ -88,17 +101,8 @@ impl Notifier {
        /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters.
        pub(crate) fn notify(&self) {
                let mut lock = self.notify_pending.lock().unwrap();
-               let mut future_probably_generated_calls = false;
-               if let Some(future_state) = lock.1.take() {
-                       future_probably_generated_calls |= future_state.lock().unwrap().complete();
-                       future_probably_generated_calls |= Arc::strong_count(&future_state) > 1;
-               }
-               if future_probably_generated_calls {
-                       // If a future made some callbacks or has not yet been drop'd (i.e. the state has more
-                       // than the one reference we hold), assume the user was notified and skip setting the
-                       // notification-required flag. This will not cause the `wait` functions above to return
-                       // and avoid any future `Future`s starting in a completed state.
-                       return;
+               if let Some(future_state) = &lock.1 {
+                       future_state.lock().unwrap().complete();
                }
                lock.0 = true;
                mem::drop(lock);
@@ -107,20 +111,14 @@ impl Notifier {
 
        /// Gets a [`Future`] that will get woken up with any waiters
        pub(crate) fn get_future(&self) -> Future {
-               let mut lock = self.notify_pending.lock().unwrap();
-               if lock.0 {
-                       Future {
-                               state: Arc::new(Mutex::new(FutureState {
-                                       callbacks: Vec::new(),
-                                       complete: true,
-                               }))
-                       }
-               } else if let Some(existing_state) = &lock.1 {
+               let mut lock = self.propagate_future_state_to_notify_flag();
+               if let Some(existing_state) = &lock.1 {
                        Future { state: Arc::clone(&existing_state) }
                } else {
                        let state = Arc::new(Mutex::new(FutureState {
                                callbacks: Vec::new(),
-                               complete: false,
+                               complete: lock.0,
+                               callbacks_made: false,
                        }));
                        lock.1 = Some(Arc::clone(&state));
                        Future { state }
@@ -153,17 +151,16 @@ impl<F: Fn() + Send> FutureCallback for F {
 pub(crate) struct FutureState {
        callbacks: Vec<Box<dyn FutureCallback>>,
        complete: bool,
+       callbacks_made: bool,
 }
 
 impl FutureState {
-       fn complete(&mut self) -> bool {
-               let mut made_calls = false;
+       fn complete(&mut self) {
                for callback in self.callbacks.drain(..) {
                        callback.call();
-                       made_calls = true;
+                       self.callbacks_made = true;
                }
                self.complete = true;
-               made_calls
        }
 }
 
@@ -180,6 +177,7 @@ impl Future {
        pub fn register_callback(&self, callback: Box<dyn FutureCallback>) {
                let mut state = self.state.lock().unwrap();
                if state.complete {
+                       state.callbacks_made = true;
                        mem::drop(state);
                        callback.call();
                } else {
@@ -283,6 +281,28 @@ mod tests {
                assert!(!callback.load(Ordering::SeqCst));
        }
 
+       #[test]
+       fn new_future_wipes_notify_bit() {
+               // Previously, if we were only using the `Future` interface to learn when a `Notifier` has
+               // been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is
+               // fetched after the notify bit has been set.
+               let notifier = Notifier::new();
+               notifier.notify();
+
+               let callback = Arc::new(AtomicBool::new(false));
+               let callback_ref = Arc::clone(&callback);
+               notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
+               assert!(callback.load(Ordering::SeqCst));
+
+               let callback = Arc::new(AtomicBool::new(false));
+               let callback_ref = Arc::clone(&callback);
+               notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst))));
+               assert!(!callback.load(Ordering::SeqCst));
+
+               notifier.notify();
+               assert!(callback.load(Ordering::SeqCst));
+       }
+
        #[cfg(feature = "std")]
        #[test]
        fn test_wait_timeout() {
@@ -334,6 +354,7 @@ mod tests {
                        state: Arc::new(Mutex::new(FutureState {
                                callbacks: Vec::new(),
                                complete: false,
+                               callbacks_made: false,
                        }))
                };
                let callback = Arc::new(AtomicBool::new(false));
@@ -352,6 +373,7 @@ mod tests {
                        state: Arc::new(Mutex::new(FutureState {
                                callbacks: Vec::new(),
                                complete: false,
+                               callbacks_made: false,
                        }))
                };
                future.state.lock().unwrap().complete();
@@ -389,6 +411,7 @@ mod tests {
                        state: Arc::new(Mutex::new(FutureState {
                                callbacks: Vec::new(),
                                complete: false,
+                               callbacks_made: false,
                        }))
                };
                let mut second_future = Future { state: Arc::clone(&future.state) };