Merge pull request #1144 from jkczyz/2021-10-invoice-payer-scoring
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Fri, 29 Oct 2021 20:16:36 +0000 (20:16 +0000)
committerGitHub <noreply@github.com>
Fri, 29 Oct 2021 20:16:36 +0000 (20:16 +0000)
Penalize failed channels

13 files changed:
fuzz/src/full_stack.rs
fuzz/src/router.rs
lightning-background-processor/Cargo.toml
lightning-background-processor/src/lib.rs
lightning-invoice/src/payment.rs
lightning-invoice/src/utils.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/shutdown_tests.rs
lightning/src/routing/mod.rs
lightning/src/routing/router.rs
lightning/src/routing/scorer.rs

index b01506871ec10ab9677590bdc22cd925596c067c..2e447dac6da1f2de57d5ea7a5ff602414b23830a 100644 (file)
@@ -382,7 +382,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
        let our_id = PublicKey::from_secret_key(&Secp256k1::signing_only(), &keys_manager.get_node_secret());
        let network_graph = NetworkGraph::new(genesis_block(network).block_hash());
        let net_graph_msg_handler = Arc::new(NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger)));
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
 
        let peers = RefCell::new([false; 256]);
        let mut loss_detector = MoneyLossDetector::new(&peers, channelmanager.clone(), monitor.clone(), PeerManager::new(MessageHandler {
index 7f7d9585cc279df031ad14304272fae6782fdfca..abd83fa58c669265d5096a9e6ac526fc1316324d 100644 (file)
@@ -248,7 +248,7 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                                }]));
                                        }
                                }
-                               let scorer = Scorer::new(0);
+                               let scorer = Scorer::with_fixed_penalty(0);
                                for target in node_pks.iter() {
                                        let params = RouteParameters {
                                                payee: Payee::new(*target).with_route_hints(last_hops.clone()),
index 4e45bb2a83c11d58b6528c59cfef3208d07e86ac..d868f14db74883bce4c1997cfbfd4da170fd8018 100644 (file)
@@ -16,4 +16,4 @@ lightning-persister = { version = "0.0.102", path = "../lightning-persister" }
 
 [dev-dependencies]
 lightning = { version = "0.0.102", path = "../lightning", features = ["_test_utils"] }
-
+lightning-invoice = { version = "0.10.0", path = "../lightning-invoice" }
index 902bef6af0a60433503f7d0d88de27ac2ace4395..95866d7d9f7e08eb268d5b677a53b30839659e1c 100644 (file)
@@ -173,7 +173,7 @@ impl BackgroundProcessor {
                Descriptor: 'static + SocketDescriptor + Send + Sync,
                CMH: 'static + Deref + Send + Sync,
                RMH: 'static + Deref + Send + Sync,
-               EH: 'static + EventHandler + Send + Sync,
+               EH: 'static + EventHandler + Send,
                CMP: 'static + Send + ChannelManagerPersister<Signer, CW, T, K, F, L>,
                M: 'static + Deref<Target = ChainMonitor<Signer, CF, T, F, L, P>> + Send + Sync,
                CM: 'static + Deref<Target = ChannelManager<Signer, CW, T, K, F, L>> + Send + Sync,
@@ -309,11 +309,14 @@ mod tests {
        use lightning::ln::features::InitFeatures;
        use lightning::ln::msgs::{ChannelMessageHandler, Init};
        use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
+       use lightning::routing::scorer::Scorer;
        use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use lightning::util::config::UserConfig;
        use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
        use lightning::util::ser::Writeable;
        use lightning::util::test_utils;
+       use lightning_invoice::payment::{InvoicePayer, RetryAttempts};
+       use lightning_invoice::utils::DefaultRouter;
        use lightning_persister::FilesystemPersister;
        use std::fs;
        use std::path::PathBuf;
@@ -619,4 +622,20 @@ mod tests {
 
                assert!(bg_processor.stop().is_ok());
        }
+
+       #[test]
+       fn test_invoice_payer() {
+               let nodes = create_nodes(2, "test_invoice_payer".to_string());
+
+               // Initiate the background processors to watch each node.
+               let data_dir = nodes[0].persister.get_data_dir();
+               let persister = move |node: &ChannelManager<InMemorySigner, Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>>| FilesystemPersister::persist_manager(data_dir.clone(), node);
+               let network_graph = Arc::new(NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()));
+               let router = DefaultRouter::new(network_graph, Arc::clone(&nodes[0].logger));
+               let scorer = Arc::new(Mutex::new(Scorer::default()));
+               let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));
+               let event_handler = Arc::clone(&invoice_payer);
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
+               assert!(bg_processor.stop().is_ok());
+       }
 }
index 4603f2bbb21f60750ce8cf2e55f9d80dfcd20674..eefae3fc1a69e711696634c4bae2aaaa435766fc 100644 (file)
 //! # use lightning::ln::{PaymentHash, PaymentSecret};
 //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 //! # use lightning::ln::msgs::LightningError;
-//! # use lightning::routing::router::{Route, RouteParameters};
+//! # use lightning::routing;
+//! # use lightning::routing::network_graph::NodeId;
+//! # use lightning::routing::router::{Route, RouteHop, RouteParameters};
 //! # use lightning::util::events::{Event, EventHandler, EventsProvider};
 //! # use lightning::util::logger::{Logger, Record};
 //! # use lightning_invoice::Invoice;
 //! # use lightning_invoice::payment::{InvoicePayer, Payer, RetryAttempts, Router};
 //! # use secp256k1::key::PublicKey;
+//! # use std::cell::RefCell;
 //! # use std::ops::Deref;
 //! #
 //! # struct FakeEventProvider {}
 //! # }
 //! #
 //! # struct FakeRouter {};
-//! # impl Router for FakeRouter {
+//! # impl<S: routing::Score> Router<S> for FakeRouter {
 //! #     fn find_route(
 //! #         &self, payer: &PublicKey, params: &RouteParameters,
-//! #         first_hops: Option<&[&ChannelDetails]>
+//! #         first_hops: Option<&[&ChannelDetails]>, scorer: &S
 //! #     ) -> Result<Route, LightningError> { unimplemented!() }
 //! # }
 //! #
+//! # struct FakeScorer {};
+//! # impl routing::Score for FakeScorer {
+//! #     fn channel_penalty_msat(
+//! #         &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+//! #     ) -> u64 { 0 }
+//! #     fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, _short_channel_id: u64) {}
+//! # }
+//! #
 //! # struct FakeLogger {};
 //! # impl Logger for FakeLogger {
 //! #     fn log(&self, record: &Record) { unimplemented!() }
@@ -78,8 +89,9 @@
 //! };
 //! # let payer = FakePayer {};
 //! # let router = FakeRouter {};
+//! # let scorer = RefCell::new(FakeScorer {});
 //! # let logger = FakeLogger {};
-//! let invoice_payer = InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 //!
 //! let invoice = "...";
 //! let invoice = invoice.parse::<Invoice>().unwrap();
@@ -105,6 +117,8 @@ use bitcoin_hashes::Hash;
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
 use lightning::ln::msgs::LightningError;
+use lightning::routing;
+use lightning::routing::{LockableScore, Score};
 use lightning::routing::router::{Payee, Route, RouteParameters};
 use lightning::util::events::{Event, EventHandler};
 use lightning::util::logger::Logger;
@@ -117,15 +131,17 @@ use std::sync::Mutex;
 use std::time::{Duration, SystemTime};
 
 /// A utility for paying [`Invoice]`s.
-pub struct InvoicePayer<P: Deref, R, L: Deref, E>
+pub struct InvoicePayer<P: Deref, R, S: Deref, L: Deref, E>
 where
        P::Target: Payer,
-       R: Router,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
        E: EventHandler,
 {
        payer: P,
        router: R,
+       scorer: S,
        logger: L,
        event_handler: E,
        payment_cache: Mutex<HashMap<PaymentHash, usize>>,
@@ -150,10 +166,11 @@ pub trait Payer {
 }
 
 /// A trait defining behavior for routing an [`Invoice`] payment.
-pub trait Router {
+pub trait Router<S: routing::Score> {
        /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values.
        fn find_route(
-               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>
+               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
        ) -> Result<Route, LightningError>;
 }
 
@@ -172,10 +189,11 @@ pub enum PaymentError {
        Sending(PaymentSendFailure),
 }
 
-impl<P: Deref, R, L: Deref, E> InvoicePayer<P, R, L, E>
+impl<P: Deref, R, S: Deref, L: Deref, E> InvoicePayer<P, R, S, L, E>
 where
        P::Target: Payer,
-       R: Router,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
        E: EventHandler,
 {
@@ -184,11 +202,12 @@ where
        /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once
        /// `retry_attempts` has been exceeded for a given [`Invoice`].
        pub fn new(
-               payer: P, router: R, logger: L, event_handler: E, retry_attempts: RetryAttempts
+               payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts
        ) -> Self {
                Self {
                        payer,
                        router,
+                       scorer,
                        logger,
                        event_handler,
                        payment_cache: Mutex::new(HashMap::new()),
@@ -242,6 +261,7 @@ where
                                        &payer,
                                        &params,
                                        Some(&first_hops.iter().collect::<Vec<_>>()),
+                                       &self.scorer.lock(),
                                ).map_err(|e| PaymentError::Routing(e))?;
 
                                let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
@@ -261,7 +281,8 @@ where
                let payer = self.payer.node_id();
                let first_hops = self.payer.first_hops();
                let route = self.router.find_route(
-                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>())
+                       &payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()),
+                       &self.scorer.lock()
                ).map_err(|e| PaymentError::Routing(e))?;
                self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e))
        }
@@ -284,16 +305,23 @@ fn has_expired(params: &RouteParameters) -> bool {
        Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, expiry_time)
 }
 
-impl<P: Deref, R, L: Deref, E> EventHandler for InvoicePayer<P, R, L, E>
+impl<P: Deref, R, S: Deref, L: Deref, E> EventHandler for InvoicePayer<P, R, S, L, E>
 where
        P::Target: Payer,
-       R: Router,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
        L::Target: Logger,
        E: EventHandler,
 {
        fn handle_event(&self, event: &Event) {
                match event {
-                       Event::PaymentPathFailed { payment_id, payment_hash, rejected_by_dest, retry, .. } => {
+                       Event::PaymentPathFailed {
+                               payment_id, payment_hash, rejected_by_dest, path, short_channel_id, retry, ..
+                       } => {
+                               if let Some(short_channel_id) = short_channel_id {
+                                       self.scorer.lock().payment_path_failed(path, *short_channel_id);
+                               }
+
                                let mut payment_cache = self.payment_cache.lock().unwrap();
                                let entry = loop {
                                        let entry = payment_cache.entry(*payment_hash);
@@ -354,11 +382,13 @@ mod tests {
        use lightning::ln::PaymentPreimage;
        use lightning::ln::features::{ChannelFeatures, NodeFeatures};
        use lightning::ln::msgs::{ErrorAction, LightningError};
-       use lightning::routing::router::{Route, RouteHop};
+       use lightning::routing::network_graph::NodeId;
+       use lightning::routing::router::{Payee, Route, RouteHop};
        use lightning::util::test_utils::TestLogger;
        use lightning::util::errors::APIError;
        use lightning::util::events::Event;
        use secp256k1::{SecretKey, PublicKey, Secp256k1};
+       use std::cell::RefCell;
        use std::time::{SystemTime, Duration};
 
        fn invoice(payment_preimage: PaymentPreimage) -> Invoice {
@@ -422,9 +452,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -450,9 +481,10 @@ mod tests {
                        .expect_value_msat(final_value_msat)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -490,9 +522,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(PaymentId([1; 32]));
                let event = Event::PaymentPathFailed {
@@ -534,9 +567,10 @@ mod tests {
                        .expect_value_msat(final_value_msat / 2)
                        .expect_value_msat(final_value_msat / 2);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -583,9 +617,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -614,9 +649,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = expired_invoice(payment_preimage);
@@ -651,9 +687,10 @@ mod tests {
                        .fails_on_attempt(2)
                        .expect_value_msat(final_value_msat);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
                assert_eq!(*payer.attempts.borrow(), 1);
@@ -680,9 +717,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -711,9 +749,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -751,9 +790,10 @@ mod tests {
        fn fails_paying_invoice_with_routing_errors() {
                let payer = TestPayer::new();
                let router = FailingRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -768,9 +808,10 @@ mod tests {
        fn fails_paying_invoice_with_sending_errors() {
                let payer = TestPayer::new().fails_on_attempt(1);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -793,9 +834,10 @@ mod tests {
 
                let payer = TestPayer::new().expect_value_msat(final_value_msat);
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_id =
                        Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap());
@@ -815,9 +857,10 @@ mod tests {
 
                let payer = TestPayer::new();
                let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
                let logger = TestLogger::new();
                let invoice_payer =
-                       InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0));
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
 
                let payment_preimage = PaymentPreimage([1; 32]);
                let invoice = invoice(payment_preimage);
@@ -830,6 +873,40 @@ mod tests {
                }
        }
 
+       #[test]
+       fn scores_failed_channel() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+               let path = TestRouter::path_for_value(final_value_msat);
+               let short_channel_id = Some(path[0].short_channel_id);
+
+               // Expect that scorer is given short_channel_id upon handling the event.
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap()));
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash,
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path,
+                       short_channel_id,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+       }
+
        struct TestRouter;
 
        impl TestRouter {
@@ -873,12 +950,13 @@ mod tests {
                }
        }
 
-       impl Router for TestRouter {
+       impl<S: routing::Score> Router<S> for TestRouter {
                fn find_route(
                        &self,
                        _payer: &PublicKey,
                        params: &RouteParameters,
                        _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
                ) -> Result<Route, LightningError> {
                        Ok(Route {
                                payee: Some(params.payee.clone()), ..Self::route_for_value(params.final_value_msat)
@@ -888,17 +966,59 @@ mod tests {
 
        struct FailingRouter;
 
-       impl Router for FailingRouter {
+       impl<S: routing::Score> Router<S> for FailingRouter {
                fn find_route(
                        &self,
                        _payer: &PublicKey,
                        _params: &RouteParameters,
                        _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
                ) -> Result<Route, LightningError> {
                        Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError })
                }
        }
 
+       struct TestScorer {
+               expectations: std::collections::VecDeque<u64>,
+       }
+
+       impl TestScorer {
+               fn new() -> Self {
+                       Self {
+                               expectations: std::collections::VecDeque::new(),
+                       }
+               }
+
+               fn expect_channel_failure(mut self, short_channel_id: u64) -> Self {
+                       self.expectations.push_back(short_channel_id);
+                       self
+               }
+       }
+
+       impl routing::Score for TestScorer {
+               fn channel_penalty_msat(
+                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+               ) -> u64 { 0 }
+
+               fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, short_channel_id: u64) {
+                       if let Some(expected_short_channel_id) = self.expectations.pop_front() {
+                               assert_eq!(short_channel_id, expected_short_channel_id);
+                       }
+               }
+       }
+
+       impl Drop for TestScorer {
+               fn drop(&mut self) {
+                       if std::thread::panicking() {
+                               return;
+                       }
+
+                       if !self.expectations.is_empty() {
+                               panic!("Unsatisfied channel failure expectations: {:?}", self.expectations);
+                       }
+               }
+       }
+
        struct TestPayer {
                expectations: core::cell::RefCell<std::collections::VecDeque<u64>>,
                attempts: core::cell::RefCell<usize>,
index ef885f20381fb626fa3ca34e1a61d3dacdbf1198..8da9994a3f7729ea1df9c407dd567fc811782d6a 100644 (file)
@@ -11,9 +11,9 @@ use lightning::chain::keysinterface::{Sign, KeysInterface};
 use lightning::ln::{PaymentHash, PaymentSecret};
 use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY};
 use lightning::ln::msgs::LightningError;
+use lightning::routing;
 use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
 use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route};
-use lightning::routing::scorer::Scorer;
 use lightning::util::logger::Logger;
 use secp256k1::key::PublicKey;
 use std::convert::TryInto;
@@ -109,13 +109,13 @@ impl<G, L: Deref> DefaultRouter<G, L> where G: Deref<Target = NetworkGraph>, L::
        }
 }
 
-impl<G, L: Deref> Router for DefaultRouter<G, L>
+impl<G, L: Deref, S: routing::Score> Router<S> for DefaultRouter<G, L>
 where G: Deref<Target = NetworkGraph>, L::Target: Logger {
        fn find_route(
                &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
        ) -> Result<Route, LightningError> {
-               let scorer = Scorer::default();
-               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, &scorer)
+               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer)
        }
 }
 
@@ -183,7 +183,7 @@ mod test {
                let first_hops = nodes[0].node.list_usable_channels();
                let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
                let logger = test_utils::TestLogger::new();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let route = find_route(
                        &nodes[0].node.get_our_node_id(), &params, network_graph,
                        Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer,
index 763a293fa5bcae790e21c7e2b5206c1e0d31ad76..698ef1a975a965626af085c2d9edba3fd74811c7 100644 (file)
@@ -6299,7 +6299,7 @@ mod tests {
                let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
                let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
                create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // To start (1), send a regular payment but don't claim it.
                let expected_route = [&nodes[1]];
@@ -6404,7 +6404,7 @@ mod tests {
                };
                let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
                let first_hops = nodes[0].node.list_usable_channels();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let route = find_route(
                        &payer_pubkey, &params, network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
                        nodes[0].logger, &scorer
@@ -6447,7 +6447,7 @@ mod tests {
                };
                let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
                let first_hops = nodes[0].node.list_usable_channels();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let route = find_route(
                        &payer_pubkey, &params, network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
                        nodes[0].logger, &scorer
@@ -6622,7 +6622,7 @@ pub mod bench {
                                let usable_channels = $node_a.list_usable_channels();
                                let payee = Payee::new($node_b.get_our_node_id())
                                        .with_features(InvoiceFeatures::known());
-                               let scorer = Scorer::new(0);
+                               let scorer = Scorer::with_fixed_penalty(0);
                                let route = get_route(&$node_a.get_our_node_id(), &payee, &dummy_graph,
                                        Some(&usable_channels.iter().map(|r| r).collect::<Vec<_>>()), 10_000, TEST_FINAL_CLTV, &logger_a, &scorer).unwrap();
 
index 92377236401de78e64ec08aa1f9af84c09801536..bcb1ac1f1ad24cd7c07ee523d22d4cdaa97843f1 100644 (file)
@@ -1015,7 +1015,7 @@ macro_rules! get_route_and_payment_hash {
                        .with_features($crate::ln::features::InvoiceFeatures::known())
                        .with_route_hints($last_hops);
                let net_graph_msg_handler = &$send_node.net_graph_msg_handler;
-               let scorer = ::routing::scorer::Scorer::new(0);
+               let scorer = ::routing::scorer::Scorer::with_fixed_penalty(0);
                let route = ::routing::router::get_route(
                        &$send_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph,
                        Some(&$send_node.node.list_usable_channels().iter().collect::<Vec<_>>()),
@@ -1353,7 +1353,7 @@ pub fn route_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route:
        let payee = Payee::new(expected_route.last().unwrap().node.get_our_node_id())
                .with_features(InvoiceFeatures::known());
        let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
        let route = get_route(
                &origin_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph,
                Some(&origin_node.node.list_usable_channels().iter().collect::<Vec<_>>()),
@@ -1372,7 +1372,7 @@ pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou
        let payee = Payee::new(expected_route.last().unwrap().node.get_our_node_id())
                .with_features(InvoiceFeatures::known());
        let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
        let route = get_route(&origin_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph, None, recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap();
        assert_eq!(route.paths.len(), 1);
        assert_eq!(route.paths[0].len(), expected_route.len());
index e8367cd55ca1a6b7dd4ccb787d1ebe150de28676..77d3f1d36a683c4eee90bcdf6e43075a31caddc2 100644 (file)
@@ -7161,7 +7161,7 @@ fn test_check_htlc_underpaying() {
        // Create some initial channels
        create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
 
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
        let payee = Payee::new(nodes[1].node.get_our_node_id()).with_features(InvoiceFeatures::known());
        let route = get_route(&nodes[0].node.get_our_node_id(), &payee, &nodes[0].net_graph_msg_handler.network_graph, None, 10_000, TEST_FINAL_CLTV, nodes[0].logger, &scorer).unwrap();
        let (_, our_payment_hash, _) = get_payment_preimage_hash!(nodes[0]);
@@ -7561,7 +7561,7 @@ fn test_bump_penalty_txn_on_revoked_htlcs() {
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1000000, 59000000, InitFeatures::known(), InitFeatures::known());
        // Lock HTLC in both directions (using a slightly lower CLTV delay to provide timely RBF bumps)
        let payee = Payee::new(nodes[1].node.get_our_node_id()).with_features(InvoiceFeatures::known());
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
        let route = get_route(&nodes[0].node.get_our_node_id(), &payee, &nodes[0].net_graph_msg_handler.network_graph,
                None, 3_000_000, 50, nodes[0].logger, &scorer).unwrap();
        let payment_preimage = send_along_route(&nodes[0], route, &[&nodes[1]], 3_000_000).0;
@@ -9061,7 +9061,7 @@ fn test_keysend_payments_to_public_node() {
                final_value_msat: 10000,
                final_cltv_expiry_delta: 40,
        };
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
        let route = find_route(&payer_pubkey, &params, &network_graph, None, nodes[0].logger, &scorer).unwrap();
 
        let test_preimage = PaymentPreimage([42; 32]);
@@ -9095,7 +9095,7 @@ fn test_keysend_payments_to_private_node() {
        };
        let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
        let first_hops = nodes[0].node.list_usable_channels();
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
        let route = find_route(
                &payer_pubkey, &params, &network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
                nodes[0].logger, &scorer
index 7a089112454779572d8052410b8818ca68dadbb5..d4280bb71516ad7ca2db6666936512e0a2dfe2a8 100644 (file)
@@ -82,7 +82,7 @@ fn updates_shutdown_wait() {
        let chan_1 = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
        let chan_2 = create_announced_chan_between_nodes(&nodes, 1, 2, InitFeatures::known(), InitFeatures::known());
        let logger = test_utils::TestLogger::new();
-       let scorer = Scorer::new(0);
+       let scorer = Scorer::with_fixed_penalty(0);
 
        let (our_payment_preimage, our_payment_hash, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 100000);
 
index 51ffd91b50483d58bc2016440bdb98d9f82886e5..d6c016468ce79f696483ccafcee03e0269b6267f 100644 (file)
@@ -14,6 +14,12 @@ pub mod router;
 pub mod scorer;
 
 use routing::network_graph::NodeId;
+use routing::router::RouteHop;
+
+use prelude::*;
+use core::cell::{RefCell, RefMut};
+use core::ops::DerefMut;
+use sync::{Mutex, MutexGuard};
 
 /// An interface used to score payment channels for path finding.
 ///
@@ -22,4 +28,49 @@ pub trait Score {
        /// Returns the fee in msats willing to be paid to avoid routing through the given channel
        /// in the direction from `source` to `target`.
        fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64;
+
+       /// Handles updating channel penalties after failing to route through a channel.
+       fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64);
+}
+
+/// A scorer that is accessed under a lock.
+///
+/// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while
+/// having shared ownership of a scorer but without requiring internal locking in [`Score`]
+/// implementations. Internal locking would be detrimental to route finding performance and could
+/// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
+///
+/// [`find_route`]: crate::routing::router::find_route
+pub trait LockableScore<'a> {
+       /// The locked [`Score`] type.
+       type Locked: 'a + Score;
+
+       /// Returns the locked scorer.
+       fn lock(&'a self) -> Self::Locked;
+}
+
+impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
+       type Locked = MutexGuard<'a, T>;
+
+       fn lock(&'a self) -> MutexGuard<'a, T> {
+               Mutex::lock(self).unwrap()
+       }
+}
+
+impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
+       type Locked = RefMut<'a, T>;
+
+       fn lock(&'a self) -> RefMut<'a, T> {
+               self.borrow_mut()
+       }
+}
+
+impl<S: Score, T: DerefMut<Target=S>> Score for T {
+       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 {
+               self.deref().channel_penalty_msat(short_channel_id, source, target)
+       }
+
+       fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64) {
+               self.deref_mut().payment_path_failed(path, short_channel_id)
+       }
 }
index 0d11dc1a61a1c4567a9a5fdb3e092f38696b4000..c035b4547cc78f2e8448fe45747154304e060404 100644 (file)
@@ -1937,7 +1937,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[2]);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Simple route to 2 via 1
 
@@ -1968,7 +1968,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[2]);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Simple route to 2 via 1
 
@@ -1987,7 +1987,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[2]);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Simple route to 2 via 1
 
@@ -2112,7 +2112,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known());
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // A route to node#2 via two paths.
                // One path allows transferring 35-40 sats, another one also allows 35-40 sats.
@@ -2248,7 +2248,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[2]);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // // Disable channels 4 and 12 by flags=2
                update_channel(&net_graph_msg_handler, &secp_ctx, &privkeys[1], UnsignedChannelUpdate {
@@ -2306,7 +2306,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[2]);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Disable nodes 1, 2, and 8 by requiring unknown feature bits
                let unknown_features = NodeFeatures::known().set_unknown_feature_required();
@@ -2347,7 +2347,7 @@ mod tests {
        fn our_chans_test() {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Route to 1 via 2 and 3 because our channel to 1 is disabled
                let payee = Payee::new(nodes[0]);
@@ -2476,7 +2476,7 @@ mod tests {
        fn partial_route_hint_test() {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Simple test across 2, 3, 5, and 4 via a last_hop channel
                // Tests the behaviour when the RouteHint contains a suboptimal hop.
@@ -2575,7 +2575,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[6]).with_route_hints(empty_last_hop(&nodes));
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Test handling of an empty RouteHint passed in Invoice.
 
@@ -2657,7 +2657,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[6]).with_route_hints(multi_hint_last_hops(&nodes));
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                // Test through channels 2, 3, 5, 8.
                // Test shows that multiple hop hints are considered.
 
@@ -2763,7 +2763,7 @@ mod tests {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let payee = Payee::new(nodes[6]).with_route_hints(last_hops_with_public_channel(&nodes));
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                // This test shows that public routes can be present in the invoice
                // which would be handled in the same manner.
 
@@ -2812,7 +2812,7 @@ mod tests {
        fn our_chans_last_hop_connect_test() {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // Simple test with outbound channel to 4 to test that last_hops and first_hops connect
                let our_chans = vec![get_channel_details(Some(42), nodes[3].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)];
@@ -2933,7 +2933,7 @@ mod tests {
                }]);
                let payee = Payee::new(target_node_id).with_route_hints(vec![last_hops]);
                let our_chans = vec![get_channel_details(Some(42), middle_node_id, InitFeatures::from_le_bytes(vec![0b11]), outbound_capacity_msat)];
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                get_route(&source_node_id, &payee, &NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()), Some(&our_chans.iter().collect::<Vec<_>>()), route_val, 42, &test_utils::TestLogger::new(), &scorer)
        }
 
@@ -2987,7 +2987,7 @@ mod tests {
 
                let (secp_ctx, mut net_graph_msg_handler, chain_monitor, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We will use a simple single-path route from
@@ -3259,7 +3259,7 @@ mod tests {
                // one of the latter hops is limited.
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known());
 
                // Path via {node7, node2, node4} is channels {12, 13, 6, 11}.
@@ -3382,7 +3382,7 @@ mod tests {
        fn ignore_fee_first_hop_test() {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[2]);
 
                // Path via node0 is channels {1, 3}. Limit them to 100 and 50 sats (total limit 50).
@@ -3428,7 +3428,7 @@ mod tests {
        fn simple_mpp_route_test() {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 3 paths:
@@ -3559,7 +3559,7 @@ mod tests {
        fn long_mpp_route_test() {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 3 paths:
@@ -3721,7 +3721,7 @@ mod tests {
        fn mpp_cheaper_route_test() {
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known());
 
                // This test checks that if we have two cheaper paths and one more expensive path,
@@ -3888,7 +3888,7 @@ mod tests {
                // if the fee is not properly accounted for, the behavior is different.
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 2 paths:
@@ -4057,7 +4057,7 @@ mod tests {
                // path finding we realize that we found more capacity than we need.
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 3 paths:
@@ -4214,7 +4214,7 @@ mod tests {
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
                let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[6]);
 
                add_channel(&net_graph_msg_handler, &secp_ctx, &our_privkey, &privkeys[1], ChannelFeatures::from_le_bytes(id_to_feature_flags(6)), 6);
@@ -4343,7 +4343,7 @@ mod tests {
                // we calculated fees on a higher value, resulting in us ignoring such paths.
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[2]);
 
                // We modify the graph to set the htlc_maximum of channel 2 to below the value we wish to
@@ -4405,7 +4405,7 @@ mod tests {
                // resulting in us thinking there is no possible path, even if other paths exist.
                let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We modify the graph to set the htlc_minimum of channel 2 and 4 as needed - channel 2
@@ -4472,7 +4472,7 @@ mod tests {
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let logger = Arc::new(test_utils::TestLogger::new());
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let payee = Payee::new(nodes[0]).with_features(InvoiceFeatures::known());
 
                {
@@ -4513,7 +4513,7 @@ mod tests {
                let payee = Payee::new(nodes[6]).with_route_hints(last_hops(&nodes));
 
                // Without penalizing each hop 100 msats, a longer path with lower fees is chosen.
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let route = get_route(
                        &our_id, &payee, &net_graph_msg_handler.network_graph, None, 100, 42,
                        Arc::clone(&logger), &scorer
@@ -4526,7 +4526,7 @@ mod tests {
 
                // Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6]
                // from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper.
-               let scorer = Scorer::new(100);
+               let scorer = Scorer::with_fixed_penalty(100);
                let route = get_route(
                        &our_id, &payee, &net_graph_msg_handler.network_graph, None, 100, 42,
                        Arc::clone(&logger), &scorer
@@ -4546,6 +4546,8 @@ mod tests {
                fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
+
+               fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, _short_channel_id: u64) {}
        }
 
        struct BadNodeScorer {
@@ -4556,6 +4558,8 @@ mod tests {
                fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
+
+               fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, _short_channel_id: u64) {}
        }
 
        #[test]
@@ -4565,7 +4569,7 @@ mod tests {
                let payee = Payee::new(nodes[6]).with_route_hints(last_hops(&nodes));
 
                // A path to nodes[6] exists when no penalties are applied to any channel.
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
                let route = get_route(
                        &our_id, &payee, &net_graph_msg_handler.network_graph, None, 100, 42,
                        Arc::clone(&logger), &scorer
@@ -4694,7 +4698,7 @@ mod tests {
                        },
                };
                let graph = NetworkGraph::read(&mut d).unwrap();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
@@ -4725,7 +4729,7 @@ mod tests {
                        },
                };
                let graph = NetworkGraph::read(&mut d).unwrap();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
@@ -4791,7 +4795,7 @@ mod benches {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
                let nodes = graph.read_only().nodes().clone();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
@@ -4826,7 +4830,7 @@ mod benches {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
                let nodes = graph.read_only().nodes().clone();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
index e3f5c8679b68d6b57ca418b70befdc0ff34e35d7..d2b167675e083748c813a65928a3b71859292b4b 100644 (file)
@@ -19,7 +19,7 @@
 //! #
 //! # use lightning::routing::network_graph::NetworkGraph;
 //! # use lightning::routing::router::{RouteParameters, find_route};
-//! # use lightning::routing::scorer::Scorer;
+//! # use lightning::routing::scorer::{Scorer, ScoringParameters};
 //! # use lightning::util::logger::{Logger, Record};
 //! # use secp256k1::key::PublicKey;
 //! #
 //! # fn find_scored_route(payer: PublicKey, params: RouteParameters, network_graph: NetworkGraph) {
 //! # let logger = FakeLogger {};
 //! #
-//! // Use the default channel penalty.
+//! // Use the default channel penalties.
 //! let scorer = Scorer::default();
 //!
-//! // Or use a custom channel penalty.
-//! let scorer = Scorer::new(1_000);
+//! // Or use custom channel penalties.
+//! let scorer = Scorer::new(ScoringParameters {
+//!     base_penalty_msat: 1000,
+//!     failure_penalty_msat: 2 * 1024 * 1000,
+//!     ..ScoringParameters::default()
+//! });
 //!
 //! let route = find_route(&payer, &params, &network_graph, None, &logger, &scorer);
 //! # }
 use routing;
 
 use routing::network_graph::NodeId;
+use routing::router::RouteHop;
+
+use prelude::*;
+#[cfg(not(feature = "no-std"))]
+use core::time::Duration;
+#[cfg(not(feature = "no-std"))]
+use std::time::Instant;
 
 /// [`routing::Score`] implementation that provides reasonable default behavior.
 ///
 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
-/// slightly higher fees are available.
+/// slightly higher fees are available. May also further penalize failed channels.
 ///
 /// See [module-level documentation] for usage.
 ///
 /// [module-level documentation]: crate::routing::scorer
 pub struct Scorer {
-       base_penalty_msat: u64,
+       params: ScoringParameters,
+       #[cfg(not(feature = "no-std"))]
+       channel_failures: HashMap<u64, (u64, Instant)>,
+       #[cfg(feature = "no-std")]
+       channel_failures: HashMap<u64, u64>,
+}
+
+/// Parameters for configuring [`Scorer`].
+pub struct ScoringParameters {
+       /// A fixed penalty in msats to apply to each channel.
+       pub base_penalty_msat: u64,
+
+       /// A penalty in msats to apply to a channel upon failure.
+       ///
+       /// This may be reduced over time based on [`failure_penalty_half_life`].
+       ///
+       /// [`failure_penalty_half_life`]: Self::failure_penalty_half_life
+       pub failure_penalty_msat: u64,
+
+       /// The time needed before any accumulated channel failure penalties are cut in half.
+       #[cfg(not(feature = "no-std"))]
+       pub failure_penalty_half_life: Duration,
 }
 
 impl Scorer {
-       /// Creates a new scorer using `base_penalty_msat` as the channel penalty.
-       pub fn new(base_penalty_msat: u64) -> Self {
-               Self { base_penalty_msat }
+       /// Creates a new scorer using the given scoring parameters.
+       pub fn new(params: ScoringParameters) -> Self {
+               Self {
+                       params,
+                       channel_failures: HashMap::new(),
+               }
+       }
+
+       /// Creates a new scorer using `penalty_msat` as a fixed channel penalty.
+       #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
+       pub fn with_fixed_penalty(penalty_msat: u64) -> Self {
+               Self::new(ScoringParameters {
+                       base_penalty_msat: penalty_msat,
+                       failure_penalty_msat: 0,
+                       #[cfg(not(feature = "no-std"))]
+                       failure_penalty_half_life: Duration::from_secs(0),
+               })
+       }
+
+       #[cfg(not(feature = "no-std"))]
+       fn decay_from(&self, penalty_msat: u64, last_failure: &Instant) -> u64 {
+               decay_from(penalty_msat, last_failure, self.params.failure_penalty_half_life)
        }
 }
 
 impl Default for Scorer {
-       /// Creates a new scorer using 500 msat as the channel penalty.
        fn default() -> Self {
-               Scorer::new(500)
+               Scorer::new(ScoringParameters::default())
+       }
+}
+
+impl Default for ScoringParameters {
+       fn default() -> Self {
+               Self {
+                       base_penalty_msat: 500,
+                       failure_penalty_msat: 1024 * 1000,
+                       #[cfg(not(feature = "no-std"))]
+                       failure_penalty_half_life: Duration::from_secs(3600),
+               }
        }
 }
 
 impl routing::Score for Scorer {
        fn channel_penalty_msat(
-               &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+               &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId
        ) -> u64 {
-               self.base_penalty_msat
+               #[cfg(not(feature = "no-std"))]
+               let failure_penalty_msat = match self.channel_failures.get(&short_channel_id) {
+                       Some((penalty_msat, last_failure)) => self.decay_from(*penalty_msat, last_failure),
+                       None => 0,
+               };
+               #[cfg(feature = "no-std")]
+               let failure_penalty_msat =
+                       self.channel_failures.get(&short_channel_id).copied().unwrap_or(0);
+
+               self.params.base_penalty_msat + failure_penalty_msat
+       }
+
+       fn payment_path_failed(&mut self, _path: &Vec<RouteHop>, short_channel_id: u64) {
+               let failure_penalty_msat = self.params.failure_penalty_msat;
+               #[cfg(not(feature = "no-std"))]
+               {
+                       let half_life = self.params.failure_penalty_half_life;
+                       self.channel_failures
+                               .entry(short_channel_id)
+                               .and_modify(|(penalty_msat, last_failure)| {
+                                       let decayed_penalty = decay_from(*penalty_msat, last_failure, half_life);
+                                       *penalty_msat = decayed_penalty + failure_penalty_msat;
+                                       *last_failure = Instant::now();
+                               })
+                               .or_insert_with(|| (failure_penalty_msat, Instant::now()));
+               }
+               #[cfg(feature = "no-std")]
+               self.channel_failures
+                       .entry(short_channel_id)
+                       .and_modify(|penalty_msat| *penalty_msat += failure_penalty_msat)
+                       .or_insert(failure_penalty_msat);
+       }
+}
+
+#[cfg(not(feature = "no-std"))]
+fn decay_from(penalty_msat: u64, last_failure: &Instant, half_life: Duration) -> u64 {
+       let decays = last_failure.elapsed().as_secs().checked_div(half_life.as_secs());
+       match decays {
+               Some(decays) => penalty_msat >> decays,
+               None => 0,
        }
 }