Merge pull request #1504 from TheBlueMatt/2022-05-pub-io
[rust-lightning] / lightning / src / routing / router.rs
index b48c1e6ab7ab090a97032dd95587ec18255e3462..8731e66da9bd866bd82d00bc42c8922cd25508e8 100644 (file)
@@ -17,9 +17,9 @@ use bitcoin::secp256k1::PublicKey;
 use ln::channelmanager::ChannelDetails;
 use ln::features::{ChannelFeatures, InvoiceFeatures, NodeFeatures};
 use ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT};
-use routing::scoring::Score;
+use routing::scoring::{ChannelUsage, Score};
 use routing::network_graph::{DirectedChannelInfoWithUpdate, EffectiveCapacity, NetworkGraph, ReadOnlyNetworkGraph, NodeId, RoutingFees};
-use util::ser::{Writeable, Readable};
+use util::ser::{Writeable, Readable, Writer};
 use util::logger::{Level, Logger};
 use util::chacha20::ChaCha20;
 
@@ -151,8 +151,8 @@ impl Readable for Route {
 
 /// Parameters needed to find a [`Route`].
 ///
-/// Passed to [`find_route`] and also provided in [`Event::PaymentPathFailed`] for retrying a failed
-/// payment path.
+/// Passed to [`find_route`] and [`build_route_from_hops`], but also provided in
+/// [`Event::PaymentPathFailed`] for retrying a failed payment path.
 ///
 /// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed
 #[derive(Clone, Debug)]
@@ -383,7 +383,7 @@ enum CandidateRouteHop<'a> {
 impl<'a> CandidateRouteHop<'a> {
        fn short_channel_id(&self) -> u64 {
                match self {
-                       CandidateRouteHop::FirstHop { details } => details.short_channel_id.unwrap(),
+                       CandidateRouteHop::FirstHop { details } => details.get_outbound_payment_scid().unwrap(),
                        CandidateRouteHop::PublicHop { short_channel_id, .. } => *short_channel_id,
                        CandidateRouteHop::PrivateHop { hint } => hint.short_channel_id,
                }
@@ -414,6 +414,16 @@ impl<'a> CandidateRouteHop<'a> {
                }
        }
 
+       fn htlc_maximum_msat(&self) -> u64 {
+               match self {
+                       CandidateRouteHop::FirstHop { details } => details.next_outbound_htlc_limit_msat,
+                       CandidateRouteHop::PublicHop { info, .. } => info.htlc_maximum_msat(),
+                       CandidateRouteHop::PrivateHop { hint } => {
+                               hint.htlc_maximum_msat.unwrap_or(u64::max_value())
+                       },
+               }
+       }
+
        fn fees(&self) -> RoutingFees {
                match self {
                        CandidateRouteHop::FirstHop { .. } => RoutingFees {
@@ -666,16 +676,11 @@ pub fn find_route<L: Deref, S: Score>(
 ) -> Result<Route, LightningError>
 where L::Target: Logger {
        let network_graph = network.read_only();
-       match get_route(
-               our_node_pubkey, &route_params.payment_params, &network_graph, first_hops, route_params.final_value_msat,
-               route_params.final_cltv_expiry_delta, logger, scorer, random_seed_bytes
-       ) {
-               Ok(mut route) => {
-                       add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
-                       Ok(route)
-               },
-               Err(err) => Err(err),
-       }
+       let mut route = get_route(our_node_pubkey, &route_params.payment_params, &network_graph, first_hops,
+               route_params.final_value_msat, route_params.final_cltv_expiry_delta, logger, scorer,
+               random_seed_bytes)?;
+       add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
+       Ok(route)
 }
 
 pub(crate) fn get_route<L: Deref, S: Score>(
@@ -796,7 +801,7 @@ where L::Target: Logger {
                HashMap::with_capacity(if first_hops.is_some() { first_hops.as_ref().unwrap().len() } else { 0 });
        if let Some(hops) = first_hops {
                for chan in hops {
-                       if chan.short_channel_id.is_none() {
+                       if chan.get_outbound_payment_scid().is_none() {
                                panic!("first_hops should be filled in with usable channels, not pending ones");
                        }
                        if chan.counterparty.node_id == *our_node_pubkey {
@@ -834,12 +839,12 @@ where L::Target: Logger {
        let recommended_value_msat = final_value_msat * ROUTE_CAPACITY_PROVISION_FACTOR as u64;
        let mut path_value_msat = final_value_msat;
 
-       // We don't want multiple paths (as per MPP) share liquidity of the same channels.
-       // This map allows paths to be aware of the channel use by other paths in the same call.
-       // This would help to make a better path finding decisions and not "overbook" channels.
-       // It is unaware of the directions (except for `next_outbound_htlc_limit_msat` in
-       // `first_hops`).
-       let mut bookkept_channels_liquidity_available_msat = HashMap::with_capacity(network_nodes.len());
+       // Keep track of how much liquidity has been used in selected channels. Used to determine
+       // if the channel can be used by additional MPP paths or to inform path finding decisions. It is
+       // aware of direction *only* to ensure that the correct htlc_maximum_msat value is used. Hence,
+       // liquidity used in one direction will not offset any used in the opposite direction.
+       let mut used_channel_liquidities: HashMap<(u64, bool), u64> =
+               HashMap::with_capacity(network_nodes.len());
 
        // Keeping track of how much value we already collected across other paths. Helps to decide:
        // - how much a new path should be transferring (upper bound);
@@ -889,9 +894,7 @@ where L::Target: Logger {
                        // - for first and last hops early in get_route
                        if $src_node_id != $dest_node_id {
                                let short_channel_id = $candidate.short_channel_id();
-                               let available_liquidity_msat = bookkept_channels_liquidity_available_msat
-                                       .entry(short_channel_id)
-                                       .or_insert_with(|| $candidate.effective_capacity().as_msat());
+                               let htlc_maximum_msat = $candidate.htlc_maximum_msat();
 
                                // It is tricky to subtract $next_hops_fee_msat from available liquidity here.
                                // It may be misleading because we might later choose to reduce the value transferred
@@ -900,7 +903,14 @@ where L::Target: Logger {
                                // fees caused by one expensive channel, but then this channel could have been used
                                // if the amount being transferred over this path is lower.
                                // We do this for now, but this is a subject for removal.
-                               if let Some(available_value_contribution_msat) = available_liquidity_msat.checked_sub($next_hops_fee_msat) {
+                               if let Some(mut available_value_contribution_msat) = htlc_maximum_msat.checked_sub($next_hops_fee_msat) {
+                                       let used_liquidity_msat = used_channel_liquidities
+                                               .get(&(short_channel_id, $src_node_id < $dest_node_id))
+                                               .map_or(0, |used_liquidity_msat| {
+                                                       available_value_contribution_msat = available_value_contribution_msat
+                                                               .saturating_sub(*used_liquidity_msat);
+                                                       *used_liquidity_msat
+                                               });
 
                                        // Routing Fragmentation Mitigation heuristic:
                                        //
@@ -1051,9 +1061,16 @@ where L::Target: Logger {
                                                                }
                                                        }
 
-                                                       let path_penalty_msat = $next_hops_path_penalty_msat.saturating_add(
-                                                               scorer.channel_penalty_msat(short_channel_id, amount_to_transfer_over_msat,
-                                                                       *available_liquidity_msat, &$src_node_id, &$dest_node_id));
+                                                       let channel_usage = ChannelUsage {
+                                                               amount_msat: amount_to_transfer_over_msat,
+                                                               inflight_htlc_msat: used_liquidity_msat,
+                                                               effective_capacity: $candidate.effective_capacity(),
+                                                       };
+                                                       let channel_penalty_msat = scorer.channel_penalty_msat(
+                                                               short_channel_id, &$src_node_id, &$dest_node_id, channel_usage
+                                                       );
+                                                       let path_penalty_msat = $next_hops_path_penalty_msat
+                                                               .saturating_add(channel_penalty_msat);
                                                        let new_graph_node = RouteGraphNode {
                                                                node_id: $src_node_id,
                                                                lowest_fee_to_peer_through_node: total_fee_msat,
@@ -1211,9 +1228,8 @@ where L::Target: Logger {
 
        // TODO: diversify by nodes (so that all paths aren't doomed if one node is offline).
        'paths_collection: loop {
-               // For every new path, start from scratch, except
-               // bookkept_channels_liquidity_available_msat, which will improve
-               // the further iterations of path finding. Also don't erase first_hop_targets.
+               // For every new path, start from scratch, except for used_channel_liquidities, which
+               // helps to avoid reusing previously selected paths in future iterations.
                targets.clear();
                dist.clear();
                hit_minimum_limit = false;
@@ -1280,16 +1296,6 @@ where L::Target: Logger {
                                                        short_channel_id: hop.short_channel_id,
                                                })
                                                .unwrap_or_else(|| CandidateRouteHop::PrivateHop { hint: hop });
-                                       let capacity_msat = candidate.effective_capacity().as_msat();
-                                       aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
-                                               .saturating_add(scorer.channel_penalty_msat(hop.short_channel_id,
-                                                       final_value_msat, capacity_msat, &source, &target));
-
-                                       aggregate_next_hops_cltv_delta = aggregate_next_hops_cltv_delta
-                                               .saturating_add(hop.cltv_expiry_delta as u32);
-
-                                       aggregate_next_hops_path_length = aggregate_next_hops_path_length
-                                               .saturating_add(1);
 
                                        if !add_entry!(candidate, source, target, aggregate_next_hops_fee_msat,
                                                                path_value_msat, aggregate_next_hops_path_htlc_minimum_msat,
@@ -1301,6 +1307,25 @@ where L::Target: Logger {
                                                hop_used = false;
                                        }
 
+                                       let used_liquidity_msat = used_channel_liquidities
+                                               .get(&(hop.short_channel_id, source < target)).copied().unwrap_or(0);
+                                       let channel_usage = ChannelUsage {
+                                               amount_msat: final_value_msat + aggregate_next_hops_fee_msat,
+                                               inflight_htlc_msat: used_liquidity_msat,
+                                               effective_capacity: candidate.effective_capacity(),
+                                       };
+                                       let channel_penalty_msat = scorer.channel_penalty_msat(
+                                               hop.short_channel_id, &source, &target, channel_usage
+                                       );
+                                       aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
+                                               .saturating_add(channel_penalty_msat);
+
+                                       aggregate_next_hops_cltv_delta = aggregate_next_hops_cltv_delta
+                                               .saturating_add(hop.cltv_expiry_delta as u32);
+
+                                       aggregate_next_hops_path_length = aggregate_next_hops_path_length
+                                               .saturating_add(1);
+
                                        // Searching for a direct channel between last checked hop and first_hop_targets
                                        if let Some(first_channels) = first_hop_targets.get(&NodeId::from_pubkey(&prev_hop_id)) {
                                                for details in first_channels {
@@ -1386,7 +1411,7 @@ where L::Target: Logger {
                                        let mut features_set = false;
                                        if let Some(first_channels) = first_hop_targets.get(&ordered_hops.last().unwrap().0.node_id) {
                                                for details in first_channels {
-                                                       if details.short_channel_id.unwrap() == ordered_hops.last().unwrap().0.candidate.short_channel_id() {
+                                                       if details.get_outbound_payment_scid().unwrap() == ordered_hops.last().unwrap().0.candidate.short_channel_id() {
                                                                ordered_hops.last_mut().unwrap().1 = details.counterparty.features.to_context();
                                                                features_set = true;
                                                                break;
@@ -1452,26 +1477,30 @@ where L::Target: Logger {
                                // Remember that we used these channels so that we don't rely
                                // on the same liquidity in future paths.
                                let mut prevented_redundant_path_selection = false;
-                               for (payment_hop, _) in payment_path.hops.iter() {
-                                       let channel_liquidity_available_msat = bookkept_channels_liquidity_available_msat.get_mut(&payment_hop.candidate.short_channel_id()).unwrap();
-                                       let mut spent_on_hop_msat = value_contribution_msat;
-                                       let next_hops_fee_msat = payment_hop.next_hops_fee_msat;
-                                       spent_on_hop_msat += next_hops_fee_msat;
-                                       if spent_on_hop_msat == *channel_liquidity_available_msat {
+                               let prev_hop_iter = core::iter::once(&our_node_id)
+                                       .chain(payment_path.hops.iter().map(|(hop, _)| &hop.node_id));
+                               for (prev_hop, (hop, _)) in prev_hop_iter.zip(payment_path.hops.iter()) {
+                                       let spent_on_hop_msat = value_contribution_msat + hop.next_hops_fee_msat;
+                                       let used_liquidity_msat = used_channel_liquidities
+                                               .entry((hop.candidate.short_channel_id(), *prev_hop < hop.node_id))
+                                               .and_modify(|used_liquidity_msat| *used_liquidity_msat += spent_on_hop_msat)
+                                               .or_insert(spent_on_hop_msat);
+                                       if *used_liquidity_msat == hop.candidate.htlc_maximum_msat() {
                                                // If this path used all of this channel's available liquidity, we know
                                                // this path will not be selected again in the next loop iteration.
                                                prevented_redundant_path_selection = true;
                                        }
-                                       *channel_liquidity_available_msat -= spent_on_hop_msat;
+                                       debug_assert!(*used_liquidity_msat <= hop.candidate.htlc_maximum_msat());
                                }
                                if !prevented_redundant_path_selection {
                                        // If we weren't capped by hitting a liquidity limit on a channel in the path,
                                        // we'll probably end up picking the same path again on the next iteration.
                                        // Decrease the available liquidity of a hop in the middle of the path.
                                        let victim_scid = payment_path.hops[(payment_path.hops.len()) / 2].0.candidate.short_channel_id();
+                                       let exhausted = u64::max_value();
                                        log_trace!(logger, "Disabling channel {} for future path building iterations to avoid duplicates.", victim_scid);
-                                       let victim_liquidity = bookkept_channels_liquidity_available_msat.get_mut(&victim_scid).unwrap();
-                                       *victim_liquidity = 0;
+                                       *used_channel_liquidities.entry((victim_scid, false)).or_default() = exhausted;
+                                       *used_channel_liquidities.entry((victim_scid, true)).or_default() = exhausted;
                                }
 
                                // Track the total amount all our collected paths allow to send so that we:
@@ -1669,7 +1698,9 @@ where L::Target: Logger {
 // destination, if the remaining CLTV expiry delta exactly matches a feasible path in the network
 // graph. In order to improve privacy, this method obfuscates the CLTV expiry deltas along the
 // payment path by adding a randomized 'shadow route' offset to the final hop.
-fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters, network_graph: &ReadOnlyNetworkGraph, random_seed_bytes: &[u8; 32]) {
+fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters,
+       network_graph: &ReadOnlyNetworkGraph, random_seed_bytes: &[u8; 32]
+) {
        let network_channels = network_graph.channels();
        let network_nodes = network_graph.nodes();
 
@@ -1751,13 +1782,87 @@ fn add_random_cltv_offset(route: &mut Route, payment_params: &PaymentParameters,
        }
 }
 
+/// Construct a route from us (payer) to the target node (payee) via the given hops (which should
+/// exclude the payer, but include the payee). This may be useful, e.g., for probing the chosen path.
+///
+/// Re-uses logic from `find_route`, so the restrictions described there also apply here.
+pub fn build_route_from_hops<L: Deref>(
+       our_node_pubkey: &PublicKey, hops: &[PublicKey], route_params: &RouteParameters, network: &NetworkGraph,
+       logger: L, random_seed_bytes: &[u8; 32]
+) -> Result<Route, LightningError>
+where L::Target: Logger {
+       let network_graph = network.read_only();
+       let mut route = build_route_from_hops_internal(
+               our_node_pubkey, hops, &route_params.payment_params, &network_graph,
+               route_params.final_value_msat, route_params.final_cltv_expiry_delta, logger, random_seed_bytes)?;
+       add_random_cltv_offset(&mut route, &route_params.payment_params, &network_graph, random_seed_bytes);
+       Ok(route)
+}
+
+fn build_route_from_hops_internal<L: Deref>(
+       our_node_pubkey: &PublicKey, hops: &[PublicKey], payment_params: &PaymentParameters,
+       network_graph: &ReadOnlyNetworkGraph, final_value_msat: u64, final_cltv_expiry_delta: u32,
+       logger: L, random_seed_bytes: &[u8; 32]
+) -> Result<Route, LightningError> where L::Target: Logger {
+
+       struct HopScorer {
+               our_node_id: NodeId,
+               hop_ids: [Option<NodeId>; MAX_PATH_LENGTH_ESTIMATE as usize],
+       }
+
+       impl Score for HopScorer {
+               fn channel_penalty_msat(&self, _short_channel_id: u64, source: &NodeId, target: &NodeId,
+                       _usage: ChannelUsage) -> u64
+               {
+                       let mut cur_id = self.our_node_id;
+                       for i in 0..self.hop_ids.len() {
+                               if let Some(next_id) = self.hop_ids[i] {
+                                       if cur_id == *source && next_id == *target {
+                                               return 0;
+                                       }
+                                       cur_id = next_id;
+                               } else {
+                                       break;
+                               }
+                       }
+                       u64::max_value()
+               }
+
+               fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
+
+               fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
+       }
+
+       impl<'a> Writeable for HopScorer {
+               #[inline]
+               fn write<W: Writer>(&self, _w: &mut W) -> Result<(), io::Error> {
+                       unreachable!();
+               }
+       }
+
+       if hops.len() > MAX_PATH_LENGTH_ESTIMATE.into() {
+               return Err(LightningError{err: "Cannot build a route exceeding the maximum path length.".to_owned(), action: ErrorAction::IgnoreError});
+       }
+
+       let our_node_id = NodeId::from_pubkey(our_node_pubkey);
+       let mut hop_ids = [None; MAX_PATH_LENGTH_ESTIMATE as usize];
+       for i in 0..hops.len() {
+               hop_ids[i] = Some(NodeId::from_pubkey(&hops[i]));
+       }
+
+       let scorer = HopScorer { our_node_id, hop_ids };
+
+       get_route(our_node_pubkey, payment_params, network_graph, None, final_value_msat,
+               final_cltv_expiry_delta, logger, &scorer, random_seed_bytes)
+}
+
 #[cfg(test)]
 mod tests {
        use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
-       use routing::router::{get_route, add_random_cltv_offset, default_node_features,
+       use routing::router::{get_route, build_route_from_hops_internal, add_random_cltv_offset, default_node_features,
                PaymentParameters, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees,
                DEFAULT_MAX_TOTAL_CLTV_EXPIRY_DELTA, MAX_PATH_LENGTH_ESTIMATE};
-       use routing::scoring::Score;
+       use routing::scoring::{ChannelUsage, Score};
        use chain::transaction::OutPoint;
        use chain::keysinterface::KeysInterface;
        use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
@@ -1801,6 +1906,7 @@ mod tests {
                        funding_txo: Some(OutPoint { txid: bitcoin::Txid::from_slice(&[0; 32]).unwrap(), index: 0 }),
                        channel_type: None,
                        short_channel_id,
+                       outbound_scid_alias: None,
                        inbound_scid_alias: None,
                        channel_value_satoshis: 0,
                        user_channel_id: 0,
@@ -1811,7 +1917,7 @@ mod tests {
                        unspendable_punishment_reserve: None,
                        confirmations_required: None,
                        force_close_spend_delay: None,
-                       is_outbound: true, is_funding_locked: true,
+                       is_outbound: true, is_channel_ready: true,
                        is_usable: true, is_public: true,
                        inbound_htlc_minimum_msat: None,
                        inbound_htlc_maximum_msat: None,
@@ -5149,7 +5255,7 @@ mod tests {
                fn write<W: Writer>(&self, _w: &mut W) -> Result<(), ::io::Error> { unimplemented!() }
        }
        impl Score for BadChannelScorer {
-               fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _capacity_msat: u64, _source: &NodeId, _target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, short_channel_id: u64, _: &NodeId, _: &NodeId, _: ChannelUsage) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
 
@@ -5167,7 +5273,7 @@ mod tests {
        }
 
        impl Score for BadNodeScorer {
-               fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _capacity_msat: u64, _source: &NodeId, target: &NodeId) -> u64 {
+               fn channel_penalty_msat(&self, _: u64, _: &NodeId, target: &NodeId, _: ChannelUsage) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
 
@@ -5452,6 +5558,26 @@ mod tests {
                assert!(path_plausibility.iter().all(|x| *x));
        }
 
+       #[test]
+       fn builds_correct_path_from_hops() {
+               let (secp_ctx, network, _, _, logger) = build_graph();
+               let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+               let network_graph = network.read_only();
+
+               let keys_manager = test_utils::TestKeysInterface::new(&[0u8; 32], Network::Testnet);
+               let random_seed_bytes = keys_manager.get_secure_random_bytes();
+
+               let payment_params = PaymentParameters::from_node_id(nodes[3]);
+               let hops = [nodes[1], nodes[2], nodes[4], nodes[3]];
+               let route = build_route_from_hops_internal(&our_id, &hops, &payment_params,
+                        &network_graph, 100, 0, Arc::clone(&logger), &random_seed_bytes).unwrap();
+               let route_hop_pubkeys = route.paths[0].iter().map(|hop| hop.pubkey).collect::<Vec<_>>();
+               assert_eq!(hops.len(), route.paths[0].len());
+               for (idx, hop_pubkey) in hops.iter().enumerate() {
+                       assert!(*hop_pubkey == route_hop_pubkeys[idx]);
+               }
+       }
+
        #[cfg(not(feature = "no-std"))]
        pub(super) fn random_init_seed() -> u64 {
                // Because the default HashMap in std pulls OS randomness, we can use it as a (bad) RNG.
@@ -5613,6 +5739,7 @@ mod benches {
                        channel_type: None,
                        short_channel_id: Some(1),
                        inbound_scid_alias: None,
+                       outbound_scid_alias: None,
                        channel_value_satoshis: 10_000_000,
                        user_channel_id: 0,
                        balance_msat: 10_000_000,
@@ -5623,7 +5750,7 @@ mod benches {
                        confirmations_required: None,
                        force_close_spend_delay: None,
                        is_outbound: true,
-                       is_funding_locked: true,
+                       is_channel_ready: true,
                        is_usable: true,
                        is_public: true,
                        inbound_htlc_minimum_msat: None,