Stupid massive reduction in struct size in routing lol
[rust-lightning] / lightning / src / routing / router.rs
index c7d57fabca5cbf5364dbdf8d79d70e98d3a29b87..798e35f6e3db1c91c961332221fa873002d633ee 100644 (file)
@@ -306,9 +306,18 @@ impl_writeable_tlv_based!(RouteHintHop, {
        (6, cltv_expiry_delta, required),
 });
 
+#[derive(Eq)]
+struct NodeIdPtr<'a>(&'a NodeId);
+impl<'a> core::hash::Hash for NodeIdPtr<'a> {
+    fn hash<H: core::hash::Hasher>(&self, h: &mut H) { self.0.as_slice()[1..].hash(h) }
+}
+impl<'a> PartialEq for NodeIdPtr<'a> {
+    fn eq(&self, o: &Self) -> bool { core::ptr::eq(self.0, o.0) || self.0 == o.0 }
+}
+
 #[derive(Eq, PartialEq)]
-struct RouteGraphNode {
-       node_id: NodeId,
+struct RouteGraphNode<'a> {
+       node_id: NodeIdPtr<'a>,
        lowest_fee_to_peer_through_node: u64,
        lowest_fee_to_node: u64,
        total_cltv_delta: u32,
@@ -326,20 +335,20 @@ struct RouteGraphNode {
        path_penalty_msat: u64,
 }
 
-impl cmp::Ord for RouteGraphNode {
-       fn cmp(&self, other: &RouteGraphNode) -> cmp::Ordering {
+impl<'a> cmp::Ord for RouteGraphNode<'a> {
+       fn cmp(&self, other: &RouteGraphNode<'a>) -> cmp::Ordering {
                let other_score = cmp::max(other.lowest_fee_to_peer_through_node, other.path_htlc_minimum_msat)
                        .checked_add(other.path_penalty_msat)
                        .unwrap_or_else(|| u64::max_value());
                let self_score = cmp::max(self.lowest_fee_to_peer_through_node, self.path_htlc_minimum_msat)
                        .checked_add(self.path_penalty_msat)
                        .unwrap_or_else(|| u64::max_value());
-               other_score.cmp(&self_score).then_with(|| other.node_id.cmp(&self.node_id))
+               other_score.cmp(&self_score).then_with(|| (other.node_id.0 as *const NodeId).cmp(&(self.node_id.0 as *const NodeId)))
        }
 }
 
-impl cmp::PartialOrd for RouteGraphNode {
-       fn partial_cmp(&self, other: &RouteGraphNode) -> Option<cmp::Ordering> {
+impl<'a> cmp::PartialOrd for RouteGraphNode<'a> {
+       fn partial_cmp(&self, other: &RouteGraphNode<'a>) -> Option<cmp::Ordering> {
                Some(self.cmp(other))
        }
 }
@@ -361,6 +370,7 @@ enum CandidateRouteHop<'a> {
        /// A hop to the payee found in the payment invoice, though not necessarily a direct channel.
        PrivateHop {
                hint: &'a RouteHintHop,
+               dst_node_id: &'a PublicKey,
        }
 }
 
@@ -369,7 +379,7 @@ impl<'a> CandidateRouteHop<'a> {
                match self {
                        CandidateRouteHop::FirstHop { details } => details.short_channel_id.unwrap(),
                        CandidateRouteHop::PublicHop { short_channel_id, .. } => *short_channel_id,
-                       CandidateRouteHop::PrivateHop { hint } => hint.short_channel_id,
+                       CandidateRouteHop::PrivateHop { hint, .. } => hint.short_channel_id,
                }
        }
 
@@ -386,7 +396,7 @@ impl<'a> CandidateRouteHop<'a> {
                match self {
                        CandidateRouteHop::FirstHop { .. } => 0,
                        CandidateRouteHop::PublicHop { info, .. } => info.direction().cltv_expiry_delta as u32,
-                       CandidateRouteHop::PrivateHop { hint } => hint.cltv_expiry_delta as u32,
+                       CandidateRouteHop::PrivateHop { hint, .. } => hint.cltv_expiry_delta as u32,
                }
        }
 
@@ -394,7 +404,7 @@ impl<'a> CandidateRouteHop<'a> {
                match self {
                        CandidateRouteHop::FirstHop { .. } => 0,
                        CandidateRouteHop::PublicHop { info, .. } => info.direction().htlc_minimum_msat,
-                       CandidateRouteHop::PrivateHop { hint } => hint.htlc_minimum_msat.unwrap_or(0),
+                       CandidateRouteHop::PrivateHop { hint, .. } => hint.htlc_minimum_msat.unwrap_or(0),
                }
        }
 
@@ -404,7 +414,7 @@ impl<'a> CandidateRouteHop<'a> {
                                base_msat: 0, proportional_millionths: 0,
                        },
                        CandidateRouteHop::PublicHop { info, .. } => info.direction().fees,
-                       CandidateRouteHop::PrivateHop { hint } => hint.fees,
+                       CandidateRouteHop::PrivateHop { hint, .. } => hint.fees,
                }
        }
 
@@ -417,6 +427,15 @@ impl<'a> CandidateRouteHop<'a> {
                        CandidateRouteHop::PrivateHop { .. } => EffectiveCapacity::Infinite,
                }
        }
+
+    // NOTE: THIS IS EXPENSIVE!
+       fn dst_node_id(&self) -> NodeId {
+               match self {
+                       CandidateRouteHop::FirstHop { details } => NodeId::from_pubkey(&details.counterparty.node_id),
+                       CandidateRouteHop::PublicHop { info, .. } => info.dest_node_id().clone(),
+                       CandidateRouteHop::PrivateHop { dst_node_id, .. } => NodeId::from_pubkey(&dst_node_id),
+               }
+       }
 }
 
 /// It's useful to keep track of the hops associated with the fees required to use them,
@@ -425,9 +444,6 @@ impl<'a> CandidateRouteHop<'a> {
 /// These fee values are useful to choose hops as we traverse the graph "payee-to-payer".
 #[derive(Clone, Debug)]
 struct PathBuildingHop<'a> {
-       // Note that this should be dropped in favor of loading it from CandidateRouteHop, but doing so
-       // is a larger refactor and will require careful performance analysis.
-       node_id: NodeId,
        candidate: CandidateRouteHop<'a>,
        fee_msat: u64,
 
@@ -753,7 +769,7 @@ where L::Target: Logger {
 
        // Map from node_id to information about the best current path to that node, including feerate
        // information.
-       let mut dist = HashMap::with_capacity(network_nodes.len());
+       let mut dist: HashMap<NodeIdPtr, _> = HashMap::with_capacity(network_nodes.len());
 
        // During routing, if we ignore a path due to an htlc_minimum_msat limit, we set this,
        // indicating that we may wish to try again with a higher value, potentially paying to meet an
@@ -869,7 +885,7 @@ where L::Target: Logger {
                                                        .and_then(|fee_msat| fee_msat.checked_add($next_hops_path_htlc_minimum_msat))
                                                        .map(|fee_msat| cmp::max(fee_msat, $candidate.htlc_minimum_msat()))
                                                        .unwrap_or_else(|| u64::max_value());
-                                               let hm_entry = dist.entry($src_node_id);
+                                               let hm_entry = dist.entry(NodeIdPtr(&$src_node_id));
                                                let old_entry = hm_entry.or_insert_with(|| {
                                                        // If there was previously no known way to access the source node
                                                        // (recall it goes payee-to-payer) of short_channel_id, first add a
@@ -883,7 +899,6 @@ where L::Target: Logger {
                                                                fee_proportional_millionths = fees.proportional_millionths;
                                                        }
                                                        PathBuildingHop {
-                                                               node_id: $dest_node_id.clone(),
                                                                candidate: $candidate.clone(),
                                                                fee_msat: 0,
                                                                src_lowest_inbound_fees: RoutingFees {
@@ -950,7 +965,7 @@ where L::Target: Logger {
                                                                scorer.channel_penalty_msat(short_channel_id, amount_to_transfer_over_msat, *available_liquidity_msat,
                                                                        &$src_node_id, &$dest_node_id)).unwrap_or_else(|| u64::max_value());
                                                        let new_graph_node = RouteGraphNode {
-                                                               node_id: $src_node_id,
+                                                               node_id: NodeIdPtr(&$src_node_id),
                                                                lowest_fee_to_peer_through_node: total_fee_msat,
                                                                lowest_fee_to_node: $next_hops_fee_msat as u64 + hop_use_fee_msat,
                                                                total_cltv_delta: hop_total_cltv_delta,
@@ -987,7 +1002,6 @@ where L::Target: Logger {
                                                                old_entry.next_hops_fee_msat = $next_hops_fee_msat;
                                                                old_entry.hop_use_fee_msat = hop_use_fee_msat;
                                                                old_entry.total_fee_msat = total_fee_msat;
-                                                               old_entry.node_id = $dest_node_id.clone();
                                                                old_entry.candidate = $candidate.clone();
                                                                old_entry.fee_msat = 0; // This value will be later filled with hop_use_fee_msat of the following channel
                                                                old_entry.path_htlc_minimum_msat = path_htlc_minimum_msat;
@@ -1042,7 +1056,7 @@ where L::Target: Logger {
        // This data can later be helpful to optimize routing (pay lower fees).
        macro_rules! add_entries_to_cheapest_to_target_node {
                ( $node: expr, $node_id: expr, $fee_to_target_msat: expr, $next_hops_value_contribution: expr, $next_hops_path_htlc_minimum_msat: expr, $next_hops_path_penalty_msat: expr, $next_hops_cltv_delta: expr ) => {
-                       let skip_node = if let Some(elem) = dist.get_mut(&$node_id) {
+                       let skip_node = if let Some(elem) = dist.get_mut(&NodeIdPtr(&$node_id)) {
                                let was_processed = elem.was_processed;
                                elem.was_processed = true;
                                was_processed
@@ -1093,6 +1107,24 @@ where L::Target: Logger {
        }
 
        let mut payment_paths = Vec::<PaymentPath>::new();
+let mut first_hops_node_ids = Vec::new();
+for route in payment_params.route_hints.iter().filter(|route| !route.0.is_empty()) {
+                       let first_hop_in_route = &(route.0)[0];
+    let have_hop_src_in_graph =
+        // Only add the hops in this route to our candidate set if either
+        // we have a direct channel to the first hop or the first hop is
+        // in the regular network graph.
+        first_hop_targets.get(&NodeId::from_pubkey(&first_hop_in_route.src_node_id)).is_some() ||
+        network_nodes.get(&NodeId::from_pubkey(&first_hop_in_route.src_node_id)).is_some();
+    if have_hop_src_in_graph {
+        let hop_iter = route.0.iter().rev();
+        let prev_hop_iter = core::iter::once(&payment_params.payee_pubkey).chain(
+            route.0.iter().skip(1).rev().map(|hop| &hop.src_node_id));
+               for (idx, (hop, prev_hop_id)) in hop_iter.zip(prev_hop_iter).enumerate() {
+            first_hops_node_ids.push(NodeId::from_pubkey(&hop.src_node_id));
+        }
+    }
+}
 
        // TODO: diversify by nodes (so that all paths aren't doomed if one node is offline).
        'paths_collection: loop {
@@ -1130,6 +1162,7 @@ where L::Target: Logger {
                // If a caller provided us with last hops, add them to routing targets. Since this happens
                // earlier than general path finding, they will be somewhat prioritized, although currently
                // it matters only if the fees are exactly the same.
+let mut a_thing_idx = 0;
                for route in payment_params.route_hints.iter().filter(|route| !route.0.is_empty()) {
                        let first_hop_in_route = &(route.0)[0];
                        let have_hop_src_in_graph =
@@ -1151,7 +1184,9 @@ where L::Target: Logger {
                                let mut aggregate_next_hops_cltv_delta: u32 = 0;
 
                                for (idx, (hop, prev_hop_id)) in hop_iter.zip(prev_hop_iter).enumerate() {
-                                       let source = NodeId::from_pubkey(&hop.src_node_id);
+let source = &first_hops_node_ids[a_thing_idx];
+a_thing_idx += 1;
+                                       //let source = NodeId::from_pubkey(&hop.src_node_id);
                                        let target = NodeId::from_pubkey(&prev_hop_id);
                                        let candidate = network_channels
                                                .get(&hop.short_channel_id)
@@ -1161,7 +1196,7 @@ where L::Target: Logger {
                                                        info,
                                                        short_channel_id: hop.short_channel_id,
                                                })
-                                               .unwrap_or_else(|| CandidateRouteHop::PrivateHop { hint: hop });
+                                               .unwrap_or_else(|| CandidateRouteHop::PrivateHop { hint: hop, dst_node_id: &prev_hop_id });
                                        let capacity_msat = candidate.effective_capacity().as_msat();
                                        aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
                                                .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, final_value_msat, capacity_msat, &source, &target))
@@ -1171,7 +1206,7 @@ where L::Target: Logger {
                                                .checked_add(hop.cltv_expiry_delta as u32)
                                                .unwrap_or_else(|| u32::max_value());
 
-                                       if !add_entry!(candidate, source, target, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta) {
+                                       if !add_entry!(candidate, *source, target, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, aggregate_next_hops_cltv_delta) {
                                                // If this hop was not used then there is no use checking the preceding hops
                                                // in the RouteHint. We can break by just searching for a direct channel between
                                                // last checked hop and first_hop_targets
@@ -1245,23 +1280,20 @@ where L::Target: Logger {
 
                        // Since we're going payee-to-payer, hitting our node as a target means we should stop
                        // traversing the graph and arrange the path out of what we found.
-                       if node_id == our_node_id {
-                               let mut new_entry = dist.remove(&our_node_id).unwrap();
+                       if *node_id.0 == our_node_id {
+                               let mut new_entry = dist.remove(&NodeIdPtr(&our_node_id)).unwrap();
                                let mut ordered_hops = vec!((new_entry.clone(), NodeFeatures::empty()));
 
                                'path_walk: loop {
                                        let mut features_set = false;
-                                       if let Some(first_channels) = first_hop_targets.get(&ordered_hops.last().unwrap().0.node_id) {
-                                               for details in first_channels {
-                                                       if details.short_channel_id.unwrap() == ordered_hops.last().unwrap().0.candidate.short_channel_id() {
-                                                               ordered_hops.last_mut().unwrap().1 = details.counterparty.features.to_context();
-                                                               features_set = true;
-                                                               break;
-                                                       }
-                                               }
+                                       if let CandidateRouteHop::FirstHop { details } = ordered_hops.last().unwrap().0.candidate {//let Some(first_channels) = first_hop_targets.get(&ordered_hops.last().unwrap().0.node_id) {
+                                               debug_assert_eq!(details.short_channel_id.unwrap(), ordered_hops.last().unwrap().0.candidate.short_channel_id());
+                                               ordered_hops.last_mut().unwrap().1 = details.counterparty.features.to_context();
+                                               features_set = true;
                                        }
+                                       let dst_node_id = ordered_hops.last().unwrap().0.candidate.dst_node_id();
                                        if !features_set {
-                                               if let Some(node) = network_nodes.get(&ordered_hops.last().unwrap().0.node_id) {
+                                               if let Some(node) = network_nodes.get(&dst_node_id) {
                                                        if let Some(node_info) = node.announcement_info.as_ref() {
                                                                ordered_hops.last_mut().unwrap().1 = node_info.features.clone();
                                                        } else {
@@ -1272,7 +1304,7 @@ where L::Target: Logger {
                                                        // hop, if the last hop was provided via a BOLT 11 invoice (though we
                                                        // should be able to extend it further as BOLT 11 does have feature
                                                        // flags for the last hop node itself).
-                                                       assert!(ordered_hops.last().unwrap().0.node_id == payee_node_id);
+                                                       assert!(dst_node_id == payee_node_id);
                                                }
                                        }
 
@@ -1280,11 +1312,11 @@ where L::Target: Logger {
                                        // save this path for the payment route. Also, update the liquidity
                                        // remaining on the used hops, so that we take them into account
                                        // while looking for more paths.
-                                       if ordered_hops.last().unwrap().0.node_id == payee_node_id {
+                                       if dst_node_id == payee_node_id {
                                                break 'path_walk;
                                        }
 
-                                       new_entry = match dist.remove(&ordered_hops.last().unwrap().0.node_id) {
+                                       new_entry = match dist.remove(&NodeIdPtr(unsafe { &*(&dst_node_id as *const NodeId) })) {
                                                Some(payment_hop) => payment_hop,
                                                // We can't arrive at None because, if we ever add an entry to targets,
                                                // we also fill in the entry in dist (see add_entry!).
@@ -1357,15 +1389,15 @@ where L::Target: Logger {
                        // If we found a path back to the payee, we shouldn't try to process it again. This is
                        // the equivalent of the `elem.was_processed` check in
                        // add_entries_to_cheapest_to_target_node!() (see comment there for more info).
-                       if node_id == payee_node_id { continue 'path_construction; }
+                       if *node_id.0 == payee_node_id { continue 'path_construction; }
 
                        // Otherwise, since the current target node is not us,
                        // keep "unrolling" the payment graph from payee to payer by
                        // finding a way to reach the current target from the payer side.
-                       match network_nodes.get(&node_id) {
+                       match network_nodes.get(node_id.0) {
                                None => {},
                                Some(node) => {
-                                       add_entries_to_cheapest_to_target_node!(node, node_id, lowest_fee_to_node, value_contribution_msat, path_htlc_minimum_msat, path_penalty_msat, total_cltv_delta);
+                                       add_entries_to_cheapest_to_target_node!(node, *node_id.0, lowest_fee_to_node, value_contribution_msat, path_htlc_minimum_msat, path_penalty_msat, total_cltv_delta);
                                },
                        }
                }
@@ -1487,8 +1519,9 @@ where L::Target: Logger {
        let mut selected_paths = Vec::<Vec<Result<RouteHop, LightningError>>>::new();
        for payment_path in drawn_routes.first().unwrap() {
                let mut path = payment_path.hops.iter().map(|(payment_hop, node_features)| {
+                       let dst_node_id = payment_hop.candidate.dst_node_id();
                        Ok(RouteHop {
-                               pubkey: PublicKey::from_slice(payment_hop.node_id.as_slice()).map_err(|_| LightningError{err: format!("Public key {:?} is invalid", &payment_hop.node_id), action: ErrorAction::IgnoreAndLog(Level::Trace)})?,
+                               pubkey: PublicKey::from_slice(dst_node_id.as_slice()).map_err(|_| LightningError{err: format!("Public key {:?} is invalid", &dst_node_id), action: ErrorAction::IgnoreAndLog(Level::Trace)})?,
                                node_features: node_features.clone(),
                                short_channel_id: payment_hop.candidate.short_channel_id(),
                                channel_features: payment_hop.candidate.features(),