Use common Persister for persistence tests
authorJurvis Tan <5944973+jurvis@users.noreply.github.com>
Tue, 29 Mar 2022 02:36:43 +0000 (19:36 -0700)
committerJurvis Tan <5944973+jurvis@users.noreply.github.com>
Wed, 30 Mar 2022 02:38:41 +0000 (19:38 -0700)
lightning-background-processor/src/lib.rs

index 6a5ec6118b11e93bafde336152e7e180ad3b255b..2ebb64b4098ee8cffd6d189bf55df9b53095ed6d 100644 (file)
@@ -272,7 +272,7 @@ impl BackgroundProcessor {
                                if last_prune_call.elapsed().as_secs() > if have_pruned { NETWORK_PRUNE_TIMER } else { FIRST_NETWORK_PRUNE_TIMER } {
                                        if let Some(ref handler) = net_graph_msg_handler {
                                                log_trace!(logger, "Pruning network graph of stale entries");
-                                               handler.network_graph().remove_stale_channels(); 
+                                               handler.network_graph().remove_stale_channels();
                                                if let Err(e) = persister.persist_graph(handler.network_graph()) {
                                                        log_error!(logger, "Error: Failed to persist network graph, check your disk and permissions {}", e)
                                                }
@@ -413,6 +413,22 @@ mod tests {
 
        struct Persister {
                data_dir: String,
+               graph_error: Option<(std::io::ErrorKind, &'static str)>,
+               manager_error: Option<(std::io::ErrorKind, &'static str)>
+       }
+
+       impl Persister {
+               fn new(data_dir: String) -> Self {
+                       Self { data_dir, graph_error: None, manager_error: None }
+               }
+
+               fn with_graph_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
+                       Self { graph_error: Some((error, message)), ..self }
+               }
+
+               fn with_manager_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
+                       Self { manager_error: Some((error, message)), ..self }
+               }
        }
 
        impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L:Deref> super::Persister<Signer, M, T, K, F, L> for Persister where
@@ -423,11 +439,17 @@ mod tests {
                L::Target: 'static + Logger,
        {
                fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error> {
-                       FilesystemPersister::persist_manager(self.data_dir.clone(), channel_manager)
+                       match self.manager_error {
+                               None => FilesystemPersister::persist_manager(self.data_dir.clone(), channel_manager),
+                               Some((error, message)) => Err(std::io::Error::new(error, message)),
+                       }
                }
 
                fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), std::io::Error> {
-                       FilesystemPersister::persist_network_graph(self.data_dir.clone(), network_graph)
+                       match self.graph_error {
+                               None => FilesystemPersister::persist_network_graph(self.data_dir.clone(), network_graph),
+                               Some((error, message)) => Err(std::io::Error::new(error, message)),
+                       }
                }
        }
 
@@ -554,7 +576,7 @@ mod tests {
 
                // Initiate the background processors to watch each node.
                let data_dir = nodes[0].persister.get_data_dir();
-               let persister = Persister { data_dir };
+               let persister = Persister::new(data_dir);
                let event_handler = |_: &_| {};
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
 
@@ -615,7 +637,7 @@ mod tests {
                // `FRESHNESS_TIMER`.
                let nodes = create_nodes(1, "test_timer_tick_called".to_string());
                let data_dir = nodes[0].persister.get_data_dir();
-               let persister = Persister { data_dir };
+               let persister = Persister::new(data_dir);
                let event_handler = |_: &_| {};
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
                loop {
@@ -637,28 +659,8 @@ mod tests {
                let nodes = create_nodes(2, "test_persist_error".to_string());
                open_channel!(nodes[0], nodes[1], 100000);
 
-               struct ChannelManagerErrorPersister {
-                       data_dir: String,
-               }
-
-               impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L:Deref> super::Persister<Signer, M, T, K, F, L> for ChannelManagerErrorPersister where
-                       M::Target: 'static + chain::Watch<Signer>,
-                       T::Target: 'static + BroadcasterInterface,
-                       K::Target: 'static + KeysInterface<Signer = Signer>,
-                       F::Target: 'static + FeeEstimator,
-                       L::Target: 'static + Logger,
-               {
-                       fn persist_manager(&self, _channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error> {
-                               Err(std::io::Error::new(std::io::ErrorKind::Other, "test"))
-                       }
-
-                       fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), std::io::Error> {
-                               FilesystemPersister::persist_network_graph(self.data_dir.clone(), network_graph)
-                       }
-               }
-
                let data_dir = nodes[0].persister.get_data_dir();
-               let persister = ChannelManagerErrorPersister{ data_dir };
+               let persister = Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test");
                let event_handler = |_: &_| {};
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
                match bg_processor.join() {
@@ -674,28 +676,8 @@ mod tests {
        fn test_network_graph_persist_error() {
                // Test that if we encounter an error during network graph persistence, an error gets returned.
                let nodes = create_nodes(2, "test_persist_network_graph_error".to_string());
-               struct NetworkGraphErrorPersister {
-                       data_dir: String,
-               }
-
-               impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L:Deref> super::Persister<Signer, M, T, K, F, L> for NetworkGraphErrorPersister where
-                       M::Target: 'static + chain::Watch<Signer>,
-                       T::Target: 'static + BroadcasterInterface,
-                       K::Target: 'static + KeysInterface<Signer = Signer>,
-                       F::Target: 'static + FeeEstimator,
-                       L::Target: 'static + Logger,
-               {
-                       fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error> {
-                               FilesystemPersister::persist_manager(self.data_dir.clone(), channel_manager)
-                       }
-
-                       fn persist_graph(&self, _network_graph: &NetworkGraph) -> Result<(), std::io::Error> {
-                               Err(std::io::Error::new(std::io::ErrorKind::Other, "test"))
-                       }
-               }
-
                let data_dir = nodes[0].persister.get_data_dir();
-               let persister = NetworkGraphErrorPersister { data_dir };
+               let persister = Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test");
                let event_handler = |_: &_| {};
                let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
 
@@ -713,7 +695,7 @@ mod tests {
                let mut nodes = create_nodes(2, "test_background_event_handling".to_string());
                let channel_value = 100000;
                let data_dir = nodes[0].persister.get_data_dir();
-               let persister = Persister { data_dir: data_dir.clone() };
+               let persister = Persister::new(data_dir.clone());
 
                // Set up a background event handler for FundingGenerationReady events.
                let (sender, receiver) = std::sync::mpsc::sync_channel(1);
@@ -744,7 +726,7 @@ mod tests {
                // Set up a background event handler for SpendableOutputs events.
                let (sender, receiver) = std::sync::mpsc::sync_channel(1);
                let event_handler = move |event: &Event| sender.send(event.clone()).unwrap();
-               let bg_processor = BackgroundProcessor::start(Persister{ data_dir: data_dir.clone() }, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
+               let bg_processor = BackgroundProcessor::start(Persister::new(data_dir), event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
 
                // Force close the channel and check that the SpendableOutputs event was handled.
                nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap();
@@ -770,7 +752,7 @@ mod tests {
 
                // Initiate the background processors to watch each node.
                let data_dir = nodes[0].persister.get_data_dir();
-               let persister = Persister { data_dir };
+               let persister = Persister::new(data_dir);
                let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
                let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes);
                let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));