From 777661ae520c9ca969e6359bff05e561011eb336 Mon Sep 17 00:00:00 2001 From: Jeffrey Czyz Date: Mon, 9 Aug 2021 22:24:41 -0500 Subject: [PATCH] Individually lock NetworkGraph fields In preparation for giving NetworkGraph shared ownership, wrap individual fields in RwLock. This allows removing the outer RwLock used in NetGraphMsgHandler. --- lightning/src/routing/network_graph.rs | 99 ++++++++++++++++---------- lightning/src/routing/router.rs | 40 ++++++----- 2 files changed, 84 insertions(+), 55 deletions(-) diff --git a/lightning/src/routing/network_graph.rs b/lightning/src/routing/network_graph.rs index 8accdab60..3b593351f 100644 --- a/lightning/src/routing/network_graph.rs +++ b/lightning/src/routing/network_graph.rs @@ -51,11 +51,11 @@ const MAX_EXCESS_BYTES_FOR_RELAY: usize = 1024; const MAX_SCIDS_PER_REPLY: usize = 8000; /// Represents the network as nodes and channels between them -#[derive(Clone, PartialEq)] pub struct NetworkGraph { genesis_hash: BlockHash, - channels: BTreeMap, - nodes: BTreeMap, + // Lock order: channels -> nodes + channels: RwLock>, + nodes: RwLock>, } /// A simple newtype for RwLockReadGuard<'a, NetworkGraph>. @@ -193,7 +193,8 @@ impl RoutingMessageHandler for NetGraphMsgHandler wh fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(ChannelAnnouncement, Option, Option)> { let network_graph = self.network_graph.read().unwrap(); let mut result = Vec::with_capacity(batch_amount as usize); - let mut iter = network_graph.get_channels().range(starting_point..); + let channels = network_graph.get_channels(); + let mut iter = channels.range(starting_point..); while result.len() < batch_amount as usize { if let Some((_, ref chan)) = iter.next() { if chan.announcement_message.is_some() { @@ -221,12 +222,13 @@ impl RoutingMessageHandler for NetGraphMsgHandler wh fn get_next_node_announcements(&self, starting_point: Option<&PublicKey>, batch_amount: u8) -> Vec { let network_graph = self.network_graph.read().unwrap(); let mut result = Vec::with_capacity(batch_amount as usize); + let nodes = network_graph.get_nodes(); let mut iter = if let Some(pubkey) = starting_point { - let mut iter = network_graph.get_nodes().range((*pubkey)..); + let mut iter = nodes.range((*pubkey)..); iter.next(); iter } else { - network_graph.get_nodes().range(..) + nodes.range(..) }; while result.len() < batch_amount as usize { if let Some((_, ref node)) = iter.next() { @@ -616,13 +618,15 @@ impl Writeable for NetworkGraph { write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION); self.genesis_hash.write(writer)?; - (self.channels.len() as u64).write(writer)?; - for (ref chan_id, ref chan_info) in self.channels.iter() { + let channels = self.channels.read().unwrap(); + (channels.len() as u64).write(writer)?; + for (ref chan_id, ref chan_info) in channels.iter() { (*chan_id).write(writer)?; chan_info.write(writer)?; } - (self.nodes.len() as u64).write(writer)?; - for (ref node_id, ref node_info) in self.nodes.iter() { + let nodes = self.nodes.read().unwrap(); + (nodes.len() as u64).write(writer)?; + for (ref node_id, ref node_info) in nodes.iter() { node_id.write(writer)?; node_info.write(writer)?; } @@ -655,8 +659,8 @@ impl Readable for NetworkGraph { Ok(NetworkGraph { genesis_hash, - channels, - nodes, + channels: RwLock::new(channels), + nodes: RwLock::new(nodes), }) } } @@ -664,36 +668,49 @@ impl Readable for NetworkGraph { impl fmt::Display for NetworkGraph { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { writeln!(f, "Network map\n[Channels]")?; - for (key, val) in self.channels.iter() { + for (key, val) in self.channels.read().unwrap().iter() { writeln!(f, " {}: {}", key, val)?; } writeln!(f, "[Nodes]")?; - for (key, val) in self.nodes.iter() { + for (key, val) in self.nodes.read().unwrap().iter() { writeln!(f, " {}: {}", log_pubkey!(key), val)?; } Ok(()) } } +impl PartialEq for NetworkGraph { + fn eq(&self, other: &Self) -> bool { + self.genesis_hash == other.genesis_hash && + *self.channels.read().unwrap() == *other.channels.read().unwrap() && + *self.nodes.read().unwrap() == *other.nodes.read().unwrap() + } +} + impl NetworkGraph { /// Returns all known valid channels' short ids along with announced channel info. /// /// (C-not exported) because we have no mapping for `BTreeMap`s - pub fn get_channels<'a>(&'a self) -> &'a BTreeMap { &self.channels } + pub fn get_channels(&self) -> RwLockReadGuard<'_, BTreeMap> { + self.channels.read().unwrap() + } + /// Returns all known nodes' public keys along with announced node info. /// /// (C-not exported) because we have no mapping for `BTreeMap`s - pub fn get_nodes<'a>(&'a self) -> &'a BTreeMap { &self.nodes } + pub fn get_nodes(&self) -> RwLockReadGuard<'_, BTreeMap> { + self.nodes.read().unwrap() + } /// Get network addresses by node id. /// Returns None if the requested node is completely unknown, /// or if node announcement for the node was never received. /// /// (C-not exported) as there is no practical way to track lifetimes of returned values. - pub fn get_addresses<'a>(&'a self, pubkey: &PublicKey) -> Option<&'a Vec> { - if let Some(node) = self.nodes.get(pubkey) { + pub fn get_addresses(&self, pubkey: &PublicKey) -> Option> { + if let Some(node) = self.nodes.read().unwrap().get(pubkey) { if let Some(node_info) = node.announcement_info.as_ref() { - return Some(&node_info.addresses) + return Some(node_info.addresses.clone()) } } None @@ -703,8 +720,8 @@ impl NetworkGraph { pub fn new(genesis_hash: BlockHash) -> NetworkGraph { Self { genesis_hash, - channels: BTreeMap::new(), - nodes: BTreeMap::new(), + channels: RwLock::new(BTreeMap::new()), + nodes: RwLock::new(BTreeMap::new()), } } @@ -729,7 +746,7 @@ impl NetworkGraph { } fn update_node_from_announcement_intern(&mut self, msg: &msgs::UnsignedNodeAnnouncement, full_msg: Option<&msgs::NodeAnnouncement>) -> Result<(), LightningError> { - match self.nodes.get_mut(&msg.node_id) { + match self.nodes.write().unwrap().get_mut(&msg.node_id) { None => Err(LightningError{err: "No existing channels for node_announcement".to_owned(), action: ErrorAction::IgnoreError}), Some(node) => { if let Some(node_info) = node.announcement_info.as_ref() { @@ -838,7 +855,9 @@ impl NetworkGraph { { full_msg.cloned() } else { None }, }; - match self.channels.entry(msg.short_channel_id) { + let mut channels = self.channels.write().unwrap(); + let mut nodes = self.nodes.write().unwrap(); + match channels.entry(msg.short_channel_id) { BtreeEntry::Occupied(mut entry) => { //TODO: because asking the blockchain if short_channel_id is valid is only optional //in the blockchain API, we need to handle it smartly here, though it's unclear @@ -852,7 +871,7 @@ impl NetworkGraph { // b) we don't track UTXOs of channels we know about and remove them if they // get reorg'd out. // c) it's unclear how to do so without exposing ourselves to massive DoS risk. - Self::remove_channel_in_nodes(&mut self.nodes, &entry.get(), msg.short_channel_id); + Self::remove_channel_in_nodes(&mut nodes, &entry.get(), msg.short_channel_id); *entry.get_mut() = chan_info; } else { return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreAndLog(Level::Trace)}) @@ -865,7 +884,7 @@ impl NetworkGraph { macro_rules! add_channel_to_node { ( $node_id: expr ) => { - match self.nodes.entry($node_id) { + match nodes.entry($node_id) { BtreeEntry::Occupied(node_entry) => { node_entry.into_mut().channels.push(msg.short_channel_id); }, @@ -891,12 +910,14 @@ impl NetworkGraph { /// May cause the removal of nodes too, if this was their last channel. /// If not permanent, makes channels unavailable for routing. pub fn close_channel_from_update(&mut self, short_channel_id: u64, is_permanent: bool) { + let mut channels = self.channels.write().unwrap(); if is_permanent { - if let Some(chan) = self.channels.remove(&short_channel_id) { - Self::remove_channel_in_nodes(&mut self.nodes, &chan, short_channel_id); + if let Some(chan) = channels.remove(&short_channel_id) { + let mut nodes = self.nodes.write().unwrap(); + Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id); } } else { - if let Some(chan) = self.channels.get_mut(&short_channel_id) { + if let Some(chan) = channels.get_mut(&short_channel_id) { if let Some(one_to_two) = chan.one_to_two.as_mut() { one_to_two.enabled = false; } @@ -937,7 +958,8 @@ impl NetworkGraph { let chan_enabled = msg.flags & (1 << 1) != (1 << 1); let chan_was_enabled; - match self.channels.get_mut(&msg.short_channel_id) { + let mut channels = self.channels.write().unwrap(); + match channels.get_mut(&msg.short_channel_id) { None => return Err(LightningError{err: "Couldn't find channel for update".to_owned(), action: ErrorAction::IgnoreError}), Some(channel) => { if let OptionalField::Present(htlc_maximum_msat) = msg.htlc_maximum_msat { @@ -1000,8 +1022,9 @@ impl NetworkGraph { } } + let mut nodes = self.nodes.write().unwrap(); if chan_enabled { - let node = self.nodes.get_mut(&dest_node_id).unwrap(); + let node = nodes.get_mut(&dest_node_id).unwrap(); let mut base_msat = msg.fee_base_msat; let mut proportional_millionths = msg.fee_proportional_millionths; if let Some(fees) = node.lowest_inbound_channel_fees { @@ -1013,11 +1036,11 @@ impl NetworkGraph { proportional_millionths }); } else if chan_was_enabled { - let node = self.nodes.get_mut(&dest_node_id).unwrap(); + let node = nodes.get_mut(&dest_node_id).unwrap(); let mut lowest_inbound_channel_fees = None; for chan_id in node.channels.iter() { - let chan = self.channels.get(chan_id).unwrap(); + let chan = channels.get(chan_id).unwrap(); let chan_info_opt; if chan.node_one == dest_node_id { chan_info_opt = chan.two_to_one.as_ref(); @@ -1268,7 +1291,7 @@ mod tests { match network.get_channels().get(&unsigned_announcement.short_channel_id) { None => panic!(), Some(_) => () - } + }; } // If we receive announcement for the same channel (with UTXO lookups disabled), @@ -1320,7 +1343,7 @@ mod tests { match network.get_channels().get(&unsigned_announcement.short_channel_id) { None => panic!(), Some(_) => () - } + }; } // If we receive announcement for the same channel (but TX is not confirmed), @@ -1353,7 +1376,7 @@ mod tests { assert_eq!(channel_entry.features, ChannelFeatures::empty()); }, _ => panic!() - } + }; } // Don't relay valid channels with excess data @@ -1484,7 +1507,7 @@ mod tests { assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144); assert!(channel_info.two_to_one.is_none()); } - } + }; } unsigned_channel_update.timestamp += 100; @@ -1645,7 +1668,7 @@ mod tests { Some(channel_info) => { assert!(channel_info.one_to_two.is_some()); } - } + }; } let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed { @@ -1663,7 +1686,7 @@ mod tests { Some(channel_info) => { assert!(!channel_info.one_to_two.as_ref().unwrap().enabled); } - } + }; } let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed { diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 5030f6aaa..3df942d73 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -443,6 +443,8 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // to use as the A* heuristic beyond just the cost to get one node further than the current // one. + let network_channels = network.get_channels(); + let network_nodes = network.get_nodes(); let dummy_directional_info = DummyDirectionalChannelInfo { // used for first_hops routes cltv_expiry_delta: 0, htlc_minimum_msat: 0, @@ -458,7 +460,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // work reliably. let allow_mpp = if let Some(features) = &payee_features { features.supports_basic_mpp() - } else if let Some(node) = network.get_nodes().get(&payee) { + } else if let Some(node) = network_nodes.get(&payee) { if let Some(node_info) = node.announcement_info.as_ref() { node_info.features.supports_basic_mpp() } else { false } @@ -492,7 +494,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // Map from node_id to information about the best current path to that node, including feerate // information. - let mut dist = HashMap::with_capacity(network.get_nodes().len()); + let mut dist = HashMap::with_capacity(network_nodes.len()); // During routing, if we ignore a path due to an htlc_minimum_msat limit, we set this, // indicating that we may wish to try again with a higher value, potentially paying to meet an @@ -511,7 +513,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // This map allows paths to be aware of the channel use by other paths in the same call. // This would help to make a better path finding decisions and not "overbook" channels. // It is unaware of the directions (except for `outbound_capacity_msat` in `first_hops`). - let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network.get_nodes().len()); + let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network_nodes.len()); // Keeping track of how much value we already collected across other paths. Helps to decide: // - how much a new path should be transferring (upper bound); @@ -629,7 +631,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // as a way to reach the $dest_node_id. let mut fee_base_msat = u32::max_value(); let mut fee_proportional_millionths = u32::max_value(); - if let Some(Some(fees)) = network.get_nodes().get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) { + if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) { fee_base_msat = fees.base_msat; fee_proportional_millionths = fees.proportional_millionths; } @@ -814,7 +816,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye if !features.requires_unknown_bits() { for chan_id in $node.channels.iter() { - let chan = network.get_channels().get(chan_id).unwrap(); + let chan = network_channels.get(chan_id).unwrap(); if !chan.features.requires_unknown_bits() { if chan.node_one == *$node_id { // ie $node is one, ie next hop in A* is two, via the two_to_one channel @@ -862,7 +864,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // Add the payee as a target, so that the payee-to-payer // search algorithm knows what to start with. - match network.get_nodes().get(payee) { + match network_nodes.get(payee) { // The payee is not in our network graph, so nothing to add here. // There is still a chance of reaching them via last_hops though, // so don't yet fail the payment here. @@ -884,7 +886,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // we have a direct channel to the first hop or the first hop is // in the regular network graph. first_hop_targets.get(&first_hop_in_route.src_node_id).is_some() || - network.get_nodes().get(&first_hop_in_route.src_node_id).is_some(); + network_nodes.get(&first_hop_in_route.src_node_id).is_some(); if have_hop_src_in_graph { // We start building the path from reverse, i.e., from payee // to the first RouteHintHop in the path. @@ -991,7 +993,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye 'path_walk: loop { if let Some(&(_, _, _, ref features)) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) { ordered_hops.last_mut().unwrap().1 = features.clone(); - } else if let Some(node) = network.get_nodes().get(&ordered_hops.last().unwrap().0.pubkey) { + } else if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.pubkey) { if let Some(node_info) = node.announcement_info.as_ref() { ordered_hops.last_mut().unwrap().1 = node_info.features.clone(); } else { @@ -1093,7 +1095,7 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // Otherwise, since the current target node is not us, // keep "unrolling" the payment graph from payee to payer by // finding a way to reach the current target from the payer side. - match network.get_nodes().get(&pubkey) { + match network_nodes.get(&pubkey) { None => {}, Some(node) => { add_entries_to_cheapest_to_target_node!(node, &pubkey, lowest_fee_to_node, value_contribution_msat, path_htlc_minimum_msat); @@ -4211,12 +4213,13 @@ mod tests { // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut seed = random_init_seed() as usize; + let nodes = graph.get_nodes(); 'load_endpoints: for _ in 0..10 { loop { seed = seed.overflowing_mul(0xdeadbeef).0; - let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let src = nodes.keys().skip(seed % nodes.len()).next().unwrap(); seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap(); let amt = seed as u64 % 200_000_000; if get_route(src, &graph, dst, None, None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() { continue 'load_endpoints; @@ -4239,12 +4242,13 @@ mod tests { // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut seed = random_init_seed() as usize; + let nodes = graph.get_nodes(); 'load_endpoints: for _ in 0..10 { loop { seed = seed.overflowing_mul(0xdeadbeef).0; - let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let src = nodes.keys().skip(seed % nodes.len()).next().unwrap(); seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap(); let amt = seed as u64 % 200_000_000; if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() { continue 'load_endpoints; @@ -4297,6 +4301,7 @@ mod benches { fn generate_routes(bench: &mut Bencher) { let mut d = test_utils::get_route_file().unwrap(); let graph = NetworkGraph::read(&mut d).unwrap(); + let nodes = graph.get_nodes(); // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut path_endpoints = Vec::new(); @@ -4304,9 +4309,9 @@ mod benches { 'load_endpoints: for _ in 0..100 { loop { seed *= 0xdeadbeef; - let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let src = nodes.keys().skip(seed % nodes.len()).next().unwrap(); seed *= 0xdeadbeef; - let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap(); let amt = seed as u64 % 1_000_000; if get_route(src, &graph, dst, None, None, &[], amt, 42, &DummyLogger{}).is_ok() { path_endpoints.push((src, dst, amt)); @@ -4328,6 +4333,7 @@ mod benches { fn generate_mpp_routes(bench: &mut Bencher) { let mut d = test_utils::get_route_file().unwrap(); let graph = NetworkGraph::read(&mut d).unwrap(); + let nodes = graph.get_nodes(); // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut path_endpoints = Vec::new(); @@ -4335,9 +4341,9 @@ mod benches { 'load_endpoints: for _ in 0..100 { loop { seed *= 0xdeadbeef; - let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let src = nodes.keys().skip(seed % nodes.len()).next().unwrap(); seed *= 0xdeadbeef; - let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap(); + let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap(); let amt = seed as u64 % 1_000_000; if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &DummyLogger{}).is_ok() { path_endpoints.push((src, dst, amt)); -- 2.39.5