f - remove Probability enum
[rust-lightning] / lightning / src / routing / scoring.rs
index 8db2c3681146101c5027ee7db14a65744e14c354..8e9a002dd519a6bef57316da1d0ed6b7ebebe811 100644 (file)
 //! #
 //! // Use the default channel penalties.
 //! let params = ProbabilisticScoringParameters::default();
-//! let scorer = ProbabilisticScorer::new(params, payer, &network_graph);
+//! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
 //!
 //! // Or use custom channel penalties.
 //! let params = ProbabilisticScoringParameters {
 //!     liquidity_penalty_multiplier_msat: 2 * 1000,
 //!     ..ProbabilisticScoringParameters::default()
 //! };
-//! let scorer = ProbabilisticScorer::new(params, payer, &network_graph);
+//! let scorer = ProbabilisticScorer::new(params, &payer, &network_graph);
 //!
 //! let route = find_route(&payer, &route_params, &network_graph, None, &logger, &scorer);
 //! # }
@@ -504,22 +504,15 @@ struct DirectedChannelLiquidity<L: Deref<Target = u64>> {
        capacity_msat: u64,
 }
 
-/// The likelihood of an event occurring.
-enum Probability {
-       Zero,
-       One,
-       Ratio { numerator: u64, denominator: u64 },
-}
-
 impl<G: Deref<Target = NetworkGraph>> ProbabilisticScorer<G> {
        /// Creates a new scorer using the given scoring parameters for sending payments from a node
        /// through a network graph.
        pub fn new(
-               params: ProbabilisticScoringParameters, node_pubkey: PublicKey, network_graph: G
+               params: ProbabilisticScoringParameters, node_pubkey: &PublicKey, network_graph: G
        ) -> Self {
                Self {
                        params,
-                       node_id: NodeId::from_pubkey(&node_pubkey),
+                       node_id: NodeId::from_pubkey(node_pubkey),
                        network_graph,
                        channel_liquidities: HashMap::new(),
                }
@@ -589,21 +582,17 @@ impl ChannelLiquidity {
 impl<L: Deref<Target = u64>> DirectedChannelLiquidity<L> {
        /// Returns the success probability of routing the given HTLC `amount_msat` through the channel
        /// in this direction.
-       fn success_probability(&self, amount_msat: u64) -> Probability {
+       fn success_probability(&self, amount_msat: u64) -> f64 {
                let max_liquidity_msat = self.max_liquidity_msat();
                let min_liquidity_msat = core::cmp::min(self.min_liquidity_msat(), max_liquidity_msat);
                if amount_msat > max_liquidity_msat {
-                       Probability::Zero
-               } else if amount_msat < min_liquidity_msat {
-                       Probability::One
+                       0.0
+               } else if amount_msat <= min_liquidity_msat {
+                       1.0
                } else {
                        let numerator = max_liquidity_msat + 1 - amount_msat;
                        let denominator = max_liquidity_msat + 1 - min_liquidity_msat;
-                       if numerator == denominator {
-                               Probability::One
-                       } else {
-                               Probability::Ratio { numerator, denominator }
-                       }
+                       numerator as f64 / denominator as f64
                }
        }
 
@@ -673,13 +662,12 @@ impl<G: Deref<Target = NetworkGraph>> Score for ProbabilisticScorer<G> {
                        .unwrap_or(&ChannelLiquidity::new())
                        .as_directed(source, target, capacity_msat)
                        .success_probability(amount_msat);
-               match success_probability {
-                       Probability::Zero => u64::max_value(),
-                       Probability::One => 0,
-                       Probability::Ratio { numerator, denominator } => {
-                               let success_probability = numerator as f64 / denominator as f64;
-                               (-(success_probability.log10()) * liquidity_penalty_multiplier_msat as f64) as u64
-                       },
+               if success_probability == 0.0 {
+                       u64::max_value()
+               } else if success_probability == 1.0 {
+                       0
+               } else {
+                       (-(success_probability.log10()) * liquidity_penalty_multiplier_msat as f64) as u64
                }
        }
 
@@ -747,20 +735,20 @@ impl<G: Deref<Target = NetworkGraph>> Writeable for ProbabilisticScorer<G> {
        #[inline]
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
                self.params.write(w)?;
-               self.node_id.write(w)?;
                self.channel_liquidities.write(w)?;
                write_tlv_fields!(w, {});
                Ok(())
        }
 }
 
-impl<G: Deref<Target = NetworkGraph>> ReadableArgs<G> for ProbabilisticScorer<G> {
+impl<G: Deref<Target = NetworkGraph>> ReadableArgs<(&PublicKey, G)> for ProbabilisticScorer<G> {
        #[inline]
-       fn read<R: Read>(r: &mut R, args: G) -> Result<Self, DecodeError> {
+       fn read<R: Read>(r: &mut R, args: (&PublicKey, G)) -> Result<Self, DecodeError> {
+               let (node_pubkey, network_graph) = args;
                let res = Ok(Self {
                        params: Readable::read(r)?,
-                       node_id: Readable::read(r)?,
-                       network_graph: args,
+                       node_id: NodeId::from_pubkey(node_pubkey),
+                       network_graph,
                        channel_liquidities: Readable::read(r)?,
                });
                read_tlv_fields!(r, {});
@@ -1334,7 +1322,7 @@ mod tests {
        fn liquidity_bounds_directed_from_lowest_node_id() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph)
+               let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
                        .with_channel(42,
                                ChannelLiquidity {
                                        min_liquidity_offset_msat: 700, max_liquidity_offset_msat: 100
@@ -1378,7 +1366,7 @@ mod tests {
        fn resets_liquidity_upper_bound_when_crossed_by_lower_bound() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph)
+               let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
                        .with_channel(42,
                                ChannelLiquidity {
                                        min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
@@ -1433,7 +1421,7 @@ mod tests {
        fn resets_liquidity_lower_bound_when_crossed_by_upper_bound() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph)
+               let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
                        .with_channel(42,
                                ChannelLiquidity {
                                        min_liquidity_offset_msat: 200, max_liquidity_offset_msat: 400
@@ -1488,7 +1476,7 @@ mod tests {
        fn increased_penalty_nearing_liquidity_upper_bound() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph);
+               let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
                let source = source_node_id();
                let target = target_node_id();
 
@@ -1510,7 +1498,7 @@ mod tests {
        fn constant_penalty_outside_liquidity_bounds() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph)
+               let scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph)
                        .with_channel(42,
                                ChannelLiquidity { min_liquidity_offset_msat: 40, max_liquidity_offset_msat: 40 });
                let source = source_node_id();
@@ -1526,7 +1514,7 @@ mod tests {
        fn does_not_penalize_own_channel() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph);
+               let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
                let sender = sender_node_id();
                let source = source_node_id();
                let failed_path = payment_path_for_amount(500);
@@ -1545,7 +1533,7 @@ mod tests {
        fn sets_liquidity_lower_bound_on_downstream_failure() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph);
+               let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
                let source = source_node_id();
                let target = target_node_id();
                let path = payment_path_for_amount(500);
@@ -1565,7 +1553,7 @@ mod tests {
        fn sets_liquidity_upper_bound_on_failure() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph);
+               let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
                let source = source_node_id();
                let target = target_node_id();
                let path = payment_path_for_amount(500);
@@ -1585,7 +1573,7 @@ mod tests {
        fn reduces_liquidity_upper_bound_along_path_on_success() {
                let network_graph = network_graph();
                let params = ProbabilisticScoringParameters::default();
-               let mut scorer = ProbabilisticScorer::new(params, sender_pubkey(), &network_graph);
+               let mut scorer = ProbabilisticScorer::new(params, &sender_pubkey(), &network_graph);
                let sender = sender_node_id();
                let source = source_node_id();
                let target = target_node_id();