git.bitcoin.ninja Git - rust-lightning/commitdiff
Indicate ongoing rapid sync to background processor.
author    Arik Sosman <git@arik.io>
          Wed, 1 Jun 2022 22:26:07 +0000 (15:26 -0700)
committer Arik Sosman <git@arik.io>
          Thu, 2 Jun 2022 17:14:08 +0000 (10:14 -0700)
Create a wrapper struct for rapid gossip sync that can be passed to
BackgroundProcessor's start method, allowing it to only start pruning
the network graph upon rapid gossip sync's completion.
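For orientation, a minimal usage sketch of the resulting API (the variable names and the surrounding setup — persister, event_handler, chain_monitor, channel_manager, net_graph_msg_handler, peer_manager, logger, scorer, network_graph and snapshot_bytes — are placeholders, assumed to be built as in the existing BackgroundProcessor documentation and tests):

    // Wrap the shared network graph and hand the wrapper to the background processor.
    let rapid_sync = Arc::new(RapidGossipSync::new(Arc::clone(&network_graph)));
    let bg_processor = BackgroundProcessor::start(
        persister, event_handler, chain_monitor, channel_manager,
        Some(net_graph_msg_handler), peer_manager, logger, Some(scorer),
        Some(Arc::clone(&rapid_sync)),
    );
    // The background processor defers network graph pruning until this first sync completes.
    rapid_sync.update_network_graph(&snapshot_bytes[..]).unwrap();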

fuzz/src/process_network_graph.rs
lightning-background-processor/Cargo.toml
lightning-background-processor/src/lib.rs
lightning-rapid-gossip-sync/src/lib.rs
lightning-rapid-gossip-sync/src/processing.rs
lightning/src/routing/network_graph.rs

fuzz/src/process_network_graph.rs
index 3f30335844644fec000b4c4b7c9c95959c255cf9..a71ae0e221b294cb302960d34bc4f6e966fa8124 100644 (file)
@@ -1,11 +1,13 @@
-// Import that needs to be added manually
+// Imports that need to be added manually
+use lightning_rapid_gossip_sync::RapidGossipSync;
 use utils::test_logger;
 
 /// Actual fuzz test, method signature and name are fixed
 fn do_test(data: &[u8]) {
        let block_hash = bitcoin::BlockHash::default();
        let network_graph = lightning::routing::network_graph::NetworkGraph::new(block_hash);
-       lightning_rapid_gossip_sync::processing::update_network_graph(&network_graph, data);
+       let rapid_sync = RapidGossipSync::new(&network_graph);
+       let _ = rapid_sync.update_network_graph(data);
 }
 
 /// Method that needs to be added manually, {name}_test
lightning-background-processor/Cargo.toml
index 00061ee6e5e852122e490b42420db81ca53c5859..5558eaaa67c43f49248fba1a4838c6a6143bdf41 100644 (file)
@@ -16,6 +16,7 @@ rustdoc-args = ["--cfg", "docsrs"]
 [dependencies]
 bitcoin = "0.28.1"
 lightning = { version = "0.0.106", path = "../lightning", features = ["std"] }
+lightning-rapid-gossip-sync = { version = "0.0.106", path = "../lightning-rapid-gossip-sync" }
 
 [dev-dependencies]
 lightning = { version = "0.0.106", path = "../lightning", features = ["_test_utils"] }
lightning-background-processor/src/lib.rs
index 95c753bca80521ef9bf203b00f94d477881ec386..603dc545eb86129946506a0d43c4d0ad02af8101 100644 (file)
@@ -9,6 +9,7 @@
 #![cfg_attr(docsrs, feature(doc_auto_cfg))]
 
 #[macro_use] extern crate lightning;
+extern crate lightning_rapid_gossip_sync;
 
 use lightning::chain;
 use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
@@ -22,6 +23,7 @@ use lightning::routing::scoring::WriteableScore;
 use lightning::util::events::{Event, EventHandler, EventsProvider};
 use lightning::util::logger::Logger;
 use lightning::util::persist::Persister;
+use lightning_rapid_gossip_sync::RapidGossipSync;
 use std::sync::Arc;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::thread;
@@ -142,6 +144,12 @@ impl BackgroundProcessor {
        /// functionality implemented by other handlers.
        /// * [`NetGraphMsgHandler`] if given will update the [`NetworkGraph`] based on payment failures.
        ///
+       /// # Rapid Gossip Sync
+       ///
+       /// If rapid gossip sync is to run at startup, pass a [`RapidGossipSync`] via
+       /// `rapid_gossip_sync`; [`BackgroundProcessor`] will then not prune the
+       /// [`NetworkGraph`] until that [`RapidGossipSync`] instance completes its first sync.
+       ///
        /// [top-level documentation]: BackgroundProcessor
        /// [`join`]: Self::join
        /// [`stop`]: Self::stop
@@ -175,9 +183,11 @@ impl BackgroundProcessor {
                PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L, UMH>> + Send + Sync,
                S: 'static + Deref<Target = SC> + Send + Sync,
                SC: WriteableScore<'a>,
+               RGS: 'static + Deref<Target = RapidGossipSync<G>> + Send
        >(
                persister: PS, event_handler: EH, chain_monitor: M, channel_manager: CM,
-               net_graph_msg_handler: Option<NG>, peer_manager: PM, logger: L, scorer: Option<S>
+               net_graph_msg_handler: Option<NG>, peer_manager: PM, logger: L, scorer: Option<S>,
+               rapid_gossip_sync: Option<RGS>
        ) -> Self
        where
                CA::Target: 'static + chain::Access,
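Because `rapid_gossip_sync` introduces a new generic parameter, callers that do not use rapid gossip sync must still name a concrete type when passing `None`. A sketch of one way to do so (the `Arc<NetworkGraph>` choice simply mirrors the test setup further below; the other arguments are placeholders):

    // No rapid gossip sync: annotate the Option so the RGS type parameter resolves.
    let no_rapid_sync: Option<Arc<RapidGossipSync<Arc<NetworkGraph>>>> = None;
    let bg_processor = BackgroundProcessor::start(
        persister, event_handler, chain_monitor, channel_manager,
        Some(net_graph_msg_handler), peer_manager, logger, Some(scorer),
        no_rapid_sync,
    );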
@@ -272,12 +282,30 @@ impl BackgroundProcessor {
                                // pruning their network graph. We run once 60 seconds after startup before
                                // continuing our normal cadence.
                                if last_prune_call.elapsed().as_secs() > if have_pruned { NETWORK_PRUNE_TIMER } else { FIRST_NETWORK_PRUNE_TIMER } {
-                                       if let Some(ref handler) = net_graph_msg_handler {
-                                               log_trace!(logger, "Pruning network graph of stale entries");
-                                               handler.network_graph().remove_stale_channels();
-                                               if let Err(e) = persister.persist_graph(handler.network_graph()) {
+                                       // The network graph must not be pruned while rapid sync completion is pending
+                                       log_trace!(logger, "Assessing prunability of network graph");
+                                       let graph_to_prune = match rapid_gossip_sync.as_ref() {
+                                               Some(rapid_sync) => {
+                                                       if rapid_sync.is_initial_sync_complete() {
+                                                               Some(rapid_sync.network_graph())
+                                                       } else {
+                                                               None
+                                                       }
+                                               },
+                                               None => net_graph_msg_handler.as_ref().map(|handler| handler.network_graph())
+                                       };
+
+                                       if let Some(network_graph_reference) = graph_to_prune {
+                                               network_graph_reference.remove_stale_channels();
+
+                                               if let Err(e) = persister.persist_graph(network_graph_reference) {
                                                        log_error!(logger, "Error: Failed to persist network graph, check your disk and permissions {}", e)
                                                }
+
+                                               last_prune_call = Instant::now();
+                                               have_pruned = true;
+                                       } else {
+                                               log_trace!(logger, "Not pruning network graph, either due to pending rapid gossip sync or absence of a prunable graph.");
                                        }
                                        if let Some(ref scorer) = scorer {
                                                log_trace!(logger, "Persisting scorer");
@@ -285,9 +313,6 @@ impl BackgroundProcessor {
                                                        log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
                                                }
                                        }
-
-                                       last_prune_call = Instant::now();
-                                       have_pruned = true;
                                }
                        }
 
@@ -370,7 +395,7 @@ mod tests {
        use lightning::chain::transaction::OutPoint;
        use lightning::get_event_msg;
        use lightning::ln::channelmanager::{BREAKDOWN_TIMEOUT, ChainParameters, ChannelManager, SimpleArcChannelManager};
-       use lightning::ln::features::InitFeatures;
+       use lightning::ln::features::{ChannelFeatures, InitFeatures};
        use lightning::ln::msgs::{ChannelMessageHandler, Init};
        use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
        use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
@@ -385,8 +410,10 @@ mod tests {
        use std::fs;
        use std::path::PathBuf;
        use std::sync::{Arc, Mutex};
+       use std::sync::mpsc::SyncSender;
        use std::time::Duration;
        use lightning::routing::scoring::{FixedPenaltyScorer};
+       use lightning_rapid_gossip_sync::RapidGossipSync;
        use super::{BackgroundProcessor, FRESHNESS_TIMER};
 
        const EVENT_DEADLINE: u64 = 5 * FRESHNESS_TIMER;
@@ -414,6 +441,7 @@ mod tests {
                logger: Arc<test_utils::TestLogger>,
                best_block: BestBlock,
                scorer: Arc<Mutex<FixedPenaltyScorer>>,
+               rapid_gossip_sync: Option<Arc<RapidGossipSync<Arc<NetworkGraph>>>>
        }
 
        impl Drop for Node {
@@ -428,6 +456,7 @@ mod tests {
 
        struct Persister {
                graph_error: Option<(std::io::ErrorKind, &'static str)>,
+               graph_persistence_notifier: Option<SyncSender<()>>,
                manager_error: Option<(std::io::ErrorKind, &'static str)>,
                scorer_error: Option<(std::io::ErrorKind, &'static str)>,
                filesystem_persister: FilesystemPersister,
@@ -436,13 +465,17 @@ mod tests {
        impl Persister {
                fn new(data_dir: String) -> Self {
                        let filesystem_persister = FilesystemPersister::new(data_dir.clone());
-                       Self { graph_error: None, manager_error: None, scorer_error: None, filesystem_persister }
+                       Self { graph_error: None, graph_persistence_notifier: None, manager_error: None, scorer_error: None, filesystem_persister }
                }
 
                fn with_graph_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
                        Self { graph_error: Some((error, message)), ..self }
                }
 
+               fn with_graph_persistence_notifier(self, sender: SyncSender<()>) -> Self {
+                       Self { graph_persistence_notifier: Some(sender), ..self }
+               }
+
                fn with_manager_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
                        Self { manager_error: Some((error, message)), ..self }
                }
@@ -461,6 +494,10 @@ mod tests {
                        }
 
                        if key == "network_graph" {
+                               if let Some(sender) = &self.graph_persistence_notifier {
+                                       sender.send(()).unwrap();
+                               };
+
                                if let Some((error, message)) = self.graph_error {
                                        return Err(std::io::Error::new(error, message))
                                }
@@ -504,7 +541,8 @@ mod tests {
                        let msg_handler = MessageHandler { chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new()), route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new() )};
                        let peer_manager = Arc::new(PeerManager::new(msg_handler, keys_manager.get_node_secret(Recipient::Node).unwrap(), &seed, logger.clone(), IgnoringMessageHandler{}));
                        let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
-                       let node = Node { node: manager, net_graph_msg_handler, peer_manager, chain_monitor, persister, tx_broadcaster, network_graph, logger, best_block, scorer };
+                       let rapid_gossip_sync = None;
+                       let node = Node { node: manager, net_graph_msg_handler, peer_manager, chain_monitor, persister, tx_broadcaster, network_graph, logger, best_block, scorer, rapid_gossip_sync };
                        nodes.push(node);
                }
 
@@ -602,7 +640,7 @@ mod tests {
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = Arc::new(Persister::new(data_dir));
                let event_handler = |_: &_| {};
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(),  nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
 
                macro_rules! check_persisted_data {
                        ($node: expr, $filepath: expr) => {
@@ -667,7 +705,7 @@ mod tests {
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = Arc::new(Persister::new(data_dir));
                let event_handler = |_: &_| {};
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
                loop {
                        let log_entries = nodes[0].logger.lines.lock().unwrap();
                        let desired_log = "Calling ChannelManager's timer_tick_occurred".to_string();
@@ -690,7 +728,7 @@ mod tests {
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = Arc::new(Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test"));
                let event_handler = |_: &_| {};
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
                match bg_processor.join() {
                        Ok(_) => panic!("Expected error persisting manager"),
                        Err(e) => {
@@ -707,7 +745,7 @@ mod tests {
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = Arc::new(Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test"));
                let event_handler = |_: &_| {};
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
 
                match bg_processor.stop() {
                        Ok(_) => panic!("Expected error persisting network graph"),
@@ -725,7 +763,7 @@ mod tests {
                let data_dir = nodes[0].persister.get_data_dir();
                let persister = Arc::new(Persister::new(data_dir).with_scorer_error(std::io::ErrorKind::Other, "test"));
                let event_handler = |_: &_| {};
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(),  nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
 
                match bg_processor.stop() {
                        Ok(_) => panic!("Expected error persisting scorer"),
@@ -748,7 +786,7 @@ mod tests {
                let event_handler = move |event: &Event| {
                        sender.send(handle_funding_generation_ready!(event, channel_value)).unwrap();
                };
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
 
                // Open a channel and check that the FundingGenerationReady event was handled.
                begin_open_channel!(nodes[0], nodes[1], channel_value);
@@ -773,7 +811,7 @@ mod tests {
                let (sender, receiver) = std::sync::mpsc::sync_channel(1);
                let event_handler = move |event: &Event| sender.send(event.clone()).unwrap();
                let persister = Arc::new(Persister::new(data_dir));
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
 
                // Force close the channel and check that the SpendableOutputs event was handled.
                nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id, &nodes[1].node.get_our_node_id()).unwrap();
@@ -791,6 +829,83 @@ mod tests {
                assert!(bg_processor.stop().is_ok());
        }
 
+       #[test]
+       fn test_scorer_persistence() {
+               let nodes = create_nodes(2, "test_scorer_persistence".to_string());
+               let data_dir = nodes[0].persister.get_data_dir();
+               let persister = Arc::new(Persister::new(data_dir));
+               let event_handler = |_: &_| {};
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(),  nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
+
+               loop {
+                       let log_entries = nodes[0].logger.lines.lock().unwrap();
+                       let expected_log = "Persisting scorer".to_string();
+                       if log_entries.get(&("lightning_background_processor".to_string(), expected_log)).is_some() {
+                               break
+                       }
+               }
+
+               assert!(bg_processor.stop().is_ok());
+       }
+
+       #[test]
+       fn test_not_pruning_network_graph_until_graph_sync_completion() {
+               let nodes = create_nodes(2, "test_not_pruning_network_graph_until_graph_sync_completion".to_string());
+               let data_dir = nodes[0].persister.get_data_dir();
+               let (sender, receiver) = std::sync::mpsc::sync_channel(1);
+               let persister = Arc::new(Persister::new(data_dir.clone()).with_graph_persistence_notifier(sender));
+               let network_graph = nodes[0].network_graph.clone();
+               let rapid_sync = Arc::new(RapidGossipSync::new(network_graph.clone()));
+               let features = ChannelFeatures::empty();
+               network_graph.add_channel_from_partial_announcement(42, 53, features, nodes[0].node.get_our_node_id(), nodes[1].node.get_our_node_id())
+                       .expect("Failed to update channel from partial announcement");
+               let original_graph_description = network_graph.to_string();
+               assert!(original_graph_description.contains("42: features: 0000, node_one:"));
+               assert_eq!(network_graph.read_only().channels().len(), 1);
+
+               let event_handler = |_: &_| {};
+               let background_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), Some(rapid_sync.clone()));
+
+               loop {
+                       let log_entries = nodes[0].logger.lines.lock().unwrap();
+                       let expected_log_a = "Assessing prunability of network graph".to_string();
+                       let expected_log_b = "Not pruning network graph, either due to pending rapid gossip sync or absence of a prunable graph.".to_string();
+                       if log_entries.get(&("lightning_background_processor".to_string(), expected_log_a)).is_some() &&
+                               log_entries.get(&("lightning_background_processor".to_string(), expected_log_b)).is_some() {
+                               break
+                       }
+               }
+
+               let initialization_input = vec![
+                       76, 68, 75, 1, 111, 226, 140, 10, 182, 241, 179, 114, 193, 166, 162, 70, 174, 99, 247,
+                       79, 147, 30, 131, 101, 225, 90, 8, 156, 104, 214, 25, 0, 0, 0, 0, 0, 97, 227, 98, 218,
+                       0, 0, 0, 4, 2, 22, 7, 207, 206, 25, 164, 197, 231, 230, 231, 56, 102, 61, 250, 251,
+                       187, 172, 38, 46, 79, 247, 108, 44, 155, 48, 219, 238, 252, 53, 192, 6, 67, 2, 36, 125,
+                       157, 176, 223, 175, 234, 116, 94, 248, 201, 225, 97, 235, 50, 47, 115, 172, 63, 136,
+                       88, 216, 115, 11, 111, 217, 114, 84, 116, 124, 231, 107, 2, 158, 1, 242, 121, 152, 106,
+                       204, 131, 186, 35, 93, 70, 216, 10, 237, 224, 183, 89, 95, 65, 3, 83, 185, 58, 138,
+                       181, 64, 187, 103, 127, 68, 50, 2, 201, 19, 17, 138, 136, 149, 185, 226, 156, 137, 175,
+                       110, 32, 237, 0, 217, 90, 31, 100, 228, 149, 46, 219, 175, 168, 77, 4, 143, 38, 128,
+                       76, 97, 0, 0, 0, 2, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 0, 1, 0, 0, 255, 2, 68,
+                       226, 0, 6, 11, 0, 1, 2, 3, 0, 0, 0, 2, 0, 40, 0, 0, 0, 0, 0, 0, 3, 232, 0, 0, 3, 232,
+                       0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 8, 153, 192, 0, 2, 27, 0, 0, 25, 0, 0,
+                       0, 1, 0, 0, 0, 125, 255, 2, 68, 226, 0, 6, 11, 0, 1, 5, 0, 0, 0, 0, 29, 129, 25, 192,
+               ];
+               rapid_sync.update_network_graph(&initialization_input[..]).unwrap();
+
+               // this should have added two channels
+               assert_eq!(network_graph.read_only().channels().len(), 3);
+
+               let _ = receiver
+                       .recv_timeout(Duration::from_secs(super::FIRST_NETWORK_PRUNE_TIMER * 5))
+                       .expect("Network graph not pruned within deadline");
+
+               background_processor.stop().unwrap();
+
+               // all channels should now be pruned
+               assert_eq!(network_graph.read_only().channels().len(), 0);
+       }
+
        #[test]
        fn test_invoice_payer() {
                let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
@@ -803,7 +918,7 @@ mod tests {
                let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes);
                let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, Arc::clone(&nodes[0].scorer), Arc::clone(&nodes[0].logger), |_: &_| {}, Retry::Attempts(2)));
                let event_handler = Arc::clone(&invoice_payer);
-               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()));
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), Some(nodes[0].scorer.clone()), nodes[0].rapid_gossip_sync.clone());
                assert!(bg_processor.stop().is_ok());
        }
 }
lightning-rapid-gossip-sync/src/lib.rs
index 123f3238ed0755facc3da65f83a410c74629c0ae..e2e7807e398b48b327b63f6b5fec49a5724464df 100644 (file)
 //! use bitcoin::blockdata::constants::genesis_block;
 //! use bitcoin::Network;
 //! use lightning::routing::network_graph::NetworkGraph;
+//! use lightning_rapid_gossip_sync::RapidGossipSync;
 //!
 //! let block_hash = genesis_block(Network::Bitcoin).header.block_hash();
 //! let network_graph = NetworkGraph::new(block_hash);
-//! let new_last_sync_timestamp_result = lightning_rapid_gossip_sync::sync_network_graph_with_file_path(&network_graph, "./rapid_sync.lngossip");
+//! let rapid_sync = RapidGossipSync::new(&network_graph);
+//! let new_last_sync_timestamp_result = rapid_sync.sync_network_graph_with_file_path("./rapid_sync.lngossip");
 //! ```
 //!
 //! The primary benefit this syncing mechanism provides is that given a trusted server, a
 extern crate test;
 
 use std::fs::File;
+use std::ops::Deref;
+use std::sync::atomic::{AtomicBool, Ordering};
 
-use lightning::routing::network_graph;
+use lightning::routing::network_graph::NetworkGraph;
 
 use crate::error::GraphSyncError;
 
@@ -68,19 +72,51 @@ pub mod error;
 /// Core functionality of this crate
 pub mod processing;
 
-/// Sync gossip data from a file
-/// Returns the last sync timestamp to be used the next time rapid sync data is queried.
+/// Rapid Gossip Sync struct
+/// See [crate-level documentation] for usage.
 ///
-/// `network_graph`: The network graph to apply the updates to
-///
-/// `sync_path`: Path to the file where the gossip update data is located
-///
-pub fn sync_network_graph_with_file_path(
-       network_graph: &network_graph::NetworkGraph,
-       sync_path: &str,
-) -> Result<u32, GraphSyncError> {
-       let mut file = File::open(sync_path)?;
-       processing::update_network_graph_from_byte_stream(&network_graph, &mut file)
+/// [crate-level documentation]: crate
+pub struct RapidGossipSync<NG: Deref<Target=NetworkGraph>> {
+       network_graph: NG,
+       is_initial_sync_complete: AtomicBool
+}
+
+impl<NG: Deref<Target=NetworkGraph>> RapidGossipSync<NG> {
+       /// Instantiates a new [`RapidGossipSync`] wrapping the given [`NetworkGraph`]
+       pub fn new(network_graph: NG) -> Self {
+               Self {
+                       network_graph,
+                       is_initial_sync_complete: AtomicBool::new(false)
+               }
+       }
+
+       /// Sync gossip data from a file.
+       /// Returns the last sync timestamp to be used the next time rapid sync data is queried.
+       ///
+       /// The update is applied to the [`NetworkGraph`] this struct was created with.
+       ///
+       /// `sync_path`: Path to the file where the gossip update data is located
+       ///
+       pub fn sync_network_graph_with_file_path(
+               &self,
+               sync_path: &str,
+       ) -> Result<u32, GraphSyncError> {
+               let mut file = File::open(sync_path)?;
+               self.update_network_graph_from_byte_stream(&mut file)
+       }
+
+       /// Gets a reference to the underlying [`NetworkGraph`] which was provided in
+       /// [`RapidGossipSync::new`].
+       ///
+       /// (C-not exported) as bindings don't support a reference-to-a-reference yet
+       pub fn network_graph(&self) -> &NG {
+               &self.network_graph
+       }
+
+       /// Returns whether a rapid gossip sync has completed at least once
+       pub fn is_initial_sync_complete(&self) -> bool {
+               self.is_initial_sync_complete.load(Ordering::Acquire)
+       }
 }
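A brief sketch of how the new accessors compose, mirroring the gating added to BackgroundProcessor above (assumes a `network_graph: NetworkGraph` built as in the crate-level example):

    let rapid_sync = RapidGossipSync::new(&network_graph);
    assert!(!rapid_sync.is_initial_sync_complete());
    // Once an initial sync has completed, the wrapped graph can safely be pruned:
    if rapid_sync.is_initial_sync_complete() {
        rapid_sync.network_graph().remove_stale_channels();
    }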
 
 #[cfg(test)]
@@ -92,8 +128,7 @@ mod tests {
 
        use lightning::ln::msgs::DecodeError;
        use lightning::routing::network_graph::NetworkGraph;
-
-       use crate::sync_network_graph_with_file_path;
+       use crate::RapidGossipSync;
 
        #[test]
        fn test_sync_from_file() {
@@ -156,7 +191,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let sync_result = sync_network_graph_with_file_path(&network_graph, &graph_sync_test_file);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let sync_result = rapid_sync.sync_network_graph_with_file_path(&graph_sync_test_file);
 
                if sync_result.is_err() {
                        panic!("Unexpected sync result: {:?}", sync_result)
@@ -187,11 +223,12 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
+               let rapid_sync = RapidGossipSync::new(&network_graph);
                let start = std::time::Instant::now();
-               let sync_result =
-                       sync_network_graph_with_file_path(&network_graph, "./res/full_graph.lngossip");
+               let sync_result = rapid_sync
+                       .sync_network_graph_with_file_path("./res/full_graph.lngossip");
                if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = &sync_result {
-                       let error_string = format!("Input file lightning-graph-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
+                       let error_string = format!("Input file lightning-rapid-gossip-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
                        #[cfg(not(require_route_graph_test))]
                        {
                                println!("{}", error_string);
@@ -218,19 +255,17 @@ pub mod bench {
        use lightning::ln::msgs::DecodeError;
        use lightning::routing::network_graph::NetworkGraph;
 
-       use crate::sync_network_graph_with_file_path;
+       use crate::RapidGossipSync;
 
        #[bench]
        fn bench_reading_full_graph_from_file(b: &mut Bencher) {
                let block_hash = genesis_block(Network::Bitcoin).block_hash();
                b.iter(|| {
                        let network_graph = NetworkGraph::new(block_hash);
-                       let sync_result = sync_network_graph_with_file_path(
-                               &network_graph,
-                               "./res/full_graph.lngossip",
-                       );
+                       let rapid_sync = RapidGossipSync::new(&network_graph);
+                       let sync_result = rapid_sync.sync_network_graph_with_file_path("./res/full_graph.lngossip");
                        if let Err(crate::error::GraphSyncError::DecodeError(DecodeError::Io(io_error))) = &sync_result {
-                               let error_string = format!("Input file lightning-graph-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
+                               let error_string = format!("Input file lightning-rapid-gossip-sync/res/full_graph.lngossip is missing! Download it from https://bitcoin.ninja/ldk-compressed_graph-bc08df7542-2022-05-05.bin\n\n{:?}", io_error);
                                #[cfg(not(require_route_graph_test))]
                                {
                                        println!("{}", error_string);
lightning-rapid-gossip-sync/src/processing.rs
index ceb8b82295336406c142a376f325070bc9255f3e..6ffc6f58ea88bf6adbf1fc9e7fbd2184a0ed78e4 100644 (file)
@@ -1,6 +1,8 @@
 use std::cmp::max;
 use std::io;
 use std::io::Read;
+use std::ops::Deref;
+use std::sync::atomic::Ordering;
 
 use bitcoin::BlockHash;
 use bitcoin::secp256k1::PublicKey;
@@ -8,10 +10,11 @@ use bitcoin::secp256k1::PublicKey;
 use lightning::ln::msgs::{
        DecodeError, ErrorAction, LightningError, OptionalField, UnsignedChannelUpdate,
 };
-use lightning::routing::network_graph;
+use lightning::routing::network_graph::NetworkGraph;
 use lightning::util::ser::{BigSize, Readable};
 
 use crate::error::GraphSyncError;
+use crate::RapidGossipSync;
 
 /// The purpose of this prefix is to identify the serialization format, should other rapid gossip
 /// sync formats arise in the future.
@@ -23,203 +26,207 @@ const GOSSIP_PREFIX: [u8; 4] = [76, 68, 75, 1];
 /// avoid malicious updates being able to trigger excessive memory allocation.
 const MAX_INITIAL_NODE_ID_VECTOR_CAPACITY: u32 = 50_000;
 
-/// Update network graph from binary data.
-/// Returns the last sync timestamp to be used the next time rapid sync data is queried.
-///
-/// `network_graph`: network graph to be updated
-///
-/// `update_data`: `&[u8]` binary stream that comprises the update data
-pub fn update_network_graph(
-       network_graph: &network_graph::NetworkGraph,
-       update_data: &[u8],
-) -> Result<u32, GraphSyncError> {
-       let mut read_cursor = io::Cursor::new(update_data);
-       update_network_graph_from_byte_stream(&network_graph, &mut read_cursor)
-}
-
-pub(crate) fn update_network_graph_from_byte_stream<R: Read>(
-       network_graph: &network_graph::NetworkGraph,
-       mut read_cursor: &mut R,
-) -> Result<u32, GraphSyncError> {
-       let mut prefix = [0u8; 4];
-       read_cursor.read_exact(&mut prefix)?;
-
-       match prefix {
-               GOSSIP_PREFIX => {},
-               _ => {
-                       return Err(DecodeError::UnknownVersion.into());
-               }
-       };
-
-       let chain_hash: BlockHash = Readable::read(read_cursor)?;
-       let latest_seen_timestamp: u32 = Readable::read(read_cursor)?;
-       // backdate the applied timestamp by a week
-       let backdated_timestamp = latest_seen_timestamp.saturating_sub(24 * 3600 * 7);
-
-       let node_id_count: u32 = Readable::read(read_cursor)?;
-       let mut node_ids: Vec<PublicKey> = Vec::with_capacity(std::cmp::min(
-               node_id_count,
-               MAX_INITIAL_NODE_ID_VECTOR_CAPACITY,
-       ) as usize);
-       for _ in 0..node_id_count {
-               let current_node_id = Readable::read(read_cursor)?;
-               node_ids.push(current_node_id);
+impl<NG: Deref<Target=NetworkGraph>> RapidGossipSync<NG> {
+       /// Update network graph from binary data.
+       /// Returns the last sync timestamp to be used the next time rapid sync data is queried.
+       ///
+       /// The update is applied to the [`NetworkGraph`] this struct was created with.
+       ///
+       /// `update_data`: `&[u8]` binary stream that comprises the update data
+       pub fn update_network_graph(&self, update_data: &[u8]) -> Result<u32, GraphSyncError> {
+               let mut read_cursor = io::Cursor::new(update_data);
+               self.update_network_graph_from_byte_stream(&mut read_cursor)
        }
 
-       let mut previous_scid: u64 = 0;
-       let announcement_count: u32 = Readable::read(read_cursor)?;
-       for _ in 0..announcement_count {
-               let features = Readable::read(read_cursor)?;
-
-               // handle SCID
-               let scid_delta: BigSize = Readable::read(read_cursor)?;
-               let short_channel_id = previous_scid
-                       .checked_add(scid_delta.0)
-                       .ok_or(DecodeError::InvalidValue)?;
-               previous_scid = short_channel_id;
-
-               let node_id_1_index: BigSize = Readable::read(read_cursor)?;
-               let node_id_2_index: BigSize = Readable::read(read_cursor)?;
-               if max(node_id_1_index.0, node_id_2_index.0) >= node_id_count as u64 {
-                       return Err(DecodeError::InvalidValue.into());
-               };
-               let node_id_1 = node_ids[node_id_1_index.0 as usize];
-               let node_id_2 = node_ids[node_id_2_index.0 as usize];
-
-               let announcement_result = network_graph.add_channel_from_partial_announcement(
-                       short_channel_id,
-                       backdated_timestamp as u64,
-                       features,
-                       node_id_1,
-                       node_id_2,
-               );
-               if let Err(lightning_error) = announcement_result {
-                       if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action {
-                               // everything is fine, just a duplicate channel announcement
-                       } else {
-                               return Err(lightning_error.into());
+
+       pub(crate) fn update_network_graph_from_byte_stream<R: Read>(
+               &self,
+               mut read_cursor: &mut R,
+       ) -> Result<u32, GraphSyncError> {
+               let mut prefix = [0u8; 4];
+               read_cursor.read_exact(&mut prefix)?;
+
+               match prefix {
+                       GOSSIP_PREFIX => {}
+                       _ => {
+                               return Err(DecodeError::UnknownVersion.into());
                        }
+               };
+
+               let chain_hash: BlockHash = Readable::read(read_cursor)?;
+               let latest_seen_timestamp: u32 = Readable::read(read_cursor)?;
+               // backdate the applied timestamp by a week
+               let backdated_timestamp = latest_seen_timestamp.saturating_sub(24 * 3600 * 7);
+
+               let node_id_count: u32 = Readable::read(read_cursor)?;
+               let mut node_ids: Vec<PublicKey> = Vec::with_capacity(std::cmp::min(
+                       node_id_count,
+                       MAX_INITIAL_NODE_ID_VECTOR_CAPACITY,
+               ) as usize);
+               for _ in 0..node_id_count {
+                       let current_node_id = Readable::read(read_cursor)?;
+                       node_ids.push(current_node_id);
                }
-       }
 
-       previous_scid = 0; // updates start at a new scid
+               let network_graph = &self.network_graph;
 
-       let update_count: u32 = Readable::read(read_cursor)?;
-       if update_count == 0 {
-               return Ok(latest_seen_timestamp);
-       }
+               let mut previous_scid: u64 = 0;
+               let announcement_count: u32 = Readable::read(read_cursor)?;
+               for _ in 0..announcement_count {
+                       let features = Readable::read(read_cursor)?;
 
-       // obtain default values for non-incremental updates
-       let default_cltv_expiry_delta: u16 = Readable::read(&mut read_cursor)?;
-       let default_htlc_minimum_msat: u64 = Readable::read(&mut read_cursor)?;
-       let default_fee_base_msat: u32 = Readable::read(&mut read_cursor)?;
-       let default_fee_proportional_millionths: u32 = Readable::read(&mut read_cursor)?;
-       let tentative_default_htlc_maximum_msat: u64 = Readable::read(&mut read_cursor)?;
-       let default_htlc_maximum_msat = if tentative_default_htlc_maximum_msat == u64::max_value() {
-               OptionalField::Absent
-       } else {
-               OptionalField::Present(tentative_default_htlc_maximum_msat)
-       };
-
-       for _ in 0..update_count {
-               let scid_delta: BigSize = Readable::read(read_cursor)?;
-               let short_channel_id = previous_scid
-                       .checked_add(scid_delta.0)
-                       .ok_or(DecodeError::InvalidValue)?;
-               previous_scid = short_channel_id;
-
-               let channel_flags: u8 = Readable::read(read_cursor)?;
-
-               // flags are always sent in full, and hence always need updating
-               let standard_channel_flags = channel_flags & 0b_0000_0011;
-
-               let mut synthetic_update = if channel_flags & 0b_1000_0000 == 0 {
-                       // full update, field flags will indicate deviations from the default
-                       UnsignedChannelUpdate {
-                               chain_hash,
-                               short_channel_id,
-                               timestamp: backdated_timestamp,
-                               flags: standard_channel_flags,
-                               cltv_expiry_delta: default_cltv_expiry_delta,
-                               htlc_minimum_msat: default_htlc_minimum_msat,
-                               htlc_maximum_msat: default_htlc_maximum_msat.clone(),
-                               fee_base_msat: default_fee_base_msat,
-                               fee_proportional_millionths: default_fee_proportional_millionths,
-                               excess_data: vec![],
-                       }
-               } else {
-                       // incremental update, field flags will indicate mutated values
-                       let read_only_network_graph = network_graph.read_only();
-                       let channel = read_only_network_graph
-                               .channels()
-                               .get(&short_channel_id)
-                               .ok_or(LightningError {
-                                       err: "Couldn't find channel for update".to_owned(),
-                                       action: ErrorAction::IgnoreError,
-                               })?;
-
-                       let directional_info = channel
-                               .get_directional_info(channel_flags)
-                               .ok_or(LightningError {
-                                       err: "Couldn't find previous directional data for update".to_owned(),
-                                       action: ErrorAction::IgnoreError,
-                               })?;
-
-                       let htlc_maximum_msat =
-                               if let Some(htlc_maximum_msat) = directional_info.htlc_maximum_msat {
-                                       OptionalField::Present(htlc_maximum_msat)
-                               } else {
-                                       OptionalField::Absent
-                               };
+                       // handle SCID
+                       let scid_delta: BigSize = Readable::read(read_cursor)?;
+                       let short_channel_id = previous_scid
+                               .checked_add(scid_delta.0)
+                               .ok_or(DecodeError::InvalidValue)?;
+                       previous_scid = short_channel_id;
+
+                       let node_id_1_index: BigSize = Readable::read(read_cursor)?;
+                       let node_id_2_index: BigSize = Readable::read(read_cursor)?;
+                       if max(node_id_1_index.0, node_id_2_index.0) >= node_id_count as u64 {
+                               return Err(DecodeError::InvalidValue.into());
+                       };
+                       let node_id_1 = node_ids[node_id_1_index.0 as usize];
+                       let node_id_2 = node_ids[node_id_2_index.0 as usize];
 
-                       UnsignedChannelUpdate {
-                               chain_hash,
+                       let announcement_result = network_graph.add_channel_from_partial_announcement(
                                short_channel_id,
-                               timestamp: backdated_timestamp,
-                               flags: standard_channel_flags,
-                               cltv_expiry_delta: directional_info.cltv_expiry_delta,
-                               htlc_minimum_msat: directional_info.htlc_minimum_msat,
-                               htlc_maximum_msat,
-                               fee_base_msat: directional_info.fees.base_msat,
-                               fee_proportional_millionths: directional_info.fees.proportional_millionths,
-                               excess_data: vec![],
+                               backdated_timestamp as u64,
+                               features,
+                               node_id_1,
+                               node_id_2,
+                       );
+                       if let Err(lightning_error) = announcement_result {
+                               if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action {
+                                       // everything is fine, just a duplicate channel announcement
+                               } else {
+                                       return Err(lightning_error.into());
+                               }
                        }
-               };
-
-               if channel_flags & 0b_0100_0000 > 0 {
-                       let cltv_expiry_delta: u16 = Readable::read(read_cursor)?;
-                       synthetic_update.cltv_expiry_delta = cltv_expiry_delta;
                }
 
-               if channel_flags & 0b_0010_0000 > 0 {
-                       let htlc_minimum_msat: u64 = Readable::read(read_cursor)?;
-                       synthetic_update.htlc_minimum_msat = htlc_minimum_msat;
-               }
+               previous_scid = 0; // updates start at a new scid
 
-               if channel_flags & 0b_0001_0000 > 0 {
-                       let fee_base_msat: u32 = Readable::read(read_cursor)?;
-                       synthetic_update.fee_base_msat = fee_base_msat;
+               let update_count: u32 = Readable::read(read_cursor)?;
+               if update_count == 0 {
+                       return Ok(latest_seen_timestamp);
                }
 
-               if channel_flags & 0b_0000_1000 > 0 {
-                       let fee_proportional_millionths: u32 = Readable::read(read_cursor)?;
-                       synthetic_update.fee_proportional_millionths = fee_proportional_millionths;
-               }
+               // obtain default values for non-incremental updates
+               let default_cltv_expiry_delta: u16 = Readable::read(&mut read_cursor)?;
+               let default_htlc_minimum_msat: u64 = Readable::read(&mut read_cursor)?;
+               let default_fee_base_msat: u32 = Readable::read(&mut read_cursor)?;
+               let default_fee_proportional_millionths: u32 = Readable::read(&mut read_cursor)?;
+               let tentative_default_htlc_maximum_msat: u64 = Readable::read(&mut read_cursor)?;
+               let default_htlc_maximum_msat = if tentative_default_htlc_maximum_msat == u64::max_value() {
+                       OptionalField::Absent
+               } else {
+                       OptionalField::Present(tentative_default_htlc_maximum_msat)
+               };
 
-               if channel_flags & 0b_0000_0100 > 0 {
-                       let tentative_htlc_maximum_msat: u64 = Readable::read(read_cursor)?;
-                       synthetic_update.htlc_maximum_msat = if tentative_htlc_maximum_msat == u64::max_value()
-                       {
-                               OptionalField::Absent
+               for _ in 0..update_count {
+                       let scid_delta: BigSize = Readable::read(read_cursor)?;
+                       let short_channel_id = previous_scid
+                               .checked_add(scid_delta.0)
+                               .ok_or(DecodeError::InvalidValue)?;
+                       previous_scid = short_channel_id;
+
+                       let channel_flags: u8 = Readable::read(read_cursor)?;
+
+                       // flags are always sent in full, and hence always need updating
+                       let standard_channel_flags = channel_flags & 0b_0000_0011;
+
+                       let mut synthetic_update = if channel_flags & 0b_1000_0000 == 0 {
+                               // full update, field flags will indicate deviations from the default
+                               UnsignedChannelUpdate {
+                                       chain_hash,
+                                       short_channel_id,
+                                       timestamp: backdated_timestamp,
+                                       flags: standard_channel_flags,
+                                       cltv_expiry_delta: default_cltv_expiry_delta,
+                                       htlc_minimum_msat: default_htlc_minimum_msat,
+                                       htlc_maximum_msat: default_htlc_maximum_msat.clone(),
+                                       fee_base_msat: default_fee_base_msat,
+                                       fee_proportional_millionths: default_fee_proportional_millionths,
+                                       excess_data: vec![],
+                               }
                        } else {
-                               OptionalField::Present(tentative_htlc_maximum_msat)
+                               // incremental update, field flags will indicate mutated values
+                               let read_only_network_graph = network_graph.read_only();
+                               let channel = read_only_network_graph
+                                       .channels()
+                                       .get(&short_channel_id)
+                                       .ok_or(LightningError {
+                                               err: "Couldn't find channel for update".to_owned(),
+                                               action: ErrorAction::IgnoreError,
+                                       })?;
+
+                               let directional_info = channel
+                                       .get_directional_info(channel_flags)
+                                       .ok_or(LightningError {
+                                               err: "Couldn't find previous directional data for update".to_owned(),
+                                               action: ErrorAction::IgnoreError,
+                                       })?;
+
+                               let htlc_maximum_msat =
+                                       if let Some(htlc_maximum_msat) = directional_info.htlc_maximum_msat {
+                                               OptionalField::Present(htlc_maximum_msat)
+                                       } else {
+                                               OptionalField::Absent
+                                       };
+
+                               UnsignedChannelUpdate {
+                                       chain_hash,
+                                       short_channel_id,
+                                       timestamp: backdated_timestamp,
+                                       flags: standard_channel_flags,
+                                       cltv_expiry_delta: directional_info.cltv_expiry_delta,
+                                       htlc_minimum_msat: directional_info.htlc_minimum_msat,
+                                       htlc_maximum_msat,
+                                       fee_base_msat: directional_info.fees.base_msat,
+                                       fee_proportional_millionths: directional_info.fees.proportional_millionths,
+                                       excess_data: vec![],
+                               }
                        };
+
+                       if channel_flags & 0b_0100_0000 > 0 {
+                               let cltv_expiry_delta: u16 = Readable::read(read_cursor)?;
+                               synthetic_update.cltv_expiry_delta = cltv_expiry_delta;
+                       }
+
+                       if channel_flags & 0b_0010_0000 > 0 {
+                               let htlc_minimum_msat: u64 = Readable::read(read_cursor)?;
+                               synthetic_update.htlc_minimum_msat = htlc_minimum_msat;
+                       }
+
+                       if channel_flags & 0b_0001_0000 > 0 {
+                               let fee_base_msat: u32 = Readable::read(read_cursor)?;
+                               synthetic_update.fee_base_msat = fee_base_msat;
+                       }
+
+                       if channel_flags & 0b_0000_1000 > 0 {
+                               let fee_proportional_millionths: u32 = Readable::read(read_cursor)?;
+                               synthetic_update.fee_proportional_millionths = fee_proportional_millionths;
+                       }
+
+                       if channel_flags & 0b_0000_0100 > 0 {
+                               let tentative_htlc_maximum_msat: u64 = Readable::read(read_cursor)?;
+                               synthetic_update.htlc_maximum_msat = if tentative_htlc_maximum_msat == u64::max_value()
+                               {
+                                       OptionalField::Absent
+                               } else {
+                                       OptionalField::Present(tentative_htlc_maximum_msat)
+                               };
+                       }
+
+                       network_graph.update_channel_unsigned(&synthetic_update)?;
                }
 
-               network_graph.update_channel_unsigned(&synthetic_update)?;
+               self.network_graph.set_last_rapid_gossip_sync_timestamp(latest_seen_timestamp);
+               self.is_initial_sync_complete.store(true, Ordering::Release);
+               Ok(latest_seen_timestamp)
        }
-
-       Ok(latest_seen_timestamp)
 }
 
 #[cfg(test)]
@@ -231,7 +238,7 @@ mod tests {
        use lightning::routing::network_graph::NetworkGraph;
 
        use crate::error::GraphSyncError;
-       use crate::processing::update_network_graph;
+       use crate::RapidGossipSync;
 
        #[test]
        fn network_graph_fails_to_update_from_clipped_input() {
@@ -254,7 +261,8 @@ mod tests {
                        0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 2, 68, 226, 0, 6, 11, 0, 1, 24, 0,
                        0, 3, 232, 0, 0, 0,
                ];
-               let update_result = update_network_graph(&network_graph, &example_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&example_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::DecodeError(DecodeError::ShortRead)) = update_result {
                        // this is the expected error type
@@ -278,7 +286,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let update_result = update_network_graph(&network_graph, &incremental_update_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&incremental_update_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
                        assert_eq!(lightning_error.err, "Couldn't find channel for update");
@@ -310,7 +319,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let update_result = update_network_graph(&network_graph, &announced_update_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&announced_update_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
                        assert_eq!(
@@ -345,7 +355,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let initialization_result = update_network_graph(&network_graph, &initialization_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                if initialization_result.is_err() {
                        panic!(
                                "Unexpected initialization result: {:?}",
@@ -373,10 +384,7 @@ mod tests {
                        0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
                        68, 226, 0, 6, 11, 0, 1, 128,
                ];
-               let update_result = update_network_graph(
-                       &network_graph,
-                       &opposite_direction_incremental_update_input[..],
-               );
+               let update_result = rapid_sync.update_network_graph(&opposite_direction_incremental_update_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
                        assert_eq!(
@@ -413,7 +421,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let initialization_result = update_network_graph(&network_graph, &initialization_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                assert!(initialization_result.is_ok());
 
                let single_direction_incremental_update_input = vec![
@@ -423,10 +432,7 @@ mod tests {
                        0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
                        68, 226, 0, 6, 11, 0, 1, 128,
                ];
-               let update_result = update_network_graph(
-                       &network_graph,
-                       &single_direction_incremental_update_input[..],
-               );
+               let update_result = rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]);
                if update_result.is_err() {
                        panic!("Unexpected update result: {:?}", update_result)
                }
@@ -474,7 +480,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let update_result = update_network_graph(&network_graph, &valid_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&valid_input[..]);
                if update_result.is_err() {
                        panic!("Unexpected update result: {:?}", update_result)
                }
lightning/src/routing/network_graph.rs
index 668e70969a85ae45db174a1ff68ec745bc6c377c..18b846cbf8895ce1e10f7ff5ba6411b1a1e91011 100644 (file)
@@ -123,9 +123,7 @@ impl Readable for NodeId {
 
 /// Represents the network as nodes and channels between them
 pub struct NetworkGraph {
-       /// The unix timestamp in UTC provided by the most recent rapid gossip sync
-       /// It will be set by the rapid sync process after every sync completion
-       pub last_rapid_gossip_sync_timestamp: Option<u32>,
+       last_rapid_gossip_sync_timestamp: Mutex<Option<u32>>,
        genesis_hash: BlockHash,
        // Lock order: channels -> nodes
        channels: RwLock<BTreeMap<u64, ChannelInfo>>,
@@ -136,11 +134,12 @@ impl Clone for NetworkGraph {
        fn clone(&self) -> Self {
                let channels = self.channels.read().unwrap();
                let nodes = self.nodes.read().unwrap();
+               let last_rapid_gossip_sync_timestamp = self.get_last_rapid_gossip_sync_timestamp();
                Self {
                        genesis_hash: self.genesis_hash.clone(),
                        channels: RwLock::new(channels.clone()),
                        nodes: RwLock::new(nodes.clone()),
-                       last_rapid_gossip_sync_timestamp: self.last_rapid_gossip_sync_timestamp.clone(),
+                       last_rapid_gossip_sync_timestamp: Mutex::new(last_rapid_gossip_sync_timestamp)
                }
        }
 }
@@ -994,8 +993,9 @@ impl Writeable for NetworkGraph {
                        node_info.write(writer)?;
                }
 
+               let last_rapid_gossip_sync_timestamp = self.get_last_rapid_gossip_sync_timestamp();
                write_tlv_fields!(writer, {
-                       (1, self.last_rapid_gossip_sync_timestamp, option),
+                       (1, last_rapid_gossip_sync_timestamp, option),
                });
                Ok(())
        }
@@ -1030,7 +1030,7 @@ impl Readable for NetworkGraph {
                        genesis_hash,
                        channels: RwLock::new(channels),
                        nodes: RwLock::new(nodes),
-                       last_rapid_gossip_sync_timestamp,
+                       last_rapid_gossip_sync_timestamp: Mutex::new(last_rapid_gossip_sync_timestamp),
                })
        }
 }
@@ -1064,7 +1064,7 @@ impl NetworkGraph {
                        genesis_hash,
                        channels: RwLock::new(BTreeMap::new()),
                        nodes: RwLock::new(BTreeMap::new()),
-                       last_rapid_gossip_sync_timestamp: None,
+                       last_rapid_gossip_sync_timestamp: Mutex::new(None),
                }
        }
 
@@ -1078,6 +1078,18 @@ impl NetworkGraph {
                }
        }
 
+       /// The unix timestamp provided by the most recent rapid gossip sync.
+       /// It will be set by the rapid sync process after every sync completion.
+       pub fn get_last_rapid_gossip_sync_timestamp(&self) -> Option<u32> {
+               self.last_rapid_gossip_sync_timestamp.lock().unwrap().clone()
+       }
+
+       /// Update the unix timestamp provided by the most recent rapid gossip sync.
+       /// This should be done automatically by the rapid sync process after every sync completion.
+       pub fn set_last_rapid_gossip_sync_timestamp(&self, last_rapid_gossip_sync_timestamp: u32) {
+               self.last_rapid_gossip_sync_timestamp.lock().unwrap().replace(last_rapid_gossip_sync_timestamp);
+       }
+
        /// Clears the `NodeAnnouncementInfo` field for all nodes in the `NetworkGraph` for testing
        /// purposes.
        #[cfg(test)]
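A quick sketch of the new accessor pair; note that `set_last_rapid_gossip_sync_timestamp` takes `&self` thanks to the interior `Mutex` (the genesis hash and timestamp values here are purely illustrative):

    let network_graph = NetworkGraph::new(genesis_block(Network::Bitcoin).header.block_hash());
    assert_eq!(network_graph.get_last_rapid_gossip_sync_timestamp(), None);
    network_graph.set_last_rapid_gossip_sync_timestamp(1654041600);
    assert_eq!(network_graph.get_last_rapid_gossip_sync_timestamp(), Some(1654041600));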
@@ -2374,13 +2386,13 @@ mod tests {
        #[test]
        fn network_graph_tlv_serialization() {
                let mut network_graph = create_network_graph();
-               network_graph.last_rapid_gossip_sync_timestamp.replace(42);
+               network_graph.set_last_rapid_gossip_sync_timestamp(42);
 
                let mut w = test_utils::TestVecWriter(Vec::new());
                network_graph.write(&mut w).unwrap();
                let reassembled_network_graph: NetworkGraph = Readable::read(&mut io::Cursor::new(&w.0)).unwrap();
                assert!(reassembled_network_graph == network_graph);
-               assert_eq!(reassembled_network_graph.last_rapid_gossip_sync_timestamp.unwrap(), 42);
+               assert_eq!(reassembled_network_graph.get_last_rapid_gossip_sync_timestamp().unwrap(), 42);
        }
 
        #[test]