X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning-invoice%2Fsrc%2Futils.rs;h=306714b0789ceefc7d23cd41ba3d76bc0decd784;hb=c353c3ed7c40e689a3b9fb6730c6dabbd3c92cc5;hp=c79e590450a2dbda227db5aedc6645607c9b555c;hpb=d6feb1c63b9474d4db633774763eb9684e7381a0;p=rust-lightning diff --git a/lightning-invoice/src/utils.rs b/lightning-invoice/src/utils.rs index c79e5904..306714b0 100644 --- a/lightning-invoice/src/utils.rs +++ b/lightning-invoice/src/utils.rs @@ -1,7 +1,7 @@ //! Convenient utilities to create an invoice. use {CreationError, Currency, Invoice, InvoiceBuilder, SignOrCreationError}; -use payment::{Payer, Router}; +use payment::{InFlightHtlcs, Payer, Router}; use crate::{prelude::*, Description, InvoiceDescription, Sha256}; use bech32::ToBase32; @@ -15,9 +15,9 @@ use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, P use lightning::ln::channelmanager::{PhantomRouteHints, MIN_CLTV_EXPIRY_DELTA}; use lightning::ln::inbound_payment::{create, create_from_hash, ExpandedKey}; use lightning::ln::msgs::LightningError; -use lightning::routing::gossip::{NetworkGraph, RoutingFees}; -use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route}; -use lightning::routing::scoring::Score; +use lightning::routing::gossip::{NetworkGraph, NodeId, RoutingFees}; +use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route, RouteHop}; +use lightning::routing::scoring::{ChannelUsage, LockableScore, Score}; use lightning::util::logger::Logger; use secp256k1::PublicKey; use core::ops::Deref; @@ -440,34 +440,63 @@ fn filter_channels(channels: Vec, min_inbound_capacity_msat: Opt } /// A [`Router`] implemented using [`find_route`]. -pub struct DefaultRouter>, L: Deref> where L::Target: Logger { +pub struct DefaultRouter>, L: Deref, S: Deref> where + L::Target: Logger, + S::Target: for <'a> LockableScore<'a>, +{ network_graph: G, logger: L, random_seed_bytes: Mutex<[u8; 32]>, + scorer: S } -impl>, L: Deref> DefaultRouter where L::Target: Logger { +impl>, L: Deref, S: Deref> DefaultRouter where + L::Target: Logger, + S::Target: for <'a> LockableScore<'a>, +{ /// Creates a new router using the given [`NetworkGraph`], a [`Logger`], and a randomness source /// `random_seed_bytes`. - pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32]) -> Self { + pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32], scorer: S) -> Self { let random_seed_bytes = Mutex::new(random_seed_bytes); - Self { network_graph, logger, random_seed_bytes } + Self { network_graph, logger, random_seed_bytes, scorer } } } -impl>, L: Deref, S: Score> Router for DefaultRouter -where L::Target: Logger { +impl>, L: Deref, S: Deref> Router for DefaultRouter where + L::Target: Logger, + S::Target: for <'a> LockableScore<'a>, +{ fn find_route( &self, payer: &PublicKey, params: &RouteParameters, _payment_hash: &PaymentHash, - first_hops: Option<&[&ChannelDetails]>, scorer: &S + first_hops: Option<&[&ChannelDetails]>, inflight_htlcs: InFlightHtlcs ) -> Result { - let network_graph = self.network_graph.read_only(); let random_seed_bytes = { let mut locked_random_seed_bytes = self.random_seed_bytes.lock().unwrap(); *locked_random_seed_bytes = sha256::Hash::hash(&*locked_random_seed_bytes).into_inner(); *locked_random_seed_bytes }; - find_route(payer, params, &network_graph, first_hops, &*self.logger, scorer, &random_seed_bytes) + + find_route( + payer, params, &self.network_graph, first_hops, &*self.logger, + &ScorerAccountingForInFlightHtlcs::new(&mut self.scorer.lock(), inflight_htlcs), + &random_seed_bytes + ) + } + + fn notify_payment_path_failed(&self, path: Vec<&RouteHop>, short_channel_id: u64) { + self.scorer.lock().payment_path_failed(&path, short_channel_id); + } + + fn notify_payment_path_successful(&self, path: Vec<&RouteHop>) { + self.scorer.lock().payment_path_successful(&path); + } + + fn notify_payment_probe_successful(&self, path: Vec<&RouteHop>) { + self.scorer.lock().probe_successful(&path); + } + + fn notify_payment_probe_failed(&self, path: Vec<&RouteHop>, short_channel_id: u64) { + self.scorer.lock().probe_failed(&path, short_channel_id); } } @@ -511,6 +540,54 @@ where } } + +/// Used to store information about all the HTLCs that are inflight across all payment attempts. +pub(crate) struct ScorerAccountingForInFlightHtlcs<'a, S: Score> { + scorer: &'a mut S, + /// Maps a channel's short channel id and its direction to the liquidity used up. + inflight_htlcs: InFlightHtlcs, +} + +impl<'a, S: Score> ScorerAccountingForInFlightHtlcs<'a, S> { + pub(crate) fn new(scorer: &'a mut S, inflight_htlcs: InFlightHtlcs) -> Self { + ScorerAccountingForInFlightHtlcs { + scorer, + inflight_htlcs + } + } +} + +#[cfg(c_bindings)] +impl<'a, S:Score> lightning::util::ser::Writeable for ScorerAccountingForInFlightHtlcs<'a, S> { + fn write(&self, writer: &mut W) -> Result<(), std::io::Error> { self.scorer.write(writer) } +} + +impl<'a, S: Score> Score for ScorerAccountingForInFlightHtlcs<'a, S> { + fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId, usage: ChannelUsage) -> u64 { + if let Some(used_liqudity) = self.inflight_htlcs.used_liquidity_msat( + source, target, short_channel_id + ) { + let usage = ChannelUsage { + inflight_htlc_msat: usage.inflight_htlc_msat + used_liqudity, + ..usage + }; + + self.scorer.channel_penalty_msat(short_channel_id, source, target, usage) + } else { + self.scorer.channel_penalty_msat(short_channel_id, source, target, usage) + } + } + + fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) { unreachable!() } + + fn payment_path_successful(&mut self, _path: &[&RouteHop]) { unreachable!() } + + fn probe_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) { unreachable!() } + + fn probe_successful(&mut self, _path: &[&RouteHop]) { unreachable!() } +} + + #[cfg(test)] mod test { use core::time::Duration; @@ -572,7 +649,7 @@ mod test { let scorer = test_utils::TestScorer::with_penalty(0); let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes(); let route = find_route( - &nodes[0].node.get_our_node_id(), &route_params, &network_graph.read_only(), + &nodes[0].node.get_our_node_id(), &route_params, &network_graph, Some(&first_hops.iter().collect::>()), &logger, &scorer, &random_seed_bytes ).unwrap(); @@ -659,7 +736,7 @@ mod test { // `msgs::ChannelUpdate` is never handled for the node(s). As the `msgs::ChannelUpdate` // is never handled, the `channel.counterparty.forwarding_info` is never assigned. let mut private_chan_cfg = UserConfig::default(); - private_chan_cfg.own_channel_config.announced_channel = false; + private_chan_cfg.channel_handshake_config.announced_channel = false; let temporary_channel_id = nodes[2].node.create_channel(nodes[0].node.get_our_node_id(), 1_000_000, 500_000_000, 42, Some(private_chan_cfg)).unwrap(); let open_channel = get_event_msg!(nodes[2], MessageSendEvent::SendOpenChannel, nodes[0].node.get_our_node_id()); nodes[0].node.handle_open_channel(&nodes[2].node.get_our_node_id(), InitFeatures::known(), &open_channel); @@ -848,7 +925,7 @@ mod test { let scorer = test_utils::TestScorer::with_penalty(0); let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes(); let route = find_route( - &nodes[0].node.get_our_node_id(), ¶ms, &network_graph.read_only(), + &nodes[0].node.get_our_node_id(), ¶ms, &network_graph, Some(&first_hops.iter().collect::>()), &logger, &scorer, &random_seed_bytes ).unwrap(); let (payment_event, fwd_idx) = { @@ -1046,7 +1123,7 @@ mod test { // `msgs::ChannelUpdate` is never handled for the node(s). As the `msgs::ChannelUpdate` // is never handled, the `channel.counterparty.forwarding_info` is never assigned. let mut private_chan_cfg = UserConfig::default(); - private_chan_cfg.own_channel_config.announced_channel = false; + private_chan_cfg.channel_handshake_config.announced_channel = false; let temporary_channel_id = nodes[1].node.create_channel(nodes[3].node.get_our_node_id(), 1_000_000, 500_000_000, 42, Some(private_chan_cfg)).unwrap(); let open_channel = get_event_msg!(nodes[1], MessageSendEvent::SendOpenChannel, nodes[3].node.get_our_node_id()); nodes[3].node.handle_open_channel(&nodes[1].node.get_our_node_id(), InitFeatures::known(), &open_channel);