Use a trait to handle ChannelManager persistence instead of an Fn
[rust-lightning] / background-processor / src / lib.rs
index 120305446475eef8a8321394f177695d5f9d0029..de9aa286a623959a41c5d8d2f0a244283c24e353 100644 (file)
@@ -48,6 +48,38 @@ const FRESHNESS_TIMER: u64 = 60;
 #[cfg(test)]
 const FRESHNESS_TIMER: u64 = 1;
 
+/// Trait which handles persisting a [`ChannelManager`] to disk.
+///
+/// [`ChannelManager`]: lightning::ln::channelmanager::ChannelManager
+pub trait ChannelManagerPersister<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
+where
+       M::Target: 'static + chain::Watch<Signer>,
+       T::Target: 'static + BroadcasterInterface,
+       K::Target: 'static + KeysInterface<Signer = Signer>,
+       F::Target: 'static + FeeEstimator,
+       L::Target: 'static + Logger,
+{
+       /// Persist the given [`ChannelManager`] to disk, returning an error if persistence failed
+       /// (which will cause the [`BackgroundProcessor`] which called this method to exit.
+       ///
+       /// [`ChannelManager`]: lightning::ln::channelmanager::ChannelManager
+       fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error>;
+}
+
+impl<Fun, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
+ChannelManagerPersister<Signer, M, T, K, F, L> for Fun where
+       M::Target: 'static + chain::Watch<Signer>,
+       T::Target: 'static + BroadcasterInterface,
+       K::Target: 'static + KeysInterface<Signer = Signer>,
+       F::Target: 'static + FeeEstimator,
+       L::Target: 'static + Logger,
+       Fun: Fn(&ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error>,
+{
+       fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error> {
+               self(channel_manager)
+       }
+}
+
 impl BackgroundProcessor {
        /// Start a background thread that takes care of responsibilities enumerated in the top-level
        /// documentation.
@@ -68,28 +100,28 @@ impl BackgroundProcessor {
        /// [`ChannelManager::write`]: lightning::ln::channelmanager::ChannelManager#impl-Writeable
        /// [`FilesystemPersister::persist_manager`]: lightning_persister::FilesystemPersister::persist_manager
        pub fn start<
-               PM, Signer,
+               Signer: 'static + Sign,
                M: 'static + Deref + Send + Sync,
                T: 'static + Deref + Send + Sync,
                K: 'static + Deref + Send + Sync,
                F: 'static + Deref + Send + Sync,
                L: 'static + Deref + Send + Sync,
                Descriptor: 'static + SocketDescriptor + Send + Sync,
-               CM: 'static + Deref + Send + Sync,
-               RM: 'static + Deref + Send + Sync
-       >(
-               persist_channel_manager: PM, channel_manager: Arc<ChannelManager<Signer, M, T, K, F, L>>,
-               peer_manager: Arc<PeerManager<Descriptor, CM, RM, L>>, logger: L,
-       ) -> Self where
-               Signer: 'static + Sign,
+               CMH: 'static + Deref + Send + Sync,
+               RMH: 'static + Deref + Send + Sync,
+               CMP: 'static + Send + ChannelManagerPersister<Signer, M, T, K, F, L>,
+               CM: 'static + Deref<Target = ChannelManager<Signer, M, T, K, F, L>> + Send + Sync,
+               PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L>> + Send + Sync,
+       >
+       (handler: CMP, channel_manager: CM, peer_manager: PM, logger: L) -> Self
+       where
                M::Target: 'static + chain::Watch<Signer>,
                T::Target: 'static + BroadcasterInterface,
                K::Target: 'static + KeysInterface<Signer = Signer>,
                F::Target: 'static + FeeEstimator,
                L::Target: 'static + Logger,
-               CM::Target: 'static + ChannelMessageHandler,
-               RM::Target: 'static + RoutingMessageHandler,
-               PM: 'static + Send + Fn(&ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error>,
+               CMH::Target: 'static + ChannelMessageHandler,
+               RMH::Target: 'static + RoutingMessageHandler,
        {
                let stop_thread = Arc::new(AtomicBool::new(false));
                let stop_thread_clone = stop_thread.clone();
@@ -100,7 +132,7 @@ impl BackgroundProcessor {
                                let updates_available =
                                        channel_manager.await_persistable_update_timeout(Duration::from_millis(100));
                                if updates_available {
-                                       persist_channel_manager(&*channel_manager)?;
+                                       handler.persist_manager(&*channel_manager)?;
                                }
                                // Exit the loop if the background processor was requested to stop.
                                if stop_thread.load(Ordering::Acquire) == true {