Make it harder to forget to call CM::process_background_events
authorMatt Corallo <git@bluematt.me>
Mon, 11 Sep 2023 03:38:14 +0000 (03:38 +0000)
committerMatt Corallo <git@bluematt.me>
Tue, 12 Sep 2023 19:06:34 +0000 (19:06 +0000)
Prior to any actions which may generate a `ChannelMonitorUpdate`,
and in general after startup,
`ChannelManager::process_background_events` must be called. This is
mostly accomplished by doing so on taking the
`total_consistency_lock` via the `PersistenceNotifierGuard`. In
order to skip this call in block connection logic, the
`PersistenceNotifierGuard::optionally_notify` constructor did not
call the `process_background_events` method.

However, this is very easy to misuse - `optionally_notify` does not
convey to the reader that they need to call
`process_background_events` at all.

Here we fix this by adding a separate
`optionally_notify_skipping_background_events` method, making the
requirements much clearer to callers.

lightning/src/ln/channelmanager.rs

index 3d96db4bbc600c6ae5dd11e01f305873857e9b2a..f3dbd9c253fad5016c574671e39db238f75b131a 100644 (file)
@@ -1236,21 +1236,32 @@ struct PersistenceNotifierGuard<'a, F: Fn() -> NotifyOption> {
 
 impl<'a> PersistenceNotifierGuard<'a, fn() -> NotifyOption> { // We don't care what the concrete F is here, it's unused
        fn notify_on_drop<C: AChannelManager>(cm: &'a C) -> PersistenceNotifierGuard<'a, impl Fn() -> NotifyOption> {
+               Self::optionally_notify(cm, || -> NotifyOption { NotifyOption::DoPersist })
+       }
+
+       fn optionally_notify<F: Fn() -> NotifyOption, C: AChannelManager>(cm: &'a C, persist_check: F)
+       -> PersistenceNotifierGuard<'a, impl Fn() -> NotifyOption> {
                let read_guard = cm.get_cm().total_consistency_lock.read().unwrap();
-               let _ = cm.get_cm().process_background_events(); // We always persist
+               let force_notify = cm.get_cm().process_background_events();
 
                PersistenceNotifierGuard {
                        event_persist_notifier: &cm.get_cm().event_persist_notifier,
-                       should_persist: || -> NotifyOption { NotifyOption::DoPersist },
+                       should_persist: move || {
+                               // Pick the "most" action between `persist_check` and the background events
+                               // processing and return that.
+                               let notify = persist_check();
+                               if force_notify == NotifyOption::DoPersist { NotifyOption::DoPersist }
+                               else { notify }
+                       },
                        _read_guard: read_guard,
                }
-
        }
 
        /// Note that if any [`ChannelMonitorUpdate`]s are possibly generated,
-       /// [`ChannelManager::process_background_events`] MUST be called first.
-       fn optionally_notify<F: Fn() -> NotifyOption, C: AChannelManager>(cm: &'a C, persist_check: F)
-       -> PersistenceNotifierGuard<'a, F> {
+       /// [`ChannelManager::process_background_events`] MUST be called first (or
+       /// [`Self::optionally_notify`] used).
+       fn optionally_notify_skipping_background_events<F: Fn() -> NotifyOption, C: AChannelManager>
+       (cm: &'a C, persist_check: F) -> PersistenceNotifierGuard<'a, F> {
                let read_guard = cm.get_cm().total_consistency_lock.read().unwrap();
 
                PersistenceNotifierGuard {
@@ -4424,7 +4435,7 @@ where
        /// it wants to detect). Thus, we have a variant exposed here for its benefit.
        pub fn maybe_update_chan_fees(&self) {
                PersistenceNotifierGuard::optionally_notify(self, || {
-                       let mut should_persist = self.process_background_events();
+                       let mut should_persist = NotifyOption::SkipPersist;
 
                        let normal_feerate = self.fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::Normal);
                        let min_mempool_feerate = self.fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::MempoolMinimum);
@@ -4469,7 +4480,7 @@ where
        /// [`ChannelConfig`]: crate::util::config::ChannelConfig
        pub fn timer_tick_occurred(&self) {
                PersistenceNotifierGuard::optionally_notify(self, || {
-                       let mut should_persist = self.process_background_events();
+                       let mut should_persist = NotifyOption::SkipPersist;
 
                        let normal_feerate = self.fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::Normal);
                        let min_mempool_feerate = self.fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::MempoolMinimum);
@@ -7002,7 +7013,7 @@ where
        fn get_and_clear_pending_msg_events(&self) -> Vec<MessageSendEvent> {
                let events = RefCell::new(Vec::new());
                PersistenceNotifierGuard::optionally_notify(self, || {
-                       let mut result = self.process_background_events();
+                       let mut result = NotifyOption::SkipPersist;
 
                        // TODO: This behavior should be documented. It's unintuitive that we query
                        // ChannelMonitors when clearing other events.
@@ -7083,8 +7094,9 @@ where
        }
 
        fn block_disconnected(&self, header: &BlockHeader, height: u32) {
-               let _persistence_guard = PersistenceNotifierGuard::optionally_notify(self,
-                       || -> NotifyOption { NotifyOption::DoPersist });
+               let _persistence_guard =
+                       PersistenceNotifierGuard::optionally_notify_skipping_background_events(
+                               self, || -> NotifyOption { NotifyOption::DoPersist });
                let new_height = height - 1;
                {
                        let mut best_block = self.best_block.write().unwrap();
@@ -7118,8 +7130,9 @@ where
                let block_hash = header.block_hash();
                log_trace!(self.logger, "{} transactions included in block {} at height {} provided", txdata.len(), block_hash, height);
 
-               let _persistence_guard = PersistenceNotifierGuard::optionally_notify(self,
-                       || -> NotifyOption { NotifyOption::DoPersist });
+               let _persistence_guard =
+                       PersistenceNotifierGuard::optionally_notify_skipping_background_events(
+                               self, || -> NotifyOption { NotifyOption::DoPersist });
                self.do_chain_event(Some(height), |channel| channel.transactions_confirmed(&block_hash, height, txdata, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger)
                        .map(|(a, b)| (a, Vec::new(), b)));
 
@@ -7138,8 +7151,9 @@ where
                let block_hash = header.block_hash();
                log_trace!(self.logger, "New best block: {} at height {}", block_hash, height);
 
-               let _persistence_guard = PersistenceNotifierGuard::optionally_notify(self,
-                       || -> NotifyOption { NotifyOption::DoPersist });
+               let _persistence_guard =
+                       PersistenceNotifierGuard::optionally_notify_skipping_background_events(
+                               self, || -> NotifyOption { NotifyOption::DoPersist });
                *self.best_block.write().unwrap() = BestBlock::new(block_hash, height);
 
                self.do_chain_event(Some(height), |channel| channel.best_block_updated(height, header.time, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger));
@@ -7182,8 +7196,9 @@ where
        }
 
        fn transaction_unconfirmed(&self, txid: &Txid) {
-               let _persistence_guard = PersistenceNotifierGuard::optionally_notify(self,
-                       || -> NotifyOption { NotifyOption::DoPersist });
+               let _persistence_guard =
+                       PersistenceNotifierGuard::optionally_notify_skipping_background_events(
+                               self, || -> NotifyOption { NotifyOption::DoPersist });
                self.do_chain_event(None, |channel| {
                        if let Some(funding_txo) = channel.context.get_funding_txo() {
                                if funding_txo.txid == *txid {
@@ -7522,9 +7537,8 @@ where
 
        fn handle_channel_update(&self, counterparty_node_id: &PublicKey, msg: &msgs::ChannelUpdate) {
                PersistenceNotifierGuard::optionally_notify(self, || {
-                       let force_persist = self.process_background_events();
                        if let Ok(persist) = handle_error!(self, self.internal_channel_update(counterparty_node_id, msg), *counterparty_node_id) {
-                               if force_persist == NotifyOption::DoPersist { NotifyOption::DoPersist } else { persist }
+                               persist
                        } else {
                                NotifyOption::SkipPersist
                        }