From 56c75f0bac02db2212f8ef5357befdd5425da55d Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Wed, 29 Sep 2021 19:15:16 +0000 Subject: [PATCH] Consider many first-hop paths to the same counterparty in routing Previously we'd simply overwritten "the" first hop path to each counterparty when routing, however this results in us ignoring all channels except the last one in the `ChannelDetails` list per counterparty. --- lightning/src/routing/router.rs | 88 ++++++++++++++++++++++++++++----- 1 file changed, 75 insertions(+), 13 deletions(-) diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 7679125cf..38dc918d4 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -487,12 +487,14 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye node_info.features.supports_basic_mpp() } else { false } } else { false }; + log_trace!(logger, "Searching for a route from payer {} to payee {} {} MPP", our_node_id, payee, + if allow_mpp { "with" } else { "without" }); // Step (1). // Prepare the data we'll use for payee-to-payer search by // inserting first hops suggested by the caller as targets. // Our search will then attempt to reach them while traversing from the payee node. - let mut first_hop_targets: HashMap<_, (_, ChannelFeatures, _, NodeFeatures)> = + let mut first_hop_targets: HashMap<_, Vec<(_, ChannelFeatures, _, NodeFeatures)>> = HashMap::with_capacity(if first_hops.is_some() { first_hops.as_ref().unwrap().len() } else { 0 }); if let Some(hops) = first_hops { for chan in hops { @@ -500,7 +502,8 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye if chan.counterparty.node_id == *our_node_id { return Err(LightningError{err: "First hop cannot have our_node_id as a destination.".to_owned(), action: ErrorAction::IgnoreError}); } - first_hop_targets.insert(chan.counterparty.node_id, (short_channel_id, chan.counterparty.features.to_context(), chan.outbound_capacity_msat, chan.counterparty.features.to_context())); + first_hop_targets.entry(chan.counterparty.node_id).or_insert(Vec::new()) + .push((short_channel_id, chan.counterparty.features.to_context(), chan.outbound_capacity_msat, chan.counterparty.features.to_context())); } if first_hop_targets.is_empty() { return Err(LightningError{err: "Cannot route when there are no outbound routes away from us".to_owned(), action: ErrorAction::IgnoreError}); @@ -824,8 +827,8 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye }; if !skip_node { - if first_hops.is_some() { - if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&$node_id) { + if let Some(first_channels) = first_hop_targets.get(&$node_id) { + for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels { add_entry!(first_hop, *our_node_id, $node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, $fee_to_target_msat, $next_hops_value_contribution, $next_hops_path_htlc_minimum_msat); } } @@ -878,9 +881,10 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // If first hop is a private channel and the only way to reach the payee, this is the only // place where it could be added. - if first_hops.is_some() { - if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&payee) { - add_entry!(first_hop, *our_node_id, payee, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, 0, path_value_msat, 0); + if let Some(first_channels) = first_hop_targets.get(&payee) { + for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels { + let added = add_entry!(first_hop, *our_node_id, payee, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, 0, path_value_msat, 0); + log_trace!(logger, "{} direct route to payee via SCID {}", if added { "Added" } else { "Skipped" }, first_hop); } } @@ -949,8 +953,10 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye } // Searching for a direct channel between last checked hop and first_hop_targets - if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&prev_hop_id) { - add_entry!(first_hop, *our_node_id , prev_hop_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat); + if let Some(first_channels) = first_hop_targets.get(&prev_hop_id) { + for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels { + add_entry!(first_hop, *our_node_id , prev_hop_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat); + } } if !hop_used { @@ -981,8 +987,10 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye // Note that we *must* check if the last hop was added as `add_entry` // always assumes that the third argument is a node to which we have a // path. - if let Some(&(ref first_hop, ref features, ref outbound_capacity_msat, _)) = first_hop_targets.get(&hop.src_node_id) { - add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat); + if let Some(first_channels) = first_hop_targets.get(&hop.src_node_id) { + for (ref first_hop, ref features, ref outbound_capacity_msat, _) in first_channels { + add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, Some(outbound_capacity_msat / 1000), features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat); + } } } } @@ -1013,8 +1021,17 @@ pub fn get_route(our_node_id: &PublicKey, network: &NetworkGraph, paye let mut ordered_hops = vec!((new_entry.clone(), NodeFeatures::empty())); 'path_walk: loop { - if let Some(&(_, _, _, ref features)) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) { - ordered_hops.last_mut().unwrap().1 = features.clone(); + let mut features_set = false; + if let Some(first_channels) = first_hop_targets.get(&ordered_hops.last().unwrap().0.pubkey) { + for (scid, _, _, ref features) in first_channels { + if *scid == ordered_hops.last().unwrap().0.short_channel_id { + ordered_hops.last_mut().unwrap().1 = features.clone(); + features_set = true; + break; + } + } + } + if features_set { } else if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.pubkey) { if let Some(node_info) = node.announcement_info.as_ref() { ordered_hops.last_mut().unwrap().1 = node_info.features.clone(); @@ -4220,6 +4237,51 @@ mod tests { } } + #[test] + fn multiple_direct_first_hops() { + // Previously we'd only ever considered one first hop path per counterparty. + // However, as we don't restrict users to one channel per peer, we really need to support + // looking at all first hop paths. + // Here we test that we do not ignore all-but-the-last first hop paths per counterparty (as + // we used to do by overwriting the `first_hop_targets` hashmap entry) and that we can MPP + // route over multiple channels with the same first hop. + let secp_ctx = Secp256k1::new(); + let (_, our_id, _, nodes) = get_nodes(&secp_ctx); + let logger = Arc::new(test_utils::TestLogger::new()); + let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()); + + { + let route = get_route(&our_id, &network_graph, &nodes[0], Some(InvoiceFeatures::known()), Some(&[ + &get_channel_details(Some(3), nodes[0], InitFeatures::known(), 200_000), + &get_channel_details(Some(2), nodes[0], InitFeatures::known(), 10_000), + ]), &[], 100_000, 42, Arc::clone(&logger)).unwrap(); + assert_eq!(route.paths.len(), 1); + assert_eq!(route.paths[0].len(), 1); + + assert_eq!(route.paths[0][0].pubkey, nodes[0]); + assert_eq!(route.paths[0][0].short_channel_id, 3); + assert_eq!(route.paths[0][0].fee_msat, 100_000); + } + { + let route = get_route(&our_id, &network_graph, &nodes[0], Some(InvoiceFeatures::known()), Some(&[ + &get_channel_details(Some(3), nodes[0], InitFeatures::known(), 50_000), + &get_channel_details(Some(2), nodes[0], InitFeatures::known(), 50_000), + ]), &[], 100_000, 42, Arc::clone(&logger)).unwrap(); + assert_eq!(route.paths.len(), 2); + assert_eq!(route.paths[0].len(), 1); + assert_eq!(route.paths[1].len(), 1); + + assert_eq!(route.paths[0][0].pubkey, nodes[0]); + assert_eq!(route.paths[0][0].short_channel_id, 3); + assert_eq!(route.paths[0][0].fee_msat, 50_000); + + assert_eq!(route.paths[1][0].pubkey, nodes[0]); + assert_eq!(route.paths[1][0].short_channel_id, 2); + assert_eq!(route.paths[1][0].fee_msat, 50_000); + } + + } + #[test] fn total_fees_single_path() { let route = Route { -- 2.39.5