X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Fwakers.rs;h=37c036da9594747ea81b142e151ff960c2892dad;hb=c558ccd6a92fa9034929769f55e65bf9c1336abd;hp=834721fd4aa3abf207e70405774f29958f743bf4;hpb=b455fb5e775d5062035242d38da9c57527ca5a85;p=rust-lightning diff --git a/lightning/src/util/wakers.rs b/lightning/src/util/wakers.rs index 834721fd..37c036da 100644 --- a/lightning/src/util/wakers.rs +++ b/lightning/src/util/wakers.rs @@ -15,11 +15,13 @@ use alloc::sync::Arc; use core::mem; -use crate::sync::{Condvar, Mutex}; +use crate::sync::Mutex; use crate::prelude::*; -#[cfg(any(test, feature = "std"))] +#[cfg(feature = "std")] +use crate::sync::Condvar; +#[cfg(feature = "std")] use std::time::Duration; use core::future::Future as StdFuture; @@ -39,20 +41,11 @@ impl Notifier { } } - pub(crate) fn wait(&self) { - Sleeper::from_single_future(self.get_future()).wait(); - } - - #[cfg(any(test, feature = "std"))] - pub(crate) fn wait_timeout(&self, max_wait: Duration) -> bool { - Sleeper::from_single_future(self.get_future()).wait_timeout(max_wait) - } - /// Wake waiters, tracking that wake needs to occur even if there are currently no waiters. pub(crate) fn notify(&self) { let mut lock = self.notify_pending.lock().unwrap(); if let Some(future_state) = &lock.1 { - if future_state.lock().unwrap().complete() { + if complete_future(future_state) { lock.1 = None; return; } @@ -76,6 +69,7 @@ impl Notifier { } else { let state = Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), + callbacks_with_state: Vec::new(), complete: lock.0, callbacks_made: false, })); @@ -119,19 +113,24 @@ pub(crate) struct FutureState { // first bool - set to false if we're just calling a Waker, and true if we're calling an actual // user-provided function. callbacks: Vec<(bool, Box)>, + callbacks_with_state: Vec<(bool, Box>) -> () + Send>)>, complete: bool, callbacks_made: bool, } -impl FutureState { - fn complete(&mut self) -> bool { - for (counts_as_call, callback) in self.callbacks.drain(..) { - callback.call(); - self.callbacks_made |= counts_as_call; - } - self.complete = true; - self.callbacks_made +fn complete_future(this: &Arc>) -> bool { + let mut state_lock = this.lock().unwrap(); + let state = &mut *state_lock; + for (counts_as_call, callback) in state.callbacks.drain(..) { + callback.call(); + state.callbacks_made |= counts_as_call; + } + for (counts_as_call, callback) in state.callbacks_with_state.drain(..) { + (callback)(this); + state.callbacks_made |= counts_as_call; } + state.complete = true; + state.callbacks_made } /// A simple future which can complete once, and calls some callback(s) when it does so. @@ -167,6 +166,29 @@ impl Future { pub fn register_callback_fn(&self, callback: F) { self.register_callback(Box::new(callback)); } + + /// Waits until this [`Future`] completes. + #[cfg(feature = "std")] + pub fn wait(self) { + Sleeper::from_single_future(self).wait(); + } + + /// Waits until this [`Future`] completes or the given amount of time has elapsed. + /// + /// Returns true if the [`Future`] completed, false if the time elapsed. + #[cfg(feature = "std")] + pub fn wait_timeout(self, max_wait: Duration) -> bool { + Sleeper::from_single_future(self).wait_timeout(max_wait) + } + + #[cfg(test)] + pub fn poll_is_complete(&self) -> bool { + let mut state = self.state.lock().unwrap(); + if state.complete { + state.callbacks_made = true; + true + } else { false } + } } use core::task::Waker; @@ -194,10 +216,12 @@ impl<'a> StdFuture for Future { /// A struct which can be used to select across many [`Future`]s at once without relying on a full /// async context. +#[cfg(feature = "std")] pub struct Sleeper { notifiers: Vec>>, } +#[cfg(feature = "std")] impl Sleeper { /// Constructs a new sleeper from one future, allowing blocking on it. pub fn from_single_future(future: Future) -> Self { @@ -222,14 +246,13 @@ impl Sleeper { for notifier_mtx in self.notifiers.iter() { let cv_ref = Arc::clone(&cv); let notified_fut_ref = Arc::clone(¬ified_fut_mtx); - let notifier_ref = Arc::clone(¬ifier_mtx); let mut notifier = notifier_mtx.lock().unwrap(); if notifier.complete { - *notified_fut_mtx.lock().unwrap() = Some(notifier_ref); + *notified_fut_mtx.lock().unwrap() = Some(Arc::clone(¬ifier_mtx)); break; } - notifier.callbacks.push((false, Box::new(move || { - *notified_fut_ref.lock().unwrap() = Some(Arc::clone(¬ifier_ref)); + notifier.callbacks_with_state.push((false, Box::new(move |notifier_ref| { + *notified_fut_ref.lock().unwrap() = Some(Arc::clone(notifier_ref)); cv_ref.notify_all(); }))); } @@ -248,7 +271,6 @@ impl Sleeper { /// Wait until one of the [`Future`]s registered with this [`Sleeper`] has completed or the /// given amount of time has elapsed. Returns true if a [`Future`] completed, false if the time /// elapsed. - #[cfg(any(test, feature = "std"))] pub fn wait_timeout(&self, max_wait: Duration) -> bool { let (cv, notified_fut_mtx) = self.setup_wait(); let notified_fut = @@ -369,12 +391,12 @@ mod tests { }); // Check that we can block indefinitely until updates are available. - let _ = persistence_notifier.wait(); + let _ = persistence_notifier.get_future().wait(); // Check that the Notifier will return after the given duration if updates are // available. loop { - if persistence_notifier.wait_timeout(Duration::from_millis(100)) { + if persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) { break } } @@ -384,17 +406,56 @@ mod tests { // Check that the Notifier will return after the given duration even if no updates // are available. loop { - if !persistence_notifier.wait_timeout(Duration::from_millis(100)) { + if !persistence_notifier.get_future().wait_timeout(Duration::from_millis(100)) { break } } } + #[cfg(feature = "std")] + #[test] + fn test_state_drops() { + // Previously, there was a leak if a `Notifier` was `drop`ed without ever being notified + // but after having been slept-on. This tests for that leak. + use crate::sync::Arc; + use std::thread; + + let notifier_a = Arc::new(Notifier::new()); + let notifier_b = Arc::new(Notifier::new()); + + let thread_notifier_a = Arc::clone(¬ifier_a); + + let future_a = notifier_a.get_future(); + let future_state_a = Arc::downgrade(&future_a.state); + + let future_b = notifier_b.get_future(); + let future_state_b = Arc::downgrade(&future_b.state); + + let join_handle = thread::spawn(move || { + // Let the other thread get to the wait point, then notify it. + std::thread::sleep(Duration::from_millis(50)); + thread_notifier_a.notify(); + }); + + // Wait on the other thread to finish its sleep, note that the leak only happened if we + // actually have to sleep here, not if we immediately return. + Sleeper::from_two_futures(future_a, future_b).wait(); + + join_handle.join().unwrap(); + + // then drop the notifiers and make sure the future states are gone. + mem::drop(notifier_a); + mem::drop(notifier_b); + + assert!(future_state_a.upgrade().is_none() && future_state_b.upgrade().is_none()); + } + #[test] fn test_future_callbacks() { let future = Future { state: Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), + callbacks_with_state: Vec::new(), complete: false, callbacks_made: false, })) @@ -404,9 +465,9 @@ mod tests { future.register_callback(Box::new(move || assert!(!callback_ref.fetch_or(true, Ordering::SeqCst)))); assert!(!callback.load(Ordering::SeqCst)); - future.state.lock().unwrap().complete(); + complete_future(&future.state); assert!(callback.load(Ordering::SeqCst)); - future.state.lock().unwrap().complete(); + complete_future(&future.state); } #[test] @@ -414,11 +475,12 @@ mod tests { let future = Future { state: Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), + callbacks_with_state: Vec::new(), complete: false, callbacks_made: false, })) }; - future.state.lock().unwrap().complete(); + complete_future(&future.state); let callback = Arc::new(AtomicBool::new(false)); let callback_ref = Arc::clone(&callback); @@ -452,6 +514,7 @@ mod tests { let mut future = Future { state: Arc::new(Mutex::new(FutureState { callbacks: Vec::new(), + callbacks_with_state: Vec::new(), complete: false, callbacks_made: false, })) @@ -466,7 +529,7 @@ mod tests { assert_eq!(Pin::new(&mut second_future).poll(&mut Context::from_waker(&second_waker)), Poll::Pending); assert!(!second_woken.load(Ordering::SeqCst)); - future.state.lock().unwrap().complete(); + complete_future(&future.state); assert!(woken.load(Ordering::SeqCst)); assert!(second_woken.load(Ordering::SeqCst)); assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(())); @@ -474,6 +537,7 @@ mod tests { } #[test] + #[cfg(feature = "std")] fn test_dropped_future_doesnt_count() { // Tests that if a Future gets drop'd before it is poll()ed `Ready` it doesn't count as // having been woken, leaving the notify-required flag set. @@ -482,8 +546,8 @@ mod tests { // If we get a future and don't touch it we're definitely still notify-required. notifier.get_future(); - assert!(notifier.wait_timeout(Duration::from_millis(1))); - assert!(!notifier.wait_timeout(Duration::from_millis(1))); + assert!(notifier.get_future().wait_timeout(Duration::from_millis(1))); + assert!(!notifier.get_future().wait_timeout(Duration::from_millis(1))); // Even if we poll'd once but didn't observe a `Ready`, we should be notify-required. let mut future = notifier.get_future(); @@ -492,7 +556,7 @@ mod tests { notifier.notify(); assert!(woken.load(Ordering::SeqCst)); - assert!(notifier.wait_timeout(Duration::from_millis(1))); + assert!(notifier.get_future().wait_timeout(Duration::from_millis(1))); // However, once we do poll `Ready` it should wipe the notify-required flag. let mut future = notifier.get_future(); @@ -502,7 +566,7 @@ mod tests { notifier.notify(); assert!(woken.load(Ordering::SeqCst)); assert_eq!(Pin::new(&mut future).poll(&mut Context::from_waker(&waker)), Poll::Ready(())); - assert!(!notifier.wait_timeout(Duration::from_millis(1))); + assert!(!notifier.get_future().wait_timeout(Duration::from_millis(1))); } #[test] @@ -565,6 +629,7 @@ mod tests { } #[test] + #[cfg(feature = "std")] fn test_multi_future_sleep() { // Tests the `Sleeper` with multiple futures. let notifier_a = Notifier::new();