Merge pull request #1226 from TheBlueMatt/2022-01-bindings-no-std
[rust-lightning] / lightning / src / routing / router.rs
index 24f042d68e4fe368f7fbef7128e865782042bcd8..e49844383bbceee091e9d42b7262416109b2d484 100644 (file)
@@ -17,7 +17,7 @@ use bitcoin::secp256k1::key::PublicKey;
 use ln::channelmanager::ChannelDetails;
 use ln::features::{ChannelFeatures, InvoiceFeatures, NodeFeatures};
 use ln::msgs::{DecodeError, ErrorAction, LightningError, MAX_VALUE_MSAT};
-use routing;
+use routing::scoring::Score;
 use routing::network_graph::{NetworkGraph, NodeId, RoutingFees};
 use util::ser::{Writeable, Readable};
 use util::logger::{Level, Logger};
@@ -529,7 +529,7 @@ fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option<u64> {
 ///
 /// [`ChannelManager::list_usable_channels`]: crate::ln::channelmanager::ChannelManager::list_usable_channels
 /// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed
-pub fn find_route<L: Deref, S: routing::Score>(
+pub fn find_route<L: Deref, S: Score>(
        our_node_pubkey: &PublicKey, params: &RouteParameters, network: &NetworkGraph,
        first_hops: Option<&[&ChannelDetails]>, logger: L, scorer: &S
 ) -> Result<Route, LightningError>
@@ -540,7 +540,7 @@ where L::Target: Logger {
        )
 }
 
-pub(crate) fn get_route<L: Deref, S: routing::Score>(
+pub(crate) fn get_route<L: Deref, S: Score>(
        our_node_pubkey: &PublicKey, payee: &Payee, network: &NetworkGraph,
        first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, final_cltv_expiry_delta: u32,
        logger: L, scorer: &S
@@ -892,9 +892,9 @@ where L::Target: Logger {
                                                                }
                                                        }
 
-                                                       let path_penalty_msat = $next_hops_path_penalty_msat
-                                                               .checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id))
-                                                               .unwrap_or_else(|| u64::max_value());
+                                                       let path_penalty_msat = $next_hops_path_penalty_msat.checked_add(
+                                                               scorer.channel_penalty_msat($chan_id.clone(), amount_to_transfer_over_msat, Some(*available_liquidity_msat),
+                                                                       &$src_node_id, &$dest_node_id)).unwrap_or_else(|| u64::max_value());
                                                        let new_graph_node = RouteGraphNode {
                                                                node_id: $src_node_id,
                                                                lowest_fee_to_peer_through_node: total_fee_msat,
@@ -1127,7 +1127,7 @@ where L::Target: Logger {
                                        let src_node_id = NodeId::from_pubkey(&hop.src_node_id);
                                        let dest_node_id = NodeId::from_pubkey(&prev_hop_id);
                                        aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
-                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id))
+                                               .checked_add(scorer.channel_penalty_msat(hop.short_channel_id, final_value_msat, None, &src_node_id, &dest_node_id))
                                                .unwrap_or_else(|| u64::max_value());
 
                                        // We assume that the recipient only included route hints for routes which had
@@ -1478,7 +1478,7 @@ where L::Target: Logger {
 
 #[cfg(test)]
 mod tests {
-       use routing;
+       use routing::scoring::Score;
        use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
        use routing::router::{get_route, Payee, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
        use chain::transaction::OutPoint;
@@ -1488,6 +1488,8 @@ mod tests {
        use ln::channelmanager;
        use util::test_utils;
        use util::ser::Writeable;
+       #[cfg(c_bindings)]
+       use util::ser::Writer;
 
        use bitcoin::hashes::sha256d::Hash as Sha256dHash;
        use bitcoin::hashes::Hash;
@@ -1519,6 +1521,7 @@ mod tests {
                        short_channel_id,
                        channel_value_satoshis: 0,
                        user_channel_id: 0,
+                       balance_msat: 0,
                        outbound_capacity_msat,
                        inbound_capacity_msat: 42,
                        unspendable_punishment_reserve: None,
@@ -4653,24 +4656,35 @@ mod tests {
                short_channel_id: u64,
        }
 
-       impl routing::Score for BadChannelScorer {
-               fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
+       #[cfg(c_bindings)]
+       impl Writeable for BadChannelScorer {
+               fn write<W: Writer>(&self, _w: &mut W) -> Result<(), ::io::Error> { unimplemented!() }
+       }
+       impl Score for BadChannelScorer {
+               fn channel_penalty_msat(&self, short_channel_id: u64, _send_amt: u64, _chan_amt: Option<u64>, _source: &NodeId, _target: &NodeId) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
 
                fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
+               fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
        }
 
        struct BadNodeScorer {
                node_id: NodeId,
        }
 
-       impl routing::Score for BadNodeScorer {
-               fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
+       #[cfg(c_bindings)]
+       impl Writeable for BadNodeScorer {
+               fn write<W: Writer>(&self, _w: &mut W) -> Result<(), ::io::Error> { unimplemented!() }
+       }
+
+       impl Score for BadNodeScorer {
+               fn channel_penalty_msat(&self, _short_channel_id: u64, _send_amt: u64, _chan_amt: Option<u64>, _source: &NodeId, target: &NodeId) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
 
                fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
+               fn payment_path_successful(&mut self, _path: &[&RouteHop]) {}
        }
 
        #[test]
@@ -4891,7 +4905,7 @@ pub(crate) mod test_utils {
 #[cfg(all(test, feature = "unstable", not(feature = "no-std")))]
 mod benches {
        use super::*;
-       use routing::scorer::Scorer;
+       use routing::scoring::Scorer;
        use util::logger::{Logger, Record};
 
        use test::Bencher;