X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Futil%2Ftest_utils.rs;h=15eaa7d4661cf8e6dd088d732370613d70c96ab4;hb=4500270488e6ed918c5f6e07310eb4a384eb6e21;hp=836e638fd92dc38ff1b513fc6e99effaca7da8c6;hpb=79541b11e8b6e62de0fc613f416e30bf1de5f3d9;p=rust-lightning diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index 836e638f..15eaa7d4 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -91,10 +91,6 @@ pub struct TestChainMonitor<'a> { pub latest_monitor_update_id: Mutex>, pub chain_monitor: chainmonitor::ChainMonitor>, pub keys_manager: &'a TestKeysInterface, - pub update_ret: Mutex>>, - /// If this is set to Some(), after the next return, we'll always return this until update_ret - /// is changed: - pub next_update_ret: Mutex>>, /// If this is set to Some(), the next update_channel call (not watch_channel) must be a /// ChannelForceClosed event for the given channel_id with should_broadcast set to the given /// boolean. @@ -107,8 +103,6 @@ impl<'a> TestChainMonitor<'a> { latest_monitor_update_id: Mutex::new(HashMap::new()), chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator, persister), keys_manager, - update_ret: Mutex::new(None), - next_update_ret: Mutex::new(None), expect_channel_force_closed: Mutex::new(None), } } @@ -124,17 +118,7 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { assert!(new_monitor == monitor); self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, monitor.get_latest_update_id())); self.added_monitors.lock().unwrap().push((funding_txo, monitor)); - let watch_res = self.chain_monitor.watch_channel(funding_txo, new_monitor); - - let ret = self.update_ret.lock().unwrap().clone(); - if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { - *self.update_ret.lock().unwrap() = Some(next_ret); - } - if ret.is_some() { - assert!(watch_res.is_ok()); - return ret.unwrap(); - } - watch_res + self.chain_monitor.watch_channel(funding_txo, new_monitor) } fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), chain::ChannelMonitorUpdateErr> { @@ -163,15 +147,6 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { &mut io::Cursor::new(&w.0), self.keys_manager).unwrap().1; assert!(new_monitor == *monitor); self.added_monitors.lock().unwrap().push((funding_txo, new_monitor)); - - let ret = self.update_ret.lock().unwrap().clone(); - if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { - *self.update_ret.lock().unwrap() = Some(next_ret); - } - if ret.is_some() { - assert!(update_res.is_ok()); - return ret.unwrap(); - } update_res } @@ -181,26 +156,43 @@ impl<'a> chain::Watch for TestChainMonitor<'a> { } pub struct TestPersister { - pub update_ret: Mutex> + pub update_ret: Mutex>, + /// If this is set to Some(), after the next return, we'll always return this until update_ret + /// is changed: + pub next_update_ret: Mutex>>, + } impl TestPersister { pub fn new() -> Self { Self { - update_ret: Mutex::new(Ok(())) + update_ret: Mutex::new(Ok(())), + next_update_ret: Mutex::new(None), } } pub fn set_update_ret(&self, ret: Result<(), chain::ChannelMonitorUpdateErr>) { *self.update_ret.lock().unwrap() = ret; } + + pub fn set_next_update_ret(&self, next_ret: Option>) { + *self.next_update_ret.lock().unwrap() = next_ret; + } } impl chainmonitor::Persist for TestPersister { fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor) -> Result<(), chain::ChannelMonitorUpdateErr> { - self.update_ret.lock().unwrap().clone() + let ret = self.update_ret.lock().unwrap().clone(); + if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { + *self.update_ret.lock().unwrap() = next_ret; + } + ret } fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &channelmonitor::ChannelMonitorUpdate, _data: &channelmonitor::ChannelMonitor) -> Result<(), chain::ChannelMonitorUpdateErr> { - self.update_ret.lock().unwrap().clone() + let ret = self.update_ret.lock().unwrap().clone(); + if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() { + *self.update_ret.lock().unwrap() = next_ret; + } + ret } }