]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Add source and target nodes to routing::Score
authorJeffrey Czyz <jkczyz@gmail.com>
Mon, 18 Oct 2021 23:36:35 +0000 (18:36 -0500)
committerJeffrey Czyz <jkczyz@gmail.com>
Tue, 19 Oct 2021 18:03:02 +0000 (13:03 -0500)
Expand routing::Score::channel_penalty_msat to include the source and
target node ids of the channel. This allows scorers to avoid certain
nodes altogether if desired.

lightning/src/routing/mod.rs
lightning/src/routing/router.rs
lightning/src/routing/scorer.rs

index 1cf73c2689eecc9029114d7ab6d24c4b19b66557..51ffd91b50483d58bc2016440bdb98d9f82886e5 100644 (file)
@@ -13,10 +13,13 @@ pub mod network_graph;
 pub mod router;
 pub mod scorer;
 
+use routing::network_graph::NodeId;
+
 /// An interface used to score payment channels for path finding.
 ///
 ///    Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
 pub trait Score {
-       /// Returns the fee in msats willing to be paid to avoid routing through the given channel.
-       fn channel_penalty_msat(&self, short_channel_id: u64) -> u64;
+       /// Returns the fee in msats willing to be paid to avoid routing through the given channel
+       /// in the direction from `source` to `target`.
+       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64;
 }
index f33ef56a16b5ad52544852b355137ae507b4c515..b617eebd42d834996279c9346e07a743bee7c1de 100644 (file)
@@ -748,7 +748,7 @@ where L::Target: Logger {
                                                        }
 
                                                        let path_penalty_msat = $next_hops_path_penalty_msat
-                                                               .checked_add(scorer.channel_penalty_msat($chan_id.clone()))
+                                                               .checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id))
                                                                .unwrap_or_else(|| u64::max_value());
                                                        let new_graph_node = RouteGraphNode {
                                                                node_id: $src_node_id,
@@ -973,15 +973,17 @@ where L::Target: Logger {
                                                _ => aggregate_next_hops_fee_msat.checked_add(999).unwrap_or(u64::max_value())
                                        }) { Some( val / 1000 ) } else { break; }; // converting from msat or breaking if max ~ infinity
 
+                                       let src_node_id = NodeId::from_pubkey(&hop.src_node_id);
+                                       let dest_node_id = NodeId::from_pubkey(&prev_hop_id);
                                        aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
-                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id))
+                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id))
                                                .unwrap_or_else(|| u64::max_value());
 
                                        // We assume that the recipient only included route hints for routes which had
                                        // sufficient value to route `final_value_msat`. Note that in the case of "0-value"
                                        // invoices where the invoice does not specify value this may not be the case, but
                                        // better to include the hints than not.
-                                       if !add_entry!(hop.short_channel_id, NodeId::from_pubkey(&hop.src_node_id), NodeId::from_pubkey(&prev_hop_id), directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
+                                       if !add_entry!(hop.short_channel_id, src_node_id, dest_node_id, directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
                                                // If this hop was not used then there is no use checking the preceding hops
                                                // in the RouteHint. We can break by just searching for a direct channel between
                                                // last checked hop and first_hop_targets
@@ -1322,7 +1324,8 @@ where L::Target: Logger {
 
 #[cfg(test)]
 mod tests {
-       use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
+       use routing;
+       use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
        use routing::router::{get_route, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
        use routing::scorer::Scorer;
        use chain::transaction::OutPoint;
@@ -4377,6 +4380,68 @@ mod tests {
                assert_eq!(path, vec![2, 4, 7, 10]);
        }
 
+       struct BadChannelScorer {
+               short_channel_id: u64,
+       }
+
+       impl routing::Score for BadChannelScorer {
+               fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
+                       if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
+               }
+       }
+
+       struct BadNodeScorer {
+               node_id: NodeId,
+       }
+
+       impl routing::Score for BadNodeScorer {
+               fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
+                       if *target == self.node_id { u64::max_value() } else { 0 }
+               }
+       }
+
+       #[test]
+       fn avoids_routing_through_bad_channels_and_nodes() {
+               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+
+               // A path to nodes[6] exists when no penalties are applied to any channel.
+               let scorer = Scorer::new(0);
+               let route = get_route(
+                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
+                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+               ).unwrap();
+               let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
+
+               assert_eq!(route.get_total_fees(), 100);
+               assert_eq!(route.get_total_amount(), 100);
+               assert_eq!(path, vec![2, 4, 6, 11, 8]);
+
+               // A different path to nodes[6] exists if channel 6 cannot be routed over.
+               let scorer = BadChannelScorer { short_channel_id: 6 };
+               let route = get_route(
+                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
+                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+               ).unwrap();
+               let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
+
+               assert_eq!(route.get_total_fees(), 300);
+               assert_eq!(route.get_total_amount(), 100);
+               assert_eq!(path, vec![2, 4, 7, 10]);
+
+               // A path to nodes[6] does not exist if nodes[2] cannot be routed through.
+               let scorer = BadNodeScorer { node_id: NodeId::from_pubkey(&nodes[2]) };
+               match get_route(
+                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
+                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+               ) {
+                       Err(LightningError { err, .. } ) => {
+                               assert_eq!(err, "Failed to find a path to the given destination");
+                       },
+                       Ok(_) => panic!("Expected error"),
+               }
+       }
+
        #[test]
        fn total_fees_single_path() {
                let route = Route {
index f58da652096349bc2094fb2ce27b4d8be7234113..0f43c3d79283ae8e458b506a783aa72d03f90347 100644 (file)
@@ -44,6 +44,8 @@
 
 use routing;
 
+use routing::network_graph::NodeId;
+
 /// [`routing::Score`] implementation that provides reasonable default behavior.
 ///
 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
@@ -71,5 +73,9 @@ impl Default for Scorer {
 }
 
 impl routing::Score for Scorer {
-       fn channel_penalty_msat(&self, _short_channel_id: u64) -> u64 { self.base_penalty_msat }
+       fn channel_penalty_msat(
+               &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+       ) -> u64 {
+               self.base_penalty_msat
+       }
 }