Move Scorer requirement away from Router trait
[rust-lightning] / lightning-invoice / src / utils.rs
index c79e590450a2dbda227db5aedc6645607c9b555c..306714b0789ceefc7d23cd41ba3d76bc0decd784 100644 (file)
@@ -1,7 +1,7 @@
 //! Convenient utilities to create an invoice.
 
 use {CreationError, Currency, Invoice, InvoiceBuilder, SignOrCreationError};
-use payment::{Payer, Router};
+use payment::{InFlightHtlcs, Payer, Router};
 
 use crate::{prelude::*, Description, InvoiceDescription, Sha256};
 use bech32::ToBase32;
@@ -15,9 +15,9 @@ use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, P
 use lightning::ln::channelmanager::{PhantomRouteHints, MIN_CLTV_EXPIRY_DELTA};
 use lightning::ln::inbound_payment::{create, create_from_hash, ExpandedKey};
 use lightning::ln::msgs::LightningError;
-use lightning::routing::gossip::{NetworkGraph, RoutingFees};
-use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route};
-use lightning::routing::scoring::Score;
+use lightning::routing::gossip::{NetworkGraph, NodeId, RoutingFees};
+use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route, RouteHop};
+use lightning::routing::scoring::{ChannelUsage, LockableScore, Score};
 use lightning::util::logger::Logger;
 use secp256k1::PublicKey;
 use core::ops::Deref;
@@ -440,34 +440,63 @@ fn filter_channels(channels: Vec<ChannelDetails>, min_inbound_capacity_msat: Opt
 }
 
 /// A [`Router`] implemented using [`find_route`].
-pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref> where L::Target: Logger {
+pub struct DefaultRouter<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> where
+       L::Target: Logger,
+       S::Target: for <'a> LockableScore<'a>,
+{
        network_graph: G,
        logger: L,
        random_seed_bytes: Mutex<[u8; 32]>,
+       scorer: S
 }
 
-impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> DefaultRouter<G, L> where L::Target: Logger {
+impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> DefaultRouter<G, L, S> where
+       L::Target: Logger,
+       S::Target: for <'a> LockableScore<'a>,
+{
        /// Creates a new router using the given [`NetworkGraph`], a [`Logger`], and a randomness source
        /// `random_seed_bytes`.
-       pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32]) -> Self {
+       pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S) -> Self {
                let random_seed_bytes = Mutex::new(random_seed_bytes);
-               Self { network_graph, logger, random_seed_bytes }
+               Self { network_graph, logger, random_seed_bytes, scorer }
        }
 }
 
-impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Score> Router<S> for DefaultRouter<G, L>
-where L::Target: Logger {
+impl<G: Deref<Target = NetworkGraph<L>>, L: Deref, S: Deref> Router for DefaultRouter<G, L, S> where
+       L::Target: Logger,
+       S::Target: for <'a> LockableScore<'a>,
+{
        fn find_route(
                &self, payer: &PublicKey, params: &RouteParameters, _payment_hash: &PaymentHash,
-               first_hops: Option<&[&ChannelDetails]>, scorer: &S
+               first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs
        ) -> Result<Route, LightningError> {
-               let network_graph = self.network_graph.read_only();
                let random_seed_bytes = {
                        let mut locked_random_seed_bytes = self.random_seed_bytes.lock().unwrap();
                        *locked_random_seed_bytes = sha256::Hash::hash(&*locked_random_seed_bytes).into_inner();
                        *locked_random_seed_bytes
                };
-               find_route(payer, params, &network_graph, first_hops, &*self.logger, scorer, &random_seed_bytes)
+
+               find_route(
+                       payer, params, &self.network_graph, first_hops, &*self.logger,
+                       &ScorerAccountingForInFlightHtlcs::new(&mut self.scorer.lock(), inflight_htlcs),
+                       &random_seed_bytes
+               )
+       }
+
+       fn notify_payment_path_failed(&self, path: Vec<&RouteHop>, short_channel_id: u64) {
+               self.scorer.lock().payment_path_failed(&path, short_channel_id);
+       }
+
+       fn notify_payment_path_successful(&self, path: Vec<&RouteHop>) {
+               self.scorer.lock().payment_path_successful(&path);
+       }
+
+       fn notify_payment_probe_successful(&self, path: Vec<&RouteHop>) {
+               self.scorer.lock().probe_successful(&path);
+       }
+
+       fn notify_payment_probe_failed(&self, path: Vec<&RouteHop>, short_channel_id: u64) {
+               self.scorer.lock().probe_failed(&path, short_channel_id);
        }
 }
 
@@ -511,6 +540,54 @@ where
        }
 }
 
+
+/// Used to store information about all the HTLCs that are inflight across all payment attempts.
+pub(crate) struct ScorerAccountingForInFlightHtlcs<'a, S: Score> {
+       scorer: &'a mut S,
+       /// Maps a channel's short channel id and its direction to the liquidity used up.
+       inflight_htlcs: InFlightHtlcs,
+}
+
+impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> {
+       pub(crate) fn new(scorer: &'a mut S, inflight_htlcs: InFlightHtlcs) -> Self {
+               ScorerAccountingForInFlightHtlcs {
+                       scorer,
+                       inflight_htlcs
+               }
+       }
+}
+
+#[cfg(c_bindings)]
+impl<'a, S:Score> lightning::util::ser::Writeable for ScorerAccountingForInFlightHtlcs<'a, S> {
+       fn write<W: lightning::util::ser::Writer>(&self, writer: &mut W) -> Result<(), std::io::Error> { self.scorer.write(writer) }
+}
+
+impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> {
+       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage) -> u64 {
+               if let Some(used_liqudity) = self.inflight_htlcs.used_liquidity_msat(
+                       source, target, short_channel_id
+               ) {
+                       let usage = ChannelUsage {
+                               inflight_htlc_msat: usage.inflight_htlc_msat + used_liqudity,
+                               ..usage
+                       };
+
+                       self.scorer.channel_penalty_msat(short_channel_id, source, target, usage)
+               } else {
+                       self.scorer.channel_penalty_msat(short_channel_id, source, target, usage)
+               }
+       }
+
+       fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) { unreachable!() }
+
+       fn payment_path_successful(&mut self, _path: &[&RouteHop]) { unreachable!() }
+
+       fn probe_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) { unreachable!() }
+
+       fn probe_successful(&mut self, _path: &[&RouteHop]) { unreachable!() }
+}
+
+
 #[cfg(test)]
 mod test {
        use core::time::Duration;
@@ -572,7 +649,7 @@ mod test {
                let scorer = test_utils::TestScorer::with_penalty(0);
                let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
-                       &nodes[0].node.get_our_node_id(), &route_params, &network_graph.read_only(),
+                       &nodes[0].node.get_our_node_id(), &route_params, &network_graph,
                        Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer, &random_seed_bytes
                ).unwrap();
 
@@ -659,7 +736,7 @@ mod test {
                // `msgs::ChannelUpdate` is never handled for the node(s). As the `msgs::ChannelUpdate`
                // is never handled, the `channel.counterparty.forwarding_info` is never assigned.
                let mut private_chan_cfg = UserConfig::default();
-               private_chan_cfg.own_channel_config.announced_channel = false;
+               private_chan_cfg.channel_handshake_config.announced_channel = false;
                let temporary_channel_id = nodes[2].node.create_channel(nodes[0].node.get_our_node_id(), 1_000_000, 500_000_000, 42, Some(private_chan_cfg)).unwrap();
                let open_channel = get_event_msg!(nodes[2], MessageSendEvent::SendOpenChannel, nodes[0].node.get_our_node_id());
                nodes[0].node.handle_open_channel(&nodes[2].node.get_our_node_id(), InitFeatures::known(), &open_channel);
@@ -848,7 +925,7 @@ mod test {
                let scorer = test_utils::TestScorer::with_penalty(0);
                let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
-                       &nodes[0].node.get_our_node_id(), &params, &network_graph.read_only(),
+                       &nodes[0].node.get_our_node_id(), &params, &network_graph,
                        Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer, &random_seed_bytes
                ).unwrap();
                let (payment_event, fwd_idx) = {
@@ -1046,7 +1123,7 @@ mod test {
                // `msgs::ChannelUpdate` is never handled for the node(s). As the `msgs::ChannelUpdate`
                // is never handled, the `channel.counterparty.forwarding_info` is never assigned.
                let mut private_chan_cfg = UserConfig::default();
-               private_chan_cfg.own_channel_config.announced_channel = false;
+               private_chan_cfg.channel_handshake_config.announced_channel = false;
                let temporary_channel_id = nodes[1].node.create_channel(nodes[3].node.get_our_node_id(), 1_000_000, 500_000_000, 42, Some(private_chan_cfg)).unwrap();
                let open_channel = get_event_msg!(nodes[1], MessageSendEvent::SendOpenChannel, nodes[3].node.get_our_node_id());
                nodes[3].node.handle_open_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &open_channel);