Stop BackgroundProcessor's thread on drop
[rust-lightning] / lightning-background-processor / src / lib.rs
index 0b886f7bf4ea2e61bf94ff4ef54014c13b04385b..cc4c9e635025117c6b87be05f6615c4788f7092b 100644 (file)
@@ -43,7 +43,7 @@ pub struct BackgroundProcessor {
        stop_thread: Arc<AtomicBool>,
        /// May be used to retrieve and handle the error if `BackgroundProcessor`'s thread
        /// exits due to an error while persisting.
-       pub thread_handle: JoinHandle<Result<(), std::io::Error>>,
+       pub thread_handle: Option<JoinHandle<Result<(), std::io::Error>>>,
 }
 
 #[cfg(not(test))]
@@ -158,13 +158,27 @@ impl BackgroundProcessor {
                                }
                        }
                });
-               Self { stop_thread: stop_thread_clone, thread_handle: handle }
+               Self { stop_thread: stop_thread_clone, thread_handle: Some(handle) }
        }
 
        /// Stop `BackgroundProcessor`'s thread.
-       pub fn stop(self) -> Result<(), std::io::Error> {
+       pub fn stop(mut self) -> Result<(), std::io::Error> {
+               assert!(self.thread_handle.is_some());
+               self.stop_and_join_thread()
+       }
+
+       fn stop_and_join_thread(&mut self) -> Result<(), std::io::Error> {
                self.stop_thread.store(true, Ordering::Release);
-               self.thread_handle.join().unwrap()
+               match self.thread_handle.take() {
+                       Some(handle) => handle.join().unwrap(),
+                       None => Ok(()),
+               }
+       }
+}
+
+impl Drop for BackgroundProcessor {
+       fn drop(&mut self) {
+               self.stop_and_join_thread().unwrap();
        }
 }
 
@@ -416,7 +430,13 @@ mod tests {
                let persister = |_: &_| Err(std::io::Error::new(std::io::ErrorKind::Other, "test"));
                let event_handler = |_| {};
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
-               let _ = bg_processor.thread_handle.join().unwrap().expect_err("Errored persisting manager: test");
+               match bg_processor.stop() {
+                       Ok(_) => panic!("Expected error persisting manager"),
+                       Err(e) => {
+                               assert_eq!(e.kind(), std::io::ErrorKind::Other);
+                               assert_eq!(e.get_ref().unwrap().to_string(), "test");
+                       },
+               }
        }
 
        #[test]