Indicate ongoing rapid sync to background processor.
[rust-lightning] / lightning-rapid-gossip-sync / src / processing.rs
index ceb8b82295336406c142a376f325070bc9255f3e..6ffc6f58ea88bf6adbf1fc9e7fbd2184a0ed78e4 100644 (file)
@@ -1,6 +1,8 @@
 use std::cmp::max;
 use std::io;
 use std::io::Read;
+use std::ops::Deref;
+use std::sync::atomic::Ordering;
 
 use bitcoin::BlockHash;
 use bitcoin::secp256k1::PublicKey;
@@ -8,10 +10,11 @@ use bitcoin::secp256k1::PublicKey;
 use lightning::ln::msgs::{
        DecodeError, ErrorAction, LightningError, OptionalField, UnsignedChannelUpdate,
 };
-use lightning::routing::network_graph;
+use lightning::routing::network_graph::NetworkGraph;
 use lightning::util::ser::{BigSize, Readable};
 
 use crate::error::GraphSyncError;
+use crate::RapidGossipSync;
 
 /// The purpose of this prefix is to identify the serialization format, should other rapid gossip
 /// sync formats arise in the future.
@@ -23,203 +26,207 @@ const GOSSIP_PREFIX: [u8; 4] = [76, 68, 75, 1];
 /// avoid malicious updates being able to trigger excessive memory allocation.
 const MAX_INITIAL_NODE_ID_VECTOR_CAPACITY: u32 = 50_000;
 
-/// Update network graph from binary data.
-/// Returns the last sync timestamp to be used the next time rapid sync data is queried.
-///
-/// `network_graph`: network graph to be updated
-///
-/// `update_data`: `&[u8]` binary stream that comprises the update data
-pub fn update_network_graph(
-       network_graph: &network_graph::NetworkGraph,
-       update_data: &[u8],
-) -> Result<u32, GraphSyncError> {
-       let mut read_cursor = io::Cursor::new(update_data);
-       update_network_graph_from_byte_stream(&network_graph, &mut read_cursor)
-}
-
-pub(crate) fn update_network_graph_from_byte_stream<R: Read>(
-       network_graph: &network_graph::NetworkGraph,
-       mut read_cursor: &mut R,
-) -> Result<u32, GraphSyncError> {
-       let mut prefix = [0u8; 4];
-       read_cursor.read_exact(&mut prefix)?;
-
-       match prefix {
-               GOSSIP_PREFIX => {},
-               _ => {
-                       return Err(DecodeError::UnknownVersion.into());
-               }
-       };
-
-       let chain_hash: BlockHash = Readable::read(read_cursor)?;
-       let latest_seen_timestamp: u32 = Readable::read(read_cursor)?;
-       // backdate the applied timestamp by a week
-       let backdated_timestamp = latest_seen_timestamp.saturating_sub(24 * 3600 * 7);
-
-       let node_id_count: u32 = Readable::read(read_cursor)?;
-       let mut node_ids: Vec<PublicKey> = Vec::with_capacity(std::cmp::min(
-               node_id_count,
-               MAX_INITIAL_NODE_ID_VECTOR_CAPACITY,
-       ) as usize);
-       for _ in 0..node_id_count {
-               let current_node_id = Readable::read(read_cursor)?;
-               node_ids.push(current_node_id);
+impl<NG: Deref<Target=NetworkGraph>> RapidGossipSync<NG> {
+       /// Update network graph from binary data.
+       /// Returns the last sync timestamp to be used the next time rapid sync data is queried.
+       ///
+       /// `network_graph`: network graph to be updated
+       ///
+       /// `update_data`: `&[u8]` binary stream that comprises the update data
+       pub fn update_network_graph(&self, update_data: &[u8]) -> Result<u32, GraphSyncError> {
+               let mut read_cursor = io::Cursor::new(update_data);
+               self.update_network_graph_from_byte_stream(&mut read_cursor)
        }
 
-       let mut previous_scid: u64 = 0;
-       let announcement_count: u32 = Readable::read(read_cursor)?;
-       for _ in 0..announcement_count {
-               let features = Readable::read(read_cursor)?;
-
-               // handle SCID
-               let scid_delta: BigSize = Readable::read(read_cursor)?;
-               let short_channel_id = previous_scid
-                       .checked_add(scid_delta.0)
-                       .ok_or(DecodeError::InvalidValue)?;
-               previous_scid = short_channel_id;
-
-               let node_id_1_index: BigSize = Readable::read(read_cursor)?;
-               let node_id_2_index: BigSize = Readable::read(read_cursor)?;
-               if max(node_id_1_index.0, node_id_2_index.0) >= node_id_count as u64 {
-                       return Err(DecodeError::InvalidValue.into());
-               };
-               let node_id_1 = node_ids[node_id_1_index.0 as usize];
-               let node_id_2 = node_ids[node_id_2_index.0 as usize];
-
-               let announcement_result = network_graph.add_channel_from_partial_announcement(
-                       short_channel_id,
-                       backdated_timestamp as u64,
-                       features,
-                       node_id_1,
-                       node_id_2,
-               );
-               if let Err(lightning_error) = announcement_result {
-                       if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action {
-                               // everything is fine, just a duplicate channel announcement
-                       } else {
-                               return Err(lightning_error.into());
+
+       pub(crate) fn update_network_graph_from_byte_stream<R: Read>(
+               &self,
+               mut read_cursor: &mut R,
+       ) -> Result<u32, GraphSyncError> {
+               let mut prefix = [0u8; 4];
+               read_cursor.read_exact(&mut prefix)?;
+
+               match prefix {
+                       GOSSIP_PREFIX => {}
+                       _ => {
+                               return Err(DecodeError::UnknownVersion.into());
                        }
+               };
+
+               let chain_hash: BlockHash = Readable::read(read_cursor)?;
+               let latest_seen_timestamp: u32 = Readable::read(read_cursor)?;
+               // backdate the applied timestamp by a week
+               let backdated_timestamp = latest_seen_timestamp.saturating_sub(24 * 3600 * 7);
+
+               let node_id_count: u32 = Readable::read(read_cursor)?;
+               let mut node_ids: Vec<PublicKey> = Vec::with_capacity(std::cmp::min(
+                       node_id_count,
+                       MAX_INITIAL_NODE_ID_VECTOR_CAPACITY,
+               ) as usize);
+               for _ in 0..node_id_count {
+                       let current_node_id = Readable::read(read_cursor)?;
+                       node_ids.push(current_node_id);
                }
-       }
 
-       previous_scid = 0; // updates start at a new scid
+               let network_graph = &self.network_graph;
 
-       let update_count: u32 = Readable::read(read_cursor)?;
-       if update_count == 0 {
-               return Ok(latest_seen_timestamp);
-       }
+               let mut previous_scid: u64 = 0;
+               let announcement_count: u32 = Readable::read(read_cursor)?;
+               for _ in 0..announcement_count {
+                       let features = Readable::read(read_cursor)?;
 
-       // obtain default values for non-incremental updates
-       let default_cltv_expiry_delta: u16 = Readable::read(&mut read_cursor)?;
-       let default_htlc_minimum_msat: u64 = Readable::read(&mut read_cursor)?;
-       let default_fee_base_msat: u32 = Readable::read(&mut read_cursor)?;
-       let default_fee_proportional_millionths: u32 = Readable::read(&mut read_cursor)?;
-       let tentative_default_htlc_maximum_msat: u64 = Readable::read(&mut read_cursor)?;
-       let default_htlc_maximum_msat = if tentative_default_htlc_maximum_msat == u64::max_value() {
-               OptionalField::Absent
-       } else {
-               OptionalField::Present(tentative_default_htlc_maximum_msat)
-       };
-
-       for _ in 0..update_count {
-               let scid_delta: BigSize = Readable::read(read_cursor)?;
-               let short_channel_id = previous_scid
-                       .checked_add(scid_delta.0)
-                       .ok_or(DecodeError::InvalidValue)?;
-               previous_scid = short_channel_id;
-
-               let channel_flags: u8 = Readable::read(read_cursor)?;
-
-               // flags are always sent in full, and hence always need updating
-               let standard_channel_flags = channel_flags & 0b_0000_0011;
-
-               let mut synthetic_update = if channel_flags & 0b_1000_0000 == 0 {
-                       // full update, field flags will indicate deviations from the default
-                       UnsignedChannelUpdate {
-                               chain_hash,
-                               short_channel_id,
-                               timestamp: backdated_timestamp,
-                               flags: standard_channel_flags,
-                               cltv_expiry_delta: default_cltv_expiry_delta,
-                               htlc_minimum_msat: default_htlc_minimum_msat,
-                               htlc_maximum_msat: default_htlc_maximum_msat.clone(),
-                               fee_base_msat: default_fee_base_msat,
-                               fee_proportional_millionths: default_fee_proportional_millionths,
-                               excess_data: vec![],
-                       }
-               } else {
-                       // incremental update, field flags will indicate mutated values
-                       let read_only_network_graph = network_graph.read_only();
-                       let channel = read_only_network_graph
-                               .channels()
-                               .get(&short_channel_id)
-                               .ok_or(LightningError {
-                                       err: "Couldn't find channel for update".to_owned(),
-                                       action: ErrorAction::IgnoreError,
-                               })?;
-
-                       let directional_info = channel
-                               .get_directional_info(channel_flags)
-                               .ok_or(LightningError {
-                                       err: "Couldn't find previous directional data for update".to_owned(),
-                                       action: ErrorAction::IgnoreError,
-                               })?;
-
-                       let htlc_maximum_msat =
-                               if let Some(htlc_maximum_msat) = directional_info.htlc_maximum_msat {
-                                       OptionalField::Present(htlc_maximum_msat)
-                               } else {
-                                       OptionalField::Absent
-                               };
+                       // handle SCID
+                       let scid_delta: BigSize = Readable::read(read_cursor)?;
+                       let short_channel_id = previous_scid
+                               .checked_add(scid_delta.0)
+                               .ok_or(DecodeError::InvalidValue)?;
+                       previous_scid = short_channel_id;
+
+                       let node_id_1_index: BigSize = Readable::read(read_cursor)?;
+                       let node_id_2_index: BigSize = Readable::read(read_cursor)?;
+                       if max(node_id_1_index.0, node_id_2_index.0) >= node_id_count as u64 {
+                               return Err(DecodeError::InvalidValue.into());
+                       };
+                       let node_id_1 = node_ids[node_id_1_index.0 as usize];
+                       let node_id_2 = node_ids[node_id_2_index.0 as usize];
 
-                       UnsignedChannelUpdate {
-                               chain_hash,
+                       let announcement_result = network_graph.add_channel_from_partial_announcement(
                                short_channel_id,
-                               timestamp: backdated_timestamp,
-                               flags: standard_channel_flags,
-                               cltv_expiry_delta: directional_info.cltv_expiry_delta,
-                               htlc_minimum_msat: directional_info.htlc_minimum_msat,
-                               htlc_maximum_msat,
-                               fee_base_msat: directional_info.fees.base_msat,
-                               fee_proportional_millionths: directional_info.fees.proportional_millionths,
-                               excess_data: vec![],
+                               backdated_timestamp as u64,
+                               features,
+                               node_id_1,
+                               node_id_2,
+                       );
+                       if let Err(lightning_error) = announcement_result {
+                               if let ErrorAction::IgnoreDuplicateGossip = lightning_error.action {
+                                       // everything is fine, just a duplicate channel announcement
+                               } else {
+                                       return Err(lightning_error.into());
+                               }
                        }
-               };
-
-               if channel_flags & 0b_0100_0000 > 0 {
-                       let cltv_expiry_delta: u16 = Readable::read(read_cursor)?;
-                       synthetic_update.cltv_expiry_delta = cltv_expiry_delta;
                }
 
-               if channel_flags & 0b_0010_0000 > 0 {
-                       let htlc_minimum_msat: u64 = Readable::read(read_cursor)?;
-                       synthetic_update.htlc_minimum_msat = htlc_minimum_msat;
-               }
+               previous_scid = 0; // updates start at a new scid
 
-               if channel_flags & 0b_0001_0000 > 0 {
-                       let fee_base_msat: u32 = Readable::read(read_cursor)?;
-                       synthetic_update.fee_base_msat = fee_base_msat;
+               let update_count: u32 = Readable::read(read_cursor)?;
+               if update_count == 0 {
+                       return Ok(latest_seen_timestamp);
                }
 
-               if channel_flags & 0b_0000_1000 > 0 {
-                       let fee_proportional_millionths: u32 = Readable::read(read_cursor)?;
-                       synthetic_update.fee_proportional_millionths = fee_proportional_millionths;
-               }
+               // obtain default values for non-incremental updates
+               let default_cltv_expiry_delta: u16 = Readable::read(&mut read_cursor)?;
+               let default_htlc_minimum_msat: u64 = Readable::read(&mut read_cursor)?;
+               let default_fee_base_msat: u32 = Readable::read(&mut read_cursor)?;
+               let default_fee_proportional_millionths: u32 = Readable::read(&mut read_cursor)?;
+               let tentative_default_htlc_maximum_msat: u64 = Readable::read(&mut read_cursor)?;
+               let default_htlc_maximum_msat = if tentative_default_htlc_maximum_msat == u64::max_value() {
+                       OptionalField::Absent
+               } else {
+                       OptionalField::Present(tentative_default_htlc_maximum_msat)
+               };
 
-               if channel_flags & 0b_0000_0100 > 0 {
-                       let tentative_htlc_maximum_msat: u64 = Readable::read(read_cursor)?;
-                       synthetic_update.htlc_maximum_msat = if tentative_htlc_maximum_msat == u64::max_value()
-                       {
-                               OptionalField::Absent
+               for _ in 0..update_count {
+                       let scid_delta: BigSize = Readable::read(read_cursor)?;
+                       let short_channel_id = previous_scid
+                               .checked_add(scid_delta.0)
+                               .ok_or(DecodeError::InvalidValue)?;
+                       previous_scid = short_channel_id;
+
+                       let channel_flags: u8 = Readable::read(read_cursor)?;
+
+                       // flags are always sent in full, and hence always need updating
+                       let standard_channel_flags = channel_flags & 0b_0000_0011;
+
+                       let mut synthetic_update = if channel_flags & 0b_1000_0000 == 0 {
+                               // full update, field flags will indicate deviations from the default
+                               UnsignedChannelUpdate {
+                                       chain_hash,
+                                       short_channel_id,
+                                       timestamp: backdated_timestamp,
+                                       flags: standard_channel_flags,
+                                       cltv_expiry_delta: default_cltv_expiry_delta,
+                                       htlc_minimum_msat: default_htlc_minimum_msat,
+                                       htlc_maximum_msat: default_htlc_maximum_msat.clone(),
+                                       fee_base_msat: default_fee_base_msat,
+                                       fee_proportional_millionths: default_fee_proportional_millionths,
+                                       excess_data: vec![],
+                               }
                        } else {
-                               OptionalField::Present(tentative_htlc_maximum_msat)
+                               // incremental update, field flags will indicate mutated values
+                               let read_only_network_graph = network_graph.read_only();
+                               let channel = read_only_network_graph
+                                       .channels()
+                                       .get(&short_channel_id)
+                                       .ok_or(LightningError {
+                                               err: "Couldn't find channel for update".to_owned(),
+                                               action: ErrorAction::IgnoreError,
+                                       })?;
+
+                               let directional_info = channel
+                                       .get_directional_info(channel_flags)
+                                       .ok_or(LightningError {
+                                               err: "Couldn't find previous directional data for update".to_owned(),
+                                               action: ErrorAction::IgnoreError,
+                                       })?;
+
+                               let htlc_maximum_msat =
+                                       if let Some(htlc_maximum_msat) = directional_info.htlc_maximum_msat {
+                                               OptionalField::Present(htlc_maximum_msat)
+                                       } else {
+                                               OptionalField::Absent
+                                       };
+
+                               UnsignedChannelUpdate {
+                                       chain_hash,
+                                       short_channel_id,
+                                       timestamp: backdated_timestamp,
+                                       flags: standard_channel_flags,
+                                       cltv_expiry_delta: directional_info.cltv_expiry_delta,
+                                       htlc_minimum_msat: directional_info.htlc_minimum_msat,
+                                       htlc_maximum_msat,
+                                       fee_base_msat: directional_info.fees.base_msat,
+                                       fee_proportional_millionths: directional_info.fees.proportional_millionths,
+                                       excess_data: vec![],
+                               }
                        };
+
+                       if channel_flags & 0b_0100_0000 > 0 {
+                               let cltv_expiry_delta: u16 = Readable::read(read_cursor)?;
+                               synthetic_update.cltv_expiry_delta = cltv_expiry_delta;
+                       }
+
+                       if channel_flags & 0b_0010_0000 > 0 {
+                               let htlc_minimum_msat: u64 = Readable::read(read_cursor)?;
+                               synthetic_update.htlc_minimum_msat = htlc_minimum_msat;
+                       }
+
+                       if channel_flags & 0b_0001_0000 > 0 {
+                               let fee_base_msat: u32 = Readable::read(read_cursor)?;
+                               synthetic_update.fee_base_msat = fee_base_msat;
+                       }
+
+                       if channel_flags & 0b_0000_1000 > 0 {
+                               let fee_proportional_millionths: u32 = Readable::read(read_cursor)?;
+                               synthetic_update.fee_proportional_millionths = fee_proportional_millionths;
+                       }
+
+                       if channel_flags & 0b_0000_0100 > 0 {
+                               let tentative_htlc_maximum_msat: u64 = Readable::read(read_cursor)?;
+                               synthetic_update.htlc_maximum_msat = if tentative_htlc_maximum_msat == u64::max_value()
+                               {
+                                       OptionalField::Absent
+                               } else {
+                                       OptionalField::Present(tentative_htlc_maximum_msat)
+                               };
+                       }
+
+                       network_graph.update_channel_unsigned(&synthetic_update)?;
                }
 
-               network_graph.update_channel_unsigned(&synthetic_update)?;
+               self.network_graph.set_last_rapid_gossip_sync_timestamp(latest_seen_timestamp);
+               self.is_initial_sync_complete.store(true, Ordering::Release);
+               Ok(latest_seen_timestamp)
        }
-
-       Ok(latest_seen_timestamp)
 }
 
 #[cfg(test)]
@@ -231,7 +238,7 @@ mod tests {
        use lightning::routing::network_graph::NetworkGraph;
 
        use crate::error::GraphSyncError;
-       use crate::processing::update_network_graph;
+       use crate::RapidGossipSync;
 
        #[test]
        fn network_graph_fails_to_update_from_clipped_input() {
@@ -254,7 +261,8 @@ mod tests {
                        0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 58, 85, 116, 216, 255, 2, 68, 226, 0, 6, 11, 0, 1, 24, 0,
                        0, 3, 232, 0, 0, 0,
                ];
-               let update_result = update_network_graph(&network_graph, &example_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&example_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::DecodeError(DecodeError::ShortRead)) = update_result {
                        // this is the expected error type
@@ -278,7 +286,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let update_result = update_network_graph(&network_graph, &incremental_update_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&incremental_update_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
                        assert_eq!(lightning_error.err, "Couldn't find channel for update");
@@ -310,7 +319,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let update_result = update_network_graph(&network_graph, &announced_update_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&announced_update_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
                        assert_eq!(
@@ -345,7 +355,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let initialization_result = update_network_graph(&network_graph, &initialization_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                if initialization_result.is_err() {
                        panic!(
                                "Unexpected initialization result: {:?}",
@@ -373,10 +384,7 @@ mod tests {
                        0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
                        68, 226, 0, 6, 11, 0, 1, 128,
                ];
-               let update_result = update_network_graph(
-                       &network_graph,
-                       &opposite_direction_incremental_update_input[..],
-               );
+               let update_result = rapid_sync.update_network_graph(&opposite_direction_incremental_update_input[..]);
                assert!(update_result.is_err());
                if let Err(GraphSyncError::LightningError(lightning_error)) = update_result {
                        assert_eq!(
@@ -413,7 +421,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let initialization_result = update_network_graph(&network_graph, &initialization_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let initialization_result = rapid_sync.update_network_graph(&initialization_input[..]);
                assert!(initialization_result.is_ok());
 
                let single_direction_incremental_update_input = vec![
@@ -423,10 +432,7 @@ mod tests {
                        0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 8, 153, 192, 0, 2, 27, 0, 0, 136, 0, 0, 0, 221, 255, 2,
                        68, 226, 0, 6, 11, 0, 1, 128,
                ];
-               let update_result = update_network_graph(
-                       &network_graph,
-                       &single_direction_incremental_update_input[..],
-               );
+               let update_result = rapid_sync.update_network_graph(&single_direction_incremental_update_input[..]);
                if update_result.is_err() {
                        panic!("Unexpected update result: {:?}", update_result)
                }
@@ -474,7 +480,8 @@ mod tests {
 
                assert_eq!(network_graph.read_only().channels().len(), 0);
 
-               let update_result = update_network_graph(&network_graph, &valid_input[..]);
+               let rapid_sync = RapidGossipSync::new(&network_graph);
+               let update_result = rapid_sync.update_network_graph(&valid_input[..]);
                if update_result.is_err() {
                        panic!("Unexpected update result: {:?}", update_result)
                }