Individually lock NetworkGraph fields
[rust-lightning] / lightning / src / routing / router.rs
index 5030f6aaacbf9569a0ff5e2d26a20a31104c0c9b..3df942d733a64da157533ec8afbff45b4f786f3f 100644 (file)
@@ -443,6 +443,8 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
        // to use as the A* heuristic beyond just the cost to get one node further than the current
        // one.
 
+       let network_channels = network.get_channels();
+       let network_nodes = network.get_nodes();
        let dummy_directional_info = DummyDirectionalChannelInfo { // used for first_hops routes
                cltv_expiry_delta: 0,
                htlc_minimum_msat: 0,
@@ -458,7 +460,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
        // work reliably.
        let allow_mpp = if let Some(features) = &payee_features {
                features.supports_basic_mpp()
-       } else if let Some(node) = network.get_nodes().get(&payee) {
+       } else if let Some(node) = network_nodes.get(&payee) {
                if let Some(node_info) = node.announcement_info.as_ref() {
                        node_info.features.supports_basic_mpp()
                } else { false }
@@ -492,7 +494,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
 
        // Map from node_id to information about the best current path to that node, including feerate
        // information.
-       let mut dist = HashMap::with_capacity(network.get_nodes().len());
+       let mut dist = HashMap::with_capacity(network_nodes.len());
 
        // During routing, if we ignore a path due to an htlc_minimum_msat limit, we set this,
        // indicating that we may wish to try again with a higher value, potentially paying to meet an
@@ -511,7 +513,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
        // This map allows paths to be aware of the channel use by other paths in the same call.
        // This would help to make a better path finding decisions and not "overbook" channels.
        // It is unaware of the directions (except for `outbound_capacity_msat` in `first_hops`).
-       let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network.get_nodes().len());
+       let mut bookkeeped_channels_liquidity_available_msat = HashMap::with_capacity(network_nodes.len());
 
        // Keeping track of how much value we already collected across other paths. Helps to decide:
        // - how much a new path should be transferring (upper bound);
@@ -629,7 +631,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                                                        // as a way to reach the $dest_node_id.
                                                        let mut fee_base_msat = u32::max_value();
                                                        let mut fee_proportional_millionths = u32::max_value();
-                                                       if let Some(Some(fees)) = network.get_nodes().get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
+                                                       if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) {
                                                                fee_base_msat = fees.base_msat;
                                                                fee_proportional_millionths = fees.proportional_millionths;
                                                        }
@@ -814,7 +816,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
 
                                if !features.requires_unknown_bits() {
                                        for chan_id in $node.channels.iter() {
-                                               let chan = network.get_channels().get(chan_id).unwrap();
+                                               let chan = network_channels.get(chan_id).unwrap();
                                                if !chan.features.requires_unknown_bits() {
                                                        if chan.node_one == *$node_id {
                                                                // ie $node is one, ie next hop in A* is two, via the two_to_one channel
@@ -862,7 +864,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
 
                // Add the payee as a target, so that the payee-to-payer
                // search algorithm knows what to start with.
-               match network.get_nodes().get(payee) {
+               match network_nodes.get(payee) {
                        // The payee is not in our network graph, so nothing to add here.
                        // There is still a chance of reaching them via last_hops though,
                        // so don't yet fail the payment here.
@@ -884,7 +886,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                                // we have a direct channel to the first hop or the first hop is
                                // in the regular network graph.
                                first_hop_targets.get(&first_hop_in_route.src_node_id).is_some() ||
-                               network.get_nodes().get(&first_hop_in_route.src_node_id).is_some();
+                               network_nodes.get(&first_hop_in_route.src_node_id).is_some();
                        if have_hop_src_in_graph {
                                // We start building the path from reverse, i.e., from payee
                                // to the first RouteHintHop in the path.
@@ -991,7 +993,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                                'path_walk: loop {
                                        if let Some(&(_, _, _, ref features)) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) {
                                                ordered_hops.last_mut().unwrap().1 = features.clone();
-                                       } else if let Some(node) = network.get_nodes().get(&ordered_hops.last().unwrap().0.pubkey) {
+                                       } else if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.pubkey) {
                                                if let Some(node_info) = node.announcement_info.as_ref() {
                                                        ordered_hops.last_mut().unwrap().1 = node_info.features.clone();
                                                } else {
@@ -1093,7 +1095,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                        // Otherwise, since the current target node is not us,
                        // keep "unrolling" the payment graph from payee to payer by
                        // finding a way to reach the current target from the payer side.
-                       match network.get_nodes().get(&pubkey) {
+                       match network_nodes.get(&pubkey) {
                                None => {},
                                Some(node) => {
                                        add_entries_to_cheapest_to_target_node!(node, &pubkey, lowest_fee_to_node, value_contribution_msat, path_htlc_minimum_msat);
@@ -4211,12 +4213,13 @@ mod tests {
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
+               let nodes = graph.get_nodes();
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                let amt = seed as u64 % 200_000_000;
                                if get_route(src, &graph, dst, None, None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() {
                                        continue 'load_endpoints;
@@ -4239,12 +4242,13 @@ mod tests {
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
+               let nodes = graph.get_nodes();
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                let amt = seed as u64 % 200_000_000;
                                if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &test_utils::TestLogger::new()).is_ok() {
                                        continue 'load_endpoints;
@@ -4297,6 +4301,7 @@ mod benches {
        fn generate_routes(bench: &mut Bencher) {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
+               let nodes = graph.get_nodes();
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
@@ -4304,9 +4309,9 @@ mod benches {
                'load_endpoints: for _ in 0..100 {
                        loop {
                                seed *= 0xdeadbeef;
-                               let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                seed *= 0xdeadbeef;
-                               let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                let amt = seed as u64 % 1_000_000;
                                if get_route(src, &graph, dst, None, None, &[], amt, 42, &DummyLogger{}).is_ok() {
                                        path_endpoints.push((src, dst, amt));
@@ -4328,6 +4333,7 @@ mod benches {
        fn generate_mpp_routes(bench: &mut Bencher) {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
+               let nodes = graph.get_nodes();
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
@@ -4335,9 +4341,9 @@ mod benches {
                'load_endpoints: for _ in 0..100 {
                        loop {
                                seed *= 0xdeadbeef;
-                               let src = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let src = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                seed *= 0xdeadbeef;
-                               let dst = graph.get_nodes().keys().skip(seed % graph.get_nodes().len()).next().unwrap();
+                               let dst = nodes.keys().skip(seed % nodes.len()).next().unwrap();
                                let amt = seed as u64 % 1_000_000;
                                if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &DummyLogger{}).is_ok() {
                                        path_endpoints.push((src, dst, amt));