Implement non-strict forwarding
[rust-lightning] / lightning / src / routing / scoring.rs
index 646405c6287ac150d88199890137020e0350079f..77e1a13065201bd9e705db19e795fb6640fe2fed 100644 (file)
@@ -64,7 +64,6 @@ use crate::util::logger::Logger;
 
 use crate::prelude::*;
 use core::{cmp, fmt};
-use core::convert::TryInto;
 use core::ops::{Deref, DerefMut};
 use core::time::Duration;
 use crate::io::{self, Read};
@@ -251,7 +250,7 @@ impl<'a, T: Score + 'a> LockableScore<'a> for RefCell<T> {
        }
 }
 
-#[cfg(not(c_bindings))]
+#[cfg(any(not(c_bindings), feature = "_test_utils", test))]
 impl<'a, T: Score + 'a> LockableScore<'a> for RwLock<T> {
        type ScoreUpdate = T;
        type ScoreLookUp = T;
@@ -653,7 +652,7 @@ impl Default for ProbabilisticScoringFeeParameters {
                        base_penalty_amount_multiplier_msat: 8192,
                        liquidity_penalty_multiplier_msat: 30_000,
                        liquidity_penalty_amount_multiplier_msat: 192,
-                       manual_node_penalties: HashMap::new(),
+                       manual_node_penalties: new_hash_map(),
                        anti_probing_penalty_msat: 250,
                        considered_impossible_penalty_msat: 1_0000_0000_000,
                        historical_liquidity_penalty_multiplier_msat: 10_000,
@@ -695,7 +694,7 @@ impl ProbabilisticScoringFeeParameters {
 
        /// Clears the list of manual penalties that are applied during path finding.
        pub fn clear_manual_penalties(&mut self) {
-               self.manual_node_penalties = HashMap::new();
+               self.manual_node_penalties = new_hash_map();
        }
 }
 
@@ -709,7 +708,7 @@ impl ProbabilisticScoringFeeParameters {
                        liquidity_penalty_amount_multiplier_msat: 0,
                        historical_liquidity_penalty_multiplier_msat: 0,
                        historical_liquidity_penalty_amount_multiplier_msat: 0,
-                       manual_node_penalties: HashMap::new(),
+                       manual_node_penalties: new_hash_map(),
                        anti_probing_penalty_msat: 0,
                        considered_impossible_penalty_msat: 0,
                        linear_success_probability: true,
@@ -819,7 +818,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ProbabilisticScorer<G, L> whe
                        decay_params,
                        network_graph,
                        logger,
-                       channel_liquidities: HashMap::new(),
+                       channel_liquidities: new_hash_map(),
                }
        }
 
@@ -1330,7 +1329,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreLookUp for Probabilistic
                        _ => return 0,
                };
                let source = candidate.source();
-               if let Some(penalty) = score_params.manual_node_penalties.get(&target) {
+               if let Some(penalty) = score_params.manual_node_penalties.get(target) {
                        return *penalty;
                }
 
@@ -1360,7 +1359,7 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreLookUp for Probabilistic
                let amount_msat = usage.amount_msat.saturating_add(usage.inflight_htlc_msat);
                let capacity_msat = usage.effective_capacity.as_msat();
                self.channel_liquidities
-                       .get(&scid)
+                       .get(scid)
                        .unwrap_or(&ChannelLiquidity::new(Duration::ZERO))
                        .as_directed(&source, &target, capacity_msat)
                        .penalty_msat(amount_msat, score_params)
@@ -2073,7 +2072,7 @@ ReadableArgs<(ProbabilisticScoringDecayParameters, G, L)> for ProbabilisticScore
                r: &mut R, args: (ProbabilisticScoringDecayParameters, G, L)
        ) -> Result<Self, DecodeError> {
                let (decay_params, network_graph, logger) = args;
-               let mut channel_liquidities = HashMap::new();
+               let mut channel_liquidities = new_hash_map();
                read_tlv_fields!(r, {
                        (0, channel_liquidities, required),
                });
@@ -2153,7 +2152,7 @@ impl Readable for ChannelLiquidity {
 #[cfg(test)]
 mod tests {
        use super::{ChannelLiquidity, HistoricalBucketRangeTracker, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters, ProbabilisticScorer};
-       use crate::blinded_path::{BlindedHop, BlindedPath};
+       use crate::blinded_path::{BlindedHop, BlindedPath, IntroductionNode};
        use crate::util::config::UserConfig;
 
        use crate::ln::channelmanager;
@@ -2167,7 +2166,7 @@ mod tests {
        use bitcoin::blockdata::constants::ChainHash;
        use bitcoin::hashes::Hash;
        use bitcoin::hashes::sha256d::Hash as Sha256dHash;
-       use bitcoin::network::constants::Network;
+       use bitcoin::network::Network;
        use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
        use core::time::Duration;
        use crate::io;
@@ -3568,7 +3567,7 @@ mod tests {
                let mut path = payment_path_for_amount(768);
                let recipient_hop = path.hops.pop().unwrap();
                let blinded_path = BlindedPath {
-                       introduction_node_id: path.hops.last().as_ref().unwrap().pubkey,
+                       introduction_node: IntroductionNode::NodeId(path.hops.last().as_ref().unwrap().pubkey),
                        blinding_point: test_utils::pubkey(42),
                        blinded_hops: vec![
                                BlindedHop { blinded_node_id: test_utils::pubkey(44), encrypted_payload: Vec::new() }
@@ -3684,44 +3683,7 @@ pub mod benches {
 
        pub fn decay_100k_channel_bounds(bench: &mut Criterion) {
                let logger = TestLogger::new();
-               let network_graph = bench_utils::read_network_graph(&logger).unwrap();
-               let mut scorer = ProbabilisticScorer::new(Default::default(), &network_graph, &logger);
-               // Score a number of random channels
-               let mut seed: u64 = 0xdeadbeef;
-               for _ in 0..100_000 {
-                       seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
-                       let (victim, victim_dst, amt) = {
-                               let rong = network_graph.read_only();
-                               let channels = rong.channels();
-                               let chan = channels.unordered_iter()
-                                       .skip((seed as usize) % channels.len())
-                                       .next().unwrap();
-                               seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
-                               let amt = seed % chan.1.capacity_sats.map(|c| c * 1000)
-                                       .or(chan.1.one_to_two.as_ref().map(|info| info.htlc_maximum_msat))
-                                       .or(chan.1.two_to_one.as_ref().map(|info| info.htlc_maximum_msat))
-                                       .unwrap_or(1_000_000_000).saturating_add(1);
-                               (*chan.0, chan.1.node_two, amt)
-                       };
-                       let path = Path {
-                               hops: vec![RouteHop {
-                                       pubkey: victim_dst.as_pubkey().unwrap(),
-                                       node_features: NodeFeatures::empty(),
-                                       short_channel_id: victim,
-                                       channel_features: ChannelFeatures::empty(),
-                                       fee_msat: amt,
-                                       cltv_expiry_delta: 42,
-                                       maybe_announced_channel: true,
-                               }],
-                               blinded_tail: None
-                       };
-                       seed = seed.overflowing_mul(6364136223846793005).0.overflowing_add(1).0;
-                       if seed % 1 == 0 {
-                               scorer.probe_failed(&path, victim, Duration::ZERO);
-                       } else {
-                               scorer.probe_successful(&path, Duration::ZERO);
-                       }
-               }
+               let (network_graph, mut scorer) = bench_utils::read_graph_scorer(&logger).unwrap();
                let mut cur_time = Duration::ZERO;
                        cur_time += Duration::from_millis(1);
                        scorer.time_passed(cur_time);