Test we prefer first hops over route hints
[rust-lightning] / lightning / src / routing / router.rs
index bca681dfc9b32f5760d3d9584ca65d802cd29381..8a7343916d34cacd166762716cdb5ad566ba21fc 100644 (file)
@@ -1622,9 +1622,13 @@ where L::Target: Logger {
                        |info| info.features.supports_basic_mpp()))
        } else { false };
 
-       log_trace!(logger, "Searching for a route from payer {} to {} {} MPP and {} first hops {}overriding the network graph", our_node_pubkey,
-               LoggedPayeePubkey(payment_params.payee.node_id()), if allow_mpp { "with" } else { "without" },
-               first_hops.map(|hops| hops.len()).unwrap_or(0), if first_hops.is_some() { "" } else { "not " });
+       let max_total_routing_fee_msat = route_params.max_total_routing_fee_msat.unwrap_or(u64::max_value());
+
+       log_trace!(logger, "Searching for a route from payer {} to {} {} MPP and {} first hops {}overriding the network graph with a fee limit of {} msat",
+               our_node_pubkey, LoggedPayeePubkey(payment_params.payee.node_id()),
+               if allow_mpp { "with" } else { "without" },
+               first_hops.map(|hops| hops.len()).unwrap_or(0), if first_hops.is_some() { "" } else { "not " },
+               max_total_routing_fee_msat);
 
        // Step (1).
        // Prepare the data we'll use for payee-to-payer search by
@@ -1890,7 +1894,6 @@ where L::Target: Logger {
                                                        }
 
                                                        // Ignore hops if augmenting the current path to them would put us over `max_total_routing_fee_msat`
-                                                       let max_total_routing_fee_msat = route_params.max_total_routing_fee_msat.unwrap_or(u64::max_value());
                                                        if total_fee_msat > max_total_routing_fee_msat {
                                                                if should_log_candidate {
                                                                        log_trace!(logger, "Ignoring {} due to exceeding max total routing fee limit.", LoggedCandidateHop(&$candidate));
@@ -2628,14 +2631,6 @@ where L::Target: Logger {
        // Make sure we would never create a route with more paths than we allow.
        debug_assert!(paths.len() <= payment_params.max_path_count.into());
 
-       // Make sure we would never create a route whose total fees exceed max_total_routing_fee_msat.
-       if let Some(max_total_routing_fee_msat) = route_params.max_total_routing_fee_msat {
-               if paths.iter().map(|p| p.fee_msat()).sum::<u64>() > max_total_routing_fee_msat {
-                       return Err(LightningError{err: format!("Failed to find route that adheres to the maximum total fee limit of {}msat",
-                               max_total_routing_fee_msat), action: ErrorAction::IgnoreError});
-               }
-       }
-
        if let Some(node_features) = payment_params.payee.node_features() {
                for path in paths.iter_mut() {
                        path.hops.last_mut().unwrap().node_features = node_features.clone();
@@ -2643,6 +2638,15 @@ where L::Target: Logger {
        }
 
        let route = Route { paths, route_params: Some(route_params.clone()) };
+
+       // Make sure we would never create a route whose total fees exceed max_total_routing_fee_msat.
+       if let Some(max_total_routing_fee_msat) = route_params.max_total_routing_fee_msat {
+               if route.get_total_fees() > max_total_routing_fee_msat {
+                       return Err(LightningError{err: format!("Failed to find route that adheres to the maximum total fee limit of {}msat",
+                               max_total_routing_fee_msat), action: ErrorAction::IgnoreError});
+               }
+       }
+
        log_info!(logger, "Got route: {}", log_route!(route));
        Ok(route)
 }
@@ -3266,11 +3270,22 @@ mod tests {
                        excess_data: Vec::new()
                });
 
-               // Now check that we'll find a path if the htlc_minimum is overrun substantially.
+               // Now check that we'll fail to find a path if we fail to find a path if the htlc_minimum
+               // is overrun. Note that the fees are actually calculated on 3*payment amount as that's
+               // what we try to find a route for, so this test only just happens to work out to exactly
+               // the fee limit.
                let mut route_params = RouteParameters::from_payment_params_and_value(
                        payment_params.clone(), 5_000);
-               // TODO: This can even overrun the fee limit set by the recipient!
                route_params.max_total_routing_fee_msat = Some(9_999);
+               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id,
+                       &route_params, &network_graph.read_only(), None, Arc::clone(&logger), &scorer,
+                       &Default::default(), &random_seed_bytes) {
+                               assert_eq!(err, "Failed to find route that adheres to the maximum total fee limit of 9999msat");
+               } else { panic!(); }
+
+               let mut route_params = RouteParameters::from_payment_params_and_value(
+                       payment_params.clone(), 5_000);
+               route_params.max_total_routing_fee_msat = Some(10_000);
                let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
                        Arc::clone(&logger), &scorer, &Default::default(), &random_seed_bytes).unwrap();
                assert_eq!(route.get_total_fees(), 10_000);
@@ -7680,6 +7695,152 @@ mod tests {
                assert_eq!(route.paths.len(), 1);
                assert_eq!(route.get_total_amount(), amt_msat);
        }
+
+       #[test]
+       fn first_hop_preferred_over_hint() {
+               // Check that if we have a first hop to a peer we'd always prefer that over a route hint
+               // they gave us, but we'd still consider all subsequent hints if they are more attractive.
+               let secp_ctx = Secp256k1::new();
+               let logger = Arc::new(ln_test_utils::TestLogger::new());
+               let network_graph = Arc::new(NetworkGraph::new(Network::Testnet, Arc::clone(&logger)));
+               let gossip_sync = P2PGossipSync::new(Arc::clone(&network_graph), None, Arc::clone(&logger));
+               let scorer = ln_test_utils::TestScorer::new();
+               let keys_manager = ln_test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
+               let random_seed_bytes = keys_manager.get_secure_random_bytes();
+               let config = UserConfig::default();
+
+               let amt_msat = 1_000_000;
+               let (our_privkey, our_node_id, privkeys, nodes) = get_nodes(&secp_ctx);
+
+               add_channel(&gossip_sync, &secp_ctx, &our_privkey, &privkeys[0],
+                       ChannelFeatures::from_le_bytes(id_to_feature_flags(1)), 1);
+               update_channel(&gossip_sync, &secp_ctx, &our_privkey, UnsignedChannelUpdate {
+                       chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+                       short_channel_id: 1,
+                       timestamp: 1,
+                       flags: 0,
+                       cltv_expiry_delta: 42,
+                       htlc_minimum_msat: 1_000,
+                       htlc_maximum_msat: 10_000_000,
+                       fee_base_msat: 800,
+                       fee_proportional_millionths: 0,
+                       excess_data: Vec::new()
+               });
+               update_channel(&gossip_sync, &secp_ctx, &privkeys[0], UnsignedChannelUpdate {
+                       chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+                       short_channel_id: 1,
+                       timestamp: 1,
+                       flags: 1,
+                       cltv_expiry_delta: 42,
+                       htlc_minimum_msat: 1_000,
+                       htlc_maximum_msat: 10_000_000,
+                       fee_base_msat: 800,
+                       fee_proportional_millionths: 0,
+                       excess_data: Vec::new()
+               });
+
+               add_channel(&gossip_sync, &secp_ctx, &privkeys[0], &privkeys[1],
+                       ChannelFeatures::from_le_bytes(id_to_feature_flags(1)), 2);
+               update_channel(&gossip_sync, &secp_ctx, &privkeys[0], UnsignedChannelUpdate {
+                       chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+                       short_channel_id: 2,
+                       timestamp: 2,
+                       flags: 0,
+                       cltv_expiry_delta: 42,
+                       htlc_minimum_msat: 1_000,
+                       htlc_maximum_msat: 10_000_000,
+                       fee_base_msat: 800,
+                       fee_proportional_millionths: 0,
+                       excess_data: Vec::new()
+               });
+               update_channel(&gossip_sync, &secp_ctx, &privkeys[1], UnsignedChannelUpdate {
+                       chain_hash: genesis_block(Network::Testnet).header.block_hash(),
+                       short_channel_id: 2,
+                       timestamp: 2,
+                       flags: 1,
+                       cltv_expiry_delta: 42,
+                       htlc_minimum_msat: 1_000,
+                       htlc_maximum_msat: 10_000_000,
+                       fee_base_msat: 800,
+                       fee_proportional_millionths: 0,
+                       excess_data: Vec::new()
+               });
+
+               let dest_node_id = nodes[2];
+
+               let route_hint = RouteHint(vec![RouteHintHop {
+                       src_node_id: our_node_id,
+                       short_channel_id: 44,
+                       fees: RoutingFees {
+                               base_msat: 234,
+                               proportional_millionths: 0,
+                       },
+                       cltv_expiry_delta: 10,
+                       htlc_minimum_msat: None,
+                       htlc_maximum_msat: Some(5_000_000),
+               },
+               RouteHintHop {
+                       src_node_id: nodes[0],
+                       short_channel_id: 45,
+                       fees: RoutingFees {
+                               base_msat: 123,
+                               proportional_millionths: 0,
+                       },
+                       cltv_expiry_delta: 10,
+                       htlc_minimum_msat: None,
+                       htlc_maximum_msat: None,
+               }]);
+
+               let payment_params = PaymentParameters::from_node_id(dest_node_id, 42)
+                       .with_route_hints(vec![route_hint]).unwrap()
+                       .with_bolt11_features(channelmanager::provided_invoice_features(&config)).unwrap();
+               let route_params = RouteParameters::from_payment_params_and_value(
+                       payment_params, amt_msat);
+
+               // First create an insufficient first hop for channel with SCID 1 and check we'd use the
+               // route hint.
+               let first_hop = get_channel_details(Some(1), nodes[0],
+                       channelmanager::provided_init_features(&config), 999_999);
+               let first_hops = vec![first_hop];
+
+               let route = get_route(&our_node_id, &route_params.clone(), &network_graph.read_only(),
+                       Some(&first_hops.iter().collect::<Vec<_>>()), Arc::clone(&logger), &scorer,
+                       &Default::default(), &random_seed_bytes).unwrap();
+               assert_eq!(route.paths.len(), 1);
+               assert_eq!(route.get_total_amount(), amt_msat);
+               assert_eq!(route.paths[0].hops.len(), 2);
+               assert_eq!(route.paths[0].hops[0].short_channel_id, 44);
+               assert_eq!(route.paths[0].hops[1].short_channel_id, 45);
+               assert_eq!(route.get_total_fees(), 123);
+
+               // Now check we would trust our first hop info, i.e., fail if we detect the route hint is
+               // for a first hop channel.
+               let mut first_hop = get_channel_details(Some(1), nodes[0], channelmanager::provided_init_features(&config), 999_999);
+               first_hop.outbound_scid_alias = Some(44);
+               let first_hops = vec![first_hop];
+
+               let route_res = get_route(&our_node_id, &route_params.clone(), &network_graph.read_only(),
+                       Some(&first_hops.iter().collect::<Vec<_>>()), Arc::clone(&logger), &scorer,
+                       &Default::default(), &random_seed_bytes);
+               assert!(route_res.is_err());
+
+               // Finally check we'd use the first hop if has sufficient outbound capacity. But we'd stil
+               // use the cheaper second hop of the route hint.
+               let mut first_hop = get_channel_details(Some(1), nodes[0],
+                       channelmanager::provided_init_features(&config), 10_000_000);
+               first_hop.outbound_scid_alias = Some(44);
+               let first_hops = vec![first_hop];
+
+               let route = get_route(&our_node_id, &route_params.clone(), &network_graph.read_only(),
+                       Some(&first_hops.iter().collect::<Vec<_>>()), Arc::clone(&logger), &scorer,
+                       &Default::default(), &random_seed_bytes).unwrap();
+               assert_eq!(route.paths.len(), 1);
+               assert_eq!(route.get_total_amount(), amt_msat);
+               assert_eq!(route.paths[0].hops.len(), 2);
+               assert_eq!(route.paths[0].hops[0].short_channel_id, 1);
+               assert_eq!(route.paths[0].hops[1].short_channel_id, 45);
+               assert_eq!(route.get_total_fees(), 123);
+       }
 }
 
 #[cfg(all(any(test, ldk_bench), not(feature = "no-std")))]