From: Matt Corallo Date: Mon, 14 Nov 2022 23:49:27 +0000 (+0000) Subject: Unset the needs-notify bit in a Notifier when a Future is fetched X-Git-Tag: v0.0.113~40^2~2 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=7527e4b7df674fc2f5442514bb7b1d2294e59ce8;p=rust-lightning Unset the needs-notify bit in a Notifier when a Future is fetched If a `Notifier` gets `notify()`ed and the a `Future` is fetched, even though the `Future` is marked completed from the start and the user may pass callbacks which are called, we'll never wipe the needs-notify bit in the `Notifier`. The solution is to keep track of the `FutureState` in the returned `Future` even though its `complete` from the start, adding a new flag in the `FutureState` which indicates callbacks have been made and checking that flag when waiting or returning a second `Future`. --- diff --git a/lightning/src/util/wakers.rs b/lightning/src/util/wakers.rs index 60684dcad..6d8b03cbd 100644 --- a/lightning/src/util/wakers.rs +++ b/lightning/src/util/wakers.rs @@ -15,7 +15,7 @@ use alloc::sync::Arc; use core::mem; -use crate::sync::{Condvar, Mutex}; +use crate::sync::{Condvar, Mutex, MutexGuard}; use crate::prelude::*; @@ -41,9 +41,22 @@ impl Notifier { } } + fn propagate_future_state_to_notify_flag(&self) -> MutexGuard<(bool, Option>>)> { + let mut lock = self.notify_pending.lock().unwrap(); + if let Some(existing_state) = &lock.1 { + if existing_state.lock().unwrap().callbacks_made { + // If the existing `FutureState` has completed and actually made callbacks, + // consider the notification flag to have been cleared and reset the future state. + lock.1.take(); + lock.0 = false; + } + } + lock + } + pub(crate) fn wait(&self) { loop { - let mut guard = self.notify_pending.lock().unwrap(); + let mut guard = self.propagate_future_state_to_notify_flag(); if guard.0 { guard.0 = false; return; @@ -61,7 +74,7 @@ impl Notifier { pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool { let current_time = Instant::now(); loop { - let mut guard = self.notify_pending.lock().unwrap(); + let mut guard = self.propagate_future_state_to_notify_flag(); if guard.0 { guard.0 = false; return true; @@ -88,17 +101,8 @@ impl Notifier { /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters. pub(crate) fn notify(&self) { let mut lock = self.notify_pending.lock().unwrap(); - let mut future_probably_generated_calls = false; - if let Some(future_state) = lock.1.take() { - future_probably_generated_calls |= future_state.lock().unwrap().complete(); - future_probably_generated_calls |= Arc::strong_count(&future_state) > 1; - } - if future_probably_generated_calls { - // If a future made some callbacks or has not yet been drop'd (i.e. the state has more - // than the one reference we hold), assume the user was notified and skip setting the - // notification-required flag. This will not cause the `wait` functions above to return - // and avoid any future `Future`s starting in a completed state. - return; + if let Some(future_state) = &lock.1 { + future_state.lock().unwrap().complete(); } lock.0 = true; mem::drop(lock); @@ -107,20 +111,14 @@ impl Notifier { /// Gets a [`Future`] that will get woken up with any waiters pub(crate) fn get_future(&self) -> Future { - let mut lock = self.notify_pending.lock().unwrap(); - if lock.0 { - Future { - state: Arc::new(Mutex::new(FutureState { - callbacks: Vec::new(), - complete: true, - })) - } - } else if let Some(existing_state) = &lock.1 { + let mut lock = self.propagate_future_state_to_notify_flag(); + if let Some(existing_state) = &lock.1 { Future { state: Arc::clone(&existing_state) } } else { let state = Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), - complete: false, + complete: lock.0, + callbacks_made: false, })); lock.1 = Some(Arc::clone(&state)); Future { state } @@ -153,17 +151,16 @@ impl FutureCallback for F { pub(crate) struct FutureState { callbacks: Vec>, complete: bool, + callbacks_made: bool, } impl FutureState { - fn complete(&mut self) -> bool { - let mut made_calls = false; + fn complete(&mut self) { for callback in self.callbacks.drain(..) { callback.call(); - made_calls = true; + self.callbacks_made = true; } self.complete = true; - made_calls } } @@ -180,6 +177,7 @@ impl Future { pub fn register_callback(&self, callback: Box) { let mut state = self.state.lock().unwrap(); if state.complete { + state.callbacks_made = true; mem::drop(state); callback.call(); } else { @@ -283,6 +281,28 @@ mod tests { assert!(!callback.load(Ordering::SeqCst)); } + #[test] + fn new_future_wipes_notify_bit() { + // Previously, if we were only using the `Future` interface to learn when a `Notifier` has + // been notified, we'd never mark the notifier as not-awaiting-notify if a `Future` is + // fetched after the notify bit has been set. + let notifier = Notifier::new(); + notifier.notify(); + + let callback = Arc::new(AtomicBool::new(false)); + let callback_ref = Arc::clone(&callback); + notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + assert!(callback.load(Ordering::SeqCst)); + + let callback = Arc::new(AtomicBool::new(false)); + let callback_ref = Arc::clone(&callback); + notifier.get_future().register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); + assert!(!callback.load(Ordering::SeqCst)); + + notifier.notify(); + assert!(callback.load(Ordering::SeqCst)); + } + #[cfg(feature = "std")] #[test] fn test_wait_timeout() { @@ -334,6 +354,7 @@ mod tests { state: Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), complete: false, + callbacks_made: false, })) }; let callback = Arc::new(AtomicBool::new(false)); @@ -352,6 +373,7 @@ mod tests { state: Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), complete: false, + callbacks_made: false, })) }; future.state.lock().unwrap().complete(); @@ -389,6 +411,7 @@ mod tests { state: Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), complete: false, + callbacks_made: false, })) }; let mut second_future = Future { state: Arc::clone(&future.state) };