Make `ChainMonitor::monitors` private and expose monitor via getter
[rust-lightning] / lightning / src / chain / chainmonitor.rs
index 8126e4d52276efee0f8466297fea0842d31e5d5a..2de5540059c38e630919a3db76bb3215b6a5399f 100644 (file)
@@ -38,7 +38,7 @@ use util::events::EventHandler;
 use ln::channelmanager::ChannelDetails;
 
 use prelude::*;
-use sync::RwLock;
+use sync::{RwLock, RwLockReadGuard};
 use core::ops::Deref;
 
 /// `Persist` defines behavior for persisting channel monitors: this could mean
@@ -92,6 +92,26 @@ pub trait Persist<ChannelSigner: Sign> {
        fn update_persisted_channel(&self, id: OutPoint, update: &ChannelMonitorUpdate, data: &ChannelMonitor<ChannelSigner>) -> Result<(), ChannelMonitorUpdateErr>;
 }
 
+struct MonitorHolder<ChannelSigner: Sign> {
+       monitor: ChannelMonitor<ChannelSigner>,
+}
+
+/// A read-only reference to a current ChannelMonitor.
+///
+/// Note that this holds a mutex in [`ChainMonitor`] and may block other events until it is
+/// released.
+pub struct LockedChannelMonitor<'a, ChannelSigner: Sign> {
+       lock: RwLockReadGuard<'a, HashMap<OutPoint, MonitorHolder<ChannelSigner>>>,
+       funding_txo: OutPoint,
+}
+
+impl<ChannelSigner: Sign> Deref for LockedChannelMonitor<'_, ChannelSigner> {
+       type Target = ChannelMonitor<ChannelSigner>;
+       fn deref(&self) -> &ChannelMonitor<ChannelSigner> {
+               &self.lock.get(&self.funding_txo).expect("Checked at construction").monitor
+       }
+}
+
 /// An implementation of [`chain::Watch`] for monitoring channels.
 ///
 /// Connected and disconnected blocks must be provided to `ChainMonitor` as documented by
@@ -108,8 +128,7 @@ pub struct ChainMonitor<ChannelSigner: Sign, C: Deref, T: Deref, F: Deref, L: De
         L::Target: Logger,
         P::Target: Persist<ChannelSigner>,
 {
-       /// The monitors
-       pub monitors: RwLock<HashMap<OutPoint, ChannelMonitor<ChannelSigner>>>,
+       monitors: RwLock<HashMap<OutPoint, MonitorHolder<ChannelSigner>>>,
        chain_source: Option<C>,
        broadcaster: T,
        logger: L,
@@ -138,9 +157,9 @@ where C::Target: chain::Filter,
                FN: Fn(&ChannelMonitor<ChannelSigner>, &TransactionData) -> Vec<TransactionOutputs>
        {
                let mut dependent_txdata = Vec::new();
-               let monitors = self.monitors.read().unwrap();
-               for monitor in monitors.values() {
-                       let mut txn_outputs = process(monitor, txdata);
+               let monitor_states = self.monitors.read().unwrap();
+               for monitor_state in monitor_states.values() {
+                       let mut txn_outputs = process(&monitor_state.monitor, txdata);
 
                        // Register any new outputs with the chain source for filtering, storing any dependent
                        // transactions from within the block that previously had not been included in txdata.
@@ -202,8 +221,8 @@ where C::Target: chain::Filter,
        /// inclusion in the return value.
        pub fn get_claimable_balances(&self, ignored_channels: &[&ChannelDetails]) -> Vec<Balance> {
                let mut ret = Vec::new();
-               let monitors = self.monitors.read().unwrap();
-               for (_, monitor) in monitors.iter().filter(|(funding_outpoint, _)| {
+               let monitor_states = self.monitors.read().unwrap();
+               for (_, monitor_state) in monitor_states.iter().filter(|(funding_outpoint, _)| {
                        for chan in ignored_channels {
                                if chan.funding_txo.as_ref() == Some(funding_outpoint) {
                                        return false;
@@ -211,11 +230,38 @@ where C::Target: chain::Filter,
                        }
                        true
                }) {
-                       ret.append(&mut monitor.get_claimable_balances());
+                       ret.append(&mut monitor_state.monitor.get_claimable_balances());
                }
                ret
        }
 
+       /// Gets the [`LockedChannelMonitor`] for a given funding outpoint, returning an `Err` if no
+       /// such [`ChannelMonitor`] is currently being monitored for.
+       ///
+       /// Note that the result holds a mutex over our monitor set, and should not be held
+       /// indefinitely.
+       pub fn get_monitor(&self, funding_txo: OutPoint) -> Result<LockedChannelMonitor<'_, ChannelSigner>, ()> {
+               let lock = self.monitors.read().unwrap();
+               if lock.get(&funding_txo).is_some() {
+                       Ok(LockedChannelMonitor { lock, funding_txo })
+               } else {
+                       Err(())
+               }
+       }
+
+       /// Lists the funding outpoint of each [`ChannelMonitor`] being monitored.
+       ///
+       /// Note that [`ChannelMonitor`]s are not removed when a channel is closed as they are always
+       /// monitoring for on-chain state resolutions.
+       pub fn list_monitors(&self) -> Vec<OutPoint> {
+               self.monitors.read().unwrap().keys().map(|outpoint| *outpoint).collect()
+       }
+
+       #[cfg(test)]
+       pub fn remove_monitor(&self, funding_txo: &OutPoint) -> ChannelMonitor<ChannelSigner> {
+               self.monitors.write().unwrap().remove(funding_txo).unwrap().monitor
+       }
+
        #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
        pub fn get_and_clear_pending_events(&self) -> Vec<events::Event> {
                use util::events::EventsProvider;
@@ -246,10 +292,10 @@ where
        }
 
        fn block_disconnected(&self, header: &BlockHeader, height: u32) {
-               let monitors = self.monitors.read().unwrap();
+               let monitor_states = self.monitors.read().unwrap();
                log_debug!(self.logger, "Latest block {} at height {} removed via block_disconnected", header.block_hash(), height);
-               for monitor in monitors.values() {
-                       monitor.block_disconnected(
+               for monitor_state in monitor_states.values() {
+                       monitor_state.monitor.block_disconnected(
                                header, height, &*self.broadcaster, &*self.fee_estimator, &*self.logger);
                }
        }
@@ -274,9 +320,9 @@ where
 
        fn transaction_unconfirmed(&self, txid: &Txid) {
                log_debug!(self.logger, "Transaction {} reorganized out of chain", txid);
-               let monitors = self.monitors.read().unwrap();
-               for monitor in monitors.values() {
-                       monitor.transaction_unconfirmed(txid, &*self.broadcaster, &*self.fee_estimator, &*self.logger);
+               let monitor_states = self.monitors.read().unwrap();
+               for monitor_state in monitor_states.values() {
+                       monitor_state.monitor.transaction_unconfirmed(txid, &*self.broadcaster, &*self.fee_estimator, &*self.logger);
                }
        }
 
@@ -293,9 +339,9 @@ where
 
        fn get_relevant_txids(&self) -> Vec<Txid> {
                let mut txids = Vec::new();
-               let monitors = self.monitors.read().unwrap();
-               for monitor in monitors.values() {
-                       txids.append(&mut monitor.get_relevant_txids());
+               let monitor_states = self.monitors.read().unwrap();
+               for monitor_state in monitor_states.values() {
+                       txids.append(&mut monitor_state.monitor.get_relevant_txids());
                }
 
                txids.sort_unstable();
@@ -338,7 +384,7 @@ where C::Target: chain::Filter,
                                monitor.load_outputs_to_watch(chain_source);
                        }
                }
-               entry.insert(monitor);
+               entry.insert(MonitorHolder { monitor });
                Ok(())
        }
 
@@ -359,7 +405,8 @@ where C::Target: chain::Filter,
                                #[cfg(not(any(test, feature = "fuzztarget")))]
                                Err(ChannelMonitorUpdateErr::PermanentFailure)
                        },
-                       Some(monitor) => {
+                       Some(monitor_state) => {
+                               let monitor = &monitor_state.monitor;
                                log_trace!(self.logger, "Updating Channel Monitor for channel {}", log_funding_info!(monitor));
                                let update_res = monitor.update_monitor(&update, &self.broadcaster, &self.fee_estimator, &self.logger);
                                if let Err(e) = &update_res {
@@ -382,8 +429,8 @@ where C::Target: chain::Filter,
 
        fn release_pending_monitor_events(&self) -> Vec<MonitorEvent> {
                let mut pending_monitor_events = Vec::new();
-               for monitor in self.monitors.read().unwrap().values() {
-                       pending_monitor_events.append(&mut monitor.get_and_clear_pending_monitor_events());
+               for monitor_state in self.monitors.read().unwrap().values() {
+                       pending_monitor_events.append(&mut monitor_state.monitor.get_and_clear_pending_monitor_events());
                }
                pending_monitor_events
        }
@@ -404,8 +451,8 @@ impl<ChannelSigner: Sign, C: Deref, T: Deref, F: Deref, L: Deref, P: Deref> even
        /// [`SpendableOutputs`]: events::Event::SpendableOutputs
        fn process_pending_events<H: Deref>(&self, handler: H) where H::Target: EventHandler {
                let mut pending_events = Vec::new();
-               for monitor in self.monitors.read().unwrap().values() {
-                       pending_events.append(&mut monitor.get_and_clear_pending_events());
+               for monitor_state in self.monitors.read().unwrap().values() {
+                       pending_events.append(&mut monitor_state.monitor.get_and_clear_pending_events());
                }
                for event in pending_events.drain(..) {
                        handler.handle_event(&event);