Merge pull request #721 from TheBlueMatt/2020-09-649-bindings
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Thu, 12 Nov 2020 21:22:54 +0000 (13:22 -0800)
committerGitHub <noreply@github.com>
Thu, 12 Nov 2020 21:22:54 +0000 (13:22 -0800)
Bindings Updates for #649

22 files changed:
CONTRIBUTING.md
Cargo.toml
fuzz/src/chanmon_consistency.rs
fuzz/src/chanmon_deser.rs
fuzz/src/full_stack.rs
fuzz/src/utils/mod.rs
fuzz/src/utils/test_persister.rs [new file with mode: 0644]
lightning-net-tokio/src/lib.rs
lightning-persister/Cargo.toml [new file with mode: 0644]
lightning-persister/src/lib.rs [new file with mode: 0644]
lightning/Cargo.toml
lightning/src/chain/chainmonitor.rs
lightning/src/chain/channelmonitor.rs
lightning/src/lib.rs
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channel.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/mod.rs
lightning/src/util/mod.rs
lightning/src/util/test_utils.rs

index 241c64208565459a0abf5a5d9818365add6de617..e8a57d85f7f7e8a9cebe857f6011e7c84dee6b6d 100644 (file)
@@ -24,6 +24,25 @@ requests.
 Major projects are tracked [here](https://github.com/rust-bitcoin/rust-lightning/projects).
 Major milestones are tracked [here](https://github.com/rust-bitcoin/rust-lightning/milestones?direction=asc&sort=title&state=open).
 
+Getting Started
+---------------
+
+First and foremost, start small.
+
+This doesn't mean don't be ambitious with the breadth and depth of your contributions but rather
+understand the project context and culture before investing an asymmetric number of hours on
+development compared to your merged work.
+
+Even if you have an extensive open source background or sound software engineering skills, consider
+that the reviewers' comprehension of the code is as much important as technical correctness.
+
+It's very welcome to ask for review, either on IRC or LDK Slack. And also for reviewers, it's nice
+to provide timelines when you hope to fulfill the request while bearing in mind for both sides that's
+a "soft" commitment.
+
+If you're eager to increase the velocity of the dev process, reviewing other contributors work is
+the best you can do while waiting review on yours.
+
 Contribution Workflow
 ---------------------
 
index 56f4ac32f3f629c3553adb07114e21370ee1c80b..c43e7927581432e9834aed21a16329aba92f80b3 100644 (file)
@@ -3,6 +3,7 @@
 members = [
     "lightning",
     "lightning-net-tokio",
+    "lightning-persister",
 ]
 
 # Our tests do actual crypo and lots of work, the tradeoff for -O1 is well worth it.
index 1650e2e25f7060fa8c63c568257983784bcbe721..d88cc71fbf50a4175690d76a62bedfc5f3df808c 100644 (file)
@@ -48,6 +48,7 @@ use lightning::routing::router::{Route, RouteHop};
 
 
 use utils::test_logger;
+use utils::test_persister::TestPersister;
 
 use bitcoin::secp256k1::key::{PublicKey,SecretKey};
 use bitcoin::secp256k1::Secp256k1;
@@ -84,7 +85,7 @@ impl Writer for VecWriter {
 
 struct TestChainMonitor {
        pub logger: Arc<dyn Logger>,
-       pub chain_monitor: Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>>>,
+       pub chain_monitor: Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>, Arc<TestPersister>>>,
        pub update_ret: Mutex<Result<(), channelmonitor::ChannelMonitorUpdateErr>>,
        // If we reload a node with an old copy of ChannelMonitors, the ChannelManager deserialization
        // logic will automatically force-close our channels for us (as we don't have an up-to-date
@@ -95,9 +96,9 @@ struct TestChainMonitor {
        pub should_update_manager: atomic::AtomicBool,
 }
 impl TestChainMonitor {
-       pub fn new(broadcaster: Arc<TestBroadcaster>, logger: Arc<dyn Logger>, feeest: Arc<FuzzEstimator>) -> Self {
+       pub fn new(broadcaster: Arc<TestBroadcaster>, logger: Arc<dyn Logger>, feeest: Arc<FuzzEstimator>, persister: Arc<TestPersister>) -> Self {
                Self {
-                       chain_monitor: Arc::new(chainmonitor::ChainMonitor::new(None, broadcaster, logger.clone(), feeest)),
+                       chain_monitor: Arc::new(chainmonitor::ChainMonitor::new(None, broadcaster, logger.clone(), feeest, persister)),
                        logger,
                        update_ret: Mutex::new(Ok(())),
                        latest_monitors: Mutex::new(HashMap::new()),
@@ -110,7 +111,7 @@ impl chain::Watch for TestChainMonitor {
 
        fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
                let mut ser = VecWriter(Vec::new());
-               monitor.write_for_disk(&mut ser).unwrap();
+               monitor.serialize_for_disk(&mut ser).unwrap();
                if let Some(_) = self.latest_monitors.lock().unwrap().insert(funding_txo, (monitor.get_latest_update_id(), ser.0)) {
                        panic!("Already had monitor pre-watch_channel");
                }
@@ -127,9 +128,9 @@ impl chain::Watch for TestChainMonitor {
                };
                let mut deserialized_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::
                        read(&mut Cursor::new(&map_entry.get().1)).unwrap().1;
-               deserialized_monitor.update_monitor(update.clone(), &&TestBroadcaster {}, &self.logger).unwrap();
+               deserialized_monitor.update_monitor(&update, &&TestBroadcaster {}, &self.logger).unwrap();
                let mut ser = VecWriter(Vec::new());
-               deserialized_monitor.write_for_disk(&mut ser).unwrap();
+               deserialized_monitor.serialize_for_disk(&mut ser).unwrap();
                map_entry.insert((update.update_id, ser.0));
                self.should_update_manager.store(true, atomic::Ordering::Relaxed);
                self.update_ret.lock().unwrap().clone()
@@ -192,7 +193,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
        macro_rules! make_node {
                ($node_id: expr) => { {
                        let logger: Arc<dyn Logger> = Arc::new(test_logger::TestLogger::new($node_id.to_string(), out.clone()));
-                       let monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone()));
+                       let monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone(), Arc::new(TestPersister{})));
 
                        let keys_manager = Arc::new(KeyProvider { node_id: $node_id, rand_bytes_id: atomic::AtomicU8::new(0) });
                        let mut config = UserConfig::default();
@@ -207,7 +208,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
        macro_rules! reload_node {
                ($ser: expr, $node_id: expr, $old_monitors: expr) => { {
                        let logger: Arc<dyn Logger> = Arc::new(test_logger::TestLogger::new($node_id.to_string(), out.clone()));
-                       let chain_monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone()));
+                       let chain_monitor = Arc::new(TestChainMonitor::new(broadcast.clone(), logger.clone(), fee_est.clone(), Arc::new(TestPersister{})));
 
                        let keys_manager = Arc::new(KeyProvider { node_id: $node_id, rand_bytes_id: atomic::AtomicU8::new(0) });
                        let mut config = UserConfig::default();
index 5a76340ff309447c9cbae3baa0b112353b027210..fd326cc2ec85bdc72670da4cae4f7388366c05af 100644 (file)
@@ -26,7 +26,7 @@ impl Writer for VecWriter {
 pub fn do_test<Out: test_logger::Output>(data: &[u8], _out: Out) {
        if let Ok((latest_block_hash, monitor)) = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(&mut Cursor::new(data)) {
                let mut w = VecWriter(Vec::new());
-               monitor.write_for_disk(&mut w).unwrap();
+               monitor.serialize_for_disk(&mut w).unwrap();
                let deserialized_copy = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(&mut Cursor::new(&w.0)).unwrap();
                assert!(latest_block_hash == deserialized_copy.0);
                assert!(monitor == deserialized_copy.1);
index 1ed17b9ea3ff0831e4767854cb1f4c54a1ae168f..3aeb3d233b0b19d62196466c3de3f2a23dfe98dd 100644 (file)
@@ -40,6 +40,7 @@ use lightning::util::logger::Logger;
 use lightning::util::config::UserConfig;
 
 use utils::test_logger;
+use utils::test_persister::TestPersister;
 
 use bitcoin::secp256k1::key::{PublicKey,SecretKey};
 use bitcoin::secp256k1::Secp256k1;
@@ -145,13 +146,13 @@ impl<'a> std::hash::Hash for Peer<'a> {
 
 type ChannelMan = ChannelManager<
        EnforcingChannelKeys,
-       Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>>>,
+       Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>, Arc<TestPersister>>>,
        Arc<TestBroadcaster>, Arc<KeyProvider>, Arc<FuzzEstimator>, Arc<dyn Logger>>;
 type PeerMan<'a> = PeerManager<Peer<'a>, Arc<ChannelMan>, Arc<NetGraphMsgHandler<Arc<dyn chain::Access>, Arc<dyn Logger>>>, Arc<dyn Logger>>;
 
 struct MoneyLossDetector<'a> {
        manager: Arc<ChannelMan>,
-       monitor: Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>>>,
+       monitor: Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>, Arc<TestPersister>>>,
        handler: PeerMan<'a>,
 
        peers: &'a RefCell<[bool; 256]>,
@@ -165,7 +166,7 @@ struct MoneyLossDetector<'a> {
 impl<'a> MoneyLossDetector<'a> {
        pub fn new(peers: &'a RefCell<[bool; 256]>,
                   manager: Arc<ChannelMan>,
-                  monitor: Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>>>,
+                  monitor: Arc<chainmonitor::ChainMonitor<EnforcingChannelKeys, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>, Arc<TestPersister>>>,
                   handler: PeerMan<'a>) -> Self {
                MoneyLossDetector {
                        manager,
@@ -333,7 +334,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
        };
 
        let broadcast = Arc::new(TestBroadcaster{});
-       let monitor = Arc::new(chainmonitor::ChainMonitor::new(None, broadcast.clone(), Arc::clone(&logger), fee_est.clone()));
+       let monitor = Arc::new(chainmonitor::ChainMonitor::new(None, broadcast.clone(), Arc::clone(&logger), fee_est.clone(), Arc::new(TestPersister{})));
 
        let keys_manager = Arc::new(KeyProvider { node_secret: our_network_key.clone(), counter: AtomicU64::new(0) });
        let mut config = UserConfig::default();
index bb5b00a5b546dfeed23eee1be9558e090abf0476..937eee6b6b2d4879a3dda91b70070970a45d1037 100644 (file)
@@ -8,3 +8,4 @@
 // licenses.
 
 pub mod test_logger;
+pub mod test_persister;
diff --git a/fuzz/src/utils/test_persister.rs b/fuzz/src/utils/test_persister.rs
new file mode 100644 (file)
index 0000000..0bd6091
--- /dev/null
@@ -0,0 +1,14 @@
+use lightning::chain::channelmonitor;
+use lightning::chain::transaction::OutPoint;
+use lightning::util::enforcing_trait_impls::EnforcingChannelKeys;
+
+pub struct TestPersister {}
+impl channelmonitor::Persist<EnforcingChannelKeys> for TestPersister {
+       fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
+               Ok(())
+       }
+
+       fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &channelmonitor::ChannelMonitorUpdate, _data: &channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
+               Ok(())
+       }
+}
index e84ee76229fecb2a3090c72ed1865102f119eb74..36384380fb14a458b31e132446ae374ffaf0b51d 100644 (file)
@@ -36,7 +36,8 @@
 //! type Logger = dyn lightning::util::logger::Logger;
 //! type ChainAccess = dyn lightning::chain::Access;
 //! type ChainFilter = dyn lightning::chain::Filter;
-//! type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor<lightning::chain::keysinterface::InMemoryChannelKeys, Arc<ChainFilter>, Arc<TxBroadcaster>, Arc<FeeEstimator>, Arc<Logger>>;
+//! type DataPersister = dyn lightning::chain::channelmonitor::Persist<lightning::chain::keysinterface::InMemoryChannelKeys>;
+//! type ChainMonitor = lightning::chain::chainmonitor::ChainMonitor<lightning::chain::keysinterface::InMemoryChannelKeys, Arc<ChainFilter>, Arc<TxBroadcaster>, Arc<FeeEstimator>, Arc<Logger>, Arc<DataPersister>>;
 //! type ChannelManager = lightning::ln::channelmanager::SimpleArcChannelManager<ChainMonitor, TxBroadcaster, FeeEstimator, Logger>;
 //! type PeerManager = lightning::ln::peer_handler::SimpleArcPeerManager<lightning_net_tokio::SocketDescriptor, ChainMonitor, TxBroadcaster, FeeEstimator, ChainAccess, Logger>;
 //!
diff --git a/lightning-persister/Cargo.toml b/lightning-persister/Cargo.toml
new file mode 100644 (file)
index 0000000..44bc952
--- /dev/null
@@ -0,0 +1,20 @@
+[package]
+name = "lightning-persister"
+version = "0.0.1"
+authors = ["Valentine Wallace", "Matt Corallo"]
+license = "Apache-2.0"
+description = """
+Utilities to manage channel data persistence and retrieval.
+"""
+
+[dependencies]
+bitcoin = "0.24"
+lightning = { version = "0.0.11", path = "../lightning" }
+libc = "0.2"
+
+[dev-dependencies.bitcoin]
+version = "0.24"
+features = ["bitcoinconsensus"]
+
+[dev-dependencies]
+lightning = { version = "0.0.11", path = "../lightning", features = ["_test_utils"] }
diff --git a/lightning-persister/src/lib.rs b/lightning-persister/src/lib.rs
new file mode 100644 (file)
index 0000000..48d4d0a
--- /dev/null
@@ -0,0 +1,413 @@
+extern crate lightning;
+extern crate bitcoin;
+extern crate libc;
+
+use bitcoin::hashes::hex::ToHex;
+use lightning::chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, ChannelMonitorUpdateErr};
+use lightning::chain::channelmonitor;
+use lightning::chain::keysinterface::ChannelKeys;
+use lightning::chain::transaction::OutPoint;
+use lightning::util::ser::{Writeable, Readable};
+use std::fs;
+use std::io::Error;
+use std::path::{Path, PathBuf};
+
+#[cfg(test)]
+use {
+       bitcoin::{BlockHash, Txid},
+       bitcoin::hashes::hex::FromHex,
+       std::collections::HashMap,
+       std::io::Cursor
+};
+
+#[cfg(not(target_os = "windows"))]
+use std::os::unix::io::AsRawFd;
+
+/// FilesystemPersister persists channel data on disk, where each channel's
+/// data is stored in a file named after its funding outpoint.
+///
+/// Warning: this module does the best it can with calls to persist data, but it
+/// can only guarantee that the data is passed to the drive. It is up to the
+/// drive manufacturers to do the actual persistence properly, which they often
+/// don't (especially on consumer-grade hardware). Therefore, it is up to the
+/// user to validate their entire storage stack, to ensure the writes are
+/// persistent.
+/// Corollary: especially when dealing with larger amounts of money, it is best
+/// practice to have multiple channel data backups and not rely only on one
+/// FilesystemPersister.
+pub struct FilesystemPersister {
+       path_to_channel_data: String,
+}
+
+trait DiskWriteable {
+       fn write(&self, writer: &mut fs::File) -> Result<(), Error>;
+}
+
+impl<ChanSigner: ChannelKeys + Writeable> DiskWriteable for ChannelMonitor<ChanSigner> {
+       fn write(&self, writer: &mut fs::File) -> Result<(), Error> {
+               self.serialize_for_disk(writer)
+       }
+}
+
+impl FilesystemPersister {
+       /// Initialize a new FilesystemPersister and set the path to the individual channels'
+       /// files.
+       pub fn new(path_to_channel_data: String) -> Self {
+               return Self {
+                       path_to_channel_data,
+               }
+       }
+
+       fn get_full_filepath(&self, funding_txo: OutPoint) -> String {
+               let mut path = PathBuf::from(&self.path_to_channel_data);
+               path.push(format!("{}_{}", funding_txo.txid.to_hex(), funding_txo.index));
+               path.to_str().unwrap().to_string()
+       }
+
+       // Utility to write a file to disk.
+       fn write_channel_data(&self, funding_txo: OutPoint, monitor: &dyn DiskWriteable) -> std::io::Result<()> {
+               fs::create_dir_all(&self.path_to_channel_data)?;
+               // Do a crazy dance with lots of fsync()s to be overly cautious here...
+               // We never want to end up in a state where we've lost the old data, or end up using the
+               // old data on power loss after we've returned.
+               // The way to atomically write a file on Unix platforms is:
+               // open(tmpname), write(tmpfile), fsync(tmpfile), close(tmpfile), rename(), fsync(dir)
+               let filename = self.get_full_filepath(funding_txo);
+               let tmp_filename = format!("{}.tmp", filename.clone());
+
+               {
+                       // Note that going by rust-lang/rust@d602a6b, on MacOS it is only safe to use
+                       // rust stdlib 1.36 or higher.
+                       let mut f = fs::File::create(&tmp_filename)?;
+                       monitor.write(&mut f)?;
+                       f.sync_all()?;
+               }
+               fs::rename(&tmp_filename, &filename)?;
+               // Fsync the parent directory on Unix.
+               #[cfg(not(target_os = "windows"))]
+               {
+                       let path = Path::new(&filename).parent().unwrap();
+                       let dir_file = fs::OpenOptions::new().read(true).open(path)?;
+                       unsafe { libc::fsync(dir_file.as_raw_fd()); }
+               }
+               Ok(())
+       }
+
+       #[cfg(test)]
+       fn load_channel_data<ChanSigner: ChannelKeys + Readable + Writeable>(&self) ->
+               Result<HashMap<OutPoint, ChannelMonitor<ChanSigner>>, ChannelMonitorUpdateErr> {
+               if let Err(_) = fs::create_dir_all(&self.path_to_channel_data) {
+                       return Err(ChannelMonitorUpdateErr::PermanentFailure);
+               }
+               let mut res = HashMap::new();
+               for file_option in fs::read_dir(&self.path_to_channel_data).unwrap() {
+                       let file = file_option.unwrap();
+                       let owned_file_name = file.file_name();
+                       let filename = owned_file_name.to_str();
+                       if !filename.is_some() || !filename.unwrap().is_ascii() || filename.unwrap().len() < 65 {
+                               return Err(ChannelMonitorUpdateErr::PermanentFailure);
+                       }
+
+                       let txid = Txid::from_hex(filename.unwrap().split_at(64).0);
+                       if txid.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); }
+
+                       let index = filename.unwrap().split_at(65).1.split('.').next().unwrap().parse();
+                       if index.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); }
+
+                       let contents = fs::read(&file.path());
+                       if contents.is_err() { return Err(ChannelMonitorUpdateErr::PermanentFailure); }
+
+                       if let Ok((_, loaded_monitor)) =
+                               <(BlockHash, ChannelMonitor<ChanSigner>)>::read(&mut Cursor::new(&contents.unwrap())) {
+                               res.insert(OutPoint { txid: txid.unwrap(), index: index.unwrap() }, loaded_monitor);
+                       } else {
+                               return Err(ChannelMonitorUpdateErr::PermanentFailure);
+                       }
+               }
+               Ok(res)
+       }
+}
+
+impl<ChanSigner: ChannelKeys + Readable + Writeable + Send + Sync> channelmonitor::Persist<ChanSigner> for FilesystemPersister {
+       fn persist_new_channel(&self, funding_txo: OutPoint, monitor: &ChannelMonitor<ChanSigner>) -> Result<(), ChannelMonitorUpdateErr> {
+               self.write_channel_data(funding_txo, monitor)
+                 .map_err(|_| ChannelMonitorUpdateErr::PermanentFailure)
+       }
+
+       fn update_persisted_channel(&self, funding_txo: OutPoint, _update: &ChannelMonitorUpdate, monitor: &ChannelMonitor<ChanSigner>) -> Result<(), ChannelMonitorUpdateErr> {
+               self.write_channel_data(funding_txo, monitor)
+                 .map_err(|_| ChannelMonitorUpdateErr::PermanentFailure)
+       }
+}
+
+#[cfg(test)]
+impl Drop for FilesystemPersister {
+       fn drop(&mut self) {
+               // We test for invalid directory names, so it's OK if directory removal
+               // fails.
+               match fs::remove_dir_all(&self.path_to_channel_data) {
+                       Err(e) => println!("Failed to remove test persister directory: {}", e),
+                       _ => {}
+               }
+       }
+}
+
+#[cfg(test)]
+mod tests {
+       extern crate lightning;
+       extern crate bitcoin;
+       use crate::FilesystemPersister;
+       use bitcoin::blockdata::block::{Block, BlockHeader};
+       use bitcoin::hashes::hex::FromHex;
+       use bitcoin::Txid;
+       use DiskWriteable;
+       use Error;
+       use lightning::chain::channelmonitor::{Persist, ChannelMonitorUpdateErr};
+       use lightning::chain::transaction::OutPoint;
+       use lightning::{check_closed_broadcast, check_added_monitors};
+       use lightning::ln::features::InitFeatures;
+       use lightning::ln::functional_test_utils::*;
+       use lightning::ln::msgs::ErrorAction;
+       use lightning::util::enforcing_trait_impls::EnforcingChannelKeys;
+       use lightning::util::events::{MessageSendEventsProvider, MessageSendEvent};
+       use lightning::util::ser::Writer;
+       use lightning::util::test_utils;
+       use std::fs;
+       use std::io;
+       #[cfg(target_os = "windows")]
+       use {
+               lightning::get_event_msg,
+               lightning::ln::msgs::ChannelMessageHandler,
+       };
+
+       struct TestWriteable{}
+       impl DiskWriteable for TestWriteable {
+               fn write(&self, writer: &mut fs::File) -> Result<(), Error> {
+                       writer.write_all(&[42; 1])
+               }
+       }
+
+       // Integration-test the FilesystemPersister. Test relaying a few payments
+       // and check that the persisted data is updated the appropriate number of
+       // times.
+       #[test]
+       fn test_filesystem_persister() {
+               // Create the nodes, giving them FilesystemPersisters for data persisters.
+               let persister_0 = FilesystemPersister::new("test_filesystem_persister_0".to_string());
+               let persister_1 = FilesystemPersister::new("test_filesystem_persister_1".to_string());
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let chain_mon_0 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &persister_0);
+               let chain_mon_1 = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[1].chain_source), &chanmon_cfgs[1].tx_broadcaster, &chanmon_cfgs[1].logger, &chanmon_cfgs[1].fee_estimator, &persister_1);
+               node_cfgs[0].chain_monitor = chain_mon_0;
+               node_cfgs[1].chain_monitor = chain_mon_1;
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+               // Check that the persisted channel data is empty before any channels are
+               // open.
+               let mut persisted_chan_data_0 = persister_0.load_channel_data::<EnforcingChannelKeys>().unwrap();
+               assert_eq!(persisted_chan_data_0.keys().len(), 0);
+               let mut persisted_chan_data_1 = persister_1.load_channel_data::<EnforcingChannelKeys>().unwrap();
+               assert_eq!(persisted_chan_data_1.keys().len(), 0);
+
+               // Helper to make sure the channel is on the expected update ID.
+               macro_rules! check_persisted_data {
+                       ($expected_update_id: expr) => {
+                               persisted_chan_data_0 = persister_0.load_channel_data::<EnforcingChannelKeys>().unwrap();
+                               assert_eq!(persisted_chan_data_0.keys().len(), 1);
+                               for mon in persisted_chan_data_0.values() {
+                                       assert_eq!(mon.get_latest_update_id(), $expected_update_id);
+                               }
+                               persisted_chan_data_1 = persister_1.load_channel_data::<EnforcingChannelKeys>().unwrap();
+                               assert_eq!(persisted_chan_data_1.keys().len(), 1);
+                               for mon in persisted_chan_data_1.values() {
+                                       assert_eq!(mon.get_latest_update_id(), $expected_update_id);
+                               }
+                       }
+               }
+
+               // Create some initial channel and check that a channel was persisted.
+               let _ = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
+               check_persisted_data!(0);
+
+               // Send a few payments and make sure the monitors are updated to the latest.
+               send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000, 8_000_000);
+               check_persisted_data!(5);
+               send_payment(&nodes[1], &vec!(&nodes[0])[..], 4000000, 4_000_000);
+               check_persisted_data!(10);
+
+               // Force close because cooperative close doesn't result in any persisted
+               // updates.
+               nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id);
+               check_closed_broadcast!(nodes[0], false);
+               check_added_monitors!(nodes[0], 1);
+
+               let node_txn = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap();
+               assert_eq!(node_txn.len(), 1);
+
+               let header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 };
+               connect_block(&nodes[1], &Block { header, txdata: vec![node_txn[0].clone(), node_txn[0].clone()]}, 1);
+               check_closed_broadcast!(nodes[1], false);
+               check_added_monitors!(nodes[1], 1);
+
+               // Make sure everything is persisted as expected after close.
+               check_persisted_data!(11);
+       }
+
+       // Test that if the persister's path to channel data is read-only, writing
+       // data to it fails. Windows ignores the read-only flag for folders, so this
+       // test is Unix-only.
+       #[cfg(not(target_os = "windows"))]
+       #[test]
+       fn test_readonly_dir() {
+               let persister = FilesystemPersister::new("test_readonly_dir_persister".to_string());
+               let test_writeable = TestWriteable{};
+               let test_txo = OutPoint {
+                       txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(),
+                       index: 0
+               };
+               // Create the persister's directory and set it to read-only.
+               let path = &persister.path_to_channel_data;
+               fs::create_dir_all(path).unwrap();
+               let mut perms = fs::metadata(path).unwrap().permissions();
+               perms.set_readonly(true);
+               fs::set_permissions(path, perms).unwrap();
+               match persister.write_channel_data(test_txo, &test_writeable) {
+                       Err(e) => assert_eq!(e.kind(), io::ErrorKind::PermissionDenied),
+                       _ => panic!("Unexpected error message")
+               }
+       }
+
+       // Test failure to rename in the process of atomically creating a channel
+       // monitor's file. We induce this failure by making the `tmp` file a
+       // directory.
+       // Explanation: given "from" = the file being renamed, "to" = the
+       // renamee that already exists: Windows should fail because it'll fail
+       // whenever "to" is a directory, and Unix should fail because if "from" is a
+       // file, then "to" is also required to be a file.
+       #[test]
+       fn test_rename_failure() {
+               let persister = FilesystemPersister::new("test_rename_failure".to_string());
+               let test_writeable = TestWriteable{};
+               let txid_hex = "8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be";
+               let outp_idx = 0;
+               let test_txo = OutPoint {
+                       txid: Txid::from_hex(txid_hex).unwrap(),
+                       index: outp_idx,
+               };
+               // Create the channel data file and make it a directory.
+               let path = &persister.path_to_channel_data;
+               fs::create_dir_all(format!("{}/{}_{}", path, txid_hex, outp_idx)).unwrap();
+               match persister.write_channel_data(test_txo, &test_writeable) {
+                       Err(e) => {
+                               #[cfg(not(target_os = "windows"))]
+                               assert_eq!(e.kind(), io::ErrorKind::Other);
+                               #[cfg(target_os = "windows")]
+                               assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+                       }
+                       _ => panic!("Unexpected error message")
+               }
+       }
+
+       // Test failure to create the temporary file in the persistence process.
+       // We induce this failure by having the temp file already exist and be a
+       // directory.
+       #[test]
+       fn test_tmp_file_creation_failure() {
+               let persister = FilesystemPersister::new("test_tmp_file_creation_failure".to_string());
+               let test_writeable = TestWriteable{};
+               let txid_hex = "8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be";
+               let outp_idx = 0;
+               let test_txo = OutPoint {
+                       txid: Txid::from_hex(txid_hex).unwrap(),
+                       index: outp_idx,
+               };
+               // Create the tmp file and make it a directory.
+               let path = &persister.path_to_channel_data;
+               fs::create_dir_all(format!("{}/{}_{}.tmp", path, txid_hex, outp_idx)).unwrap();
+               match persister.write_channel_data(test_txo, &test_writeable) {
+                       Err(e) => {
+                               #[cfg(not(target_os = "windows"))]
+                               assert_eq!(e.kind(), io::ErrorKind::Other);
+                               #[cfg(target_os = "windows")]
+                               assert_eq!(e.kind(), io::ErrorKind::PermissionDenied);
+                       }
+                       _ => panic!("Unexpected error message")
+               }
+       }
+
+       // Test that if the persister's path to channel data is read-only, writing a
+       // monitor to it results in the persister returning a PermanentFailure.
+       // Windows ignores the read-only flag for folders, so this test is Unix-only.
+       #[cfg(not(target_os = "windows"))]
+       #[test]
+       fn test_readonly_dir_perm_failure() {
+               let persister = FilesystemPersister::new("test_readonly_dir_perm_failure".to_string());
+               fs::create_dir_all(&persister.path_to_channel_data).unwrap();
+
+               // Set up a dummy channel and force close. This will produce a monitor
+               // that we can then use to test persistence.
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+               let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
+               nodes[1].node.force_close_channel(&chan.2);
+               let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap();
+
+               // Set the persister's directory to read-only, which should result in
+               // returning a permanent failure when we then attempt to persist a
+               // channel update.
+               let path = &persister.path_to_channel_data;
+               let mut perms = fs::metadata(path).unwrap().permissions();
+               perms.set_readonly(true);
+               fs::set_permissions(path, perms).unwrap();
+
+               let test_txo = OutPoint {
+                       txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(),
+                       index: 0
+               };
+               match persister.persist_new_channel(test_txo, &added_monitors[0].1) {
+                       Err(ChannelMonitorUpdateErr::PermanentFailure) => {},
+                       _ => panic!("unexpected result from persisting new channel")
+               }
+
+               nodes[1].node.get_and_clear_pending_msg_events();
+               added_monitors.clear();
+       }
+
+       // Test that if a persister's directory name is invalid, monitor persistence
+       // will fail.
+       #[cfg(target_os = "windows")]
+       #[test]
+       fn test_fail_on_open() {
+               // Set up a dummy channel and force close. This will produce a monitor
+               // that we can then use to test persistence.
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let mut node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+               let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
+               nodes[1].node.force_close_channel(&chan.2);
+               let mut added_monitors = nodes[1].chain_monitor.added_monitors.lock().unwrap();
+
+               // Create the persister with an invalid directory name and test that the
+               // channel fails to open because the directories fail to be created. There
+               // don't seem to be invalid filename characters on Unix that Rust doesn't
+               // handle, hence why the test is Windows-only.
+               let persister = FilesystemPersister::new(":<>/".to_string());
+
+               let test_txo = OutPoint {
+                       txid: Txid::from_hex("8984484a580b825b9972d7adb15050b3ab624ccd731946b3eeddb92f4e7ef6be").unwrap(),
+                       index: 0
+               };
+               match persister.persist_new_channel(test_txo, &added_monitors[0].1) {
+                       Err(ChannelMonitorUpdateErr::PermanentFailure) => {},
+                       _ => panic!("unexpected result from persisting new channel")
+               }
+
+               nodes[1].node.get_and_clear_pending_msg_events();
+               added_monitors.clear();
+       }
+}
index 3b12468e9c4a964e353dc6df9978719312b10366..b5ec64d99b0aab3824ef407162aba5a5acee770d 100644 (file)
@@ -12,6 +12,7 @@ Still missing tons of error-handling. See GitHub issues for suggested projects i
 
 [features]
 fuzztarget = ["bitcoin/fuzztarget"]
+_test_utils = ["hex", "regex"]
 # Unlog messages superior at targeted level.
 max_level_off = []
 max_level_error = []
@@ -25,6 +26,9 @@ unsafe_revoked_tx_signing = []
 [dependencies]
 bitcoin = "0.24"
 
+hex = { version = "0.3", optional = true }
+regex = { version = "0.1.80", optional = true }
+
 [dev-dependencies.bitcoin]
 version = "0.24"
 features = ["bitcoinconsensus"]
index 179d0edb55fd7942eda7b819d16c57f53989fbd8..469837f07ed6cf8f790bd4abe2315ed3d7f43b49 100644 (file)
@@ -34,7 +34,8 @@ use bitcoin::blockdata::block::BlockHeader;
 use chain;
 use chain::Filter;
 use chain::chaininterface::{BroadcasterInterface, FeeEstimator};
-use chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, ChannelMonitorUpdateErr, MonitorEvent, MonitorUpdateError};
+use chain::channelmonitor;
+use chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdate, ChannelMonitorUpdateErr, MonitorEvent, Persist};
 use chain::transaction::{OutPoint, TransactionData};
 use chain::keysinterface::ChannelKeys;
 use util::logger::Logger;
@@ -55,25 +56,28 @@ use std::ops::Deref;
 /// [`chain::Watch`]: ../trait.Watch.html
 /// [`ChannelManager`]: ../../ln/channelmanager/struct.ChannelManager.html
 /// [module-level documentation]: index.html
-pub struct ChainMonitor<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref>
+pub struct ChainMonitor<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref, P: Deref>
        where C::Target: chain::Filter,
         T::Target: BroadcasterInterface,
         F::Target: FeeEstimator,
         L::Target: Logger,
+        P::Target: channelmonitor::Persist<ChanSigner>,
 {
        /// The monitors
        pub monitors: Mutex<HashMap<OutPoint, ChannelMonitor<ChanSigner>>>,
        chain_source: Option<C>,
        broadcaster: T,
        logger: L,
-       fee_estimator: F
+       fee_estimator: F,
+       persister: P,
 }
 
-impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref> ChainMonitor<ChanSigner, C, T, F, L>
-       where C::Target: chain::Filter,
-             T::Target: BroadcasterInterface,
-             F::Target: FeeEstimator,
-             L::Target: Logger,
+impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref, P: Deref> ChainMonitor<ChanSigner, C, T, F, L, P>
+where C::Target: chain::Filter,
+           T::Target: BroadcasterInterface,
+           F::Target: FeeEstimator,
+           L::Target: Logger,
+           P::Target: channelmonitor::Persist<ChanSigner>,
 {
        /// Dispatches to per-channel monitors, which are responsible for updating their on-chain view
        /// of a channel and reacting accordingly based on transactions in the connected block. See
@@ -95,8 +99,8 @@ impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref> ChainMonit
 
                        if let Some(ref chain_source) = self.chain_source {
                                for (txid, outputs) in txn_outputs.drain(..) {
-                                       for (idx, output) in outputs.iter().enumerate() {
-                                               chain_source.register_output(&OutPoint { txid, index: idx as u16 }, &output.script_pubkey);
+                                       for (idx, output) in outputs.iter() {
+                                               chain_source.register_output(&OutPoint { txid, index: *idx as u16 }, &output.script_pubkey);
                                        }
                                }
                        }
@@ -124,27 +128,47 @@ impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref> ChainMonit
        /// transactions relevant to the watched channels.
        ///
        /// [`chain::Filter`]: ../trait.Filter.html
-       pub fn new(chain_source: Option<C>, broadcaster: T, logger: L, feeest: F) -> Self {
+       pub fn new(chain_source: Option<C>, broadcaster: T, logger: L, feeest: F, persister: P) -> Self {
                Self {
                        monitors: Mutex::new(HashMap::new()),
                        chain_source,
                        broadcaster,
                        logger,
                        fee_estimator: feeest,
+                       persister,
                }
        }
+}
+
+impl<ChanSigner: ChannelKeys, C: Deref + Sync + Send, T: Deref + Sync + Send, F: Deref + Sync + Send, L: Deref + Sync + Send, P: Deref + Sync + Send> chain::Watch for ChainMonitor<ChanSigner, C, T, F, L, P>
+where C::Target: chain::Filter,
+           T::Target: BroadcasterInterface,
+           F::Target: FeeEstimator,
+           L::Target: Logger,
+           P::Target: channelmonitor::Persist<ChanSigner>,
+{
+       type Keys = ChanSigner;
 
        /// Adds the monitor that watches the channel referred to by the given outpoint.
        ///
        /// Calls back to [`chain::Filter`] with the funding transaction and outputs to watch.
        ///
+       /// Note that we persist the given `ChannelMonitor` while holding the `ChainMonitor`
+       /// monitors lock.
+       ///
        /// [`chain::Filter`]: ../trait.Filter.html
-       fn add_monitor(&self, outpoint: OutPoint, monitor: ChannelMonitor<ChanSigner>) -> Result<(), MonitorUpdateError> {
+       fn watch_channel(&self, funding_outpoint: OutPoint, monitor: ChannelMonitor<ChanSigner>) -> Result<(), ChannelMonitorUpdateErr> {
                let mut monitors = self.monitors.lock().unwrap();
-               let entry = match monitors.entry(outpoint) {
-                       hash_map::Entry::Occupied(_) => return Err(MonitorUpdateError("Channel monitor for given outpoint is already present")),
+               let entry = match monitors.entry(funding_outpoint) {
+                       hash_map::Entry::Occupied(_) => {
+                               log_error!(self.logger, "Failed to add new channel data: channel monitor for given outpoint is already present");
+                               return Err(ChannelMonitorUpdateErr::PermanentFailure)},
                        hash_map::Entry::Vacant(e) => e,
                };
+               if let Err(e) = self.persister.persist_new_channel(funding_outpoint, &monitor) {
+                       log_error!(self.logger, "Failed to persist new channel data");
+                       return Err(e);
+               }
                {
                        let funding_txo = monitor.get_funding_txo();
                        log_trace!(self.logger, "Got new Channel Monitor for channel {}", log_bytes!(funding_txo.0.to_channel_id()[..]));
@@ -152,8 +176,8 @@ impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref> ChainMonit
                        if let Some(ref chain_source) = self.chain_source {
                                chain_source.register_tx(&funding_txo.0.txid, &funding_txo.1);
                                for (txid, outputs) in monitor.get_outputs_to_watch().iter() {
-                                       for (idx, script_pubkey) in outputs.iter().enumerate() {
-                                               chain_source.register_output(&OutPoint { txid: *txid, index: idx as u16 }, &script_pubkey);
+                                       for (idx, script_pubkey) in outputs.iter() {
+                                               chain_source.register_output(&OutPoint { txid: *txid, index: *idx as u16 }, script_pubkey);
                                        }
                                }
                        }
@@ -162,38 +186,34 @@ impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref> ChainMonit
                Ok(())
        }
 
-       /// Updates the monitor that watches the channel referred to by the given outpoint.
-       fn update_monitor(&self, outpoint: OutPoint, update: ChannelMonitorUpdate) -> Result<(), MonitorUpdateError> {
+       /// Note that we persist the given `ChannelMonitor` update while holding the
+       /// `ChainMonitor` monitors lock.
+       fn update_channel(&self, funding_txo: OutPoint, update: ChannelMonitorUpdate) -> Result<(), ChannelMonitorUpdateErr> {
+               // Update the monitor that watches the channel referred to by the given outpoint.
                let mut monitors = self.monitors.lock().unwrap();
-               match monitors.get_mut(&outpoint) {
+               match monitors.get_mut(&funding_txo) {
+                       None => {
+                               log_error!(self.logger, "Failed to update channel monitor: no such monitor registered");
+                               Err(ChannelMonitorUpdateErr::PermanentFailure)
+                       },
                        Some(orig_monitor) => {
                                log_trace!(self.logger, "Updating Channel Monitor for channel {}", log_funding_info!(orig_monitor));
-                               orig_monitor.update_monitor(update, &self.broadcaster, &self.logger)
-                       },
-                       None => Err(MonitorUpdateError("No such monitor registered"))
-               }
-       }
-}
-
-impl<ChanSigner: ChannelKeys, C: Deref + Sync + Send, T: Deref + Sync + Send, F: Deref + Sync + Send, L: Deref + Sync + Send> chain::Watch for ChainMonitor<ChanSigner, C, T, F, L>
-       where C::Target: chain::Filter,
-             T::Target: BroadcasterInterface,
-             F::Target: FeeEstimator,
-             L::Target: Logger,
-{
-       type Keys = ChanSigner;
-
-       fn watch_channel(&self, funding_txo: OutPoint, monitor: ChannelMonitor<ChanSigner>) -> Result<(), ChannelMonitorUpdateErr> {
-               match self.add_monitor(funding_txo, monitor) {
-                       Ok(_) => Ok(()),
-                       Err(_) => Err(ChannelMonitorUpdateErr::PermanentFailure),
-               }
-       }
-
-       fn update_channel(&self, funding_txo: OutPoint, update: ChannelMonitorUpdate) -> Result<(), ChannelMonitorUpdateErr> {
-               match self.update_monitor(funding_txo, update) {
-                       Ok(_) => Ok(()),
-                       Err(_) => Err(ChannelMonitorUpdateErr::PermanentFailure),
+                               let update_res = orig_monitor.update_monitor(&update, &self.broadcaster, &self.logger);
+                               if let Err(e) = &update_res {
+                                       log_error!(self.logger, "Failed to update channel monitor: {:?}", e);
+                               }
+                               // Even if updating the monitor returns an error, the monitor's state will
+                               // still be changed. So, persist the updated monitor despite the error.
+                               let persist_res = self.persister.update_persisted_channel(funding_txo, &update, orig_monitor);
+                               if let Err(ref e) = persist_res {
+                                       log_error!(self.logger, "Failed to persist channel monitor update: {:?}", e);
+                               }
+                               if update_res.is_err() {
+                                       Err(ChannelMonitorUpdateErr::PermanentFailure)
+                               } else {
+                                       persist_res
+                               }
+                       }
                }
        }
 
@@ -206,11 +226,12 @@ impl<ChanSigner: ChannelKeys, C: Deref + Sync + Send, T: Deref + Sync + Send, F:
        }
 }
 
-impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref> events::EventsProvider for ChainMonitor<ChanSigner, C, T, F, L>
+impl<ChanSigner: ChannelKeys, C: Deref, T: Deref, F: Deref, L: Deref, P: Deref> events::EventsProvider for ChainMonitor<ChanSigner, C, T, F, L, P>
        where C::Target: chain::Filter,
              T::Target: BroadcasterInterface,
              F::Target: FeeEstimator,
              L::Target: Logger,
+             P::Target: channelmonitor::Persist<ChanSigner>,
 {
        fn get_and_clear_pending_events(&self) -> Vec<Event> {
                let mut pending_events = Vec::new();
index 31e4565ecab766851cc47699e30f74a512018bf6..889dfa211ee78d167dbad2229ffcaf7dc82d1d4c 100644 (file)
@@ -57,7 +57,7 @@ use std::io::Error;
 
 /// An update generated by the underlying Channel itself which contains some new information the
 /// ChannelMonitor should be made aware of.
-#[cfg_attr(test, derive(PartialEq))]
+#[cfg_attr(any(test, feature = "_test_utils"), derive(PartialEq))]
 #[derive(Clone)]
 #[must_use]
 pub struct ChannelMonitorUpdate {
@@ -95,7 +95,7 @@ impl Readable for ChannelMonitorUpdate {
 }
 
 /// An error enum representing a failure to persist a channel monitor update.
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub enum ChannelMonitorUpdateErr {
        /// Used to indicate a temporary failure (eg connection to a watchtower or remote backup of
        /// our state failed, but is expected to succeed at some point in the future).
@@ -159,7 +159,7 @@ pub enum ChannelMonitorUpdateErr {
 /// inconsistent with the ChannelMonitor being called. eg for ChannelMonitor::update_monitor this
 /// means you tried to update a monitor for a different channel or the ChannelMonitorUpdate was
 /// corrupted.
-/// Contains a human-readable error message.
+/// Contains a developer-readable error message.
 #[derive(Debug)]
 pub struct MonitorUpdateError(pub &'static str);
 
@@ -470,7 +470,7 @@ enum OnchainEvent {
 const SERIALIZATION_VERSION: u8 = 1;
 const MIN_SERIALIZATION_VERSION: u8 = 1;
 
-#[cfg_attr(test, derive(PartialEq))]
+#[cfg_attr(any(test, feature = "_test_utils"), derive(PartialEq))]
 #[derive(Clone)]
 pub(crate) enum ChannelMonitorUpdateStep {
        LatestHolderCommitmentTXInfo {
@@ -666,7 +666,7 @@ pub struct ChannelMonitor<ChanSigner: ChannelKeys> {
        // interface knows about the TXOs that we want to be notified of spends of. We could probably
        // be smart and derive them from the above storage fields, but its much simpler and more
        // Obviously Correct (tm) if we just keep track of them explicitly.
-       outputs_to_watch: HashMap<Txid, Vec<Script>>,
+       outputs_to_watch: HashMap<Txid, Vec<(u32, Script)>>,
 
        #[cfg(test)]
        pub onchain_tx_handler: OnchainTxHandler<ChanSigner>,
@@ -696,7 +696,7 @@ pub struct ChannelMonitor<ChanSigner: ChannelKeys> {
        secp_ctx: Secp256k1<secp256k1::All>, //TODO: dedup this a bit...
 }
 
-#[cfg(any(test, feature = "fuzztarget"))]
+#[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
 /// Used only in testing and fuzztarget to check serialization roundtrips don't change the
 /// underlying object
 impl<ChanSigner: ChannelKeys> PartialEq for ChannelMonitor<ChanSigner> {
@@ -746,7 +746,7 @@ impl<ChanSigner: ChannelKeys + Writeable> ChannelMonitor<ChanSigner> {
        /// the "reorg path" (ie disconnecting blocks until you find a common ancestor from both the
        /// returned block hash and the the current chain and then reconnecting blocks to get to the
        /// best chain) upon deserializing the object!
-       pub fn write_for_disk<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
+       pub fn serialize_for_disk<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
                //TODO: We still write out all the serialization here manually instead of using the fancy
                //serialization framework we have, we should migrate things over to it.
                writer.write_all(&[SERIALIZATION_VERSION; 1])?;
@@ -914,10 +914,11 @@ impl<ChanSigner: ChannelKeys + Writeable> ChannelMonitor<ChanSigner> {
                }
 
                (self.outputs_to_watch.len() as u64).write(writer)?;
-               for (txid, output_scripts) in self.outputs_to_watch.iter() {
+               for (txid, idx_scripts) in self.outputs_to_watch.iter() {
                        txid.write(writer)?;
-                       (output_scripts.len() as u64).write(writer)?;
-                       for script in output_scripts.iter() {
+                       (idx_scripts.len() as u64).write(writer)?;
+                       for (idx, script) in idx_scripts.iter() {
+                               idx.write(writer)?;
                                script.write(writer)?;
                        }
                }
@@ -963,7 +964,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                onchain_tx_handler.provide_latest_holder_tx(initial_holder_commitment_tx);
 
                let mut outputs_to_watch = HashMap::new();
-               outputs_to_watch.insert(funding_info.0.txid, vec![funding_info.1.clone()]);
+               outputs_to_watch.insert(funding_info.0.txid, vec![(funding_info.0.index as u32, funding_info.1.clone())]);
 
                ChannelMonitor {
                        latest_update_id: 0,
@@ -1161,28 +1162,28 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        /// itself.
        ///
        /// panics if the given update is not the next update by update_id.
-       pub fn update_monitor<B: Deref, L: Deref>(&mut self, mut updates: ChannelMonitorUpdate, broadcaster: &B, logger: &L) -> Result<(), MonitorUpdateError>
+       pub fn update_monitor<B: Deref, L: Deref>(&mut self, updates: &ChannelMonitorUpdate, broadcaster: &B, logger: &L) -> Result<(), MonitorUpdateError>
                where B::Target: BroadcasterInterface,
                                        L::Target: Logger,
        {
                if self.latest_update_id + 1 != updates.update_id {
                        panic!("Attempted to apply ChannelMonitorUpdates out of order, check the update_id before passing an update to update_monitor!");
                }
-               for update in updates.updates.drain(..) {
+               for update in updates.updates.iter() {
                        match update {
                                ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs } => {
                                        if self.lockdown_from_offchain { panic!(); }
-                                       self.provide_latest_holder_commitment_tx_info(commitment_tx, htlc_outputs)?
+                                       self.provide_latest_holder_commitment_tx_info(commitment_tx.clone(), htlc_outputs.clone())?
                                },
                                ChannelMonitorUpdateStep::LatestCounterpartyCommitmentTXInfo { unsigned_commitment_tx, htlc_outputs, commitment_number, their_revocation_point } =>
-                                       self.provide_latest_counterparty_commitment_tx_info(&unsigned_commitment_tx, htlc_outputs, commitment_number, their_revocation_point, logger),
+                                       self.provide_latest_counterparty_commitment_tx_info(&unsigned_commitment_tx, htlc_outputs.clone(), *commitment_number, *their_revocation_point, logger),
                                ChannelMonitorUpdateStep::PaymentPreimage { payment_preimage } =>
                                        self.provide_payment_preimage(&PaymentHash(Sha256::hash(&payment_preimage.0[..]).into_inner()), &payment_preimage),
                                ChannelMonitorUpdateStep::CommitmentSecret { idx, secret } =>
-                                       self.provide_secret(idx, secret)?,
+                                       self.provide_secret(*idx, *secret)?,
                                ChannelMonitorUpdateStep::ChannelForceClosed { should_broadcast } => {
                                        self.lockdown_from_offchain = true;
-                                       if should_broadcast {
+                                       if *should_broadcast {
                                                self.broadcast_latest_holder_commitment_txn(broadcaster, logger);
                                        } else {
                                                log_error!(logger, "You have a toxic holder commitment transaction avaible in channel monitor, read comment in ChannelMonitor::get_latest_holder_commitment_txn to be informed of manual action to take");
@@ -1209,7 +1210,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        /// transaction), which we must learn about spends of via block_connected().
        ///
        /// (C-not exported) because we have no HashMap bindings
-       pub fn get_outputs_to_watch(&self) -> &HashMap<Txid, Vec<Script>> {
+       pub fn get_outputs_to_watch(&self) -> &HashMap<Txid, Vec<(u32, Script)>> {
                // If we've detected a counterparty commitment tx on chain, we must include it in the set
                // of outputs to watch for spends of, otherwise we're likely to lose user funds. Because
                // its trivial to do, double-check that here.
@@ -1264,7 +1265,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        /// HTLC-Success/HTLC-Timeout transactions.
        /// Return updates for HTLC pending in the channel and failed automatically by the broadcast of
        /// revoked counterparty commitment tx
-       fn check_spend_counterparty_transaction<L: Deref>(&mut self, tx: &Transaction, height: u32, logger: &L) -> (Vec<ClaimRequest>, (Txid, Vec<TxOut>)) where L::Target: Logger {
+       fn check_spend_counterparty_transaction<L: Deref>(&mut self, tx: &Transaction, height: u32, logger: &L) -> (Vec<ClaimRequest>, (Txid, Vec<(u32, TxOut)>)) where L::Target: Logger {
                // Most secp and related errors trying to create keys means we have no hope of constructing
                // a spend transaction...so we return no transactions to broadcast
                let mut claimable_outpoints = Vec::new();
@@ -1319,7 +1320,9 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                        if !claimable_outpoints.is_empty() || per_commitment_option.is_some() { // ie we're confident this is actually ours
                                // We're definitely a counterparty commitment transaction!
                                log_trace!(logger, "Got broadcast of revoked counterparty commitment transaction, going to generate general spend tx with {} inputs", claimable_outpoints.len());
-                               watch_outputs.append(&mut tx.output.clone());
+                               for (idx, outp) in tx.output.iter().enumerate() {
+                                       watch_outputs.push((idx as u32, outp.clone()));
+                               }
                                self.counterparty_commitment_txn_on_chain.insert(commitment_txid, commitment_number);
 
                                macro_rules! check_htlc_fails {
@@ -1366,7 +1369,9 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                        // already processed the block, resulting in the counterparty_commitment_txn_on_chain entry
                        // not being generated by the above conditional. Thus, to be safe, we go ahead and
                        // insert it here.
-                       watch_outputs.append(&mut tx.output.clone());
+                       for (idx, outp) in tx.output.iter().enumerate() {
+                               watch_outputs.push((idx as u32, outp.clone()));
+                       }
                        self.counterparty_commitment_txn_on_chain.insert(commitment_txid, commitment_number);
 
                        log_trace!(logger, "Got broadcast of non-revoked counterparty commitment transaction {}", commitment_txid);
@@ -1456,7 +1461,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        }
 
        /// Attempts to claim a counterparty HTLC-Success/HTLC-Timeout's outputs using the revocation key
-       fn check_spend_counterparty_htlc<L: Deref>(&mut self, tx: &Transaction, commitment_number: u64, height: u32, logger: &L) -> (Vec<ClaimRequest>, Option<(Txid, Vec<TxOut>)>) where L::Target: Logger {
+       fn check_spend_counterparty_htlc<L: Deref>(&mut self, tx: &Transaction, commitment_number: u64, height: u32, logger: &L) -> (Vec<ClaimRequest>, Option<(Txid, Vec<(u32, TxOut)>)>) where L::Target: Logger {
                let htlc_txid = tx.txid();
                if tx.input.len() != 1 || tx.output.len() != 1 || tx.input[0].witness.len() != 5 {
                        return (Vec::new(), None)
@@ -1478,10 +1483,11 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                log_trace!(logger, "Counterparty HTLC broadcast {}:{}", htlc_txid, 0);
                let witness_data = InputMaterial::Revoked { per_commitment_point, counterparty_delayed_payment_base_key: self.counterparty_tx_cache.counterparty_delayed_payment_base_key, counterparty_htlc_base_key: self.counterparty_tx_cache.counterparty_htlc_base_key,  per_commitment_key, input_descriptor: InputDescriptors::RevokedOutput, amount: tx.output[0].value, htlc: None, on_counterparty_tx_csv: self.counterparty_tx_cache.on_counterparty_tx_csv };
                let claimable_outpoints = vec!(ClaimRequest { absolute_timelock: height + self.counterparty_tx_cache.on_counterparty_tx_csv as u32, aggregable: true, outpoint: BitcoinOutPoint { txid: htlc_txid, vout: 0}, witness_data });
-               (claimable_outpoints, Some((htlc_txid, tx.output.clone())))
+               let outputs = vec![(0, tx.output[0].clone())];
+               (claimable_outpoints, Some((htlc_txid, outputs)))
        }
 
-       fn broadcast_by_holder_state(&self, commitment_tx: &Transaction, holder_tx: &HolderSignedTx) -> (Vec<ClaimRequest>, Vec<TxOut>, Option<(Script, PublicKey, PublicKey)>) {
+       fn broadcast_by_holder_state(&self, commitment_tx: &Transaction, holder_tx: &HolderSignedTx) -> (Vec<ClaimRequest>, Vec<(u32, TxOut)>, Option<(Script, PublicKey, PublicKey)>) {
                let mut claim_requests = Vec::with_capacity(holder_tx.htlc_outputs.len());
                let mut watch_outputs = Vec::with_capacity(holder_tx.htlc_outputs.len());
 
@@ -1502,7 +1508,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                                                        } else { None },
                                                amount: htlc.amount_msat,
                                }});
-                               watch_outputs.push(commitment_tx.output[transaction_output_index as usize].clone());
+                               watch_outputs.push((transaction_output_index, commitment_tx.output[transaction_output_index as usize].clone()));
                        }
                }
 
@@ -1512,7 +1518,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        /// Attempts to claim any claimable HTLCs in a commitment transaction which was not (yet)
        /// revoked using data in holder_claimable_outpoints.
        /// Should not be used if check_spend_revoked_transaction succeeds.
-       fn check_spend_holder_transaction<L: Deref>(&mut self, tx: &Transaction, height: u32, logger: &L) -> (Vec<ClaimRequest>, (Txid, Vec<TxOut>)) where L::Target: Logger {
+       fn check_spend_holder_transaction<L: Deref>(&mut self, tx: &Transaction, height: u32, logger: &L) -> (Vec<ClaimRequest>, (Txid, Vec<(u32, TxOut)>)) where L::Target: Logger {
                let commitment_txid = tx.txid();
                let mut claim_requests = Vec::new();
                let mut watch_outputs = Vec::new();
@@ -1662,7 +1668,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        /// [`get_outputs_to_watch`].
        ///
        /// [`get_outputs_to_watch`]: #method.get_outputs_to_watch
-       pub fn block_connected<B: Deref, F: Deref, L: Deref>(&mut self, header: &BlockHeader, txdata: &TransactionData, height: u32, broadcaster: B, fee_estimator: F, logger: L)-> Vec<(Txid, Vec<TxOut>)>
+       pub fn block_connected<B: Deref, F: Deref, L: Deref>(&mut self, header: &BlockHeader, txdata: &TransactionData, height: u32, broadcaster: B, fee_estimator: F, logger: L)-> Vec<(Txid, Vec<(u32, TxOut)>)>
                where B::Target: BroadcasterInterface,
                      F::Target: FeeEstimator,
                                        L::Target: Logger,
@@ -1763,9 +1769,23 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                // Determine new outputs to watch by comparing against previously known outputs to watch,
                // updating the latter in the process.
                watch_outputs.retain(|&(ref txid, ref txouts)| {
-                       let output_scripts = txouts.iter().map(|o| o.script_pubkey.clone()).collect();
-                       self.outputs_to_watch.insert(txid.clone(), output_scripts).is_none()
+                       let idx_and_scripts = txouts.iter().map(|o| (o.0, o.1.script_pubkey.clone())).collect();
+                       self.outputs_to_watch.insert(txid.clone(), idx_and_scripts).is_none()
                });
+               #[cfg(test)]
+               {
+                       // If we see a transaction for which we registered outputs previously,
+                       // make sure the registered scriptpubkey at the expected index match
+                       // the actual transaction output one. We failed this case before #653.
+                       for tx in &txn_matched {
+                               if let Some(outputs) = self.get_outputs_to_watch().get(&tx.txid()) {
+                                       for idx_and_script in outputs.iter() {
+                                               assert!((idx_and_script.0 as usize) < tx.output.len());
+                                               assert_eq!(tx.output[idx_and_script.0 as usize].script_pubkey, idx_and_script.1);
+                                       }
+                               }
+                       }
+               }
                watch_outputs
        }
 
@@ -1813,8 +1833,19 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        fn spends_watched_output(&self, tx: &Transaction) -> bool {
                for input in tx.input.iter() {
                        if let Some(outputs) = self.get_outputs_to_watch().get(&input.previous_output.txid) {
-                               for (idx, _script_pubkey) in outputs.iter().enumerate() {
-                                       if idx == input.previous_output.vout as usize {
+                               for (idx, _script_pubkey) in outputs.iter() {
+                                       if *idx == input.previous_output.vout {
+                                               #[cfg(test)]
+                                               {
+                                                       // If the expected script is a known type, check that the witness
+                                                       // appears to be spending the correct type (ie that the match would
+                                                       // actually succeed in BIP 158/159-style filters).
+                                                       if _script_pubkey.is_v0_p2wsh() {
+                                                               assert_eq!(&bitcoin::Address::p2wsh(&Script::from(input.witness.last().unwrap().clone()), bitcoin::Network::Bitcoin).script_pubkey(), _script_pubkey);
+                                                       } else if _script_pubkey.is_v0_p2wpkh() {
+                                                               assert_eq!(&bitcoin::Address::p2wpkh(&bitcoin::PublicKey::from_slice(&input.witness.last().unwrap()).unwrap(), bitcoin::Network::Bitcoin).unwrap().script_pubkey(), _script_pubkey);
+                                                       } else { panic!(); }
+                                               }
                                                return true;
                                        }
                                }
@@ -2092,6 +2123,61 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
        }
 }
 
+/// `Persist` defines behavior for persisting channel monitors: this could mean
+/// writing once to disk, and/or uploading to one or more backup services.
+///
+/// Note that for every new monitor, you **must** persist the new `ChannelMonitor`
+/// to disk/backups. And, on every update, you **must** persist either the
+/// `ChannelMonitorUpdate` or the updated monitor itself. Otherwise, there is risk
+/// of situations such as revoking a transaction, then crashing before this
+/// revocation can be persisted, then unintentionally broadcasting a revoked
+/// transaction and losing money. This is a risk because previous channel states
+/// are toxic, so it's important that whatever channel state is persisted is
+/// kept up-to-date.
+pub trait Persist<Keys: ChannelKeys>: Send + Sync {
+       /// Persist a new channel's data. The data can be stored any way you want, but
+       /// the identifier provided by Rust-Lightning is the channel's outpoint (and
+       /// it is up to you to maintain a correct mapping between the outpoint and the
+       /// stored channel data). Note that you **must** persist every new monitor to
+       /// disk. See the `Persist` trait documentation for more details.
+       ///
+       /// See [`ChannelMonitor::serialize_for_disk`] for writing out a `ChannelMonitor`,
+       /// and [`ChannelMonitorUpdateErr`] for requirements when returning errors.
+       ///
+       /// [`ChannelMonitor::serialize_for_disk`]: struct.ChannelMonitor.html#method.serialize_for_disk
+       /// [`ChannelMonitorUpdateErr`]: enum.ChannelMonitorUpdateErr.html
+       fn persist_new_channel(&self, id: OutPoint, data: &ChannelMonitor<Keys>) -> Result<(), ChannelMonitorUpdateErr>;
+
+       /// Update one channel's data. The provided `ChannelMonitor` has already
+       /// applied the given update.
+       ///
+       /// Note that on every update, you **must** persist either the
+       /// `ChannelMonitorUpdate` or the updated monitor itself to disk/backups. See
+       /// the `Persist` trait documentation for more details.
+       ///
+       /// If an implementer chooses to persist the updates only, they need to make
+       /// sure that all the updates are applied to the `ChannelMonitors` *before*
+       /// the set of channel monitors is given to the `ChannelManager`
+       /// deserialization routine. See [`ChannelMonitor::update_monitor`] for
+       /// applying a monitor update to a monitor. If full `ChannelMonitors` are
+       /// persisted, then there is no need to persist individual updates.
+       ///
+       /// Note that there could be a performance tradeoff between persisting complete
+       /// channel monitors on every update vs. persisting only updates and applying
+       /// them in batches. The size of each monitor grows `O(number of state updates)`
+       /// whereas updates are small and `O(1)`.
+       ///
+       /// See [`ChannelMonitor::serialize_for_disk`] for writing out a `ChannelMonitor`,
+       /// [`ChannelMonitorUpdate::write`] for writing out an update, and
+       /// [`ChannelMonitorUpdateErr`] for requirements when returning errors.
+       ///
+       /// [`ChannelMonitor::update_monitor`]: struct.ChannelMonitor.html#impl-1
+       /// [`ChannelMonitor::serialize_for_disk`]: struct.ChannelMonitor.html#method.serialize_for_disk
+       /// [`ChannelMonitorUpdate::write`]: struct.ChannelMonitorUpdate.html#method.write
+       /// [`ChannelMonitorUpdateErr`]: enum.ChannelMonitorUpdateErr.html
+       fn update_persisted_channel(&self, id: OutPoint, update: &ChannelMonitorUpdate, data: &ChannelMonitor<Keys>) -> Result<(), ChannelMonitorUpdateErr>;
+}
+
 const MAX_ALLOC_SIZE: usize = 64*1024;
 
 impl<ChanSigner: ChannelKeys + Readable> Readable for (BlockHash, ChannelMonitor<ChanSigner>) {
@@ -2316,13 +2402,13 @@ impl<ChanSigner: ChannelKeys + Readable> Readable for (BlockHash, ChannelMonitor
                }
 
                let outputs_to_watch_len: u64 = Readable::read(reader)?;
-               let mut outputs_to_watch = HashMap::with_capacity(cmp::min(outputs_to_watch_len as usize, MAX_ALLOC_SIZE / (mem::size_of::<Txid>() + mem::size_of::<Vec<Script>>())));
+               let mut outputs_to_watch = HashMap::with_capacity(cmp::min(outputs_to_watch_len as usize, MAX_ALLOC_SIZE / (mem::size_of::<Txid>() + mem::size_of::<u32>() + mem::size_of::<Vec<Script>>())));
                for _ in 0..outputs_to_watch_len {
                        let txid = Readable::read(reader)?;
                        let outputs_len: u64 = Readable::read(reader)?;
-                       let mut outputs = Vec::with_capacity(cmp::min(outputs_len as usize, MAX_ALLOC_SIZE / mem::size_of::<Script>()));
+                       let mut outputs = Vec::with_capacity(cmp::min(outputs_len as usize, MAX_ALLOC_SIZE / (mem::size_of::<u32>() + mem::size_of::<Script>())));
                        for _ in 0..outputs_len {
-                               outputs.push(Readable::read(reader)?);
+                               outputs.push((Readable::read(reader)?, Readable::read(reader)?));
                        }
                        if let Some(_) = outputs_to_watch.insert(txid, outputs) {
                                return Err(DecodeError::InvalidValue);
index dbd89a2ccdf182f88a6ee32b3ff311456b37b0aa..c5d369268cc6836f462ddb69d4c1b58c2a2047c5 100644 (file)
@@ -18,7 +18,7 @@
 //! generated/etc. This makes it a good candidate for tight integration into an existing wallet
 //! instead of having a rather-separate lightning appendage to a wallet.
 
-#![cfg_attr(not(feature = "fuzztarget"), deny(missing_docs))]
+#![cfg_attr(not(any(feature = "fuzztarget", feature = "_test_utils")), deny(missing_docs))]
 #![forbid(unsafe_code)]
 
 // In general, rust is absolutely horrid at supporting users doing things like,
@@ -28,8 +28,8 @@
 #![allow(ellipsis_inclusive_range_patterns)]
 
 extern crate bitcoin;
-#[cfg(test)] extern crate hex;
-#[cfg(test)] extern crate regex;
+#[cfg(any(test, feature = "_test_utils"))] extern crate hex;
+#[cfg(any(test, feature = "_test_utils"))] extern crate regex;
 
 #[macro_use]
 pub mod util;
index e6eb9e6a24aadbc5cb08cfc0423201d969333f52..689c3496de16ebfd83fdc782d7ddec931d36b0f9 100644 (file)
 //! There are a bunch of these as their handling is relatively error-prone so they are split out
 //! here. See also the chanmon_fail_consistency fuzz test.
 
-use chain::channelmonitor::ChannelMonitorUpdateErr;
+use bitcoin::blockdata::block::BlockHeader;
+use bitcoin::hash_types::BlockHash;
+use bitcoin::network::constants::Network;
+use chain::channelmonitor::{ChannelMonitor, ChannelMonitorUpdateErr};
 use chain::transaction::OutPoint;
+use chain::Watch;
 use ln::channelmanager::{RAACommitmentOrder, PaymentPreimage, PaymentHash, PaymentSecret, PaymentSendFailure};
 use ln::features::InitFeatures;
 use ln::msgs;
 use ln::msgs::{ChannelMessageHandler, ErrorAction, RoutingMessageHandler};
 use routing::router::get_route;
+use util::enforcing_trait_impls::EnforcingChannelKeys;
 use util::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsProvider};
 use util::errors::APIError;
+use util::ser::Readable;
 
 use bitcoin::hashes::sha256::Hash as Sha256;
 use bitcoin::hashes::Hash;
@@ -29,10 +35,11 @@ use ln::functional_test_utils::*;
 
 use util::test_utils;
 
-#[test]
-fn test_simple_monitor_permanent_update_fail() {
+// If persister_fail is true, we have the persister return a PermanentFailure
+// instead of the higher-level ChainMonitor.
+fn do_test_simple_monitor_permanent_update_fail(persister_fail: bool) {
        // Test that we handle a simple permanent monitor update failure
-       let chanmon_cfgs = create_chanmon_cfgs(2);
+       let mut chanmon_cfgs = create_chanmon_cfgs(2);
        let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
@@ -41,7 +48,10 @@ fn test_simple_monitor_permanent_update_fail() {
 
        let (_, payment_hash_1) = get_payment_preimage_hash!(&nodes[0]);
 
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::PermanentFailure);
+       match persister_fail {
+               true => chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::PermanentFailure)),
+               false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::PermanentFailure))
+       }
        let net_graph_msg_handler = &nodes[0].net_graph_msg_handler;
        let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[1].node.get_our_node_id(), None, &Vec::new(), 1000000, TEST_FINAL_CLTV, &logger).unwrap();
        unwrap_send_err!(nodes[0].node.send_payment(&route, payment_hash_1, &None), true, APIError::ChannelUnavailable {..}, {});
@@ -64,10 +74,87 @@ fn test_simple_monitor_permanent_update_fail() {
        assert_eq!(nodes[0].node.list_channels().len(), 0);
 }
 
-fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
+#[test]
+fn test_monitor_and_persister_update_fail() {
+       // Test that if both updating the `ChannelMonitor` and persisting the updated
+       // `ChannelMonitor` fail, then the failure from updating the `ChannelMonitor`
+       // one that gets returned.
+       let chanmon_cfgs = create_chanmon_cfgs(2);
+       let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+       let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+       let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+       // Create some initial channel
+       let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
+       let outpoint = OutPoint { txid: chan.3.txid(), index: 0 };
+
+       // Rebalance the network to generate htlc in the two directions
+       send_payment(&nodes[0], &vec!(&nodes[1])[..], 10_000_000, 10_000_000);
+
+       // Route an HTLC from node 0 to node 1 (but don't settle)
+       let preimage = route_payment(&nodes[0], &vec!(&nodes[1])[..], 9_000_000).0;
+
+       // Make a copy of the ChainMonitor so we can capture the error it returns on a
+       // bogus update. Note that if instead we updated the nodes[0]'s ChainMonitor
+       // directly, the node would fail to be `Drop`'d at the end because its
+       // ChannelManager and ChainMonitor would be out of sync.
+       let chain_source = test_utils::TestChainSource::new(Network::Testnet);
+       let logger = test_utils::TestLogger::with_id(format!("node {}", 0));
+       let persister = test_utils::TestPersister::new();
+       let chain_mon = {
+               let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap();
+               let monitor = monitors.get(&outpoint).unwrap();
+               let mut w = test_utils::TestVecWriter(Vec::new());
+               monitor.serialize_for_disk(&mut w).unwrap();
+               let new_monitor = <(BlockHash, ChannelMonitor<EnforcingChannelKeys>)>::read(
+                       &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
+               assert!(new_monitor == *monitor);
+               let chain_mon = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister);
+               assert!(chain_mon.watch_channel(outpoint, new_monitor).is_ok());
+               chain_mon
+       };
+       let header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 };
+       chain_mon.chain_monitor.block_connected(&header, &[], 200);
+
+       // Set the persister's return value to be a TemporaryFailure.
+       persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure));
+
+       // Try to update ChannelMonitor
+       assert!(nodes[1].node.claim_funds(preimage, &None, 9_000_000));
+       check_added_monitors!(nodes[1], 1);
+       let updates = get_htlc_update_msgs!(nodes[1], nodes[0].node.get_our_node_id());
+       assert_eq!(updates.update_fulfill_htlcs.len(), 1);
+       nodes[0].node.handle_update_fulfill_htlc(&nodes[1].node.get_our_node_id(), &updates.update_fulfill_htlcs[0]);
+       if let Some(ref mut channel) = nodes[0].node.channel_state.lock().unwrap().by_id.get_mut(&chan.2) {
+               if let Ok((_, _, _, update)) = channel.commitment_signed(&updates.commitment_signed, &node_cfgs[0].fee_estimator, &node_cfgs[0].logger) {
+                       // Check that even though the persister is returning a TemporaryFailure,
+                       // because the update is bogus, ultimately the error that's returned
+                       // should be a PermanentFailure.
+                       if let Err(ChannelMonitorUpdateErr::PermanentFailure) = chain_mon.chain_monitor.update_channel(outpoint, update.clone()) {} else { panic!("Expected monitor error to be permanent"); }
+                       logger.assert_log_contains("lightning::chain::chainmonitor".to_string(), "Failed to persist channel monitor update: TemporaryFailure".to_string(), 1);
+                       if let Ok(_) = nodes[0].chain_monitor.update_channel(outpoint, update) {} else { assert!(false); }
+               } else { assert!(false); }
+       } else { assert!(false); };
+
+       check_added_monitors!(nodes[0], 1);
+       let events = nodes[0].node.get_and_clear_pending_events();
+       assert_eq!(events.len(), 1);
+}
+
+#[test]
+fn test_simple_monitor_permanent_update_fail() {
+       do_test_simple_monitor_permanent_update_fail(false);
+
+       // Test behavior when the persister returns a PermanentFailure.
+       do_test_simple_monitor_permanent_update_fail(true);
+}
+
+// If persister_fail is true, we have the persister return a TemporaryFailure instead of the
+// higher-level ChainMonitor.
+fn do_test_simple_monitor_temporary_update_fail(disconnect: bool, persister_fail: bool) {
        // Test that we can recover from a simple temporary monitor update failure optionally with
        // a disconnect in between
-       let chanmon_cfgs = create_chanmon_cfgs(2);
+       let mut chanmon_cfgs = create_chanmon_cfgs(2);
        let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
@@ -76,7 +163,10 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
 
        let (payment_preimage_1, payment_hash_1) = get_payment_preimage_hash!(&nodes[0]);
 
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       match persister_fail {
+               true => chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure)),
+               false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure))
+       }
 
        {
                let net_graph_msg_handler = &nodes[0].net_graph_msg_handler;
@@ -95,7 +185,10 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
                reconnect_nodes(&nodes[0], &nodes[1], (true, true), (0, 0), (0, 0), (0, 0), (0, 0), (false, false));
        }
 
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       match persister_fail {
+               true => chanmon_cfgs[0].persister.set_update_ret(Ok(())),
+               false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()))
+       }
        let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[0].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[0], 0);
@@ -125,7 +218,10 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
        // Now set it to failed again...
        let (_, payment_hash_2) = get_payment_preimage_hash!(&nodes[0]);
        {
-               *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+               match persister_fail {
+                       true => chanmon_cfgs[0].persister.set_update_ret(Err(ChannelMonitorUpdateErr::TemporaryFailure)),
+                       false => *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure))
+               }
                let net_graph_msg_handler = &nodes[0].net_graph_msg_handler;
                let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[1].node.get_our_node_id(), None, &Vec::new(), 1000000, TEST_FINAL_CLTV, &logger).unwrap();
                unwrap_send_err!(nodes[0].node.send_payment(&route, payment_hash_2, &None), false, APIError::MonitorUpdateFailed, {});
@@ -155,8 +251,12 @@ fn do_test_simple_monitor_temporary_update_fail(disconnect: bool) {
 
 #[test]
 fn test_simple_monitor_temporary_update_fail() {
-       do_test_simple_monitor_temporary_update_fail(false);
-       do_test_simple_monitor_temporary_update_fail(true);
+       do_test_simple_monitor_temporary_update_fail(false, false);
+       do_test_simple_monitor_temporary_update_fail(true, false);
+
+       // Test behavior when the persister returns a TemporaryFailure.
+       do_test_simple_monitor_temporary_update_fail(false, true);
+       do_test_simple_monitor_temporary_update_fail(true, true);
 }
 
 fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
@@ -191,7 +291,7 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
        // Now try to send a second payment which will fail to send
        let (payment_preimage_2, payment_hash_2) = get_payment_preimage_hash!(nodes[0]);
        {
-               *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+               *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
                let net_graph_msg_handler = &nodes[0].net_graph_msg_handler;
                let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph.read().unwrap(), &nodes[1].node.get_our_node_id(), None, &Vec::new(), 1000000, TEST_FINAL_CLTV, &logger).unwrap();
                unwrap_send_err!(nodes[0].node.send_payment(&route, payment_hash_2, &None), false, APIError::MonitorUpdateFailed, {});
@@ -245,7 +345,7 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
        }
 
        // Now fix monitor updating...
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[0].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[0], 0);
@@ -532,14 +632,14 @@ fn test_monitor_update_fail_cs() {
        let send_event = SendEvent::from_event(nodes[0].node.get_and_clear_pending_msg_events().remove(0));
        nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        nodes[1].node.handle_commitment_signed(&nodes[0].node.get_our_node_id(), &send_event.commitment_msg);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1);
        check_added_monitors!(nodes[1], 1);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -563,7 +663,7 @@ fn test_monitor_update_fail_cs() {
                        assert!(updates.update_fee.is_none());
                        assert_eq!(*node_id, nodes[0].node.get_our_node_id());
 
-                       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+                       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
                        nodes[0].node.handle_commitment_signed(&nodes[1].node.get_our_node_id(), &updates.commitment_signed);
                        assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
                        nodes[0].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1);
@@ -573,7 +673,7 @@ fn test_monitor_update_fail_cs() {
                _ => panic!("Unexpected event"),
        }
 
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[0].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[0], 0);
@@ -622,7 +722,7 @@ fn test_monitor_update_fail_no_rebroadcast() {
        nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]);
        let bs_raa = commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true, false, true);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        nodes[1].node.handle_revoke_and_ack(&nodes[0].node.get_our_node_id(), &bs_raa);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1);
@@ -630,7 +730,7 @@ fn test_monitor_update_fail_no_rebroadcast() {
        assert!(nodes[1].node.get_and_clear_pending_events().is_empty());
        check_added_monitors!(nodes[1], 1);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
@@ -684,7 +784,7 @@ fn test_monitor_update_raa_while_paused() {
        check_added_monitors!(nodes[1], 1);
        let bs_raa = get_event_msg!(nodes[1], MessageSendEvent::SendRevokeAndACK, nodes[0].node.get_our_node_id());
 
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        nodes[0].node.handle_update_add_htlc(&nodes[1].node.get_our_node_id(), &send_event_2.msgs[0]);
        nodes[0].node.handle_commitment_signed(&nodes[1].node.get_our_node_id(), &send_event_2.commitment_msg);
        assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
@@ -696,7 +796,7 @@ fn test_monitor_update_raa_while_paused() {
        nodes[0].logger.assert_log("lightning::ln::channelmanager".to_string(), "Previous monitor update failure prevented responses to RAA".to_string(), 1);
        check_added_monitors!(nodes[0], 1);
 
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[0].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[0], 0);
@@ -779,7 +879,7 @@ fn do_test_monitor_update_fail_raa(test_ignore_second_cs: bool) {
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
 
        // Now fail monitor updating.
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        nodes[1].node.handle_revoke_and_ack(&nodes[2].node.get_our_node_id(), &bs_revoke_and_ack);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1);
@@ -797,7 +897,7 @@ fn do_test_monitor_update_fail_raa(test_ignore_second_cs: bool) {
                check_added_monitors!(nodes[0], 1);
        }
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(()); // We succeed in updating the monitor for the first channel
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(())); // We succeed in updating the monitor for the first channel
        send_event = SendEvent::from_event(nodes[0].node.get_and_clear_pending_msg_events().remove(0));
        nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &send_event.msgs[0]);
        commitment_signed_dance!(nodes[1], nodes[0], send_event.commitment_msg, false, true);
@@ -858,7 +958,7 @@ fn do_test_monitor_update_fail_raa(test_ignore_second_cs: bool) {
 
        // Restore monitor updating, ensuring we immediately get a fail-back update and a
        // update_add update.
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&chan_2.2).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1020,7 +1120,7 @@ fn test_monitor_update_fail_reestablish() {
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        commitment_signed_dance!(nodes[1], nodes[2], updates.commitment_signed, false);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        nodes[0].node.peer_connected(&nodes[1].node.get_our_node_id(), &msgs::Init { features: InitFeatures::empty() });
        nodes[1].node.peer_connected(&nodes[0].node.get_our_node_id(), &msgs::Init { features: InitFeatures::empty() });
 
@@ -1049,7 +1149,7 @@ fn test_monitor_update_fail_reestablish() {
        check_added_monitors!(nodes[1], 0);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&chan_1.2).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1123,7 +1223,7 @@ fn raa_no_response_awaiting_raa_state() {
        // Now we have a CS queued up which adds a new HTLC (which will need a RAA/CS response from
        // nodes[1]) followed by an RAA. Fail the monitor updating prior to the CS, deliver the RAA,
        // then restore channel monitor updates.
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]);
        nodes[1].node.handle_commitment_signed(&nodes[0].node.get_our_node_id(), &payment_event.commitment_msg);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
@@ -1135,7 +1235,7 @@ fn raa_no_response_awaiting_raa_state() {
        nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Previous monitor update failure prevented responses to RAA".to_string(), 1);
        check_added_monitors!(nodes[1], 1);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        // nodes[1] should be AwaitingRAA here!
@@ -1228,7 +1328,7 @@ fn claim_while_disconnected_monitor_update_fail() {
 
        // Now deliver a's reestablish, freeing the claim from the holding cell, but fail the monitor
        // update.
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
 
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &as_reconnect);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
@@ -1257,7 +1357,7 @@ fn claim_while_disconnected_monitor_update_fail() {
 
        // Now un-fail the monitor, which will result in B sending its original commitment update,
        // receiving the commitment update from A, and the resulting commitment dances.
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1342,7 +1442,7 @@ fn monitor_failed_no_reestablish_response() {
                check_added_monitors!(nodes[0], 1);
        }
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        let mut events = nodes[0].node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), 1);
        let payment_event = SendEvent::from_event(events.pop().unwrap());
@@ -1366,7 +1466,7 @@ fn monitor_failed_no_reestablish_response() {
        nodes[1].node.handle_channel_reestablish(&nodes[0].node.get_our_node_id(), &as_reconnect);
        nodes[0].node.handle_channel_reestablish(&nodes[1].node.get_our_node_id(), &bs_reconnect);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1445,7 +1545,7 @@ fn first_message_on_recv_ordering() {
        let payment_event = SendEvent::from_event(events.pop().unwrap());
        assert_eq!(payment_event.node_id, nodes[1].node.get_our_node_id());
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
 
        // Deliver the final RAA for the first payment, which does not require a response. RAAs
        // generally require a commitment_signed, so the fact that we're expecting an opposite response
@@ -1464,7 +1564,7 @@ fn first_message_on_recv_ordering() {
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Previous monitor update failure prevented generation of RAA".to_string(), 1);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1509,7 +1609,7 @@ fn test_monitor_update_fail_claim() {
 
        let (payment_preimage_1, _) = route_payment(&nodes[0], &[&nodes[1]], 1000000);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        assert!(nodes[1].node.claim_funds(payment_preimage_1, &None, 1_000_000));
        check_added_monitors!(nodes[1], 1);
 
@@ -1523,7 +1623,7 @@ fn test_monitor_update_fail_claim() {
 
        // Successfully update the monitor on the 1<->2 channel, but the 0<->1 channel should still be
        // paused, so forward shouldn't succeed until we call channel_monitor_updated().
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
 
        let mut events = nodes[2].node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), 1);
@@ -1612,13 +1712,13 @@ fn test_monitor_update_on_pending_forwards() {
        nodes[1].node.handle_update_add_htlc(&nodes[2].node.get_our_node_id(), &payment_event.msgs[0]);
        commitment_signed_dance!(nodes[1], nodes[2], payment_event.commitment_msg, false);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        expect_pending_htlcs_forwardable!(nodes[1]);
        check_added_monitors!(nodes[1], 1);
        assert!(nodes[1].node.get_and_clear_pending_msg_events().is_empty());
        nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&chan_1.2).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1675,14 +1775,14 @@ fn monitor_update_claim_fail_no_response() {
        nodes[1].node.handle_update_add_htlc(&nodes[0].node.get_our_node_id(), &payment_event.msgs[0]);
        let as_raa = commitment_signed_dance!(nodes[1], nodes[0], payment_event.commitment_msg, false, true, false, true);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        assert!(nodes[1].node.claim_funds(payment_preimage_1, &None, 1_000_000));
        check_added_monitors!(nodes[1], 1);
        let events = nodes[1].node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), 0);
        nodes[1].logger.assert_log("lightning::ln::channelmanager".to_string(), "Temporary failure claiming HTLC, treating as success: Failed to update ChannelMonitor".to_string(), 1);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1728,19 +1828,19 @@ fn do_during_funding_monitor_fail(confirm_a_first: bool, restore_b_before_conf:
        nodes[0].node.funding_transaction_generated(&temporary_channel_id, funding_output);
        check_added_monitors!(nodes[0], 0);
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        let funding_created_msg = get_event_msg!(nodes[0], MessageSendEvent::SendFundingCreated, nodes[1].node.get_our_node_id());
        let channel_id = OutPoint { txid: funding_created_msg.funding_txid, index: funding_created_msg.funding_output_index }.to_channel_id();
        nodes[1].node.handle_funding_created(&nodes[0].node.get_our_node_id(), &funding_created_msg);
        check_added_monitors!(nodes[1], 1);
 
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Err(ChannelMonitorUpdateErr::TemporaryFailure);
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
        nodes[0].node.handle_funding_signed(&nodes[1].node.get_our_node_id(), &get_event_msg!(nodes[1], MessageSendEvent::SendFundingSigned, nodes[0].node.get_our_node_id()));
        assert!(nodes[0].node.get_and_clear_pending_msg_events().is_empty());
        nodes[0].logger.assert_log("lightning::ln::channelmanager".to_string(), "Failed to update ChannelMonitor".to_string(), 1);
        check_added_monitors!(nodes[0], 1);
        assert!(nodes[0].node.get_and_clear_pending_events().is_empty());
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[0].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[0].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[0], 0);
@@ -1777,7 +1877,7 @@ fn do_during_funding_monitor_fail(confirm_a_first: bool, restore_b_before_conf:
                assert!(nodes[1].node.get_and_clear_pending_events().is_empty());
        }
 
-       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[1].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        let (outpoint, latest_update) = nodes[1].chain_monitor.latest_monitor_update_id.lock().unwrap().get(&channel_id).unwrap().clone();
        nodes[1].node.channel_monitor_updated(&outpoint, latest_update);
        check_added_monitors!(nodes[1], 0);
@@ -1843,7 +1943,7 @@ fn test_path_paused_mpp() {
 
        // Set it so that the first monitor update (for the path 0 -> 1 -> 3) succeeds, but the second
        // (for the path 0 -> 2 -> 3) fails.
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
        *nodes[0].chain_monitor.next_update_ret.lock().unwrap() = Some(Err(ChannelMonitorUpdateErr::TemporaryFailure));
 
        // Now check that we get the right return value, indicating that the first path succeeded but
@@ -1855,7 +1955,7 @@ fn test_path_paused_mpp() {
                if let Err(APIError::MonitorUpdateFailed) = results[1] {} else { panic!(); }
        } else { panic!(); }
        check_added_monitors!(nodes[0], 2);
-       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Ok(());
+       *nodes[0].chain_monitor.update_ret.lock().unwrap() = Some(Ok(()));
 
        // Pass the first HTLC of the payment along to nodes[3].
        let mut events = nodes[0].node.get_and_clear_pending_msg_events();
index 8782bb3ac405569cd9d029949c939cd38b0041ae..c34fc6a38c28e5ae89a5cd98e094ee7a5348781a 100644 (file)
@@ -4017,12 +4017,6 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
                        }
                }
 
-               for _htlc in self.pending_outbound_htlcs.drain(..) {
-                       //TODO: Do something with the remaining HTLCs
-                       //(we need to have the ChannelManager monitor them so we can claim the inbound HTLCs
-                       //which correspond)
-               }
-
                self.channel_state = ChannelState::ShutdownComplete as u32;
                self.update_time_counter += 1;
                self.latest_monitor_update_id += 1;
index 379fe19300938127f0602d2c323b532f8604e97c..b43c98c840392c7b983de276e0867ba5e3baf795 100644 (file)
@@ -405,9 +405,9 @@ pub struct ChannelManager<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref,
        last_block_hash: Mutex<BlockHash>,
        secp_ctx: Secp256k1<secp256k1::All>,
 
-       #[cfg(test)]
+       #[cfg(any(test, feature = "_test_utils"))]
        pub(super) channel_state: Mutex<ChannelHolder<ChanSigner>>,
-       #[cfg(not(test))]
+       #[cfg(not(any(test, feature = "_test_utils")))]
        channel_state: Mutex<ChannelHolder<ChanSigner>>,
        our_network_key: SecretKey,
 
index 7d54b6a92166d9c3378e4f0bf3431d157c28d37d..bc8351e4da09e333b5bcc481d8a94e079187753e 100644 (file)
@@ -94,6 +94,7 @@ pub struct TestChanMonCfg {
        pub tx_broadcaster: test_utils::TestBroadcaster,
        pub fee_estimator: test_utils::TestFeeEstimator,
        pub chain_source: test_utils::TestChainSource,
+       pub persister: test_utils::TestPersister,
        pub logger: test_utils::TestLogger,
 }
 
@@ -169,7 +170,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                let old_monitors = self.chain_monitor.chain_monitor.monitors.lock().unwrap();
                                for (_, old_monitor) in old_monitors.iter() {
                                        let mut w = test_utils::TestVecWriter(Vec::new());
-                                       old_monitor.write_for_disk(&mut w).unwrap();
+                                       old_monitor.serialize_for_disk(&mut w).unwrap();
                                        let (_, deserialized_monitor) = <(BlockHash, ChannelMonitor<EnforcingChannelKeys>)>::read(
                                                &mut ::std::io::Cursor::new(&w.0)).unwrap();
                                        deserialized_monitors.push(deserialized_monitor);
@@ -191,14 +192,20 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        keys_manager: self.keys_manager,
                                        fee_estimator: &test_utils::TestFeeEstimator { sat_per_kw: 253 },
                                        chain_monitor: self.chain_monitor,
-                                       tx_broadcaster: self.tx_broadcaster.clone(),
+                                       tx_broadcaster: &test_utils::TestBroadcaster {
+                                               txn_broadcasted: Mutex::new(self.tx_broadcaster.txn_broadcasted.lock().unwrap().clone())
+                                       },
                                        logger: &test_utils::TestLogger::new(),
                                        channel_monitors,
                                }).unwrap();
                        }
 
+                       let persister = test_utils::TestPersister::new();
+                       let broadcaster = test_utils::TestBroadcaster {
+                               txn_broadcasted: Mutex::new(self.tx_broadcaster.txn_broadcasted.lock().unwrap().clone())
+                       };
                        let chain_source = test_utils::TestChainSource::new(Network::Testnet);
-                       let chain_monitor = test_utils::TestChainMonitor::new(Some(&chain_source), self.tx_broadcaster.clone(), &self.logger, &feeest);
+                       let chain_monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &broadcaster, &self.logger, &feeest, &persister);
                        for deserialized_monitor in deserialized_monitors.drain(..) {
                                if let Err(_) = chain_monitor.watch_channel(deserialized_monitor.get_funding_txo().0, deserialized_monitor) {
                                        panic!();
@@ -247,6 +254,8 @@ macro_rules! get_revoke_commit_msgs {
        }
 }
 
+/// Get an specific event message from the pending events queue.
+#[macro_export]
 macro_rules! get_event_msg {
        ($node: expr, $event_type: path, $node_id: expr) => {
                {
@@ -263,6 +272,7 @@ macro_rules! get_event_msg {
        }
 }
 
+#[cfg(test)]
 macro_rules! get_htlc_update_msgs {
        ($node: expr, $node_id: expr) => {
                {
@@ -279,6 +289,7 @@ macro_rules! get_htlc_update_msgs {
        }
 }
 
+#[cfg(test)]
 macro_rules! get_feerate {
        ($node: expr, $channel_id: expr) => {
                {
@@ -289,6 +300,7 @@ macro_rules! get_feerate {
        }
 }
 
+#[cfg(test)]
 macro_rules! get_local_commitment_txn {
        ($node: expr, $channel_id: expr) => {
                {
@@ -305,6 +317,8 @@ macro_rules! get_local_commitment_txn {
        }
 }
 
+/// Check the error from attempting a payment.
+#[macro_export]
 macro_rules! unwrap_send_err {
        ($res: expr, $all_failed: expr, $type: pat, $check: expr) => {
                match &$res {
@@ -327,6 +341,8 @@ macro_rules! unwrap_send_err {
        }
 }
 
+/// Check whether N channel monitor(s) have been added.
+#[macro_export]
 macro_rules! check_added_monitors {
        ($node: expr, $count: expr) => {
                {
@@ -553,6 +569,9 @@ macro_rules! get_closing_signed_broadcast {
        }
 }
 
+/// Check that a channel's closing channel update has been broadcasted, and optionally
+/// check whether an error message event has occurred.
+#[macro_export]
 macro_rules! check_closed_broadcast {
        ($node: expr, $with_error_msg: expr) => {{
                let events = $node.node.get_and_clear_pending_msg_events();
@@ -750,6 +769,8 @@ macro_rules! commitment_signed_dance {
        }
 }
 
+/// Get a payment preimage and hash.
+#[macro_export]
 macro_rules! get_payment_preimage_hash {
        ($node: expr) => {
                {
@@ -779,6 +800,7 @@ macro_rules! expect_pending_htlcs_forwardable {
        }}
 }
 
+#[cfg(test)]
 macro_rules! expect_payment_received {
        ($node: expr, $expected_payment_hash: expr, $expected_recv_value: expr) => {
                let events = $node.node.get_and_clear_pending_events();
@@ -807,6 +829,7 @@ macro_rules! expect_payment_sent {
        }
 }
 
+#[cfg(test)]
 macro_rules! expect_payment_failed {
        ($node: expr, $expected_payment_hash: expr, $rejected_by_dest: expr $(, $expected_error_code: expr, $expected_error_data: expr)*) => {
                let events = $node.node.get_and_clear_pending_events();
@@ -1105,7 +1128,8 @@ pub fn create_chanmon_cfgs(node_count: usize) -> Vec<TestChanMonCfg> {
                let fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 };
                let chain_source = test_utils::TestChainSource::new(Network::Testnet);
                let logger = test_utils::TestLogger::with_id(format!("node {}", i));
-               chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger });
+               let persister = test_utils::TestPersister::new();
+               chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister });
        }
 
        chan_mon_cfgs
@@ -1117,7 +1141,7 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec<TestChanMon
        for i in 0..node_count {
                let seed = [i as u8; 32];
                let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
-               let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[i].chain_source), &chanmon_cfgs[i].tx_broadcaster, &chanmon_cfgs[i].logger, &chanmon_cfgs[i].fee_estimator);
+               let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[i].chain_source), &chanmon_cfgs[i].tx_broadcaster, &chanmon_cfgs[i].logger, &chanmon_cfgs[i].fee_estimator, &chanmon_cfgs[i].persister);
                nodes.push(NodeCfg { chain_source: &chanmon_cfgs[i].chain_source, logger: &chanmon_cfgs[i].logger, tx_broadcaster: &chanmon_cfgs[i].tx_broadcaster, fee_estimator: &chanmon_cfgs[i].fee_estimator, chain_monitor, keys_manager, node_seed: seed });
        }
 
@@ -1131,7 +1155,7 @@ pub fn create_node_chanmgrs<'a, 'b>(node_count: usize, cfgs: &'a Vec<NodeCfg<'b>
                default_config.channel_options.announced_channel = true;
                default_config.peer_channel_config_limits.force_announced_channel_preference = false;
                default_config.own_channel_config.our_htlc_minimum_msat = 1000; // sanitization being done by the sender, to exerce receiver logic we need to lift of limit
-               let node = ChannelManager::new(Network::Testnet, cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger.clone(), &cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { default_config }, 0);
+               let node = ChannelManager::new(Network::Testnet, cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, cfgs[i].logger, &cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { default_config }, 0);
                chanmgrs.push(node);
        }
 
@@ -1284,6 +1308,7 @@ pub fn get_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec<Node<'a, 'b,
        }
 }
 
+#[cfg(test)]
 macro_rules! get_channel_value_stat {
        ($node: expr, $channel_id: expr) => {{
                let chan_lock = $node.node.channel_state.lock().unwrap();
index c94b220ce8702877b80a9912e39d9fbf9aa0b3f3..a2e12504031780718736fd9086f57a656da826b3 100644 (file)
@@ -4307,6 +4307,7 @@ fn test_no_txn_manager_serialize_deserialize() {
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let logger: test_utils::TestLogger;
        let fee_estimator: test_utils::TestFeeEstimator;
+       let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
        let keys_manager: test_utils::TestKeysInterface;
        let nodes_0_deserialized: ChannelManager<EnforcingChannelKeys, &test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
@@ -4318,11 +4319,12 @@ fn test_no_txn_manager_serialize_deserialize() {
 
        let nodes_0_serialized = nodes[0].node.encode();
        let mut chan_0_monitor_serialized = test_utils::TestVecWriter(Vec::new());
-       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut chan_0_monitor_serialized).unwrap();
+       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut chan_0_monitor_serialized).unwrap();
 
        logger = test_utils::TestLogger::new();
        fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 };
-       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator);
+       persister = test_utils::TestPersister::new();
+       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister);
        nodes[0].chain_monitor = &new_chain_monitor;
        let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..];
        let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor<EnforcingChannelKeys>)>::read(&mut chan_0_monitor_read).unwrap();
@@ -4380,6 +4382,7 @@ fn test_manager_serialize_deserialize_events() {
        let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let fee_estimator: test_utils::TestFeeEstimator;
+       let persister: test_utils::TestPersister;
        let logger: test_utils::TestLogger;
        let new_chain_monitor: test_utils::TestChainMonitor;
        let keys_manager: test_utils::TestKeysInterface;
@@ -4425,11 +4428,12 @@ fn test_manager_serialize_deserialize_events() {
        // Start the de/seriailization process mid-channel creation to check that the channel manager will hold onto events that are serialized
        let nodes_0_serialized = nodes[0].node.encode();
        let mut chan_0_monitor_serialized = test_utils::TestVecWriter(Vec::new());
-       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut chan_0_monitor_serialized).unwrap();
+       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut chan_0_monitor_serialized).unwrap();
 
        fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 };
        logger = test_utils::TestLogger::new();
-       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator);
+       persister = test_utils::TestPersister::new();
+       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister);
        nodes[0].chain_monitor = &new_chain_monitor;
        let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..];
        let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor<EnforcingChannelKeys>)>::read(&mut chan_0_monitor_read).unwrap();
@@ -4502,6 +4506,7 @@ fn test_simple_manager_serialize_deserialize() {
        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
        let logger: test_utils::TestLogger;
        let fee_estimator: test_utils::TestFeeEstimator;
+       let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
        let keys_manager: test_utils::TestKeysInterface;
        let nodes_0_deserialized: ChannelManager<EnforcingChannelKeys, &test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
@@ -4515,11 +4520,12 @@ fn test_simple_manager_serialize_deserialize() {
 
        let nodes_0_serialized = nodes[0].node.encode();
        let mut chan_0_monitor_serialized = test_utils::TestVecWriter(Vec::new());
-       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut chan_0_monitor_serialized).unwrap();
+       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut chan_0_monitor_serialized).unwrap();
 
        logger = test_utils::TestLogger::new();
        fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 };
-       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator);
+       persister = test_utils::TestPersister::new();
+       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister);
        nodes[0].chain_monitor = &new_chain_monitor;
        let mut chan_0_monitor_read = &chan_0_monitor_serialized.0[..];
        let (_, mut chan_0_monitor) = <(BlockHash, ChannelMonitor<EnforcingChannelKeys>)>::read(&mut chan_0_monitor_read).unwrap();
@@ -4561,6 +4567,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() {
        let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]);
        let logger: test_utils::TestLogger;
        let fee_estimator: test_utils::TestFeeEstimator;
+       let persister: test_utils::TestPersister;
        let new_chain_monitor: test_utils::TestChainMonitor;
        let keys_manager: test_utils::TestKeysInterface;
        let nodes_0_deserialized: ChannelManager<EnforcingChannelKeys, &test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestLogger>;
@@ -4572,7 +4579,7 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() {
        let mut node_0_stale_monitors_serialized = Vec::new();
        for monitor in nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter() {
                let mut writer = test_utils::TestVecWriter(Vec::new());
-               monitor.1.write_for_disk(&mut writer).unwrap();
+               monitor.1.serialize_for_disk(&mut writer).unwrap();
                node_0_stale_monitors_serialized.push(writer.0);
        }
 
@@ -4591,13 +4598,14 @@ fn test_manager_serialize_deserialize_inconsistent_monitor() {
        let mut node_0_monitors_serialized = Vec::new();
        for monitor in nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter() {
                let mut writer = test_utils::TestVecWriter(Vec::new());
-               monitor.1.write_for_disk(&mut writer).unwrap();
+               monitor.1.serialize_for_disk(&mut writer).unwrap();
                node_0_monitors_serialized.push(writer.0);
        }
 
        logger = test_utils::TestLogger::new();
        fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 };
-       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator);
+       persister = test_utils::TestPersister::new();
+       new_chain_monitor = test_utils::TestChainMonitor::new(Some(nodes[0].chain_source), nodes[0].tx_broadcaster.clone(), &logger, &fee_estimator, &persister);
        nodes[0].chain_monitor = &new_chain_monitor;
 
        let mut node_0_stale_monitors = Vec::new();
@@ -5742,7 +5750,7 @@ fn test_key_derivation_params() {
        // We manually create the node configuration to backup the seed.
        let seed = [42; 32];
        let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
-       let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator);
+       let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &chanmon_cfgs[0].persister);
        let node = NodeCfg { chain_source: &chanmon_cfgs[0].chain_source, logger: &chanmon_cfgs[0].logger, tx_broadcaster: &chanmon_cfgs[0].tx_broadcaster, fee_estimator: &chanmon_cfgs[0].fee_estimator, chain_monitor, keys_manager, node_seed: seed };
        let mut node_cfgs = create_node_cfgs(3, &chanmon_cfgs);
        node_cfgs.remove(0);
@@ -7157,60 +7165,6 @@ fn test_failure_delay_dust_htlc_local_commitment() {
        do_test_failure_delay_dust_htlc_local_commitment(false);
 }
 
-#[test]
-fn test_no_failure_dust_htlc_local_commitment() {
-       // Transaction filters for failing back dust htlc based on local commitment txn infos has been
-       // prone to error, we test here that a dummy transaction don't fail them.
-
-       let chanmon_cfgs = create_chanmon_cfgs(2);
-       let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
-       let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
-       let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
-       let chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-
-       // Rebalance a bit
-       send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000, 8_000_000);
-
-       let as_dust_limit = nodes[0].node.channel_state.lock().unwrap().by_id.get(&chan.2).unwrap().holder_dust_limit_satoshis;
-       let bs_dust_limit = nodes[1].node.channel_state.lock().unwrap().by_id.get(&chan.2).unwrap().holder_dust_limit_satoshis;
-
-       // We route 2 dust-HTLCs between A and B
-       let (preimage_1, _) = route_payment(&nodes[0], &[&nodes[1]], bs_dust_limit*1000);
-       let (preimage_2, _) = route_payment(&nodes[1], &[&nodes[0]], as_dust_limit*1000);
-
-       // Build a dummy invalid transaction trying to spend a commitment tx
-       let input = TxIn {
-               previous_output: BitcoinOutPoint { txid: chan.3.txid(), vout: 0 },
-               script_sig: Script::new(),
-               sequence: 0,
-               witness: Vec::new(),
-       };
-
-       let outp = TxOut {
-               script_pubkey: Builder::new().push_opcode(opcodes::all::OP_RETURN).into_script(),
-               value: 10000,
-       };
-
-       let dummy_tx = Transaction {
-               version: 2,
-               lock_time: 0,
-               input: vec![input],
-               output: vec![outp]
-       };
-
-       let header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 };
-       nodes[0].chain_monitor.chain_monitor.block_connected(&header, &[(0, &dummy_tx)], 1);
-       assert_eq!(nodes[0].node.get_and_clear_pending_events().len(), 0);
-       assert_eq!(nodes[0].node.get_and_clear_pending_msg_events().len(), 0);
-       // We broadcast a few more block to check everything is all right
-       connect_blocks(&nodes[0], 20, 1, true, header.block_hash());
-       assert_eq!(nodes[0].node.get_and_clear_pending_events().len(), 0);
-       assert_eq!(nodes[0].node.get_and_clear_pending_msg_events().len(), 0);
-
-       claim_payment(&nodes[0], &vec!(&nodes[1])[..], preimage_1, bs_dust_limit*1000);
-       claim_payment(&nodes[1], &vec!(&nodes[0])[..], preimage_2, as_dust_limit*1000);
-}
-
 fn do_test_sweep_outbound_htlc_failure_update(revoked: bool, local: bool) {
        // Outbound HTLC-failure updates must be cancelled if we get a reorg before we reach ANTI_REORG_DELAY.
        // Broadcast of revoked remote commitment tx, trigger failure-update of dust/non-dust HTLCs
@@ -7461,6 +7415,7 @@ fn test_data_loss_protect() {
        // * we close channel in case of detecting other being fallen behind
        // * we are able to claim our own outputs thanks to to_remote being static
        let keys_manager;
+       let persister;
        let logger;
        let fee_estimator;
        let tx_broadcaster;
@@ -7477,7 +7432,7 @@ fn test_data_loss_protect() {
        // Cache node A state before any channel update
        let previous_node_state = nodes[0].node.encode();
        let mut previous_chain_monitor_state = test_utils::TestVecWriter(Vec::new());
-       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.write_for_disk(&mut previous_chain_monitor_state).unwrap();
+       nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap().iter().next().unwrap().1.serialize_for_disk(&mut previous_chain_monitor_state).unwrap();
 
        send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000, 8_000_000);
        send_payment(&nodes[0], &vec!(&nodes[1])[..], 8000000, 8_000_000);
@@ -7492,7 +7447,8 @@ fn test_data_loss_protect() {
        tx_broadcaster = test_utils::TestBroadcaster{txn_broadcasted: Mutex::new(Vec::new())};
        fee_estimator = test_utils::TestFeeEstimator { sat_per_kw: 253 };
        keys_manager = test_utils::TestKeysInterface::new(&nodes[0].node_seed, Network::Testnet);
-       monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &fee_estimator);
+       persister = test_utils::TestPersister::new();
+       monitor = test_utils::TestChainMonitor::new(Some(&chain_source), &tx_broadcaster, &logger, &fee_estimator, &persister);
        node_state_0 = {
                let mut channel_monitors = HashMap::new();
                channel_monitors.insert(OutPoint { txid: chan.3.txid(), index: 0 }, &mut chain_monitor);
@@ -8353,15 +8309,16 @@ fn test_update_err_monitor_lockdown() {
        // Copy ChainMonitor to simulate a watchtower and update block height of node 0 until its ChannelMonitor timeout HTLC onchain
        let chain_source = test_utils::TestChainSource::new(Network::Testnet);
        let logger = test_utils::TestLogger::with_id(format!("node {}", 0));
+       let persister = test_utils::TestPersister::new();
        let watchtower = {
                let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap();
                let monitor = monitors.get(&outpoint).unwrap();
                let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write_for_disk(&mut w).unwrap();
+               monitor.serialize_for_disk(&mut w).unwrap();
                let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(
                                &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
                assert!(new_monitor == *monitor);
-               let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator);
+               let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister);
                assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok());
                watchtower
        };
@@ -8411,15 +8368,16 @@ fn test_concurrent_monitor_claim() {
        // Copy ChainMonitor to simulate watchtower Alice and update block height her ChannelMonitor timeout HTLC onchain
        let chain_source = test_utils::TestChainSource::new(Network::Testnet);
        let logger = test_utils::TestLogger::with_id(format!("node {}", "Alice"));
+       let persister = test_utils::TestPersister::new();
        let watchtower_alice = {
                let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap();
                let monitor = monitors.get(&outpoint).unwrap();
                let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write_for_disk(&mut w).unwrap();
+               monitor.serialize_for_disk(&mut w).unwrap();
                let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(
                                &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
                assert!(new_monitor == *monitor);
-               let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator);
+               let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister);
                assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok());
                watchtower
        };
@@ -8436,15 +8394,16 @@ fn test_concurrent_monitor_claim() {
        // Copy ChainMonitor to simulate watchtower Bob and make it receive a commitment update first.
        let chain_source = test_utils::TestChainSource::new(Network::Testnet);
        let logger = test_utils::TestLogger::with_id(format!("node {}", "Bob"));
+       let persister = test_utils::TestPersister::new();
        let watchtower_bob = {
                let monitors = nodes[0].chain_monitor.chain_monitor.monitors.lock().unwrap();
                let monitor = monitors.get(&outpoint).unwrap();
                let mut w = test_utils::TestVecWriter(Vec::new());
-               monitor.write_for_disk(&mut w).unwrap();
+               monitor.serialize_for_disk(&mut w).unwrap();
                let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(
                                &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
                assert!(new_monitor == *monitor);
-               let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator);
+               let watchtower = test_utils::TestChainMonitor::new(Some(&chain_source), &chanmon_cfgs[0].tx_broadcaster, &logger, &chanmon_cfgs[0].fee_estimator, &persister);
                assert!(watchtower.watch_channel(outpoint, new_monitor).is_ok());
                watchtower
        };
@@ -8497,3 +8456,50 @@ fn test_concurrent_monitor_claim() {
                check_spends!(htlc_txn[1], bob_state_y);
        }
 }
+
+#[test]
+fn test_htlc_no_detection() {
+       // This test is a mutation to underscore the detection logic bug we had
+        // before #653. HTLC value routed is above the remaining balance, thus
+        // inverting HTLC and `to_remote` output. HTLC will come second and
+        // it wouldn't be seen by pre-#653 detection as we were enumerate()'ing
+        // on a watched outputs vector (Vec<TxOut>) thus implicitly relying on
+        // outputs order detection for correct spending children filtring.
+
+        let chanmon_cfgs = create_chanmon_cfgs(2);
+        let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+        let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
+        let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+        // Create some initial channels
+        let chan_1 = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 10001, InitFeatures::known(), InitFeatures::known());
+
+        send_payment(&nodes[0], &vec!(&nodes[1])[..], 1_000_000, 1_000_000);
+        let (_, our_payment_hash) = route_payment(&nodes[0], &vec!(&nodes[1])[..], 2_000_000);
+        let local_txn = get_local_commitment_txn!(nodes[0], chan_1.2);
+        assert_eq!(local_txn[0].input.len(), 1);
+        assert_eq!(local_txn[0].output.len(), 3);
+        check_spends!(local_txn[0], chan_1.3);
+
+        // Timeout HTLC on A's chain and so it can generate a HTLC-Timeout tx
+        let header = BlockHeader { version: 0x20000000, prev_blockhash: Default::default(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 };
+        connect_block(&nodes[0], &Block { header, txdata: vec![local_txn[0].clone()] }, 200);
+       // We deliberately connect the local tx twice as this should provoke a failure calling
+       // this test before #653 fix.
+        connect_block(&nodes[0], &Block { header, txdata: vec![local_txn[0].clone()] }, 200);
+        check_closed_broadcast!(nodes[0], false);
+        check_added_monitors!(nodes[0], 1);
+
+        let htlc_timeout = {
+                let node_txn = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap();
+                assert_eq!(node_txn[0].input.len(), 1);
+                assert_eq!(node_txn[0].input[0].witness.last().unwrap().len(), OFFERED_HTLC_SCRIPT_WEIGHT);
+                check_spends!(node_txn[0], local_txn[0]);
+                node_txn[0].clone()
+        };
+
+        let header_201 = BlockHeader { version: 0x20000000, prev_blockhash: header.block_hash(), merkle_root: Default::default(), time: 42, bits: 42, nonce: 42 };
+        connect_block(&nodes[0], &Block { header: header_201, txdata: vec![htlc_timeout.clone()] }, 201);
+        connect_blocks(&nodes[0], ANTI_REORG_DELAY - 1, 201, true, header_201.block_hash());
+        expect_payment_failed!(nodes[0], our_payment_hash, true);
+}
index cd959a74dc53cc0e176a25e175735011aea6a8ec..1afcb3530fe16b995a4671c563da98d8e8e8d115 100644 (file)
@@ -38,9 +38,9 @@ mod wire;
 // without the node parameter being mut. This is incorrect, and thus newer rustcs will complain
 // about an unnecessary mut. Thus, we silence the unused_mut warning in two test modules below.
 
-#[cfg(test)]
+#[cfg(any(test, feature = "_test_utils"))]
 #[macro_use]
-pub(crate) mod functional_test_utils;
+pub mod functional_test_utils;
 #[cfg(test)]
 #[allow(unused_mut)]
 mod functional_tests;
index 84c8e6c14daa841c2d01772b3cbe15f69bbe0ea7..a3cfaa9804e6e9f58c3521e6ab2d07c3862fd0b6 100644 (file)
@@ -32,10 +32,10 @@ pub(crate) mod macro_logger;
 pub mod logger;
 pub mod config;
 
-#[cfg(test)]
-pub(crate) mod test_utils;
+#[cfg(any(test, feature = "_test_utils"))]
+pub mod test_utils;
 
 /// impls of traits that add exra enforcement on the way they're called. Useful for detecting state
 /// machine errors and used in fuzz targets and tests.
-#[cfg(any(test, feature = "fuzztarget"))]
+#[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
 pub mod enforcing_trait_impls;
index 0370c0e1a402150bbdd749d892e10f895573a1a6..6c3552d5df0853f1b8b2e79592573373c25f7c99 100644 (file)
@@ -63,19 +63,19 @@ impl chaininterface::FeeEstimator for TestFeeEstimator {
 pub struct TestChainMonitor<'a> {
        pub added_monitors: Mutex<Vec<(OutPoint, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>>,
        pub latest_monitor_update_id: Mutex<HashMap<[u8; 32], (OutPoint, u64)>>,
-       pub chain_monitor: chainmonitor::ChainMonitor<EnforcingChannelKeys, &'a TestChainSource, &'a chaininterface::BroadcasterInterface, &'a TestFeeEstimator, &'a TestLogger>,
-       pub update_ret: Mutex<Result<(), channelmonitor::ChannelMonitorUpdateErr>>,
+       pub chain_monitor: chainmonitor::ChainMonitor<EnforcingChannelKeys, &'a TestChainSource, &'a chaininterface::BroadcasterInterface, &'a TestFeeEstimator, &'a TestLogger, &'a channelmonitor::Persist<EnforcingChannelKeys>>,
+       pub update_ret: Mutex<Option<Result<(), channelmonitor::ChannelMonitorUpdateErr>>>,
        // If this is set to Some(), after the next return, we'll always return this until update_ret
        // is changed:
        pub next_update_ret: Mutex<Option<Result<(), channelmonitor::ChannelMonitorUpdateErr>>>,
 }
 impl<'a> TestChainMonitor<'a> {
-       pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator) -> Self {
+       pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator, persister: &'a channelmonitor::Persist<EnforcingChannelKeys>) -> Self {
                Self {
                        added_monitors: Mutex::new(Vec::new()),
                        latest_monitor_update_id: Mutex::new(HashMap::new()),
-                       chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator),
-                       update_ret: Mutex::new(Ok(())),
+                       chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator, persister),
+                       update_ret: Mutex::new(None),
                        next_update_ret: Mutex::new(None),
                }
        }
@@ -87,19 +87,23 @@ impl<'a> chain::Watch for TestChainMonitor<'a> {
                // At every point where we get a monitor update, we should be able to send a useful monitor
                // to a watchtower and disk...
                let mut w = TestVecWriter(Vec::new());
-               monitor.write_for_disk(&mut w).unwrap();
+               monitor.serialize_for_disk(&mut w).unwrap();
                let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(
                        &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
                assert!(new_monitor == monitor);
                self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, monitor.get_latest_update_id()));
                self.added_monitors.lock().unwrap().push((funding_txo, monitor));
-               assert!(self.chain_monitor.watch_channel(funding_txo, new_monitor).is_ok());
+               let watch_res = self.chain_monitor.watch_channel(funding_txo, new_monitor);
 
                let ret = self.update_ret.lock().unwrap().clone();
                if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
-                       *self.update_ret.lock().unwrap() = next_ret;
+                       *self.update_ret.lock().unwrap() = Some(next_ret);
                }
-               ret
+               if ret.is_some() {
+                       assert!(watch_res.is_ok());
+                       return ret.unwrap();
+               }
+               watch_res
        }
 
        fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
@@ -110,23 +114,27 @@ impl<'a> chain::Watch for TestChainMonitor<'a> {
                                &mut ::std::io::Cursor::new(&w.0)).unwrap() == update);
 
                self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, update.update_id));
-               assert!(self.chain_monitor.update_channel(funding_txo, update).is_ok());
+               let update_res = self.chain_monitor.update_channel(funding_txo, update);
                // At every point where we get a monitor update, we should be able to send a useful monitor
                // to a watchtower and disk...
                let monitors = self.chain_monitor.monitors.lock().unwrap();
                let monitor = monitors.get(&funding_txo).unwrap();
                w.0.clear();
-               monitor.write_for_disk(&mut w).unwrap();
+               monitor.serialize_for_disk(&mut w).unwrap();
                let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(
-                               &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
+                       &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
                assert!(new_monitor == *monitor);
                self.added_monitors.lock().unwrap().push((funding_txo, new_monitor));
 
                let ret = self.update_ret.lock().unwrap().clone();
                if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
-                       *self.update_ret.lock().unwrap() = next_ret;
+                       *self.update_ret.lock().unwrap() = Some(next_ret);
                }
-               ret
+               if ret.is_some() {
+                       assert!(update_res.is_ok());
+                       return ret.unwrap();
+               }
+               update_res
        }
 
        fn release_pending_monitor_events(&self) -> Vec<MonitorEvent> {
@@ -134,6 +142,30 @@ impl<'a> chain::Watch for TestChainMonitor<'a> {
        }
 }
 
+pub struct TestPersister {
+       pub update_ret: Mutex<Result<(), channelmonitor::ChannelMonitorUpdateErr>>
+}
+impl TestPersister {
+       pub fn new() -> Self {
+               Self {
+                       update_ret: Mutex::new(Ok(()))
+               }
+       }
+
+       pub fn set_update_ret(&self, ret: Result<(), channelmonitor::ChannelMonitorUpdateErr>) {
+               *self.update_ret.lock().unwrap() = ret;
+       }
+}
+impl channelmonitor::Persist<EnforcingChannelKeys> for TestPersister {
+       fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
+               self.update_ret.lock().unwrap().clone()
+       }
+
+       fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &channelmonitor::ChannelMonitorUpdate, _data: &channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
+               self.update_ret.lock().unwrap().clone()
+       }
+}
+
 pub struct TestBroadcaster {
        pub txn_broadcasted: Mutex<Vec<Transaction>>,
 }