const MAX_SCIDS_PER_REPLY: usize = 8000;
/// Represents the network as nodes and channels between them
-#[derive(Clone, PartialEq)]
pub struct NetworkGraph {
genesis_hash: BlockHash,
- channels: BTreeMap<u64, ChannelInfo>,
- nodes: BTreeMap<PublicKey, NodeInfo>,
+ // Lock order: channels -> nodes
+ channels: RwLock<BTreeMap<u64, ChannelInfo>>,
+ nodes: RwLock<BTreeMap<PublicKey, NodeInfo>>,
}
/// A simple newtype for RwLockReadGuard<'a, NetworkGraph>.
fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
let network_graph = self.network_graph.read().unwrap();
let mut result = Vec::with_capacity(batch_amount as usize);
- let mut iter = network_graph.get_channels().range(starting_point..);
+ let channels = network_graph.get_channels();
+ let mut iter = channels.range(starting_point..);
while result.len() < batch_amount as usize {
if let Some((_, ref chan)) = iter.next() {
if chan.announcement_message.is_some() {
fn get_next_node_announcements(&self, starting_point: Option<&PublicKey>, batch_amount: u8) -> Vec<NodeAnnouncement> {
let network_graph = self.network_graph.read().unwrap();
let mut result = Vec::with_capacity(batch_amount as usize);
+ let nodes = network_graph.get_nodes();
let mut iter = if let Some(pubkey) = starting_point {
- let mut iter = network_graph.get_nodes().range((*pubkey)..);
+ let mut iter = nodes.range((*pubkey)..);
iter.next();
iter
} else {
- network_graph.get_nodes().range(..)
+ nodes.range(..)
};
while result.len() < batch_amount as usize {
if let Some((_, ref node)) = iter.next() {
write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
self.genesis_hash.write(writer)?;
- (self.channels.len() as u64).write(writer)?;
- for (ref chan_id, ref chan_info) in self.channels.iter() {
+ let channels = self.channels.read().unwrap();
+ (channels.len() as u64).write(writer)?;
+ for (ref chan_id, ref chan_info) in channels.iter() {
(*chan_id).write(writer)?;
chan_info.write(writer)?;
}
- (self.nodes.len() as u64).write(writer)?;
- for (ref node_id, ref node_info) in self.nodes.iter() {
+ let nodes = self.nodes.read().unwrap();
+ (nodes.len() as u64).write(writer)?;
+ for (ref node_id, ref node_info) in nodes.iter() {
node_id.write(writer)?;
node_info.write(writer)?;
}
Ok(NetworkGraph {
genesis_hash,
- channels,
- nodes,
+ channels: RwLock::new(channels),
+ nodes: RwLock::new(nodes),
})
}
}
impl fmt::Display for NetworkGraph {
fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
writeln!(f, "Network map\n[Channels]")?;
- for (key, val) in self.channels.iter() {
+ for (key, val) in self.channels.read().unwrap().iter() {
writeln!(f, " {}: {}", key, val)?;
}
writeln!(f, "[Nodes]")?;
- for (key, val) in self.nodes.iter() {
+ for (key, val) in self.nodes.read().unwrap().iter() {
writeln!(f, " {}: {}", log_pubkey!(key), val)?;
}
Ok(())
}
}
+impl PartialEq for NetworkGraph {
+ fn eq(&self, other: &Self) -> bool {
+ self.genesis_hash == other.genesis_hash &&
+ *self.channels.read().unwrap() == *other.channels.read().unwrap() &&
+ *self.nodes.read().unwrap() == *other.nodes.read().unwrap()
+ }
+}
+
impl NetworkGraph {
/// Returns all known valid channels' short ids along with announced channel info.
///
/// (C-not exported) because we have no mapping for `BTreeMap`s
- pub fn get_channels<'a>(&'a self) -> &'a BTreeMap<u64, ChannelInfo> { &self.channels }
+ pub fn get_channels(&self) -> RwLockReadGuard<'_, BTreeMap<u64, ChannelInfo>> {
+ self.channels.read().unwrap()
+ }
+
/// Returns all known nodes' public keys along with announced node info.
///
/// (C-not exported) because we have no mapping for `BTreeMap`s
- pub fn get_nodes<'a>(&'a self) -> &'a BTreeMap<PublicKey, NodeInfo> { &self.nodes }
+ pub fn get_nodes(&self) -> RwLockReadGuard<'_, BTreeMap<PublicKey, NodeInfo>> {
+ self.nodes.read().unwrap()
+ }
/// Get network addresses by node id.
/// Returns None if the requested node is completely unknown,
/// or if node announcement for the node was never received.
///
/// (C-not exported) as there is no practical way to track lifetimes of returned values.
- pub fn get_addresses<'a>(&'a self, pubkey: &PublicKey) -> Option<&'a Vec<NetAddress>> {
- if let Some(node) = self.nodes.get(pubkey) {
+ pub fn get_addresses(&self, pubkey: &PublicKey) -> Option<Vec<NetAddress>> {
+ if let Some(node) = self.nodes.read().unwrap().get(pubkey) {
if let Some(node_info) = node.announcement_info.as_ref() {
- return Some(&node_info.addresses)
+ return Some(node_info.addresses.clone())
}
}
None
pub fn new(genesis_hash: BlockHash) -> NetworkGraph {
Self {
genesis_hash,
- channels: BTreeMap::new(),
- nodes: BTreeMap::new(),
+ channels: RwLock::new(BTreeMap::new()),
+ nodes: RwLock::new(BTreeMap::new()),
}
}
}
fn update_node_from_announcement_intern(&mut self, msg: &msgs::UnsignedNodeAnnouncement, full_msg: Option<&msgs::NodeAnnouncement>) -> Result<(), LightningError> {
- match self.nodes.get_mut(&msg.node_id) {
+ match self.nodes.write().unwrap().get_mut(&msg.node_id) {
None => Err(LightningError{err: "No existing channels for node_announcement".to_owned(), action: ErrorAction::IgnoreError}),
Some(node) => {
if let Some(node_info) = node.announcement_info.as_ref() {
{ full_msg.cloned() } else { None },
};
- match self.channels.entry(msg.short_channel_id) {
+ let mut channels = self.channels.write().unwrap();
+ let mut nodes = self.nodes.write().unwrap();
+ match channels.entry(msg.short_channel_id) {
BtreeEntry::Occupied(mut entry) => {
//TODO: because asking the blockchain if short_channel_id is valid is only optional
//in the blockchain API, we need to handle it smartly here, though it's unclear
// b) we don't track UTXOs of channels we know about and remove them if they
// get reorg'd out.
// c) it's unclear how to do so without exposing ourselves to massive DoS risk.
- Self::remove_channel_in_nodes(&mut self.nodes, &entry.get(), msg.short_channel_id);
+ Self::remove_channel_in_nodes(&mut nodes, &entry.get(), msg.short_channel_id);
*entry.get_mut() = chan_info;
} else {
return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreAndLog(Level::Trace)})
macro_rules! add_channel_to_node {
( $node_id: expr ) => {
- match self.nodes.entry($node_id) {
+ match nodes.entry($node_id) {
BtreeEntry::Occupied(node_entry) => {
node_entry.into_mut().channels.push(msg.short_channel_id);
},
/// May cause the removal of nodes too, if this was their last channel.
/// If not permanent, makes channels unavailable for routing.
pub fn close_channel_from_update(&mut self, short_channel_id: u64, is_permanent: bool) {
+ let mut channels = self.channels.write().unwrap();
if is_permanent {
- if let Some(chan) = self.channels.remove(&short_channel_id) {
- Self::remove_channel_in_nodes(&mut self.nodes, &chan, short_channel_id);
+ if let Some(chan) = channels.remove(&short_channel_id) {
+ let mut nodes = self.nodes.write().unwrap();
+ Self::remove_channel_in_nodes(&mut nodes, &chan, short_channel_id);
}
} else {
- if let Some(chan) = self.channels.get_mut(&short_channel_id) {
+ if let Some(chan) = channels.get_mut(&short_channel_id) {
if let Some(one_to_two) = chan.one_to_two.as_mut() {
one_to_two.enabled = false;
}
let chan_enabled = msg.flags & (1 << 1) != (1 << 1);
let chan_was_enabled;
- match self.channels.get_mut(&msg.short_channel_id) {
+ let mut channels = self.channels.write().unwrap();
+ match channels.get_mut(&msg.short_channel_id) {
None => return Err(LightningError{err: "Couldn't find channel for update".to_owned(), action: ErrorAction::IgnoreError}),
Some(channel) => {
if let OptionalField::Present(htlc_maximum_msat) = msg.htlc_maximum_msat {
}
}
+ let mut nodes = self.nodes.write().unwrap();
if chan_enabled {
- let node = self.nodes.get_mut(&dest_node_id).unwrap();
+ let node = nodes.get_mut(&dest_node_id).unwrap();
let mut base_msat = msg.fee_base_msat;
let mut proportional_millionths = msg.fee_proportional_millionths;
if let Some(fees) = node.lowest_inbound_channel_fees {
proportional_millionths
});
} else if chan_was_enabled {
- let node = self.nodes.get_mut(&dest_node_id).unwrap();
+ let node = nodes.get_mut(&dest_node_id).unwrap();
let mut lowest_inbound_channel_fees = None;
for chan_id in node.channels.iter() {
- let chan = self.channels.get(chan_id).unwrap();
+ let chan = channels.get(chan_id).unwrap();
let chan_info_opt;
if chan.node_one == dest_node_id {
chan_info_opt = chan.two_to_one.as_ref();
match network.get_channels().get(&unsigned_announcement.short_channel_id) {
None => panic!(),
Some(_) => ()
- }
+ };
}
// If we receive announcement for the same channel (with UTXO lookups disabled),
match network.get_channels().get(&unsigned_announcement.short_channel_id) {
None => panic!(),
Some(_) => ()
- }
+ };
}
// If we receive announcement for the same channel (but TX is not confirmed),
assert_eq!(channel_entry.features, ChannelFeatures::empty());
},
_ => panic!()
- }
+ };
}
// Don't relay valid channels with excess data
assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144);
assert!(channel_info.two_to_one.is_none());
}
- }
+ };
}
unsigned_channel_update.timestamp += 100;
Some(channel_info) => {
assert!(channel_info.one_to_two.is_some());
}
- }
+ };
}
let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {
Some(channel_info) => {
assert!(!channel_info.one_to_two.as_ref().unwrap().enabled);
}
- }
+ };
}
let channel_close_msg = HTLCFailChannelUpdate::ChannelClosed {
// to use as the A* heuristic beyond just the cost to get one node further than the current
// one.
+ let network_channels = network.get_channels();
+ let network_nodes = network.get_nodes();
let dummy_directional_info = DummyDirectionalChannelInfo { // used for first_hops routes
cltv_expiry_delta: 0,
htlc_minimum_msat: 0,
// work reliably.
let allow_mpp = if let Some(features) = &payee_features {
features.supports_basic_mpp()
- } else if let Some(node) = network.get_nodes().get(&payee) {
+ } else if let Some(node) = network_nodes.get(&payee) {
if let Some(node_info) = node.announcement_info.as_ref() {
node_info.features.supports_basic_mpp()
} else { false }
// Map from node_id to information about the best current path to that node, including feerate
// information.
- let mut dist = HashMap::with_capacity(network.get_nodes().len());
+ let mut dist = HashMap::with_capacity(network_nodes.len());
// During routing, if we ignore a path due to an htlc_minimum_msat limit, we set this,
// indicating that we may wish to try again with a higher value, potentially paying to meet an
// This map allows paths to be aware of the channel use by other paths in the same call.
// This would help to make a better path finding decisions and not "overbook" channels.
// It is unaware of the directions (except for `outbound_capacity_msat` in `first_hops`).
- let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network.get_nodes().len());
+ let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network_nodes.len());
// Keeping track of how much value we already collected across other paths. Helps to decide:
// - how much a new path should be transferring (upper bound);
// as a way to reach the $dest_node_id.
let mut fee_base_msat = u32::max_value();
let mut fee_proportional_millionths = u32::max_value();
- if let Some(Some(fees)) = network.get_nodes().get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
+ if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
fee_base_msat = fees.base_msat;
fee_proportional_millionths = fees.proportional_millionths;
}
if !features.requires_unknown_bits() {
for chan_id in $node.channels.iter() {
- let chan = network.get_channels().get(chan_id).unwrap();
+ let chan = network_channels.get(chan_id).unwrap();
if !chan.features.requires_unknown_bits() {
if chan.node_one == *$node_id {
// ie $node is one, ie next hop in A* is two, via the two_to_one channel
// Add the payee as a target, so that the payee-to-payer
// search algorithm knows what to start with.
- match network.get_nodes().get(payee) {
+ match network_nodes.get(payee) {
// The payee is not in our network graph, so nothing to add here.
// There is still a chance of reaching them via last_hops though,
// so don't yet fail the payment here.
// we have a direct channel to the first hop or the first hop is
// in the regular network graph.
first_hop_targets.get(&first_hop_in_route.src_node_id).is_some() ||
- network.get_nodes().get(&first_hop_in_route.src_node_id).is_some();
+ network_nodes.get(&first_hop_in_route.src_node_id).is_some();
if have_hop_src_in_graph {
// We start building the path from reverse, i.e., from payee
// to the first RouteHintHop in the path.
'path_walk: loop {
if let Some(&(_, _, _, ref features)) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) {
ordered_hops.last_mut().unwrap().1 = features.clone();
- } else if let Some(node) = network.get_nodes().get(&ordered_hops.last().unwrap().0.pubkey) {
+ } else if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.pubkey) {
if let Some(node_info) = node.announcement_info.as_ref() {
ordered_hops.last_mut().unwrap().1 = node_info.features.clone();
} else {
// Otherwise, since the current target node is not us,
// keep "unrolling" the payment graph from payee to payer by
// finding a way to reach the current target from the payer side.
- match network.get_nodes().get(&pubkey) {
+ match network_nodes.get(&pubkey) {
None => {},
Some(node) => {
add_entries_to_cheapest_to_target_node!(node, &pubkey, lowest_fee_to_node, value_contribution_msat, path_htlc_minimum_msat);
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut seed = random_init_seed() as usize;
+ let nodes = graph.get_nodes();
'load_endpoints: for _ in 0..10 {
loop {
seed = seed.overflowing_mul(0xdeadbeef).0;
- let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed = seed.overflowing_mul(0xdeadbeef).0;
- let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 200_000_000;
if get_route(src, &graph, dst, None, None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() {
continue 'load_endpoints;
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut seed = random_init_seed() as usize;
+ let nodes = graph.get_nodes();
'load_endpoints: for _ in 0..10 {
loop {
seed = seed.overflowing_mul(0xdeadbeef).0;
- let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed = seed.overflowing_mul(0xdeadbeef).0;
- let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 200_000_000;
if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() {
continue 'load_endpoints;
fn generate_routes(bench: &mut Bencher) {
let mut d = test_utils::get_route_file().unwrap();
let graph = NetworkGraph::read(&mut d).unwrap();
+ let nodes = graph.get_nodes();
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut path_endpoints = Vec::new();
'load_endpoints: for _ in 0..100 {
loop {
seed *= 0xdeadbeef;
- let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed *= 0xdeadbeef;
- let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 1_000_000;
if get_route(src, &graph, dst, None, None, &[], amt, 42, &DummyLogger{}).is_ok() {
path_endpoints.push((src, dst, amt));
fn generate_mpp_routes(bench: &mut Bencher) {
let mut d = test_utils::get_route_file().unwrap();
let graph = NetworkGraph::read(&mut d).unwrap();
+ let nodes = graph.get_nodes();
// First, get 100 (source, destination) pairs for which route-getting actually succeeds...
let mut path_endpoints = Vec::new();
'load_endpoints: for _ in 0..100 {
loop {
seed *= 0xdeadbeef;
- let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
seed *= 0xdeadbeef;
- let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+ let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
let amt = seed as u64 % 1_000_000;
if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &DummyLogger{}).is_ok() {
path_endpoints.push((src, dst, amt));