From: John Cantrell Date: Mon, 11 Apr 2022 17:50:31 +0000 (-0400) Subject: implement Persist and Persister with generic KVStorePersister trait X-Git-Tag: v0.0.107~58^2 X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=commitdiff_plain;h=49649442798af0344ac7ed5dbbb2c99d90f7a3fe;p=rust-lightning implement Persist and Persister with generic KVStorePersister trait --- diff --git a/lightning-background-processor/Cargo.toml b/lightning-background-processor/Cargo.toml index bd6d54d87..16ec763fb 100644 --- a/lightning-background-processor/Cargo.toml +++ b/lightning-background-processor/Cargo.toml @@ -16,8 +16,8 @@ rustdoc-args = ["--cfg", "docsrs"] [dependencies] bitcoin = "0.27" lightning = { version = "0.0.106", path = "../lightning", features = ["std"] } -lightning-persister = { version = "0.0.106", path = "../lightning-persister" } [dev-dependencies] lightning = { version = "0.0.106", path = "../lightning", features = ["_test_utils"] } lightning-invoice = { version = "0.14.0", path = "../lightning-invoice" } +lightning-persister = { version = "0.0.106", path = "../lightning-persister" } diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index 73f420c98..6beee915b 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -20,6 +20,7 @@ use lightning::ln::peer_handler::{CustomMessageHandler, PeerManager, SocketDescr use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler}; use lightning::util::events::{Event, EventHandler, EventsProvider}; use lightning::util::logger::Logger; +use lightning::util::persist::Persister; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; @@ -80,22 +81,6 @@ const FIRST_NETWORK_PRUNE_TIMER: u64 = 60; #[cfg(test)] const FIRST_NETWORK_PRUNE_TIMER: u64 = 1; -/// Trait that handles persisting a [`ChannelManager`] and [`NetworkGraph`] to disk. -pub trait Persister -where - M::Target: 'static + chain::Watch, - T::Target: 'static + BroadcasterInterface, - K::Target: 'static + KeysInterface, - F::Target: 'static + FeeEstimator, - L::Target: 'static + Logger, -{ - /// Persist the given [`ChannelManager`] to disk, returning an error if persistence failed - /// (which will cause the [`BackgroundProcessor`] which called this method to exit). - fn persist_manager(&self, channel_manager: &ChannelManager) -> Result<(), std::io::Error>; - - /// Persist the given [`NetworkGraph`] to disk, returning an error if persistence failed. - fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), std::io::Error>; -} /// Decorates an [`EventHandler`] with common functionality provided by standard [`EventHandler`]s. struct DecoratingEventHandler< @@ -138,12 +123,12 @@ impl BackgroundProcessor { /// /// [`Persister::persist_manager`] is responsible for writing out the [`ChannelManager`] to disk, and/or /// uploading to one or more backup services. See [`ChannelManager::write`] for writing out a - /// [`ChannelManager`]. See [`FilesystemPersister::persist_manager`] for Rust-Lightning's + /// [`ChannelManager`]. See the `lightning-persister` crate for LDK's /// provided implementation. /// /// [`Persister::persist_graph`] is responsible for writing out the [`NetworkGraph`] to disk. See - /// [`NetworkGraph::write`] for writing out a [`NetworkGraph`]. See [`FilesystemPersister::persist_network_graph`] - /// for Rust-Lightning's provided implementation. + /// [`NetworkGraph::write`] for writing out a [`NetworkGraph`]. See the `lightning-persister` crate + /// for LDK's provided implementation. /// /// Typically, users should either implement [`Persister::persist_manager`] to never return an /// error or call [`join`] and handle any error that may arise. For the latter case, @@ -161,8 +146,8 @@ impl BackgroundProcessor { /// [`stop`]: Self::stop /// [`ChannelManager`]: lightning::ln::channelmanager::ChannelManager /// [`ChannelManager::write`]: lightning::ln::channelmanager::ChannelManager#impl-Writeable - /// [`FilesystemPersister::persist_manager`]: lightning_persister::FilesystemPersister::persist_manager - /// [`FilesystemPersister::persist_network_graph`]: lightning_persister::FilesystemPersister::persist_network_graph + /// [`Persister::persist_manager`]: lightning::util::persist::Persister::persist_manager + /// [`Persister::persist_graph`]: lightning::util::persist::Persister::persist_graph /// [`NetworkGraph`]: lightning::routing::network_graph::NetworkGraph /// [`NetworkGraph::write`]: lightning::routing::network_graph::NetworkGraph#impl-Writeable pub fn start< @@ -180,7 +165,7 @@ impl BackgroundProcessor { CMH: 'static + Deref + Send + Sync, RMH: 'static + Deref + Send + Sync, EH: 'static + EventHandler + Send, - PS: 'static + Send + Persister, + PS: 'static + Deref + Send, M: 'static + Deref> + Send + Sync, CM: 'static + Deref> + Send + Sync, NG: 'static + Deref> + Send + Sync, @@ -202,6 +187,7 @@ impl BackgroundProcessor { CMH::Target: 'static + ChannelMessageHandler, RMH::Target: 'static + RoutingMessageHandler, UMH::Target: 'static + CustomMessageHandler, + PS::Target: 'static + Persister { let stop_thread = Arc::new(AtomicBool::new(false)); let stop_thread_clone = stop_thread.clone(); @@ -365,10 +351,11 @@ mod tests { use lightning::util::logger::Logger; use lightning::util::ser::Writeable; use lightning::util::test_utils; + use lightning::util::persist::KVStorePersister; use lightning_invoice::payment::{InvoicePayer, RetryAttempts}; use lightning_invoice::utils::DefaultRouter; use lightning_persister::FilesystemPersister; - use std::fs; + use std::fs::{self, File}; use std::ops::Deref; use std::path::PathBuf; use std::sync::{Arc, Mutex}; @@ -414,12 +401,14 @@ mod tests { struct Persister { data_dir: String, graph_error: Option<(std::io::ErrorKind, &'static str)>, - manager_error: Option<(std::io::ErrorKind, &'static str)> + manager_error: Option<(std::io::ErrorKind, &'static str)>, + filesystem_persister: FilesystemPersister, } impl Persister { fn new(data_dir: String) -> Self { - Self { data_dir, graph_error: None, manager_error: None } + let filesystem_persister = FilesystemPersister::new(data_dir.clone()); + Self { data_dir, graph_error: None, manager_error: None, filesystem_persister } } fn with_graph_error(self, error: std::io::ErrorKind, message: &'static str) -> Self { @@ -431,25 +420,21 @@ mod tests { } } - impl super::Persister for Persister where - M::Target: 'static + chain::Watch, - T::Target: 'static + BroadcasterInterface, - K::Target: 'static + KeysInterface, - F::Target: 'static + FeeEstimator, - L::Target: 'static + Logger, - { - fn persist_manager(&self, channel_manager: &ChannelManager) -> Result<(), std::io::Error> { - match self.manager_error { - None => FilesystemPersister::persist_manager(self.data_dir.clone(), channel_manager), - Some((error, message)) => Err(std::io::Error::new(error, message)), + impl KVStorePersister for Persister { + fn persist(&self, key: &str, object: &W) -> std::io::Result<()> { + if key == "manager" { + if let Some((error, message)) = self.manager_error { + return Err(std::io::Error::new(error, message)) + } } - } - fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), std::io::Error> { - match self.graph_error { - None => FilesystemPersister::persist_network_graph(self.data_dir.clone(), network_graph), - Some((error, message)) => Err(std::io::Error::new(error, message)), + if key == "network_graph" { + if let Some((error, message)) = self.graph_error { + return Err(std::io::Error::new(error, message)) + } } + + self.filesystem_persister.persist(key, object) } } @@ -576,7 +561,7 @@ mod tests { // Initiate the background processors to watch each node. let data_dir = nodes[0].persister.get_data_dir(); - let persister = Persister::new(data_dir); + let persister = Arc::new(Persister::new(data_dir)); let event_handler = |_: &_| {}; let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); @@ -637,7 +622,7 @@ mod tests { // `FRESHNESS_TIMER`. let nodes = create_nodes(1, "test_timer_tick_called".to_string()); let data_dir = nodes[0].persister.get_data_dir(); - let persister = Persister::new(data_dir); + let persister = Arc::new(Persister::new(data_dir)); let event_handler = |_: &_| {}; let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); loop { @@ -660,7 +645,7 @@ mod tests { open_channel!(nodes[0], nodes[1], 100000); let data_dir = nodes[0].persister.get_data_dir(); - let persister = Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test"); + let persister = Arc::new(Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test")); let event_handler = |_: &_| {}; let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); match bg_processor.join() { @@ -677,7 +662,7 @@ mod tests { // Test that if we encounter an error during network graph persistence, an error gets returned. let nodes = create_nodes(2, "test_persist_network_graph_error".to_string()); let data_dir = nodes[0].persister.get_data_dir(); - let persister = Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test"); + let persister = Arc::new(Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test")); let event_handler = |_: &_| {}; let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); @@ -695,7 +680,7 @@ mod tests { let mut nodes = create_nodes(2, "test_background_event_handling".to_string()); let channel_value = 100000; let data_dir = nodes[0].persister.get_data_dir(); - let persister = Persister::new(data_dir.clone()); + let persister = Arc::new(Persister::new(data_dir.clone())); // Set up a background event handler for FundingGenerationReady events. let (sender, receiver) = std::sync::mpsc::sync_channel(1); @@ -726,7 +711,8 @@ mod tests { // Set up a background event handler for SpendableOutputs events. let (sender, receiver) = std::sync::mpsc::sync_channel(1); let event_handler = move |event: &Event| sender.send(event.clone()).unwrap(); - let bg_processor = BackgroundProcessor::start(Persister::new(data_dir), event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); + let persister = Arc::new(Persister::new(data_dir)); + let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); // Force close the channel and check that the SpendableOutputs event was handled. nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap(); @@ -752,7 +738,7 @@ mod tests { // Initiate the background processors to watch each node. let data_dir = nodes[0].persister.get_data_dir(); - let persister = Persister::new(data_dir); + let persister = Arc::new(Persister::new(data_dir)); let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0))); let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes); let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2))); diff --git a/lightning-persister/src/lib.rs b/lightning-persister/src/lib.rs index 450062127..c23baf8ad 100644 --- a/lightning-persister/src/lib.rs +++ b/lightning-persister/src/lib.rs @@ -15,20 +15,13 @@ extern crate bitcoin; extern crate libc; use bitcoin::hash_types::{BlockHash, Txid}; -use bitcoin::hashes::hex::{FromHex, ToHex}; -use lightning::routing::network_graph::NetworkGraph; -use crate::util::DiskWriteable; -use lightning::chain; -use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator}; -use lightning::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate}; -use lightning::chain::chainmonitor; +use bitcoin::hashes::hex::FromHex; +use lightning::chain::channelmonitor::ChannelMonitor; use lightning::chain::keysinterface::{Sign, KeysInterface}; -use lightning::chain::transaction::OutPoint; -use lightning::ln::channelmanager::ChannelManager; -use lightning::util::logger::Logger; use lightning::util::ser::{ReadableArgs, Writeable}; +use lightning::util::persist::KVStorePersister; use std::fs; -use std::io::{Cursor, Error, Write}; +use std::io::Cursor; use std::ops::Deref; use std::path::{Path, PathBuf}; @@ -48,31 +41,6 @@ pub struct FilesystemPersister { path_to_channel_data: String, } -impl DiskWriteable for ChannelMonitor { - fn write_to_file(&self, writer: &mut W) -> Result<(), Error> { - self.write(writer) - } -} - -impl DiskWriteable for ChannelManager -where - M::Target: chain::Watch, - T::Target: BroadcasterInterface, - K::Target: KeysInterface, - F::Target: FeeEstimator, - L::Target: Logger, -{ - fn write_to_file(&self, writer: &mut W) -> Result<(), std::io::Error> { - self.write(writer) - } -} - -impl DiskWriteable for NetworkGraph { - fn write_to_file(&self, writer: &mut W) -> Result<(), std::io::Error> { - self.write(writer) - } -} - impl FilesystemPersister { /// Initialize a new FilesystemPersister and set the path to the individual channels' /// files. @@ -87,43 +55,14 @@ impl FilesystemPersister { self.path_to_channel_data.clone() } - pub(crate) fn path_to_monitor_data(&self) -> PathBuf { - let mut path = PathBuf::from(self.path_to_channel_data.clone()); - path.push("monitors"); - path - } - - /// Writes the provided `ChannelManager` to the path provided at `FilesystemPersister` - /// initialization, within a file called "manager". - pub fn persist_manager( - data_dir: String, - manager: &ChannelManager - ) -> Result<(), std::io::Error> - where - M::Target: chain::Watch, - T::Target: BroadcasterInterface, - K::Target: KeysInterface, - F::Target: FeeEstimator, - L::Target: Logger, - { - let path = PathBuf::from(data_dir); - util::write_to_file(path, "manager".to_string(), manager) - } - - /// Write the provided `NetworkGraph` to the path provided at `FilesystemPersister` - /// initialization, within a file called "network_graph" - pub fn persist_network_graph(data_dir: String, network_graph: &NetworkGraph) -> Result<(), std::io::Error> { - let path = PathBuf::from(data_dir); - util::write_to_file(path, "network_graph".to_string(), network_graph) - } - /// Read `ChannelMonitor`s from disk. pub fn read_channelmonitors ( &self, keys_manager: K ) -> Result)>, std::io::Error> where K::Target: KeysInterface + Sized, { - let path = self.path_to_monitor_data(); + let mut path = PathBuf::from(&self.path_to_channel_data); + path.push("monitors"); if !Path::new(&path).exists() { return Ok(Vec::new()); } @@ -180,22 +119,11 @@ impl FilesystemPersister { } } -impl chainmonitor::Persist for FilesystemPersister { - // TODO: We really need a way for the persister to inform the user that its time to crash/shut - // down once these start returning failure. - // A PermanentFailure implies we need to shut down since we're force-closing channels without - // even broadcasting! - - fn persist_new_channel(&self, funding_txo: OutPoint, monitor: &ChannelMonitor, _update_id: chainmonitor::MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { - let filename = format!("{}_{}", funding_txo.txid.to_hex(), funding_txo.index); - util::write_to_file(self.path_to_monitor_data(), filename, monitor) - .map_err(|_| chain::ChannelMonitorUpdateErr::PermanentFailure) - } - - fn update_persisted_channel(&self, funding_txo: OutPoint, _update: &Option, monitor: &ChannelMonitor, _update_id: chainmonitor::MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { - let filename = format!("{}_{}", funding_txo.txid.to_hex(), funding_txo.index); - util::write_to_file(self.path_to_monitor_data(), filename, monitor) - .map_err(|_| chain::ChannelMonitorUpdateErr::PermanentFailure) +impl KVStorePersister for FilesystemPersister { + fn persist(&self, key: &str, object: &W) -> std::io::Result<()> { + let mut dest_file = PathBuf::from(self.path_to_channel_data.clone()); + dest_file.push(key); + util::write_to_file(dest_file, object) } } diff --git a/lightning-persister/src/util.rs b/lightning-persister/src/util.rs index f26296794..25bd00f5e 100644 --- a/lightning-persister/src/util.rs +++ b/lightning-persister/src/util.rs @@ -2,27 +2,20 @@ extern crate winapi; use std::fs; -use std::path::{Path, PathBuf}; -use std::io::{BufWriter, Write}; +use std::path::PathBuf; +use std::io::BufWriter; #[cfg(not(target_os = "windows"))] use std::os::unix::io::AsRawFd; +use lightning::util::ser::Writeable; + #[cfg(target_os = "windows")] use { std::ffi::OsStr, std::os::windows::ffi::OsStrExt }; -pub(crate) trait DiskWriteable { - fn write_to_file(&self, writer: &mut W) -> Result<(), std::io::Error>; -} - -pub(crate) fn get_full_filepath(mut filepath: PathBuf, filename: String) -> String { - filepath.push(filename); - filepath.to_str().unwrap().to_string() -} - #[cfg(target_os = "windows")] macro_rules! call { ($e: expr) => ( @@ -40,45 +33,43 @@ fn path_to_windows_str>(path: T) -> Vec(path: PathBuf, filename: String, data: &D) -> std::io::Result<()> { - fs::create_dir_all(path.clone())?; +pub(crate) fn write_to_file(dest_file: PathBuf, data: &W) -> std::io::Result<()> { + let mut tmp_file = dest_file.clone(); + tmp_file.set_extension("tmp"); + + let parent_directory = dest_file.parent().unwrap(); + fs::create_dir_all(parent_directory)?; // Do a crazy dance with lots of fsync()s to be overly cautious here... // We never want to end up in a state where we've lost the old data, or end up using the // old data on power loss after we've returned. // The way to atomically write a file on Unix platforms is: // open(tmpname), write(tmpfile), fsync(tmpfile), close(tmpfile), rename(), fsync(dir) - let filename_with_path = get_full_filepath(path, filename); - let tmp_filename = format!("{}.tmp", filename_with_path.clone()); - { // Note that going by rust-lang/rust@d602a6b, on MacOS it is only safe to use // rust stdlib 1.36 or higher. - let mut buf = BufWriter::new(fs::File::create(&tmp_filename)?); - data.write_to_file(&mut buf)?; + let mut buf = BufWriter::new(fs::File::create(&tmp_file)?); + data.write(&mut buf)?; buf.into_inner()?.sync_all()?; } // Fsync the parent directory on Unix. #[cfg(not(target_os = "windows"))] { - fs::rename(&tmp_filename, &filename_with_path)?; - let path = Path::new(&filename_with_path).parent().unwrap(); - let dir_file = fs::OpenOptions::new().read(true).open(path)?; + fs::rename(&tmp_file, &dest_file)?; + let dir_file = fs::OpenOptions::new().read(true).open(parent_directory)?; unsafe { libc::fsync(dir_file.as_raw_fd()); } } #[cfg(target_os = "windows")] { - let src = PathBuf::from(tmp_filename.clone()); - let dst = PathBuf::from(filename_with_path.clone()); - if Path::new(&filename_with_path.clone()).exists() { + if dest_file.exists() { unsafe {winapi::um::winbase::ReplaceFileW( - path_to_windows_str(dst).as_ptr(), path_to_windows_str(src).as_ptr(), std::ptr::null(), + path_to_windows_str(dest_file).as_ptr(), path_to_windows_str(tmp_file).as_ptr(), std::ptr::null(), winapi::um::winbase::REPLACEFILE_IGNORE_MERGE_ERRORS, std::ptr::null_mut() as *mut winapi::ctypes::c_void, std::ptr::null_mut() as *mut winapi::ctypes::c_void )}; } else { call!(unsafe {winapi::um::winbase::MoveFileExW( - path_to_windows_str(src).as_ptr(), path_to_windows_str(dst).as_ptr(), + path_to_windows_str(tmp_file).as_ptr(), path_to_windows_str(dest_file).as_ptr(), winapi::um::winbase::MOVEFILE_WRITE_THROUGH | winapi::um::winbase::MOVEFILE_REPLACE_EXISTING )}); } @@ -88,15 +79,17 @@ pub(crate) fn write_to_file(path: PathBuf, filename: String, d #[cfg(test)] mod tests { - use super::{DiskWriteable, get_full_filepath, write_to_file}; + use lightning::util::ser::{Writer, Writeable}; + + use super::{write_to_file}; use std::fs; use std::io; use std::io::Write; use std::path::PathBuf; struct TestWriteable{} - impl DiskWriteable for TestWriteable { - fn write_to_file(&self, writer: &mut W) -> Result<(), io::Error> { + impl Writeable for TestWriteable { + fn write(&self, writer: &mut W) -> Result<(), std::io::Error> { writer.write_all(&[42; 1]) } } @@ -114,7 +107,9 @@ mod tests { let mut perms = fs::metadata(path.to_string()).unwrap().permissions(); perms.set_readonly(true); fs::set_permissions(path.to_string(), perms).unwrap(); - match write_to_file(PathBuf::from(path.to_string()), filename, &test_writeable) { + let mut dest_file = PathBuf::from(path); + dest_file.push(filename); + match write_to_file(dest_file, &test_writeable) { Err(e) => assert_eq!(e.kind(), io::ErrorKind::PermissionDenied), _ => panic!("Unexpected error message") } @@ -132,10 +127,12 @@ mod tests { fn test_rename_failure() { let test_writeable = TestWriteable{}; let filename = "test_rename_failure_filename"; - let path = PathBuf::from("test_rename_failure_dir"); + let path = "test_rename_failure_dir"; + let mut dest_file = PathBuf::from(path); + dest_file.push(filename); // Create the channel data file and make it a directory. - fs::create_dir_all(get_full_filepath(path.clone(), filename.to_string())).unwrap(); - match write_to_file(path.clone(), filename.to_string(), &test_writeable) { + fs::create_dir_all(dest_file.clone()).unwrap(); + match write_to_file(dest_file, &test_writeable) { Err(e) => assert_eq!(e.raw_os_error(), Some(libc::EISDIR)), _ => panic!("Unexpected Ok(())") } @@ -145,16 +142,18 @@ mod tests { #[test] fn test_diskwriteable_failure() { struct FailingWriteable {} - impl DiskWriteable for FailingWriteable { - fn write_to_file(&self, _writer: &mut W) -> Result<(), std::io::Error> { + impl Writeable for FailingWriteable { + fn write(&self, _writer: &mut W) -> Result<(), std::io::Error> { Err(std::io::Error::new(std::io::ErrorKind::Other, "expected failure")) } } let filename = "test_diskwriteable_failure"; - let path = PathBuf::from("test_diskwriteable_failure_dir"); + let path = "test_diskwriteable_failure_dir"; let test_writeable = FailingWriteable{}; - match write_to_file(path.clone(), filename.to_string(), &test_writeable) { + let mut dest_file = PathBuf::from(path); + dest_file.push(filename); + match write_to_file(dest_file, &test_writeable) { Err(e) => { assert_eq!(e.kind(), std::io::ErrorKind::Other); assert_eq!(e.get_ref().unwrap().to_string(), "expected failure"); @@ -171,12 +170,13 @@ mod tests { fn test_tmp_file_creation_failure() { let test_writeable = TestWriteable{}; let filename = "test_tmp_file_creation_failure_filename".to_string(); - let path = PathBuf::from("test_tmp_file_creation_failure_dir"); - - // Create the tmp file and make it a directory. - let tmp_path = get_full_filepath(path.clone(), format!("{}.tmp", filename.clone())); - fs::create_dir_all(tmp_path).unwrap(); - match write_to_file(path, filename, &test_writeable) { + let path = "test_tmp_file_creation_failure_dir"; + let mut dest_file = PathBuf::from(path); + dest_file.push(filename); + let mut tmp_file = dest_file.clone(); + tmp_file.set_extension("tmp"); + fs::create_dir_all(tmp_file).unwrap(); + match write_to_file(dest_file, &test_writeable) { Err(e) => { #[cfg(not(target_os = "windows"))] assert_eq!(e.raw_os_error(), Some(libc::EISDIR)); diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index a1e92a0f8..95826b7e0 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -20,6 +20,7 @@ pub mod errors; pub mod ser; pub mod message_signing; pub mod invoice; +pub mod persist; pub(crate) mod atomic_counter; pub(crate) mod byte_utils; diff --git a/lightning/src/util/persist.rs b/lightning/src/util/persist.rs new file mode 100644 index 000000000..9476331c1 --- /dev/null +++ b/lightning/src/util/persist.rs @@ -0,0 +1,77 @@ +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +//! This module contains a simple key-value store trait KVStorePersister that +//! allows one to implement the persistence for [`ChannelManager`], [`NetworkGraph`], +//! and [`ChannelMonitor`] all in one place. + +use core::ops::Deref; +use bitcoin::hashes::hex::ToHex; +use io::{self}; + +use crate::{chain::{keysinterface::{Sign, KeysInterface}, self, transaction::{OutPoint}, chaininterface::{BroadcasterInterface, FeeEstimator}, chainmonitor::{Persist, MonitorUpdateId}, channelmonitor::{ChannelMonitor, ChannelMonitorUpdate}}, ln::channelmanager::ChannelManager, routing::network_graph::NetworkGraph}; +use super::{logger::Logger, ser::Writeable}; + +/// Trait for a key-value store for persisting some writeable object at some key +/// Implementing `KVStorePersister` provides auto-implementations for [`Persister`] +/// and [`Persist`] traits. It uses "manager", "network_graph", +/// and "monitors/{funding_txo_id}_{funding_txo_index}" for keys. +pub trait KVStorePersister { + /// Persist the given writeable using the provided key + fn persist(&self, key: &str, object: &W) -> io::Result<()>; +} + +/// Trait that handles persisting a [`ChannelManager`] and [`NetworkGraph`] to disk. +pub trait Persister + where M::Target: 'static + chain::Watch, + T::Target: 'static + BroadcasterInterface, + K::Target: 'static + KeysInterface, + F::Target: 'static + FeeEstimator, + L::Target: 'static + Logger, +{ + /// Persist the given ['ChannelManager'] to disk, returning an error if persistence failed. + fn persist_manager(&self, channel_manager: &ChannelManager) -> Result<(), io::Error>; + + /// Persist the given [`NetworkGraph`] to disk, returning an error if persistence failed. + fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), io::Error>; +} + +impl Persister for A + where M::Target: 'static + chain::Watch, + T::Target: 'static + BroadcasterInterface, + K::Target: 'static + KeysInterface, + F::Target: 'static + FeeEstimator, + L::Target: 'static + Logger, +{ + /// Persist the given ['ChannelManager'] to disk, returning an error if persistence failed. + fn persist_manager(&self, channel_manager: &ChannelManager) -> Result<(), io::Error> { + self.persist("manager", channel_manager) + } + + /// Persist the given [`NetworkGraph`] to disk, returning an error if persistence failed. + fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), io::Error> { + self.persist("network_graph", network_graph) + } +} + +impl Persist for K { + // TODO: We really need a way for the persister to inform the user that its time to crash/shut + // down once these start returning failure. + // A PermanentFailure implies we need to shut down since we're force-closing channels without + // even broadcasting! + + fn persist_new_channel(&self, funding_txo: OutPoint, monitor: &ChannelMonitor, _update_id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { + let key = format!("monitors/{}_{}", funding_txo.txid.to_hex(), funding_txo.index); + self.persist(&key, monitor) + .map_err(|_| chain::ChannelMonitorUpdateErr::PermanentFailure) + } + + fn update_persisted_channel(&self, funding_txo: OutPoint, _update: &Option, monitor: &ChannelMonitor, _update_id: MonitorUpdateId) -> Result<(), chain::ChannelMonitorUpdateErr> { + let key = format!("monitors/{}_{}", funding_txo.txid.to_hex(), funding_txo.index); + self.persist(&key, monitor) + .map_err(|_| chain::ChannelMonitorUpdateErr::PermanentFailure) + } +}