Use Persister to return errors in tests not chain::Watch
[rust-lightning] / lightning / src / util / test_utils.rs
index 836e638fd92dc38ff1b513fc6e99effaca7da8c6..15eaa7d4661cf8e6dd088d732370613d70c96ab4 100644 (file)
@@ -91,10 +91,6 @@ pub struct TestChainMonitor<'a> {
        pub latest_monitor_update_id: Mutex<HashMap<[u8; 32], (OutPoint, u64)>>,
        pub chain_monitor: chainmonitor::ChainMonitor<EnforcingSigner, &'a TestChainSource, &'a chaininterface::BroadcasterInterface, &'a TestFeeEstimator, &'a TestLogger, &'a chainmonitor::Persist<EnforcingSigner>>,
        pub keys_manager: &'a TestKeysInterface,
-       pub update_ret: Mutex<Option<Result<(), chain::ChannelMonitorUpdateErr>>>,
-       /// If this is set to Some(), after the next return, we'll always return this until update_ret
-       /// is changed:
-       pub next_update_ret: Mutex<Option<Result<(), chain::ChannelMonitorUpdateErr>>>,
        /// If this is set to Some(), the next update_channel call (not watch_channel) must be a
        /// ChannelForceClosed event for the given channel_id with should_broadcast set to the given
        /// boolean.
@@ -107,8 +103,6 @@ impl<'a> TestChainMonitor<'a> {
                        latest_monitor_update_id: Mutex::new(HashMap::new()),
                        chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator, persister),
                        keys_manager,
-                       update_ret: Mutex::new(None),
-                       next_update_ret: Mutex::new(None),
                        expect_channel_force_closed: Mutex::new(None),
                }
        }
@@ -124,17 +118,7 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
                assert!(new_monitor == monitor);
                self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, monitor.get_latest_update_id()));
                self.added_monitors.lock().unwrap().push((funding_txo, monitor));
-               let watch_res = self.chain_monitor.watch_channel(funding_txo, new_monitor);
-
-               let ret = self.update_ret.lock().unwrap().clone();
-               if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
-                       *self.update_ret.lock().unwrap() = Some(next_ret);
-               }
-               if ret.is_some() {
-                       assert!(watch_res.is_ok());
-                       return ret.unwrap();
-               }
-               watch_res
+               self.chain_monitor.watch_channel(funding_txo, new_monitor)
        }
 
        fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), chain::ChannelMonitorUpdateErr> {
@@ -163,15 +147,6 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
                        &mut io::Cursor::new(&w.0), self.keys_manager).unwrap().1;
                assert!(new_monitor == *monitor);
                self.added_monitors.lock().unwrap().push((funding_txo, new_monitor));
-
-               let ret = self.update_ret.lock().unwrap().clone();
-               if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
-                       *self.update_ret.lock().unwrap() = Some(next_ret);
-               }
-               if ret.is_some() {
-                       assert!(update_res.is_ok());
-                       return ret.unwrap();
-               }
                update_res
        }
 
@@ -181,26 +156,43 @@ impl<'a> chain::Watch<EnforcingSigner> for TestChainMonitor<'a> {
 }
 
 pub struct TestPersister {
-       pub update_ret: Mutex<Result<(), chain::ChannelMonitorUpdateErr>>
+       pub update_ret: Mutex<Result<(), chain::ChannelMonitorUpdateErr>>,
+       /// If this is set to Some(), after the next return, we'll always return this until update_ret
+       /// is changed:
+       pub next_update_ret: Mutex<Option<Result<(), chain::ChannelMonitorUpdateErr>>>,
+
 }
 impl TestPersister {
        pub fn new() -> Self {
                Self {
-                       update_ret: Mutex::new(Ok(()))
+                       update_ret: Mutex::new(Ok(())),
+                       next_update_ret: Mutex::new(None),
                }
        }
 
        pub fn set_update_ret(&self, ret: Result<(), chain::ChannelMonitorUpdateErr>) {
                *self.update_ret.lock().unwrap() = ret;
        }
+
+       pub fn set_next_update_ret(&self, next_ret: Option<Result<(), chain::ChannelMonitorUpdateErr>>) {
+               *self.next_update_ret.lock().unwrap() = next_ret;
+       }
 }
 impl<Signer: keysinterface::Sign> chainmonitor::Persist<Signer> for TestPersister {
        fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor<Signer>) -> Result<(), chain::ChannelMonitorUpdateErr> {
-               self.update_ret.lock().unwrap().clone()
+               let ret = self.update_ret.lock().unwrap().clone();
+               if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
+                       *self.update_ret.lock().unwrap() = next_ret;
+               }
+               ret
        }
 
        fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &channelmonitor::ChannelMonitorUpdate, _data: &channelmonitor::ChannelMonitor<Signer>) -> Result<(), chain::ChannelMonitorUpdateErr> {
-               self.update_ret.lock().unwrap().clone()
+               let ret = self.update_ret.lock().unwrap().clone();
+               if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
+                       *self.update_ret.lock().unwrap() = next_ret;
+               }
+               ret
        }
 }