Merge pull request #2323 from ariard/2023-05-remove-ariard-pgp-key
diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs
index 9d13facad8110490ddbbf12bb40bc7747881d5a5..4d270286d5027c780d36a16a58711d11e72dcf47 100644
@@ -108,7 +108,7 @@ const PING_TIMER: u64 = 1;
 const NETWORK_PRUNE_TIMER: u64 = 60 * 60;
 
 #[cfg(not(test))]
-const SCORER_PERSIST_TIMER: u64 = 30;
+const SCORER_PERSIST_TIMER: u64 = 60 * 60;
 #[cfg(test)]
 const SCORER_PERSIST_TIMER: u64 = 1;
 
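With update_scorer (below) now reporting which events touched the scorer so they can be persisted immediately, the periodic scorer persist only needs to act as a fallback, so its cadence drops from every 30 seconds to hourly, matching NETWORK_PRUNE_TIMER. A minimal sketch of such a timer fallback; the helper name and the Instant-based bookkeeping are illustrative, not the actual define_run_body! internals:

use std::time::{Duration, Instant};

const SCORER_PERSIST_TIMER: u64 = 60 * 60;

// Hypothetical helper: returns true (and resets the timer) once an hour has passed
// since the last scorer persist.
fn scorer_persist_timer_elapsed(last_scorer_persist: &mut Instant) -> bool {
	if last_scorer_persist.elapsed() >= Duration::from_secs(SCORER_PERSIST_TIMER) {
		*last_scorer_persist = Instant::now();
		true
	} else {
		false
	}
}
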
@@ -236,9 +236,11 @@ fn handle_network_graph_update<L: Deref>(
        }
 }
 
+/// Updates scorer based on event and returns whether an update occurred so we can decide whether
+/// to persist.
 fn update_scorer<'a, S: 'static + Deref<Target = SC> + Send + Sync, SC: 'a + WriteableScore<'a>>(
        scorer: &'a S, event: &Event
-) {
+) -> bool {
        let mut score = scorer.lock();
        match event {
                Event::PaymentPathFailed { ref path, short_channel_id: Some(scid), .. } => {
@@ -258,8 +260,9 @@ fn update_scorer<'a, S: 'static + Deref<Target = SC> + Send + Sync, SC: 'a + Wri
                Event::ProbeFailed { path, short_channel_id: Some(scid), .. } => {
                        score.probe_failed(path, *scid);
                },
-               _ => {},
+               _ => return false,
        }
+       true
 }
 
 macro_rules! define_run_body {
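
update_scorer previously returned nothing; it now reports whether the event actually touched scoring state, and the call sites added further down persist the scorer only when it returns true. The scorer lock taken at the top of the function is released before the caller persists. A self-contained sketch of that mutate-and-report pattern; the types here are stand-ins, not LDK's:

use std::sync::Mutex;

struct Scorer { path_failures: u64 }

enum Event { PaymentPathFailed { scid: u64 }, FundingGenerationReady }

fn update_scorer(scorer: &Mutex<Scorer>, event: &Event) -> bool {
	// Lock only for the in-memory update; the guard is dropped before the caller persists.
	let mut score = scorer.lock().unwrap();
	match event {
		Event::PaymentPathFailed { scid: _ } => score.path_failures += 1,
		// Events that are not scoring-relevant skip the disk write entirely.
		_ => return false,
	}
	true
}

fn main() {
	let scorer = Mutex::new(Scorer { path_failures: 0 });
	for event in [Event::PaymentPathFailed { scid: 42 }, Event::FundingGenerationReady] {
		if update_scorer(&scorer, &event) {
			println!("persisting scorer after update");
		}
	}
}
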
@@ -352,9 +355,15 @@ macro_rules! define_run_body {
                        // Note that we want to run a graph prune once not long after startup before
                        // falling back to our usual hourly prunes. This avoids short-lived clients never
                        // pruning their network graph. We run once 60 seconds after startup before
-                       // continuing our normal cadence.
+                       // continuing our normal cadence. For RGS, since 60 seconds is likely too long,
+                       // we prune after an initial sync completes.
                        let prune_timer = if have_pruned { NETWORK_PRUNE_TIMER } else { FIRST_NETWORK_PRUNE_TIMER };
-                       if $timer_elapsed(&mut last_prune_call, prune_timer) {
+                       let prune_timer_elapsed = $timer_elapsed(&mut last_prune_call, prune_timer);
+                       let should_prune = match $gossip_sync {
+                               GossipSync::Rapid(_) => !have_pruned || prune_timer_elapsed,
+                               _ => prune_timer_elapsed,
+                       };
+                       if should_prune {
                                // The network graph must not be pruned while rapid sync completion is pending
                                if let Some(network_graph) = $gossip_sync.prunable_network_graph() {
                                        #[cfg(feature = "std")] {
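
Previously the prune (and the graph persistence that follows later in the macro) fired only when the timer elapsed. With rapid gossip sync, the graph is now also pruned as soon as the first sync has been applied, rather than waiting out FIRST_NETWORK_PRUNE_TIMER; the existing prunable_network_graph() check still keeps the prune from racing a pending RGS application. A reduced sketch of the decision, with a hypothetical stand-in for the GossipSync wrapper:

// Illustrative only: GossipSyncKind stands in for the crate's GossipSync enum, and
// the booleans for the macro-local state (have_pruned, the elapsed prune timer).
enum GossipSyncKind { Rapid, P2P, NoSync }

fn should_prune(gossip_sync: &GossipSyncKind, have_pruned: bool, prune_timer_elapsed: bool) -> bool {
	match gossip_sync {
		// With RGS, prune right after the initial sync completes, then fall back
		// to the normal hourly cadence.
		GossipSyncKind::Rapid => !have_pruned || prune_timer_elapsed,
		_ => prune_timer_elapsed,
	}
}

fn main() {
	assert!(should_prune(&GossipSyncKind::Rapid, false, false));
	assert!(!should_prune(&GossipSyncKind::P2P, false, false));
	assert!(should_prune(&GossipSyncKind::NoSync, true, true));
}
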
@@ -616,12 +625,19 @@ where
                let network_graph = gossip_sync.network_graph();
                let event_handler = &event_handler;
                let scorer = &scorer;
+               let logger = &logger;
+               let persister = &persister;
                async move {
                        if let Some(network_graph) = network_graph {
                                handle_network_graph_update(network_graph, &event)
                        }
                        if let Some(ref scorer) = scorer {
-                               update_scorer(scorer, &event);
+                               if update_scorer(scorer, &event) {
+                                       log_trace!(logger, "Persisting scorer after update");
+                                       if let Err(e) = persister.persist_scorer(&scorer) {
+                                               log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
+                                       }
+                               }
                        }
                        event_handler(event).await;
                }
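
The async event handler needs the logger and persister inside the async move block, so they are re-bound as references first; the block then captures only those references, leaving the originals usable by the rest of process_events_async and by every subsequent event. A small sketch of that capture idiom, independent of LDK types (tokio is only used here to drive the futures):

struct Persister;
impl Persister {
	fn persist_scorer(&self) { println!("persisting scorer"); }
}

#[tokio::main]
async fn main() {
	let persister = Persister;
	let handle_event = |event: u32| {
		let persister = &persister; // only this reference is moved into the async block
		async move {
			println!("handling event {}", event);
			persister.persist_scorer();
		}
	};
	handle_event(1).await;
	handle_event(2).await; // `persister` was not consumed by the first call
}
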
@@ -751,7 +767,12 @@ impl BackgroundProcessor {
                                        handle_network_graph_update(network_graph, &event)
                                }
                                if let Some(ref scorer) = scorer {
-                                       update_scorer(scorer, &event);
+                                       if update_scorer(scorer, &event) {
+                                               log_trace!(logger, "Persisting scorer after update");
+                                               if let Err(e) = persister.persist_scorer(&scorer) {
+                                                       log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
+                                               }
+                                       }
                                }
                                event_handler.handle_event(event);
                        };
@@ -817,7 +838,6 @@ impl Drop for BackgroundProcessor {
 
 #[cfg(all(feature = "std", test))]
 mod tests {
-       use bitcoin::blockdata::block::BlockHeader;
        use bitcoin::blockdata::constants::genesis_block;
        use bitcoin::blockdata::locktime::PackedLockTime;
        use bitcoin::blockdata::transaction::{Transaction, TxOut};
@@ -833,6 +853,7 @@ mod tests {
        use lightning::ln::channelmanager;
        use lightning::ln::channelmanager::{BREAKDOWN_TIMEOUT, ChainParameters, MIN_CLTV_EXPIRY_DELTA, PaymentId};
        use lightning::ln::features::{ChannelFeatures, NodeFeatures};
+       use lightning::ln::functional_test_utils::*;
        use lightning::ln::msgs::{ChannelMessageHandler, Init};
        use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
        use lightning::routing::gossip::{NetworkGraph, NodeId, P2PGossipSync};
@@ -849,8 +870,6 @@ mod tests {
        use std::sync::{Arc, Mutex};
        use std::sync::mpsc::SyncSender;
        use std::time::Duration;
-       use bitcoin::hashes::Hash;
-       use bitcoin::TxMerkleNode;
        use lightning_rapid_gossip_sync::RapidGossipSync;
        use super::{BackgroundProcessor, GossipSync, FRESHNESS_TIMER};
 
@@ -866,7 +885,7 @@ mod tests {
                fn disconnect_socket(&mut self) {}
        }
 
-       type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter< Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>>>, Arc<test_utils::TestLogger>>;
+       type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
 
        type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;
 
@@ -1000,8 +1019,9 @@ mod tests {
        }
 
        impl Score for TestScorer {
+               type ScoreParams = ();
                fn channel_penalty_msat(
-                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage
+                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage, _score_params: &Self::ScoreParams
                ) -> u64 { unimplemented!(); }
 
                fn payment_path_failed(&mut self, actual_path: &Path, actual_short_channel_id: u64) {
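
The test scorer also tracks an upstream Score trait change: channel_penalty_msat now receives a &Self::ScoreParams, and implementations with no tunable parameters (like TestScorer) can use the unit type. A minimal sketch of that associated-type pattern with illustrative names, not the LDK trait itself:

trait ScoreLike {
	// Per-implementation scoring parameters, passed by reference at lookup time.
	type ScoreParams;
	fn channel_penalty_msat(&self, short_channel_id: u64, params: &Self::ScoreParams) -> u64;
}

struct FlatScorer { penalty_msat: u64 }

impl ScoreLike for FlatScorer {
	// No tunable parameters, so opt out with the unit type, as TestScorer does.
	type ScoreParams = ();
	fn channel_penalty_msat(&self, _short_channel_id: u64, _params: &Self::ScoreParams) -> u64 {
		self.penalty_msat
	}
}

fn main() {
	let scorer = FlatScorer { penalty_msat: 500 };
	assert_eq!(scorer.channel_penalty_msat(42, &()), 500);
}
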
@@ -1114,7 +1134,7 @@ mod tests {
                        let network_graph = Arc::new(NetworkGraph::new(network, logger.clone()));
                        let scorer = Arc::new(Mutex::new(TestScorer::new()));
                        let seed = [i as u8; 32];
-                       let router = Arc::new(DefaultRouter::new(network_graph.clone(), logger.clone(), seed, scorer.clone()));
+                       let router = Arc::new(DefaultRouter::new(network_graph.clone(), logger.clone(), seed, scorer.clone(), ()));
                        let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
                        let persister = Arc::new(FilesystemPersister::new(format!("{}_persister_{}", &persist_dir, i)));
                        let now = Duration::from_secs(genesis_block.header.time as u64);
@@ -1189,7 +1209,7 @@ mod tests {
                for i in 1..=depth {
                        let prev_blockhash = node.best_block.block_hash();
                        let height = node.best_block.height() + 1;
-                       let header = BlockHeader { version: 0x20000000, prev_blockhash, merkle_root: TxMerkleNode::all_zeros(), time: height, bits: 42, nonce: 42 };
+                       let header = create_dummy_header(prev_blockhash, height);
                        let txdata = vec![(0, tx)];
                        node.best_block = BestBlock::new(header.block_hash(), height);
                        match i {
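
The hand-rolled BlockHeader (and the bitcoin imports removed above) is replaced by the create_dummy_header helper pulled in from lightning::ln::functional_test_utils. A sketch of an equivalent helper, assuming it mirrors the removed inline construction (bitcoin 0.29-era field names); the real helper may differ in detail:

use bitcoin::blockdata::block::BlockHeader;
use bitcoin::hashes::Hash;
use bitcoin::{BlockHash, TxMerkleNode};

fn create_dummy_header(prev_blockhash: BlockHash, time: u32) -> BlockHeader {
	BlockHeader {
		version: 0x20000000,
		prev_blockhash,
		merkle_root: TxMerkleNode::all_zeros(),
		time,
		bits: 42,
		nonce: 42,
	}
}
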
@@ -1709,6 +1729,10 @@ mod tests {
                if !std::thread::panicking() {
                        bg_processor.stop().unwrap();
                }
+
+               let log_entries = nodes[0].logger.lines.lock().unwrap();
+               let expected_log = "Persisting scorer after update".to_string();
+               assert_eq!(*log_entries.get(&("lightning_background_processor".to_string(), expected_log)).unwrap(), 5);
        }
 
        #[tokio::test]
@@ -1751,6 +1775,10 @@ mod tests {
                let t2 = tokio::spawn(async move {
                        do_test_payment_path_scoring!(nodes, receiver.recv().await);
                        exit_sender.send(()).unwrap();
+
+                       let log_entries = nodes[0].logger.lines.lock().unwrap();
+                       let expected_log = "Persisting scorer after update".to_string();
+                       assert_eq!(*log_entries.get(&("lightning_background_processor".to_string(), expected_log)).unwrap(), 5);
                });
 
                let (r1, r2) = tokio::join!(t1, t2);