Merge pull request #2120 from valentinewallace/2023-03-blinded-pathfinding
[rust-lightning] / lightning-background-processor / src / lib.rs
index a9a69de7a3a002e7a15eea111b9646e67c7e3180..0cfa9801badb3a8d31a30cbdce81b93cbaab1f5a 100644 (file)
@@ -25,7 +25,7 @@ extern crate lightning_rapid_gossip_sync;
 use lightning::chain;
 use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
 use lightning::chain::chainmonitor::{ChainMonitor, Persist};
-use lightning::chain::keysinterface::{EntropySource, NodeSigner, SignerProvider};
+use lightning::sign::{EntropySource, NodeSigner, SignerProvider};
 use lightning::events::{Event, PathFailure};
 #[cfg(feature = "std")]
 use lightning::events::{EventHandler, EventsProvider};
@@ -108,7 +108,7 @@ const PING_TIMER: u64 = 1;
 const NETWORK_PRUNE_TIMER: u64 = 60 * 60;
 
 #[cfg(not(test))]
-const SCORER_PERSIST_TIMER: u64 = 30;
+const SCORER_PERSIST_TIMER: u64 = 60 * 60;
 #[cfg(test)]
 const SCORER_PERSIST_TIMER: u64 = 1;
 
@@ -236,9 +236,11 @@ fn handle_network_graph_update<L: Deref>(
        }
 }
 
+/// Updates scorer based on event and returns whether an update occurred so we can decide whether
+/// to persist.
 fn update_scorer<'a, S: 'static + Deref<Target = SC> + Send + Sync, SC: 'a + WriteableScore<'a>>(
        scorer: &'a S, event: &Event
-) {
+) -> bool {
        let mut score = scorer.lock();
        match event {
                Event::PaymentPathFailed { ref path, short_channel_id: Some(scid), .. } => {
@@ -258,8 +260,9 @@ fn update_scorer<'a, S: 'static + Deref<Target = SC> + Send + Sync, SC: 'a + Wri
                Event::ProbeFailed { path, short_channel_id: Some(scid), .. } => {
                        score.probe_failed(path, *scid);
                },
-               _ => {},
+               _ => return false,
        }
+       true
 }
 
 macro_rules! define_run_body {
@@ -352,9 +355,15 @@ macro_rules! define_run_body {
                        // Note that we want to run a graph prune once not long after startup before
                        // falling back to our usual hourly prunes. This avoids short-lived clients never
                        // pruning their network graph. We run once 60 seconds after startup before
-                       // continuing our normal cadence.
+                       // continuing our normal cadence. For RGS, since 60 seconds is likely too long,
+                       // we prune after an initial sync completes.
                        let prune_timer = if have_pruned { NETWORK_PRUNE_TIMER } else { FIRST_NETWORK_PRUNE_TIMER };
-                       if $timer_elapsed(&mut last_prune_call, prune_timer) {
+                       let prune_timer_elapsed = $timer_elapsed(&mut last_prune_call, prune_timer);
+                       let should_prune = match $gossip_sync {
+                               GossipSync::Rapid(_) => !have_pruned || prune_timer_elapsed,
+                               _ => prune_timer_elapsed,
+                       };
+                       if should_prune {
                                // The network graph must not be pruned while rapid sync completion is pending
                                if let Some(network_graph) = $gossip_sync.prunable_network_graph() {
                                        #[cfg(feature = "std")] {
@@ -506,12 +515,13 @@ use core::task;
 /// # use lightning_background_processor::{process_events_async, GossipSync};
 /// # type MyBroadcaster = dyn lightning::chain::chaininterface::BroadcasterInterface + Send + Sync;
 /// # type MyFeeEstimator = dyn lightning::chain::chaininterface::FeeEstimator + Send + Sync;
-/// # type MyNodeSigner = dyn lightning::chain::keysinterface::NodeSigner + Send + Sync;
+/// # type MyNodeSigner = dyn lightning::sign::NodeSigner + Send + Sync;
 /// # type MyUtxoLookup = dyn lightning::routing::utxo::UtxoLookup + Send + Sync;
 /// # type MyFilter = dyn lightning::chain::Filter + Send + Sync;
 /// # type MyLogger = dyn lightning::util::logger::Logger + Send + Sync;
-/// # type MyChainMonitor = lightning::chain::chainmonitor::ChainMonitor<lightning::chain::keysinterface::InMemorySigner, Arc<MyFilter>, Arc<MyBroadcaster>, Arc<MyFeeEstimator>, Arc<MyLogger>, Arc<MyPersister>>;
-/// # type MyPeerManager = lightning::ln::peer_handler::SimpleArcPeerManager<MySocketDescriptor, MyChainMonitor, MyBroadcaster, MyFeeEstimator, MyUtxoLookup, MyLogger>;
+/// # type MyMessageRouter = dyn lightning::onion_message::MessageRouter + Send + Sync;
+/// # type MyChainMonitor = lightning::chain::chainmonitor::ChainMonitor<lightning::sign::InMemorySigner, Arc<MyFilter>, Arc<MyBroadcaster>, Arc<MyFeeEstimator>, Arc<MyLogger>, Arc<MyPersister>>;
+/// # type MyPeerManager = lightning::ln::peer_handler::SimpleArcPeerManager<MySocketDescriptor, MyChainMonitor, MyBroadcaster, MyFeeEstimator, MyUtxoLookup, MyLogger, MyMessageRouter>;
 /// # type MyNetworkGraph = lightning::routing::gossip::NetworkGraph<Arc<MyLogger>>;
 /// # type MyGossipSync = lightning::routing::gossip::P2PGossipSync<Arc<MyNetworkGraph>, Arc<MyUtxoLookup>, Arc<MyLogger>>;
 /// # type MyChannelManager = lightning::ln::channelmanager::SimpleArcChannelManager<MyChainMonitor, MyBroadcaster, MyFeeEstimator, MyLogger>;
@@ -616,12 +626,19 @@ where
                let network_graph = gossip_sync.network_graph();
                let event_handler = &event_handler;
                let scorer = &scorer;
+               let logger = &logger;
+               let persister = &persister;
                async move {
                        if let Some(network_graph) = network_graph {
                                handle_network_graph_update(network_graph, &event)
                        }
                        if let Some(ref scorer) = scorer {
-                               update_scorer(scorer, &event);
+                               if update_scorer(scorer, &event) {
+                                       log_trace!(logger, "Persisting scorer after update");
+                                       if let Err(e) = persister.persist_scorer(&scorer) {
+                                               log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
+                                       }
+                               }
                        }
                        event_handler(event).await;
                }
@@ -751,7 +768,12 @@ impl BackgroundProcessor {
                                        handle_network_graph_update(network_graph, &event)
                                }
                                if let Some(ref scorer) = scorer {
-                                       update_scorer(scorer, &event);
+                                       if update_scorer(scorer, &event) {
+                                               log_trace!(logger, "Persisting scorer after update");
+                                               if let Err(e) = persister.persist_scorer(&scorer) {
+                                                       log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
+                                               }
+                                       }
                                }
                                event_handler.handle_event(event);
                        };
@@ -817,15 +839,14 @@ impl Drop for BackgroundProcessor {
 
 #[cfg(all(feature = "std", test))]
 mod tests {
-       use bitcoin::blockdata::block::BlockHeader;
-       use bitcoin::blockdata::constants::genesis_block;
+       use bitcoin::blockdata::constants::{genesis_block, ChainHash};
        use bitcoin::blockdata::locktime::PackedLockTime;
        use bitcoin::blockdata::transaction::{Transaction, TxOut};
        use bitcoin::network::constants::Network;
        use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1};
        use lightning::chain::{BestBlock, Confirm, chainmonitor};
        use lightning::chain::channelmonitor::ANTI_REORG_DELAY;
-       use lightning::chain::keysinterface::{InMemorySigner, KeysManager};
+       use lightning::sign::{InMemorySigner, KeysManager};
        use lightning::chain::transaction::OutPoint;
        use lightning::events::{Event, PathFailure, MessageSendEventsProvider, MessageSendEvent};
        use lightning::{get_event_msg, get_event};
@@ -833,6 +854,7 @@ mod tests {
        use lightning::ln::channelmanager;
        use lightning::ln::channelmanager::{BREAKDOWN_TIMEOUT, ChainParameters, MIN_CLTV_EXPIRY_DELTA, PaymentId};
        use lightning::ln::features::{ChannelFeatures, NodeFeatures};
+       use lightning::ln::functional_test_utils::*;
        use lightning::ln::msgs::{ChannelMessageHandler, Init};
        use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
        use lightning::routing::gossip::{NetworkGraph, NodeId, P2PGossipSync};
@@ -849,8 +871,6 @@ mod tests {
        use std::sync::{Arc, Mutex};
        use std::sync::mpsc::SyncSender;
        use std::time::Duration;
-       use bitcoin::hashes::Hash;
-       use bitcoin::TxMerkleNode;
        use lightning_rapid_gossip_sync::RapidGossipSync;
        use super::{BackgroundProcessor, GossipSync, FRESHNESS_TIMER};
 
@@ -866,7 +886,7 @@ mod tests {
                fn disconnect_socket(&mut self) {}
        }
 
-       type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter< Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>>>, Arc<test_utils::TestLogger>>;
+       type ChannelManager = channelmanager::ChannelManager<Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<KeysManager>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<DefaultRouter<Arc<NetworkGraph<Arc<test_utils::TestLogger>>>, Arc<test_utils::TestLogger>, Arc<Mutex<TestScorer>>, (), TestScorer>>, Arc<test_utils::TestLogger>>;
 
        type ChainMonitor = chainmonitor::ChainMonitor<InMemorySigner, Arc<test_utils::TestChainSource>, Arc<test_utils::TestBroadcaster>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>, Arc<FilesystemPersister>>;
 
@@ -1000,8 +1020,9 @@ mod tests {
        }
 
        impl Score for TestScorer {
+               type ScoreParams = ();
                fn channel_penalty_msat(
-                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage
+                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId, _usage: ChannelUsage, _score_params: &Self::ScoreParams
                ) -> u64 { unimplemented!(); }
 
                fn payment_path_failed(&mut self, actual_path: &Path, actual_short_channel_id: u64) {
@@ -1104,7 +1125,7 @@ mod tests {
        fn create_nodes(num_nodes: usize, persist_dir: &str) -> (String, Vec<Node>) {
                let persist_temp_path = env::temp_dir().join(persist_dir);
                let persist_dir = persist_temp_path.to_string_lossy().to_string();
-               let network = Network::Testnet;
+               let network = Network::Bitcoin;
                let mut nodes = Vec::new();
                for i in 0..num_nodes {
                        let tx_broadcaster = Arc::new(test_utils::TestBroadcaster::new(network));
@@ -1114,8 +1135,8 @@ mod tests {
                        let network_graph = Arc::new(NetworkGraph::new(network, logger.clone()));
                        let scorer = Arc::new(Mutex::new(TestScorer::new()));
                        let seed = [i as u8; 32];
-                       let router = Arc::new(DefaultRouter::new(network_graph.clone(), logger.clone(), seed, scorer.clone()));
-                       let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
+                       let router = Arc::new(DefaultRouter::new(network_graph.clone(), logger.clone(), seed, scorer.clone(), ()));
+                       let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Bitcoin));
                        let persister = Arc::new(FilesystemPersister::new(format!("{}_persister_{}", &persist_dir, i)));
                        let now = Duration::from_secs(genesis_block.header.time as u64);
                        let keys_manager = Arc::new(KeysManager::new(&seed, now.as_secs(), now.subsec_nanos()));
@@ -1126,7 +1147,7 @@ mod tests {
                        let p2p_gossip_sync = Arc::new(P2PGossipSync::new(network_graph.clone(), Some(chain_source.clone()), logger.clone()));
                        let rapid_gossip_sync = Arc::new(RapidGossipSync::new(network_graph.clone(), logger.clone()));
                        let msg_handler = MessageHandler {
-                               chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new()),
+                               chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new(ChainHash::using_genesis_block(Network::Testnet))),
                                route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new()),
                                onion_message_handler: IgnoringMessageHandler{}, custom_message_handler: IgnoringMessageHandler{}
                        };
@@ -1137,8 +1158,12 @@ mod tests {
 
                for i in 0..num_nodes {
                        for j in (i+1)..num_nodes {
-                               nodes[i].node.peer_connected(&nodes[j].node.get_our_node_id(), &Init { features: nodes[j].node.init_features(), remote_network_address: None }, true).unwrap();
-                               nodes[j].node.peer_connected(&nodes[i].node.get_our_node_id(), &Init { features: nodes[i].node.init_features(), remote_network_address: None }, false).unwrap();
+                               nodes[i].node.peer_connected(&nodes[j].node.get_our_node_id(), &Init {
+                                       features: nodes[j].node.init_features(), networks: None, remote_network_address: None
+                               }, true).unwrap();
+                               nodes[j].node.peer_connected(&nodes[i].node.get_our_node_id(), &Init {
+                                       features: nodes[i].node.init_features(), networks: None, remote_network_address: None
+                               }, false).unwrap();
                        }
                }
 
@@ -1189,7 +1214,7 @@ mod tests {
                for i in 1..=depth {
                        let prev_blockhash = node.best_block.block_hash();
                        let height = node.best_block.height() + 1;
-                       let header = BlockHeader { version: 0x20000000, prev_blockhash, merkle_root: TxMerkleNode::all_zeros(), time: height, bits: 42, nonce: 42 };
+                       let header = create_dummy_header(prev_blockhash, height);
                        let txdata = vec![(0, tx)];
                        node.best_block = BestBlock::new(header.block_hash(), height);
                        match i {
@@ -1709,6 +1734,10 @@ mod tests {
                if !std::thread::panicking() {
                        bg_processor.stop().unwrap();
                }
+
+               let log_entries = nodes[0].logger.lines.lock().unwrap();
+               let expected_log = "Persisting scorer after update".to_string();
+               assert_eq!(*log_entries.get(&("lightning_background_processor".to_string(), expected_log)).unwrap(), 5);
        }
 
        #[tokio::test]
@@ -1751,6 +1780,10 @@ mod tests {
                let t2 = tokio::spawn(async move {
                        do_test_payment_path_scoring!(nodes, receiver.recv().await);
                        exit_sender.send(()).unwrap();
+
+                       let log_entries = nodes[0].logger.lines.lock().unwrap();
+                       let expected_log = "Persisting scorer after update".to_string();
+                       assert_eq!(*log_entries.get(&("lightning_background_processor".to_string(), expected_log)).unwrap(), 5);
                });
 
                let (r1, r2) = tokio::join!(t1, t2);