]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Include any recipient overpayment amounts in the route fee limit
authorMatt Corallo <git@bluematt.me>
Thu, 28 Sep 2023 18:19:36 +0000 (18:19 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 28 Sep 2023 20:39:36 +0000 (20:39 +0000)
If the user told us to limit their total fee exposure, we should
do so including any potential overpayment to the recipient, which
is ultimately a part of the "fee" as far as the user is concerned.

lightning/src/routing/router.rs

index bca681dfc9b32f5760d3d9584ca65d802cd29381..19cf4e05d950d2b7c7b05cabdfd33a7377c5526a 100644 (file)
@@ -2628,14 +2628,6 @@ where L::Target: Logger {
        // Make sure we would never create a route with more paths than we allow.
        debug_assert!(paths.len() <= payment_params.max_path_count.into());
 
-       // Make sure we would never create a route whose total fees exceed max_total_routing_fee_msat.
-       if let Some(max_total_routing_fee_msat) = route_params.max_total_routing_fee_msat {
-               if paths.iter().map(|p| p.fee_msat()).sum::<u64>() > max_total_routing_fee_msat {
-                       return Err(LightningError{err: format!("Failed to find route that adheres to the maximum total fee limit of {}msat",
-                               max_total_routing_fee_msat), action: ErrorAction::IgnoreError});
-               }
-       }
-
        if let Some(node_features) = payment_params.payee.node_features() {
                for path in paths.iter_mut() {
                        path.hops.last_mut().unwrap().node_features = node_features.clone();
@@ -2643,6 +2635,15 @@ where L::Target: Logger {
        }
 
        let route = Route { paths, route_params: Some(route_params.clone()) };
+
+       // Make sure we would never create a route whose total fees exceed max_total_routing_fee_msat.
+       if let Some(max_total_routing_fee_msat) = route_params.max_total_routing_fee_msat {
+               if route.get_total_fees() > max_total_routing_fee_msat {
+                       return Err(LightningError{err: format!("Failed to find route that adheres to the maximum total fee limit of {}msat",
+                               max_total_routing_fee_msat), action: ErrorAction::IgnoreError});
+               }
+       }
+
        log_info!(logger, "Got route: {}", log_route!(route));
        Ok(route)
 }
@@ -3266,11 +3267,22 @@ mod tests {
                        excess_data: Vec::new()
                });
 
-               // Now check that we'll find a path if the htlc_minimum is overrun substantially.
+               // Now check that we'll fail to find a path if we fail to find a path if the htlc_minimum
+               // is overrun. Note that the fees are actually calculated on 3*payment amount as that's
+               // what we try to find a route for, so this test only just happens to work out to exactly
+               // the fee limit.
                let mut route_params = RouteParameters::from_payment_params_and_value(
                        payment_params.clone(), 5_000);
-               // TODO: This can even overrun the fee limit set by the recipient!
                route_params.max_total_routing_fee_msat = Some(9_999);
+               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id,
+                       &route_params, &network_graph.read_only(), None, Arc::clone(&logger), &scorer,
+                       &Default::default(), &random_seed_bytes) {
+                               assert_eq!(err, "Failed to find route that adheres to the maximum total fee limit of 9999msat");
+               } else { panic!(); }
+
+               let mut route_params = RouteParameters::from_payment_params_and_value(
+                       payment_params.clone(), 5_000);
+               route_params.max_total_routing_fee_msat = Some(10_000);
                let route = get_route(&our_id, &route_params, &network_graph.read_only(), None,
                        Arc::clone(&logger), &scorer, &Default::default(), &random_seed_bytes).unwrap();
                assert_eq!(route.get_total_fees(), 10_000);