]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Consider many first-hop paths to the same counterparty in routing
authorMatt Corallo <git@bluematt.me>
Wed, 29 Sep 2021 19:15:16 +0000 (19:15 +0000)
committerMatt Corallo <git@bluematt.me>
Wed, 29 Sep 2021 19:15:16 +0000 (19:15 +0000)
Previously we'd simply overwritten "the" first hop path to each
counterparty when routing, however this results in us ignoring all
channels except the last one in the `ChannelDetails` list per
counterparty.

lightning/src/routing/router.rs

index 7679125cf5fe7ff91dea35bc4f327da5c234c9ea..38dc918d48c8cb7311f09fbfccbf217e2bd60f8b 100644 (file)
@@ -487,12 +487,14 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                        node_info.features.supports_basic_mpp()
                } else { false }
        } else { false };
+       log_trace!(logger, "Searching for a route from payer {} to payee {} {} MPP", our_node_id, payee,
+               if allow_mpp { "with" } else { "without" });
 
        // Step (1).
        // Prepare the data we'll use for payee-to-payer search by
        // inserting first hops suggested by the caller as targets.
        // Our search will then attempt to reach them while traversing from the payee node.
-       let mut first_hop_targets: HashMap<_, (_, ChannelFeatures, _, NodeFeatures)> =
+       let mut first_hop_targets: HashMap<_, Vec<(_, ChannelFeatures, _, NodeFeatures)>> =
                HashMap::with_capacity(if first_hops.is_some() { first_hops.as_ref().unwrap().len() } else { 0 });
        if let Some(hops) = first_hops {
                for chan in hops {
@@ -500,7 +502,8 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                        if chan.counterparty.node_id == *our_node_id {
                                return Err(LightningError{err: "First hop cannot have our_node_id as a destination.".to_owned(), action: ErrorAction::IgnoreError});
                        }
-                       first_hop_targets.insert(chan.counterparty.node_id, (short_channel_id, chan.counterparty.features.to_context(), chan.outbound_capacity_msat, chan.counterparty.features.to_context()));
+                       first_hop_targets.entry(chan.counterparty.node_id).or_insert(Vec::new())
+                               .push((short_channel_id, chan.counterparty.features.to_context(), chan.outbound_capacity_msat, chan.counterparty.features.to_context()));
                }
                if first_hop_targets.is_empty() {
                        return Err(LightningError{err: "Cannot route when there are no outbound routes away from us".to_owned(), action: ErrorAction::IgnoreError});
@@ -824,8 +827,8 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                        };
 
                        if !skip_node {
-                               if first_hops.is_some() {
-                                       if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&$node_id) {
+                               if let Some(first_channels) = first_hop_targets.get(&$node_id) {
+                                       for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels {
                                                add_entry!(first_hop, *our_node_id, $node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, $fee_to_target_msat, $next_hops_value_contribution, $next_hops_path_htlc_minimum_msat);
                                        }
                                }
@@ -878,9 +881,10 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
 
                // If first hop is a private channel and the only way to reach the payee, this is the only
                // place where it could be added.
-               if first_hops.is_some() {
-                       if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&payee) {
-                               add_entry!(first_hop, *our_node_id, payee, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, 0, path_value_msat, 0);
+               if let Some(first_channels) = first_hop_targets.get(&payee) {
+                       for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels {
+                               let added = add_entry!(first_hop, *our_node_id, payee, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, 0, path_value_msat, 0);
+                               log_trace!(logger, "{} direct route to payee via SCID {}", if added { "Added" } else { "Skipped" }, first_hop);
                        }
                }
 
@@ -949,8 +953,10 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                                        }
 
                                        // Searching for a direct channel between last checked hop and first_hop_targets
-                                       if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&prev_hop_id) {
-                                               add_entry!(first_hop, *our_node_id , prev_hop_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat);
+                                       if let Some(first_channels) = first_hop_targets.get(&prev_hop_id) {
+                                               for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels {
+                                                       add_entry!(first_hop, *our_node_id , prev_hop_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat);
+                                               }
                                        }
 
                                        if !hop_used {
@@ -981,8 +987,10 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                                                // Note that we *must* check if the last hop was added as `add_entry`
                                                // always assumes that the third argument is a node to which we have a
                                                // path.
-                                               if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&hop.src_node_id) {
-                                                       add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat);
+                                               if let Some(first_channels) = first_hop_targets.get(&hop.src_node_id) {
+                                                       for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels {
+                                                               add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat);
+                                                       }
                                                }
                                        }
                                }
@@ -1013,8 +1021,17 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
                                let mut ordered_hops = vec!((new_entry.clone(), NodeFeatures::empty()));
 
                                'path_walk: loop {
-                                       if let Some(&(_, _, _, ref features)) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) {
-                                               ordered_hops.last_mut().unwrap().1 = features.clone();
+                                       let mut features_set = false;
+                                       if let Some(first_channels) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) {
+                                               for (scid, _, _, ref features) in first_channels {
+                                                       if *scid == ordered_hops.last().unwrap().0.short_channel_id {
+                                                               ordered_hops.last_mut().unwrap().1 = features.clone();
+                                                               features_set = true;
+                                                               break;
+                                                       }
+                                               }
+                                       }
+                                       if features_set {
                                        } else if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.pubkey) {
                                                if let Some(node_info) = node.announcement_info.as_ref() {
                                                        ordered_hops.last_mut().unwrap().1 = node_info.features.clone();
@@ -4220,6 +4237,51 @@ mod tests {
                }
        }
 
+       #[test]
+       fn multiple_direct_first_hops() {
+               // Previously we'd only ever considered one first hop path per counterparty.
+               // However, as we don't restrict users to one channel per peer, we really need to support
+               // looking at all first hop paths.
+               // Here we test that we do not ignore all-but-the-last first hop paths per counterparty (as
+               // we used to do by overwriting the `first_hop_targets` hashmap entry) and that we can MPP
+               // route over multiple channels with the same first hop.
+               let secp_ctx = Secp256k1::new();
+               let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+               let logger = Arc::new(test_utils::TestLogger::new());
+               let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
+
+               {
+                       let route = get_route(&our_id, &network_graph, &nodes[0], Some(InvoiceFeatures::known()), Some(&[
+                               &get_channel_details(Some(3), nodes[0], InitFeatures::known(), 200_000),
+                               &get_channel_details(Some(2), nodes[0], InitFeatures::known(), 10_000),
+                       ]), &[], 100_000, 42, Arc::clone(&logger)).unwrap();
+                       assert_eq!(route.paths.len(), 1);
+                       assert_eq!(route.paths[0].len(), 1);
+
+                       assert_eq!(route.paths[0][0].pubkey, nodes[0]);
+                       assert_eq!(route.paths[0][0].short_channel_id, 3);
+                       assert_eq!(route.paths[0][0].fee_msat, 100_000);
+               }
+               {
+                       let route = get_route(&our_id, &network_graph, &nodes[0], Some(InvoiceFeatures::known()), Some(&[
+                               &get_channel_details(Some(3), nodes[0], InitFeatures::known(), 50_000),
+                               &get_channel_details(Some(2), nodes[0], InitFeatures::known(), 50_000),
+                       ]), &[], 100_000, 42, Arc::clone(&logger)).unwrap();
+                       assert_eq!(route.paths.len(), 2);
+                       assert_eq!(route.paths[0].len(), 1);
+                       assert_eq!(route.paths[1].len(), 1);
+
+                       assert_eq!(route.paths[0][0].pubkey, nodes[0]);
+                       assert_eq!(route.paths[0][0].short_channel_id, 3);
+                       assert_eq!(route.paths[0][0].fee_msat, 50_000);
+
+                       assert_eq!(route.paths[1][0].pubkey, nodes[0]);
+                       assert_eq!(route.paths[1][0].short_channel_id, 2);
+                       assert_eq!(route.paths[1][0].fee_msat, 50_000);
+               }
+
+       }
+
        #[test]
        fn total_fees_single_path() {
                let route = Route {