Merge pull request #1078 from TheBlueMatt/2021-09-chan-types
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Wed, 3 Nov 2021 16:58:33 +0000 (16:58 +0000)
committerGitHub <noreply@github.com>
Wed, 3 Nov 2021 16:58:33 +0000 (16:58 +0000)
Implement channel_type negotiation

35 files changed:
CHANGELOG.md
fuzz/src/chanmon_consistency.rs
fuzz/src/full_stack.rs
fuzz/src/router.rs
lightning-background-processor/Cargo.toml
lightning-background-processor/src/lib.rs
lightning-block-sync/Cargo.toml
lightning-invoice/Cargo.toml
lightning-invoice/src/lib.rs
lightning-invoice/src/payment.rs [new file with mode: 0644]
lightning-invoice/src/utils.rs
lightning-net-tokio/Cargo.toml
lightning-persister/Cargo.toml
lightning/Cargo.toml
lightning/src/ln/chanmon_update_fail_tests.rs
lightning/src/ln/channel.rs
lightning/src/ln/channelmanager.rs
lightning/src/ln/features.rs
lightning/src/ln/functional_test_utils.rs
lightning/src/ln/functional_tests.rs
lightning/src/ln/onion_route_tests.rs
lightning/src/ln/onion_utils.rs
lightning/src/ln/payment_tests.rs
lightning/src/ln/peer_handler.rs
lightning/src/ln/shutdown_tests.rs
lightning/src/routing/mod.rs
lightning/src/routing/network_graph.rs
lightning/src/routing/router.rs
lightning/src/routing/scorer.rs
lightning/src/util/events.rs
lightning/src/util/logger.rs
lightning/src/util/macro_logger.rs
lightning/src/util/ser.rs
lightning/src/util/ser_macros.rs
lightning/src/util/test_utils.rs

index 518badd1ed21d56c6533f4cbb8dd53c649078af3..887b738ff3cbd5ad8ed69ae171ca8105d8db453c 100644 (file)
@@ -1,3 +1,69 @@
+# 0.0.103 - 2021-11-02
+
+## API Updates
+ * This release is almost entirely focused on a new API in the
+   `lightning-invoice` crate - the `InvoicePayer`. `InvoicePayer` is a
+   struct which takes a reference to a `ChannelManager` and a `NetworkGraph`
+   and retries payments as paths fail. It limits retries to a configurable
+   number, but is not serialized to disk and may retry additional times across
+   a serialization/load. In order to learn about failed payments, it must
+   receive `Event`s directly from the `ChannelManager`, wrapping a
+   user-provided `EventHandler` which it provides all unhandled events to
+   (#1059).
+ * `get_route` has been renamed `find_route` (#1059) and now takes a `Payee`
+   struct in replacement of a number of its long list of arguments (#1134).
+   `Payee` is further stored in the `Route` object returned and provided in the
+   `RouteParameters` contained in `Event::PaymentPathFailed` (#1059).
+ * `ChannelMonitor`s must now be persisted after calls which provide new block
+   data, prior to `MonitorEvent`s being passed back to `ChannelManager` for
+   processing. If you are using a `ChainMonitor` this is handled for you.
+   The `Persist` API has been updated to `Option`ally take the
+   `ChannelMonitorUpdate` as persistence events that result from chain data no
+   longer have a corresponding update (#1108).
+ * `routing::Score` now has a `payment_path_failed` method which it can use to
+   learn which channels often fail payments. It is automatically called by
+   `InvoicePayer` for failed payment paths (#1144).
+ * The default `Scorer` implementation is now a type alias to a type generic
+   across different clocks and supports serialization to persist scoring data
+   across restarts (#1146).
+ * `Event::PaymentSent` now includes the full fee which was spent across all
+   payment paths which were fulfilled or pending when the payment was fulfilled
+   (#1142).
+ * `NetGraphMsgHandler` now takes a `Deref` to the `NetworkGraph`, allowing for
+   shared references to the graph data to make serialization and references to
+   the graph data in the `InvoicePayer`'s `Router` simpler (#1149).
+ * `routing::Score::channel_penalty_msat` has been updated to provide the
+   `NodeId` of both the source and destination nodes of a channel (#1133).
+
+## Bug Fixes
+ * Delay disconnecting peers if we receive messages from them even if it takes
+   a while to receive a pong from them. Further, avoid sending too many gossip
+   messages between pings to ensure we should always receive pongs in a timely
+   manner. Together, these should significantly reduce instances of us failing
+   to remain connected to a peer during initial gossip sync (#1137).
+ * If a payment is sent, creating an outbound HTLC and sending it to our
+   counterparty (implying the `ChannelMonitor` was persisted on disk), but the
+   `ChannelManager` was not persisted prior to shutdown/crash, no
+   `Event::PaymentPathFailed` event will be generated if the HTLC is eventually
+   failed on chain (#1104).
+
+## Serialization Compatibility
+ * All above new Events/fields are ignored by prior clients. All above new
+   Events/fields are not present when reading objects serialized by prior
+   versions of the library.
+ * Payments for which a `Route` was generated using a previous version or for
+   which the payment was originally sent by a previous version of the library
+   will not be retried by an `InvoicePayer`.
+
+This release was singularly focused and some contributions by third parties
+were delayed.
+In total, this release features 38 files changed, 4414 insertions, and 969
+deletions in 71 commits from 2 authors, in alphabetical order:
+
+ * Jeffrey Czyz
+ * Matt Corallo
+
+
 # 0.0.102 - 2021-10-18
 
 ## API Updates
index 28bab5a28f8942550e26c9036984f57cd6a84f3b..acdf3cb3ca2dc3b2f902aa1a323e645b837d2368 100644 (file)
@@ -271,8 +271,8 @@ fn check_payment_err(send_err: PaymentSendFailure) {
                PaymentSendFailure::AllFailedRetrySafe(per_path_results) => {
                        for api_err in per_path_results { check_api_err(api_err); }
                },
-               PaymentSendFailure::PartialFailure(per_path_results) => {
-                       for res in per_path_results { if let Err(api_err) = res { check_api_err(api_err); } }
+               PaymentSendFailure::PartialFailure { results, .. } => {
+                       for res in results { if let Err(api_err) = res { check_api_err(api_err); } }
                },
        }
 }
@@ -305,6 +305,7 @@ fn send_payment(source: &ChanMan, dest: &ChanMan, dest_chan_id: u64, amt: u64, p
                        fee_msat: amt,
                        cltv_expiry_delta: 200,
                }]],
+               payee: None,
        }, payment_hash, &Some(payment_secret)) {
                check_payment_err(err);
                false
@@ -330,6 +331,7 @@ fn send_hop_payment(source: &ChanMan, middle: &ChanMan, middle_chan_id: u64, des
                        fee_msat: amt,
                        cltv_expiry_delta: 200,
                }]],
+               payee: None,
        }, payment_hash, &Some(payment_secret)) {
                check_payment_err(err);
                false
index 84ba7bab502ea4f44ec2bca37ee83a01b1a5e404..829ef20b0779ac680322e1352b1bf4cf167e2173 100644 (file)
@@ -38,7 +38,7 @@ use lightning::ln::peer_handler::{MessageHandler,PeerManager,SocketDescriptor,Ig
 use lightning::ln::msgs::DecodeError;
 use lightning::ln::script::ShutdownScript;
 use lightning::routing::network_graph::{NetGraphMsgHandler, NetworkGraph};
-use lightning::routing::router::get_route;
+use lightning::routing::router::{find_route, Payee, RouteParameters};
 use lightning::routing::scorer::Scorer;
 use lightning::util::config::UserConfig;
 use lightning::util::errors::APIError;
@@ -161,7 +161,7 @@ type ChannelMan = ChannelManager<
        EnforcingSigner,
        Arc<chainmonitor::ChainMonitor<EnforcingSigner, Arc<dyn chain::Filter>, Arc<TestBroadcaster>, Arc<FuzzEstimator>, Arc<dyn Logger>, Arc<TestPersister>>>,
        Arc<TestBroadcaster>, Arc<KeyProvider>, Arc<FuzzEstimator>, Arc<dyn Logger>>;
-type PeerMan<'a> = PeerManager<Peer<'a>, Arc<ChannelMan>, Arc<NetGraphMsgHandler<Arc<dyn chain::Access>, Arc<dyn Logger>>>, Arc<dyn Logger>, IgnoringMessageHandler>;
+type PeerMan<'a> = PeerManager<Peer<'a>, Arc<ChannelMan>, Arc<NetGraphMsgHandler<Arc<NetworkGraph>, Arc<dyn chain::Access>, Arc<dyn Logger>>>, Arc<dyn Logger>, IgnoringMessageHandler>;
 
 struct MoneyLossDetector<'a> {
        manager: Arc<ChannelMan>,
@@ -380,9 +380,9 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
        };
        let channelmanager = Arc::new(ChannelManager::new(fee_est.clone(), monitor.clone(), broadcast.clone(), Arc::clone(&logger), keys_manager.clone(), config, params));
        let our_id = PublicKey::from_secret_key(&Secp256k1::signing_only(), &keys_manager.get_node_secret());
-       let network_graph = NetworkGraph::new(genesis_block(network).block_hash());
-       let net_graph_msg_handler = Arc::new(NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger)));
-       let scorer = Scorer::new(0);
+       let network_graph = Arc::new(NetworkGraph::new(genesis_block(network).block_hash()));
+       let net_graph_msg_handler = Arc::new(NetGraphMsgHandler::new(Arc::clone(&network_graph), None, Arc::clone(&logger)));
+       let scorer = Scorer::with_fixed_penalty(0);
 
        let peers = RefCell::new([false; 256]);
        let mut loss_detector = MoneyLossDetector::new(&peers, channelmanager.clone(), monitor.clone(), PeerManager::new(MessageHandler {
@@ -437,8 +437,14 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                }
                        },
                        4 => {
-                               let value = slice_to_be24(get_slice!(3)) as u64;
-                               let route = match get_route(&our_id, &net_graph_msg_handler.network_graph, &get_pubkey!(), None, None, &Vec::new(), value, 42, Arc::clone(&logger), &scorer) {
+                               let final_value_msat = slice_to_be24(get_slice!(3)) as u64;
+                               let payee = Payee::from_node_id(get_pubkey!());
+                               let params = RouteParameters {
+                                       payee,
+                                       final_value_msat,
+                                       final_cltv_expiry_delta: 42,
+                               };
+                               let route = match find_route(&our_id, &params, &network_graph, None, Arc::clone(&logger), &scorer) {
                                        Ok(route) => route,
                                        Err(_) => return,
                                };
@@ -454,8 +460,14 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                }
                        },
                        15 => {
-                               let value = slice_to_be24(get_slice!(3)) as u64;
-                               let mut route = match get_route(&our_id, &net_graph_msg_handler.network_graph, &get_pubkey!(), None, None, &Vec::new(), value, 42, Arc::clone(&logger), &scorer) {
+                               let final_value_msat = slice_to_be24(get_slice!(3)) as u64;
+                               let payee = Payee::from_node_id(get_pubkey!());
+                               let params = RouteParameters {
+                                       payee,
+                                       final_value_msat,
+                                       final_cltv_expiry_delta: 42,
+                               };
+                               let mut route = match find_route(&our_id, &params, &network_graph, None, Arc::clone(&logger), &scorer) {
                                        Ok(route) => route,
                                        Err(_) => return,
                                };
index 6a207254ac9992e172d0fb194d7f6b76745e8d80..8c9b4b7d815698656fd79985977b296029919c93 100644 (file)
@@ -16,7 +16,7 @@ use lightning::chain::transaction::OutPoint;
 use lightning::ln::channelmanager::{ChannelDetails, ChannelCounterparty};
 use lightning::ln::features::InitFeatures;
 use lightning::ln::msgs;
-use lightning::routing::router::{get_route, RouteHint, RouteHintHop};
+use lightning::routing::router::{find_route, Payee, RouteHint, RouteHintHop, RouteParameters};
 use lightning::routing::scorer::Scorer;
 use lightning::util::logger::Logger;
 use lightning::util::ser::Readable;
@@ -248,12 +248,16 @@ pub fn do_test<Out: test_logger::Output>(data: &[u8], out: Out) {
                                                }]));
                                        }
                                }
-                               let scorer = Scorer::new(0);
+                               let scorer = Scorer::with_fixed_penalty(0);
                                for target in node_pks.iter() {
-                                       let _ = get_route(&our_pubkey, &net_graph, target, None,
+                                       let params = RouteParameters {
+                                               payee: Payee::from_node_id(*target).with_route_hints(last_hops.clone()),
+                                               final_value_msat: slice_to_be64(get_slice!(8)),
+                                               final_cltv_expiry_delta: slice_to_be32(get_slice!(4)),
+                                       };
+                                       let _ = find_route(&our_pubkey, &params, &net_graph,
                                                first_hops.map(|c| c.iter().collect::<Vec<_>>()).as_ref().map(|a| a.as_slice()),
-                                               &last_hops.iter().collect::<Vec<_>>(),
-                                               slice_to_be64(get_slice!(8)), slice_to_be32(get_slice!(4)), Arc::clone(&logger), &scorer);
+                                               Arc::clone(&logger), &scorer);
                                }
                        },
                }
index 4e45bb2a83c11d58b6528c59cfef3208d07e86ac..f5e90e1410d27ea226b1d67c82c48aab316a280a 100644 (file)
@@ -1,6 +1,6 @@
 [package]
 name = "lightning-background-processor"
-version = "0.0.102"
+version = "0.0.103"
 authors = ["Valentine Wallace <vwallace@protonmail.com>"]
 license = "MIT OR Apache-2.0"
 repository = "http://github.com/rust-bitcoin/rust-lightning"
@@ -11,9 +11,9 @@ edition = "2018"
 
 [dependencies]
 bitcoin = "0.27"
-lightning = { version = "0.0.102", path = "../lightning", features = ["allow_wallclock_use"] }
-lightning-persister = { version = "0.0.102", path = "../lightning-persister" }
+lightning = { version = "0.0.103", path = "../lightning", features = ["allow_wallclock_use"] }
+lightning-persister = { version = "0.0.103", path = "../lightning-persister" }
 
 [dev-dependencies]
-lightning = { version = "0.0.102", path = "../lightning", features = ["_test_utils"] }
-
+lightning = { version = "0.0.103", path = "../lightning", features = ["_test_utils"] }
+lightning-invoice = { version = "0.11.0", path = "../lightning-invoice" }
index 0fef55a95c2c5369408613cebfa574e9b8b95c05..593f84a90898534742f5dbd0c0458ea3f53cf059 100644 (file)
@@ -14,9 +14,8 @@ use lightning::chain::chainmonitor::{ChainMonitor, Persist};
 use lightning::chain::keysinterface::{Sign, KeysInterface};
 use lightning::ln::channelmanager::ChannelManager;
 use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler};
-use lightning::ln::peer_handler::{PeerManager, SocketDescriptor};
-use lightning::ln::peer_handler::CustomMessageHandler;
-use lightning::routing::network_graph::NetGraphMsgHandler;
+use lightning::ln::peer_handler::{CustomMessageHandler, PeerManager, SocketDescriptor};
+use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
 use lightning::util::events::{Event, EventHandler, EventsProvider};
 use lightning::util::logger::Logger;
 use std::sync::Arc;
@@ -104,7 +103,8 @@ ChannelManagerPersister<Signer, M, T, K, F, L> for Fun where
 /// Decorates an [`EventHandler`] with common functionality provided by standard [`EventHandler`]s.
 struct DecoratingEventHandler<
        E: EventHandler,
-       N: Deref<Target = NetGraphMsgHandler<A, L>>,
+       N: Deref<Target = NetGraphMsgHandler<G, A, L>>,
+       G: Deref<Target = NetworkGraph>,
        A: Deref,
        L: Deref,
 >
@@ -115,10 +115,11 @@ where A::Target: chain::Access, L::Target: Logger {
 
 impl<
        E: EventHandler,
-       N: Deref<Target = NetGraphMsgHandler<A, L>>,
+       N: Deref<Target = NetGraphMsgHandler<G, A, L>>,
+       G: Deref<Target = NetworkGraph>,
        A: Deref,
        L: Deref,
-> EventHandler for DecoratingEventHandler<E, N, A, L>
+> EventHandler for DecoratingEventHandler<E, N, G, A, L>
 where A::Target: chain::Access, L::Target: Logger {
        fn handle_event(&self, event: &Event) {
                if let Some(event_handler) = &self.net_graph_msg_handler {
@@ -169,16 +170,17 @@ impl BackgroundProcessor {
                T: 'static + Deref + Send + Sync,
                K: 'static + Deref + Send + Sync,
                F: 'static + Deref + Send + Sync,
+               G: 'static + Deref<Target = NetworkGraph> + Send + Sync,
                L: 'static + Deref + Send + Sync,
                P: 'static + Deref + Send + Sync,
                Descriptor: 'static + SocketDescriptor + Send + Sync,
                CMH: 'static + Deref + Send + Sync,
                RMH: 'static + Deref + Send + Sync,
-               EH: 'static + EventHandler + Send + Sync,
+               EH: 'static + EventHandler + Send,
                CMP: 'static + Send + ChannelManagerPersister<Signer, CW, T, K, F, L>,
                M: 'static + Deref<Target = ChainMonitor<Signer, CF, T, F, L, P>> + Send + Sync,
                CM: 'static + Deref<Target = ChannelManager<Signer, CW, T, K, F, L>> + Send + Sync,
-               NG: 'static + Deref<Target = NetGraphMsgHandler<CA, L>> + Send + Sync,
+               NG: 'static + Deref<Target = NetGraphMsgHandler<G, CA, L>> + Send + Sync,
                UMH: 'static + Deref + Send + Sync,
                PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L, UMH>> + Send + Sync,
        >(
@@ -236,8 +238,7 @@ impl BackgroundProcessor {
                                        // timer, we should have disconnected all sockets by now (and they're probably
                                        // dead anyway), so disconnect them by calling `timer_tick_occurred()` twice.
                                        log_trace!(logger, "Awoke after more than double our ping timer, disconnecting peers.");
-                                       peer_manager.timer_tick_occurred();
-                                       peer_manager.timer_tick_occurred();
+                                       peer_manager.disconnect_all_peers();
                                        last_ping_call = Instant::now();
                                } else if last_ping_call.elapsed().as_secs() > PING_TIMER {
                                        log_trace!(logger, "Calling PeerManager's timer_tick_occurred");
@@ -316,6 +317,8 @@ mod tests {
        use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
        use lightning::util::ser::Writeable;
        use lightning::util::test_utils;
+       use lightning_invoice::payment::{InvoicePayer, RetryAttempts};
+       use lightning_invoice::utils::DefaultRouter;
        use lightning_persister::FilesystemPersister;
        use std::fs;
        use std::path::PathBuf;
@@ -339,11 +342,12 @@ mod tests {
 
        struct Node {
                node: Arc<SimpleArcChannelManager<ChainMonitor, test_utils::TestBroadcaster, test_utils::TestFeeEstimator, test_utils::TestLogger>>,
-               net_graph_msg_handler: Option<Arc<NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>>>,
+               net_graph_msg_handler: Option<Arc<NetGraphMsgHandler<Arc<NetworkGraph>, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>>>,
                peer_manager: Arc<PeerManager<TestDescriptor, Arc<test_utils::TestChannelMessageHandler>, Arc<test_utils::TestRoutingMessageHandler>, Arc<test_utils::TestLogger>, IgnoringMessageHandler>>,
                chain_monitor: Arc<ChainMonitor>,
                persister: Arc<FilesystemPersister>,
                tx_broadcaster: Arc<test_utils::TestBroadcaster>,
+               network_graph: Arc<NetworkGraph>,
                logger: Arc<test_utils::TestLogger>,
                best_block: BestBlock,
        }
@@ -381,11 +385,11 @@ mod tests {
                        let best_block = BestBlock::from_genesis(network);
                        let params = ChainParameters { network, best_block };
                        let manager = Arc::new(ChannelManager::new(fee_estimator.clone(), chain_monitor.clone(), tx_broadcaster.clone(), logger.clone(), keys_manager.clone(), UserConfig::default(), params));
-                       let network_graph = NetworkGraph::new(genesis_block.header.block_hash());
-                       let net_graph_msg_handler = Some(Arc::new(NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), logger.clone())));
+                       let network_graph = Arc::new(NetworkGraph::new(genesis_block.header.block_hash()));
+                       let net_graph_msg_handler = Some(Arc::new(NetGraphMsgHandler::new(network_graph.clone(), Some(chain_source.clone()), logger.clone())));
                        let msg_handler = MessageHandler { chan_handler: Arc::new(test_utils::TestChannelMessageHandler::new()), route_handler: Arc::new(test_utils::TestRoutingMessageHandler::new() )};
                        let peer_manager = Arc::new(PeerManager::new(msg_handler, keys_manager.get_node_secret(), &seed, logger.clone(), IgnoringMessageHandler{}));
-                       let node = Node { node: manager, net_graph_msg_handler, peer_manager, chain_monitor, persister, tx_broadcaster, logger, best_block };
+                       let node = Node { node: manager, net_graph_msg_handler, peer_manager, chain_monitor, persister, tx_broadcaster, network_graph, logger, best_block };
                        nodes.push(node);
                }
 
@@ -621,4 +625,19 @@ mod tests {
 
                assert!(bg_processor.stop().is_ok());
        }
+
+       #[test]
+       fn test_invoice_payer() {
+               let nodes = create_nodes(2, "test_invoice_payer".to_string());
+
+               // Initiate the background processors to watch each node.
+               let data_dir = nodes[0].persister.get_data_dir();
+               let persister = move |node: &ChannelManager<InMemorySigner, Arc<ChainMonitor>, Arc<test_utils::TestBroadcaster>, Arc<KeysManager>, Arc<test_utils::TestFeeEstimator>, Arc<test_utils::TestLogger>>| FilesystemPersister::persist_manager(data_dir.clone(), node);
+               let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger));
+               let scorer = Arc::new(Mutex::new(test_utils::TestScorer::default()));
+               let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));
+               let event_handler = Arc::clone(&invoice_payer);
+               let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
+               assert!(bg_processor.stop().is_ok());
+       }
 }
index 06ca52f70cd50e786b45e0b56b4142f003f993f5..e9d5f52816e4e56cb7a0bc4142ca82d4964b5647 100644 (file)
@@ -1,6 +1,6 @@
 [package]
 name = "lightning-block-sync"
-version = "0.0.102"
+version = "0.0.103"
 authors = ["Jeffrey Czyz", "Matt Corallo"]
 license = "MIT OR Apache-2.0"
 repository = "http://github.com/rust-bitcoin/rust-lightning"
@@ -15,7 +15,7 @@ rpc-client = [ "serde", "serde_json", "chunked_transfer" ]
 
 [dependencies]
 bitcoin = "0.27"
-lightning = { version = "0.0.102", path = "../lightning" }
+lightning = { version = "0.0.103", path = "../lightning" }
 tokio = { version = "1.0", features = [ "io-util", "net", "time" ], optional = true }
 serde = { version = "1.0", features = ["derive"], optional = true }
 serde_json = { version = "1.0", optional = true }
index 886fbc553acae0f97c772e2fbdf9d067146e668e..010becc27198333100fb740664f92cc258f64279 100644 (file)
@@ -1,7 +1,7 @@
 [package]
 name = "lightning-invoice"
 description = "Data structures to parse and serialize BOLT11 lightning invoices"
-version = "0.10.0"
+version = "0.11.0"
 authors = ["Sebastian Geisler <sgeisler@wh2.tu-dresden.de>"]
 documentation = "https://docs.rs/lightning-invoice/"
 license = "MIT OR Apache-2.0"
@@ -10,11 +10,11 @@ readme = "README.md"
 
 [dependencies]
 bech32 = "0.8"
-lightning = { version = "0.0.102", path = "../lightning" }
+lightning = { version = "0.0.103", path = "../lightning" }
 secp256k1 = { version = "0.20", features = ["recovery"] }
 num-traits = "0.2.8"
 bitcoin_hashes = "0.10"
 
 [dev-dependencies]
 hex = "0.3"
-lightning = { version = "0.0.102", path = "../lightning", features = ["_test_utils"] }
+lightning = { version = "0.0.103", path = "../lightning", features = ["_test_utils"] }
index 61e394da6dba4ef70979a5dae1d673ce2b483732..04ba08bd8552c34379576b59047de719ef200ae3 100644 (file)
 //!   * For parsing use `str::parse::<Invoice>(&self)` (see the docs of `impl FromStr for Invoice`)
 //!   * For constructing invoices use the `InvoiceBuilder`
 //!   * For serializing invoices use the `Display`/`ToString` traits
+pub mod payment;
 pub mod utils;
 
 extern crate bech32;
 extern crate bitcoin_hashes;
-extern crate lightning;
+#[macro_use] extern crate lightning;
 extern crate num_traits;
 extern crate secp256k1;
 
@@ -378,7 +379,8 @@ pub enum TaggedField {
 
 /// SHA-256 hash
 #[derive(Clone, Debug, Hash, Eq, PartialEq)]
-pub struct Sha256(pub sha256::Hash);
+pub struct Sha256(/// (C-not exported) as the native hash types are not currently mapped
+       pub sha256::Hash);
 
 /// Description string
 ///
@@ -1187,6 +1189,19 @@ impl Invoice {
                        .unwrap_or(Duration::from_secs(DEFAULT_EXPIRY_TIME))
        }
 
+       /// Returns whether the invoice has expired.
+       pub fn is_expired(&self) -> bool {
+               Self::is_expired_from_epoch(self.timestamp(), self.expiry_time())
+       }
+
+       /// Returns whether the expiry time from the given epoch has passed.
+       pub(crate) fn is_expired_from_epoch(epoch: &SystemTime, expiry_time: Duration) -> bool {
+               match epoch.elapsed() {
+                       Ok(elapsed) => elapsed > expiry_time,
+                       Err(_) => false,
+               }
+       }
+
        /// Returns the invoice's `min_final_cltv_expiry` time, if present, otherwise
        /// [`DEFAULT_MIN_FINAL_CLTV_EXPIRY`].
        pub fn min_final_cltv_expiry(&self) -> u64 {
@@ -1208,10 +1223,10 @@ impl Invoice {
        }
 
        /// Returns a list of all routes included in the invoice as the underlying hints
-       pub fn route_hints(&self) -> Vec<&RouteHint> {
+       pub fn route_hints(&self) -> Vec<RouteHint> {
                find_all_extract!(
                        self.signed_invoice.known_tagged_fields(), TaggedField::PrivateRoute(ref x), x
-               ).map(|route| &**route).collect()
+               ).map(|route| (**route).clone()).collect()
        }
 
        /// Returns the currency for which the invoice was issued
@@ -1219,8 +1234,13 @@ impl Invoice {
                self.signed_invoice.currency()
        }
 
+       /// Returns the amount if specified in the invoice as millisatoshis.
+       pub fn amount_milli_satoshis(&self) -> Option<u64> {
+               self.signed_invoice.amount_pico_btc().map(|v| v / 10)
+       }
+
        /// Returns the amount if specified in the invoice as pico <currency>.
-       pub fn amount_pico_btc(&self) -> Option<u64> {
+       fn amount_pico_btc(&self) -> Option<u64> {
                self.signed_invoice.amount_pico_btc()
        }
 }
@@ -1867,6 +1887,7 @@ mod test {
                assert!(invoice.check_signature().is_ok());
                assert_eq!(invoice.tagged_fields().count(), 10);
 
+               assert_eq!(invoice.amount_milli_satoshis(), Some(123));
                assert_eq!(invoice.amount_pico_btc(), Some(1230));
                assert_eq!(invoice.currency(), Currency::BitcoinTestnet);
                assert_eq!(
@@ -1913,5 +1934,33 @@ mod test {
 
                assert_eq!(invoice.min_final_cltv_expiry(), DEFAULT_MIN_FINAL_CLTV_EXPIRY);
                assert_eq!(invoice.expiry_time(), Duration::from_secs(DEFAULT_EXPIRY_TIME));
+               assert!(!invoice.is_expired());
+       }
+
+       #[test]
+       fn test_expiration() {
+               use ::*;
+               use secp256k1::Secp256k1;
+               use secp256k1::key::SecretKey;
+
+               let timestamp = SystemTime::now()
+                       .checked_sub(Duration::from_secs(DEFAULT_EXPIRY_TIME * 2))
+                       .unwrap();
+               let signed_invoice = InvoiceBuilder::new(Currency::Bitcoin)
+                       .description("Test".into())
+                       .payment_hash(sha256::Hash::from_slice(&[0;32][..]).unwrap())
+                       .payment_secret(PaymentSecret([0; 32]))
+                       .timestamp(timestamp)
+                       .build_raw()
+                       .unwrap()
+                       .sign::<_, ()>(|hash| {
+                               let privkey = SecretKey::from_slice(&[41; 32]).unwrap();
+                               let secp_ctx = Secp256k1::new();
+                               Ok(secp_ctx.sign_recoverable(hash, &privkey))
+                       })
+                       .unwrap();
+               let invoice = Invoice::from_signed(signed_invoice).unwrap();
+
+               assert!(invoice.is_expired());
        }
 }
diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs
new file mode 100644 (file)
index 0000000..075559b
--- /dev/null
@@ -0,0 +1,1335 @@
+// This file is Copyright its original authors, visible in version control
+// history.
+//
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
+//! A module for paying Lightning invoices.
+//!
+//! Defines an [`InvoicePayer`] utility for paying invoices, parameterized by [`Payer`] and
+//! [`Router`] traits. Implementations of [`Payer`] provide the payer's node id, channels, and means
+//! to send a payment over a [`Route`]. Implementations of [`Router`] find a [`Route`] between payer
+//! and payee using information provided by the payer and from the payee's [`Invoice`].
+//!
+//! [`InvoicePayer`] is capable of retrying failed payments. It accomplishes this by implementing
+//! [`EventHandler`] which decorates a user-provided handler. It will intercept any
+//! [`Event::PaymentPathFailed`] events and retry the failed paths for a fixed number of total
+//! attempts or until retry is no longer possible. In such a situation, [`InvoicePayer`] will pass
+//! along the events to the user-provided handler.
+//!
+//! # Example
+//!
+//! ```
+//! # extern crate lightning;
+//! # extern crate lightning_invoice;
+//! # extern crate secp256k1;
+//! #
+//! # use lightning::ln::{PaymentHash, PaymentSecret};
+//! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
+//! # use lightning::ln::msgs::LightningError;
+//! # use lightning::routing;
+//! # use lightning::routing::network_graph::NodeId;
+//! # use lightning::routing::router::{Route, RouteHop, RouteParameters};
+//! # use lightning::util::events::{Event, EventHandler, EventsProvider};
+//! # use lightning::util::logger::{Logger, Record};
+//! # use lightning_invoice::Invoice;
+//! # use lightning_invoice::payment::{InvoicePayer, Payer, RetryAttempts, Router};
+//! # use secp256k1::key::PublicKey;
+//! # use std::cell::RefCell;
+//! # use std::ops::Deref;
+//! #
+//! # struct FakeEventProvider {}
+//! # impl EventsProvider for FakeEventProvider {
+//! #     fn process_pending_events<H: Deref>(&self, handler: H) where H::Target: EventHandler {}
+//! # }
+//! #
+//! # struct FakePayer {}
+//! # impl Payer for FakePayer {
+//! #     fn node_id(&self) -> PublicKey { unimplemented!() }
+//! #     fn first_hops(&self) -> Vec<ChannelDetails> { unimplemented!() }
+//! #     fn send_payment(
+//! #         &self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>
+//! #     ) -> Result<PaymentId, PaymentSendFailure> { unimplemented!() }
+//! #     fn retry_payment(
+//! #         &self, route: &Route, payment_id: PaymentId
+//! #     ) -> Result<(), PaymentSendFailure> { unimplemented!() }
+//! # }
+//! #
+//! # struct FakeRouter {};
+//! # impl<S: routing::Score> Router<S> for FakeRouter {
+//! #     fn find_route(
+//! #         &self, payer: &PublicKey, params: &RouteParameters,
+//! #         first_hops: Option<&[&ChannelDetails]>, scorer: &S
+//! #     ) -> Result<Route, LightningError> { unimplemented!() }
+//! # }
+//! #
+//! # struct FakeScorer {};
+//! # impl routing::Score for FakeScorer {
+//! #     fn channel_penalty_msat(
+//! #         &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+//! #     ) -> u64 { 0 }
+//! #     fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
+//! # }
+//! #
+//! # struct FakeLogger {};
+//! # impl Logger for FakeLogger {
+//! #     fn log(&self, record: &Record) { unimplemented!() }
+//! # }
+//! #
+//! # fn main() {
+//! let event_handler = |event: &Event| {
+//!     match event {
+//!         Event::PaymentPathFailed { .. } => println!("payment failed after retries"),
+//!         Event::PaymentSent { .. } => println!("payment successful"),
+//!         _ => {},
+//!     }
+//! };
+//! # let payer = FakePayer {};
+//! # let router = FakeRouter {};
+//! # let scorer = RefCell::new(FakeScorer {});
+//! # let logger = FakeLogger {};
+//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+//!
+//! let invoice = "...";
+//! let invoice = invoice.parse::<Invoice>().unwrap();
+//! invoice_payer.pay_invoice(&invoice).unwrap();
+//!
+//! # let event_provider = FakeEventProvider {};
+//! loop {
+//!     event_provider.process_pending_events(&invoice_payer);
+//! }
+//! # }
+//! ```
+//!
+//! # Note
+//!
+//! The [`Route`] is computed before each payment attempt. Any updates affecting path finding such
+//! as updates to the network graph or changes to channel scores should be applied prior to
+//! retries, typically by way of composing [`EventHandler`]s accordingly.
+
+use crate::Invoice;
+
+use bitcoin_hashes::Hash;
+
+use lightning::ln::{PaymentHash, PaymentSecret};
+use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
+use lightning::ln::msgs::LightningError;
+use lightning::routing;
+use lightning::routing::{LockableScore, Score};
+use lightning::routing::router::{Payee, Route, RouteParameters};
+use lightning::util::events::{Event, EventHandler};
+use lightning::util::logger::Logger;
+
+use secp256k1::key::PublicKey;
+
+use std::collections::hash_map::{self, HashMap};
+use std::ops::Deref;
+use std::sync::Mutex;
+use std::time::{Duration, SystemTime};
+
+/// A utility for paying [`Invoice]`s.
+pub struct InvoicePayer<P: Deref, R, S: Deref, L: Deref, E>
+where
+       P::Target: Payer,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
+       L::Target: Logger,
+       E: EventHandler,
+{
+       payer: P,
+       router: R,
+       scorer: S,
+       logger: L,
+       event_handler: E,
+       payment_cache: Mutex<HashMap<PaymentHash, usize>>,
+       retry_attempts: RetryAttempts,
+}
+
+/// A trait defining behavior of an [`Invoice`] payer.
+pub trait Payer {
+       /// Returns the payer's node id.
+       fn node_id(&self) -> PublicKey;
+
+       /// Returns the payer's channels.
+       fn first_hops(&self) -> Vec<ChannelDetails>;
+
+       /// Sends a payment over the Lightning Network using the given [`Route`].
+       fn send_payment(
+               &self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>
+       ) -> Result<PaymentId, PaymentSendFailure>;
+
+       /// Retries a failed payment path for the [`PaymentId`] using the given [`Route`].
+       fn retry_payment(&self, route: &Route, payment_id: PaymentId) -> Result<(), PaymentSendFailure>;
+}
+
+/// A trait defining behavior for routing an [`Invoice`] payment.
+pub trait Router<S: routing::Score> {
+       /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values.
+       fn find_route(
+               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
+       ) -> Result<Route, LightningError>;
+}
+
+/// Number of attempts to retry payment path failures for an [`Invoice`].
+///
+/// Note that this is the number of *path* failures, not full payment retries. For multi-path
+/// payments, if this is less than the total number of paths, we will never even retry all of the
+/// payment's paths.
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+pub struct RetryAttempts(pub usize);
+
+/// An error that may occur when making a payment.
+#[derive(Clone, Debug)]
+pub enum PaymentError {
+       /// An error resulting from the provided [`Invoice`] or payment hash.
+       Invoice(&'static str),
+       /// An error occurring when finding a route.
+       Routing(LightningError),
+       /// An error occurring when sending a payment.
+       Sending(PaymentSendFailure),
+}
+
+impl<P: Deref, R, S: Deref, L: Deref, E> InvoicePayer<P, R, S, L, E>
+where
+       P::Target: Payer,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
+       L::Target: Logger,
+       E: EventHandler,
+{
+       /// Creates an invoice payer that retries failed payment paths.
+       ///
+       /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once
+       /// `retry_attempts` has been exceeded for a given [`Invoice`].
+       pub fn new(
+               payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts
+       ) -> Self {
+               Self {
+                       payer,
+                       router,
+                       scorer,
+                       logger,
+                       event_handler,
+                       payment_cache: Mutex::new(HashMap::new()),
+                       retry_attempts,
+               }
+       }
+
+       /// Pays the given [`Invoice`], caching it for later use in case a retry is needed.
+       ///
+       /// You should ensure that the `invoice.payment_hash()` is unique and the same payment_hash has
+       /// never been paid before. Because [`InvoicePayer`] is stateless no effort is made to do so
+       /// for you.
+       pub fn pay_invoice(&self, invoice: &Invoice) -> Result<PaymentId, PaymentError> {
+               if invoice.amount_milli_satoshis().is_none() {
+                       Err(PaymentError::Invoice("amount missing"))
+               } else {
+                       self.pay_invoice_internal(invoice, None, 0)
+               }
+       }
+
+       /// Pays the given zero-value [`Invoice`] using the given amount, caching it for later use in
+       /// case a retry is needed.
+       ///
+       /// You should ensure that the `invoice.payment_hash()` is unique and the same payment_hash has
+       /// never been paid before. Because [`InvoicePayer`] is stateless no effort is made to do so
+       /// for you.
+       pub fn pay_zero_value_invoice(
+               &self, invoice: &Invoice, amount_msats: u64
+       ) -> Result<PaymentId, PaymentError> {
+               if invoice.amount_milli_satoshis().is_some() {
+                       Err(PaymentError::Invoice("amount unexpected"))
+               } else {
+                       self.pay_invoice_internal(invoice, Some(amount_msats), 0)
+               }
+       }
+
+       fn pay_invoice_internal(
+               &self, invoice: &Invoice, amount_msats: Option<u64>, retry_count: usize
+       ) -> Result<PaymentId, PaymentError> {
+               debug_assert!(invoice.amount_milli_satoshis().is_some() ^ amount_msats.is_some());
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               if invoice.is_expired() {
+                       log_trace!(self.logger, "Invoice expired prior to first send for payment {}", log_bytes!(payment_hash.0));
+                       return Err(PaymentError::Invoice("Invoice expired prior to send"));
+               }
+               let retry_data_payment_id = loop {
+                       let mut payment_cache = self.payment_cache.lock().unwrap();
+                       match payment_cache.entry(payment_hash) {
+                               hash_map::Entry::Vacant(entry) => {
+                                       let payer = self.payer.node_id();
+                                       let mut payee = Payee::from_node_id(invoice.recover_payee_pub_key())
+                                               .with_expiry_time(expiry_time_from_unix_epoch(&invoice).as_secs())
+                                               .with_route_hints(invoice.route_hints());
+                                       if let Some(features) = invoice.features() {
+                                               payee = payee.with_features(features.clone());
+                                       }
+                                       let params = RouteParameters {
+                                               payee,
+                                               final_value_msat: invoice.amount_milli_satoshis().or(amount_msats).unwrap(),
+                                               final_cltv_expiry_delta: invoice.min_final_cltv_expiry() as u32,
+                                       };
+                                       let first_hops = self.payer.first_hops();
+                                       let route = self.router.find_route(
+                                               &payer,
+                                               &params,
+                                               Some(&first_hops.iter().collect::<Vec<_>>()),
+                                               &self.scorer.lock(),
+                                       ).map_err(|e| PaymentError::Routing(e))?;
+
+                                       let payment_secret = Some(invoice.payment_secret().clone());
+                                       let payment_id = match self.payer.send_payment(&route, payment_hash, &payment_secret) {
+                                               Ok(payment_id) => payment_id,
+                                               Err(PaymentSendFailure::ParameterError(e)) =>
+                                                       return Err(PaymentError::Sending(PaymentSendFailure::ParameterError(e))),
+                                               Err(PaymentSendFailure::PathParameterError(e)) =>
+                                                       return Err(PaymentError::Sending(PaymentSendFailure::PathParameterError(e))),
+                                               Err(PaymentSendFailure::AllFailedRetrySafe(e)) => {
+                                                       if retry_count >= self.retry_attempts.0 {
+                                                               return Err(PaymentError::Sending(PaymentSendFailure::AllFailedRetrySafe(e)))
+                                                       }
+                                                       break None;
+                                               },
+                                               Err(PaymentSendFailure::PartialFailure { results: _, failed_paths_retry, payment_id }) => {
+                                                       if let Some(retry_data) = failed_paths_retry {
+                                                               entry.insert(retry_count);
+                                                               break Some((retry_data, payment_id));
+                                                       } else {
+                                                               // This may happen if we send a payment and some paths fail, but
+                                                               // only due to a temporary monitor failure or the like, implying
+                                                               // they're really in-flight, but we haven't sent the initial
+                                                               // HTLC-Add messages yet.
+                                                               payment_id
+                                                       }
+                                               },
+                                       };
+                                       entry.insert(retry_count);
+                                       return Ok(payment_id);
+                               },
+                               hash_map::Entry::Occupied(_) => return Err(PaymentError::Invoice("payment pending")),
+                       }
+               };
+               if let Some((retry_data, payment_id)) = retry_data_payment_id {
+                       // Some paths were sent, even if we failed to send the full MPP value our recipient may
+                       // misbehave and claim the funds, at which point we have to consider the payment sent,
+                       // so return `Ok()` here, ignoring any retry errors.
+                       let _ = self.retry_payment(payment_id, payment_hash, &retry_data);
+                       Ok(payment_id)
+               } else {
+                       self.pay_invoice_internal(invoice, amount_msats, retry_count + 1)
+               }
+       }
+
+       fn retry_payment(&self, payment_id: PaymentId, payment_hash: PaymentHash, params: &RouteParameters)
+       -> Result<(), ()> {
+               let route;
+               {
+                       let mut payment_cache = self.payment_cache.lock().unwrap();
+                       let entry = loop {
+                               let entry = payment_cache.entry(payment_hash);
+                               match entry {
+                                       hash_map::Entry::Occupied(_) => break entry,
+                                       hash_map::Entry::Vacant(entry) => entry.insert(0),
+                               };
+                       };
+                       if let hash_map::Entry::Occupied(mut entry) = entry {
+                               let max_payment_attempts = self.retry_attempts.0 + 1;
+                               let attempts = entry.get_mut();
+                               *attempts += 1;
+
+                               if *attempts >= max_payment_attempts {
+                                       log_trace!(self.logger, "Payment {} exceeded maximum attempts; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts);
+                                       return Err(());
+                               } else if has_expired(params) {
+                                       log_trace!(self.logger, "Invoice expired for payment {}; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts);
+                                       return Err(());
+                               }
+
+                               let payer = self.payer.node_id();
+                               let first_hops = self.payer.first_hops();
+                               route = self.router.find_route(&payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()), &self.scorer.lock());
+                               if route.is_err() {
+                                       log_trace!(self.logger, "Failed to find a route for payment {}; not retrying (attempts: {})", log_bytes!(payment_hash.0), attempts);
+                                       return Err(());
+                               }
+                       } else {
+                               unreachable!();
+                       }
+               }
+
+               let retry_res = self.payer.retry_payment(&route.unwrap(), payment_id);
+               match retry_res {
+                       Ok(()) => Ok(()),
+                       Err(PaymentSendFailure::ParameterError(_)) |
+                       Err(PaymentSendFailure::PathParameterError(_)) => {
+                               log_trace!(self.logger, "Failed to retry for payment {} due to bogus route/payment data, not retrying.", log_bytes!(payment_hash.0));
+                               return Err(());
+                       },
+                       Err(PaymentSendFailure::AllFailedRetrySafe(_)) => {
+                               self.retry_payment(payment_id, payment_hash, params)
+                       },
+                       Err(PaymentSendFailure::PartialFailure { results: _, failed_paths_retry, .. }) => {
+                               if let Some(retry) = failed_paths_retry {
+                                       self.retry_payment(payment_id, payment_hash, &retry)
+                               } else {
+                                       Ok(())
+                               }
+                       },
+               }
+       }
+
+       /// Removes the payment cached by the given payment hash.
+       ///
+       /// Should be called once a payment has failed or succeeded if not using [`InvoicePayer`] as an
+       /// [`EventHandler`]. Otherwise, calling this method is unnecessary.
+       pub fn remove_cached_payment(&self, payment_hash: &PaymentHash) {
+               self.payment_cache.lock().unwrap().remove(payment_hash);
+       }
+}
+
+fn expiry_time_from_unix_epoch(invoice: &Invoice) -> Duration {
+       invoice.timestamp().duration_since(SystemTime::UNIX_EPOCH).unwrap() + invoice.expiry_time()
+}
+
+fn has_expired(params: &RouteParameters) -> bool {
+       if let Some(expiry_time) = params.payee.expiry_time {
+               Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, Duration::from_secs(expiry_time))
+       } else { false }
+}
+
+impl<P: Deref, R, S: Deref, L: Deref, E> EventHandler for InvoicePayer<P, R, S, L, E>
+where
+       P::Target: Payer,
+       R: for <'a> Router<<<S as Deref>::Target as routing::LockableScore<'a>>::Locked>,
+       S::Target: for <'a> routing::LockableScore<'a>,
+       L::Target: Logger,
+       E: EventHandler,
+{
+       fn handle_event(&self, event: &Event) {
+               match event {
+                       Event::PaymentPathFailed {
+                               all_paths_failed, payment_id, payment_hash, rejected_by_dest, path, short_channel_id, retry, ..
+                       } => {
+                               if let Some(short_channel_id) = short_channel_id {
+                                       let t = path.iter().collect::<Vec<_>>();
+                                       self.scorer.lock().payment_path_failed(&t, *short_channel_id);
+                               }
+
+                               if *rejected_by_dest {
+                                       log_trace!(self.logger, "Payment {} rejected by destination; not retrying", log_bytes!(payment_hash.0));
+                               } else if payment_id.is_none() {
+                                       log_trace!(self.logger, "Payment {} has no id; not retrying", log_bytes!(payment_hash.0));
+                               } else if let Some(params) = retry {
+                                       if self.retry_payment(payment_id.unwrap(), *payment_hash, params).is_ok() {
+                                               // We retried at least somewhat, don't provide the PaymentPathFailed event to the user.
+                                               return;
+                                       }
+                               } else {
+                                       log_trace!(self.logger, "Payment {} missing retry params; not retrying", log_bytes!(payment_hash.0));
+                               }
+                               if *all_paths_failed { self.payment_cache.lock().unwrap().remove(payment_hash); }
+                       },
+                       Event::PaymentSent { payment_hash, .. } => {
+                               let mut payment_cache = self.payment_cache.lock().unwrap();
+                               let attempts = payment_cache
+                                       .remove(payment_hash)
+                                       .map_or(1, |attempts| attempts + 1);
+                               log_trace!(self.logger, "Payment {} succeeded (attempts: {})", log_bytes!(payment_hash.0), attempts);
+                       },
+                       _ => {},
+               }
+
+               // Delegate to the decorated event handler unless the payment is retried.
+               self.event_handler.handle_event(event)
+       }
+}
+
+#[cfg(test)]
+mod tests {
+       use super::*;
+       use crate::{DEFAULT_EXPIRY_TIME, InvoiceBuilder, Currency};
+       use utils::create_invoice_from_channelmanager;
+       use bitcoin_hashes::sha256::Hash as Sha256;
+       use lightning::ln::PaymentPreimage;
+       use lightning::ln::features::{ChannelFeatures, NodeFeatures, InitFeatures};
+       use lightning::ln::functional_test_utils::*;
+       use lightning::ln::msgs::{ErrorAction, LightningError};
+       use lightning::routing::network_graph::NodeId;
+       use lightning::routing::router::{Payee, Route, RouteHop};
+       use lightning::util::test_utils::TestLogger;
+       use lightning::util::errors::APIError;
+       use lightning::util::events::{Event, MessageSendEventsProvider};
+       use secp256k1::{SecretKey, PublicKey, Secp256k1};
+       use std::cell::RefCell;
+       use std::collections::VecDeque;
+       use std::time::{SystemTime, Duration};
+
+       fn invoice(payment_preimage: PaymentPreimage) -> Invoice {
+               let payment_hash = Sha256::hash(&payment_preimage.0);
+               let private_key = SecretKey::from_slice(&[42; 32]).unwrap();
+               InvoiceBuilder::new(Currency::Bitcoin)
+                       .description("test".into())
+                       .payment_hash(payment_hash)
+                       .payment_secret(PaymentSecret([0; 32]))
+                       .current_timestamp()
+                       .min_final_cltv_expiry(144)
+                       .amount_milli_satoshis(128)
+                       .build_signed(|hash| {
+                               Secp256k1::new().sign_recoverable(hash, &private_key)
+                       })
+                       .unwrap()
+       }
+
+       fn zero_value_invoice(payment_preimage: PaymentPreimage) -> Invoice {
+               let payment_hash = Sha256::hash(&payment_preimage.0);
+               let private_key = SecretKey::from_slice(&[42; 32]).unwrap();
+               InvoiceBuilder::new(Currency::Bitcoin)
+                       .description("test".into())
+                       .payment_hash(payment_hash)
+                       .payment_secret(PaymentSecret([0; 32]))
+                       .current_timestamp()
+                       .min_final_cltv_expiry(144)
+                       .build_signed(|hash| {
+                               Secp256k1::new().sign_recoverable(hash, &private_key)
+                       })
+                       .unwrap()
+       }
+
+       fn expired_invoice(payment_preimage: PaymentPreimage) -> Invoice {
+               let payment_hash = Sha256::hash(&payment_preimage.0);
+               let private_key = SecretKey::from_slice(&[42; 32]).unwrap();
+               let timestamp = SystemTime::now()
+                       .checked_sub(Duration::from_secs(DEFAULT_EXPIRY_TIME * 2))
+                       .unwrap();
+               InvoiceBuilder::new(Currency::Bitcoin)
+                       .description("test".into())
+                       .payment_hash(payment_hash)
+                       .payment_secret(PaymentSecret([0; 32]))
+                       .timestamp(timestamp)
+                       .min_final_cltv_expiry(144)
+                       .amount_milli_satoshis(128)
+                       .build_signed(|hash| {
+                               Secp256k1::new().sign_recoverable(hash, &private_key)
+                       })
+                       .unwrap()
+       }
+
+       #[test]
+       fn pays_invoice_on_first_attempt() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               invoice_payer.handle_event(&Event::PaymentSent {
+                       payment_id, payment_preimage, payment_hash, fee_paid_msat: None
+               });
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 1);
+       }
+
+       #[test]
+       fn pays_invoice_on_retry() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new()
+                       .expect_value_msat(final_value_msat)
+                       .expect_value_msat(final_value_msat / 2);
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash,
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path: TestRouter::path_for_value(final_value_msat),
+                       short_channel_id: None,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), false);
+               assert_eq!(*payer.attempts.borrow(), 2);
+
+               invoice_payer.handle_event(&Event::PaymentSent {
+                       payment_id, payment_preimage, payment_hash, fee_paid_msat: None
+               });
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 2);
+       }
+
+       #[test]
+       fn retries_payment_path_for_unknown_payment() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_id = Some(PaymentId([1; 32]));
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash,
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path: TestRouter::path_for_value(final_value_msat),
+                       short_channel_id: None,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), false);
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), false);
+               assert_eq!(*payer.attempts.borrow(), 2);
+
+               invoice_payer.handle_event(&Event::PaymentSent {
+                       payment_id, payment_preimage, payment_hash, fee_paid_msat: None
+               });
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 2);
+       }
+
+       #[test]
+       fn fails_paying_invoice_after_max_retries() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new()
+                       .expect_value_msat(final_value_msat)
+                       .expect_value_msat(final_value_msat / 2)
+                       .expect_value_msat(final_value_msat / 2);
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()),
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: true,
+                       path: TestRouter::path_for_value(final_value_msat),
+                       short_channel_id: None,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), false);
+               assert_eq!(*payer.attempts.borrow(), 2);
+
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()),
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path: TestRouter::path_for_value(final_value_msat / 2),
+                       short_channel_id: None,
+                       retry: Some(RouteParameters {
+                               final_value_msat: final_value_msat / 2, ..TestRouter::retry_for_invoice(&invoice)
+                       }),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), false);
+               assert_eq!(*payer.attempts.borrow(), 3);
+
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 3);
+       }
+
+       #[test]
+       fn fails_paying_invoice_with_missing_retry_params() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()),
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path: vec![],
+                       short_channel_id: None,
+                       retry: None,
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 1);
+       }
+
+       #[test]
+       fn fails_paying_invoice_after_expiration() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = expired_invoice(payment_preimage);
+               if let PaymentError::Invoice(msg) = invoice_payer.pay_invoice(&invoice).unwrap_err() {
+                       assert_eq!(msg, "Invoice expired prior to send");
+               } else { panic!("Expected Invoice Error"); }
+       }
+
+       #[test]
+       fn fails_retrying_invoice_after_expiration() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               let mut retry_data = TestRouter::retry_for_invoice(&invoice);
+               retry_data.payee.expiry_time = Some(SystemTime::now()
+                       .checked_sub(Duration::from_secs(2)).unwrap()
+                       .duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs());
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()),
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path: vec![],
+                       short_channel_id: None,
+                       retry: Some(retry_data),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 1);
+       }
+
+       #[test]
+       fn fails_paying_invoice_after_retry_error() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+
+               let payer = TestPayer::new()
+                       .fails_on_attempt(2)
+                       .expect_value_msat(final_value_msat);
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()),
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path: TestRouter::path_for_value(final_value_msat / 2),
+                       short_channel_id: None,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 2);
+       }
+
+       #[test]
+       fn fails_paying_invoice_after_rejected_by_payee() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash: PaymentHash(invoice.payment_hash().clone().into_inner()),
+                       network_update: None,
+                       rejected_by_dest: true,
+                       all_paths_failed: false,
+                       path: vec![],
+                       short_channel_id: None,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 1);
+       }
+
+       #[test]
+       fn fails_repaying_invoice_with_pending_payment() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+
+               // Cannot repay an invoice pending payment.
+               match invoice_payer.pay_invoice(&invoice) {
+                       Err(PaymentError::Invoice("payment pending")) => {},
+                       Err(_) => panic!("unexpected error"),
+                       Ok(_) => panic!("expected invoice error"),
+               }
+
+               // Can repay an invoice once cleared from cache.
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               invoice_payer.remove_cached_payment(&payment_hash);
+               assert!(invoice_payer.pay_invoice(&invoice).is_ok());
+
+               // Cannot retry paying an invoice if cleared from cache.
+               invoice_payer.remove_cached_payment(&payment_hash);
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash,
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path: vec![],
+                       short_channel_id: None,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+               assert_eq!(*event_handled.borrow(), true);
+       }
+
+       #[test]
+       fn fails_paying_invoice_with_routing_errors() {
+               let payer = TestPayer::new();
+               let router = FailingRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               match invoice_payer.pay_invoice(&invoice) {
+                       Err(PaymentError::Routing(_)) => {},
+                       Err(_) => panic!("unexpected error"),
+                       Ok(_) => panic!("expected routing error"),
+               }
+       }
+
+       #[test]
+       fn fails_paying_invoice_with_sending_errors() {
+               let payer = TestPayer::new().fails_on_attempt(1);
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               match invoice_payer.pay_invoice(&invoice) {
+                       Err(PaymentError::Sending(_)) => {},
+                       Err(_) => panic!("unexpected error"),
+                       Ok(_) => panic!("expected sending error"),
+               }
+       }
+
+       #[test]
+       fn pays_zero_value_invoice_using_amount() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = zero_value_invoice(payment_preimage);
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               let final_value_msat = 100;
+
+               let payer = TestPayer::new().expect_value_msat(final_value_msat);
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+
+               let payment_id =
+                       Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap());
+               assert_eq!(*payer.attempts.borrow(), 1);
+
+               invoice_payer.handle_event(&Event::PaymentSent {
+                       payment_id, payment_preimage, payment_hash, fee_paid_msat: None
+               });
+               assert_eq!(*event_handled.borrow(), true);
+               assert_eq!(*payer.attempts.borrow(), 1);
+       }
+
+       #[test]
+       fn fails_paying_zero_value_invoice_with_amount() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new());
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+
+               // Cannot repay an invoice pending payment.
+               match invoice_payer.pay_zero_value_invoice(&invoice, 100) {
+                       Err(PaymentError::Invoice("amount unexpected")) => {},
+                       Err(_) => panic!("unexpected error"),
+                       Ok(_) => panic!("expected invoice error"),
+               }
+       }
+
+       #[test]
+       fn scores_failed_channel() {
+               let event_handled = core::cell::RefCell::new(false);
+               let event_handler = |_: &_| { *event_handled.borrow_mut() = true; };
+
+               let payment_preimage = PaymentPreimage([1; 32]);
+               let invoice = invoice(payment_preimage);
+               let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
+               let final_value_msat = invoice.amount_milli_satoshis().unwrap();
+               let path = TestRouter::path_for_value(final_value_msat);
+               let short_channel_id = Some(path[0].short_channel_id);
+
+               // Expect that scorer is given short_channel_id upon handling the event.
+               let payer = TestPayer::new();
+               let router = TestRouter {};
+               let scorer = RefCell::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap()));
+               let logger = TestLogger::new();
+               let invoice_payer =
+                       InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
+
+               let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
+               let event = Event::PaymentPathFailed {
+                       payment_id,
+                       payment_hash,
+                       network_update: None,
+                       rejected_by_dest: false,
+                       all_paths_failed: false,
+                       path,
+                       short_channel_id,
+                       retry: Some(TestRouter::retry_for_invoice(&invoice)),
+               };
+               invoice_payer.handle_event(&event);
+       }
+
+       struct TestRouter;
+
+       impl TestRouter {
+               fn route_for_value(final_value_msat: u64) -> Route {
+                       Route {
+                               paths: vec![
+                                       vec![RouteHop {
+                                               pubkey: PublicKey::from_slice(&hex::decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619").unwrap()[..]).unwrap(),
+                                               channel_features: ChannelFeatures::empty(),
+                                               node_features: NodeFeatures::empty(),
+                                               short_channel_id: 0, fee_msat: final_value_msat / 2, cltv_expiry_delta: 144
+                                       }],
+                                       vec![RouteHop {
+                                               pubkey: PublicKey::from_slice(&hex::decode("0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c").unwrap()[..]).unwrap(),
+                                               channel_features: ChannelFeatures::empty(),
+                                               node_features: NodeFeatures::empty(),
+                                               short_channel_id: 1, fee_msat: final_value_msat / 2, cltv_expiry_delta: 144
+                                       }],
+                               ],
+                               payee: None,
+                       }
+               }
+
+               fn path_for_value(final_value_msat: u64) -> Vec<RouteHop> {
+                       TestRouter::route_for_value(final_value_msat).paths[0].clone()
+               }
+
+               fn retry_for_invoice(invoice: &Invoice) -> RouteParameters {
+                       let mut payee = Payee::from_node_id(invoice.recover_payee_pub_key())
+                               .with_expiry_time(expiry_time_from_unix_epoch(invoice).as_secs())
+                               .with_route_hints(invoice.route_hints());
+                       if let Some(features) = invoice.features() {
+                               payee = payee.with_features(features.clone());
+                       }
+                       let final_value_msat = invoice.amount_milli_satoshis().unwrap() / 2;
+                       RouteParameters {
+                               payee,
+                               final_value_msat,
+                               final_cltv_expiry_delta: invoice.min_final_cltv_expiry() as u32,
+                       }
+               }
+       }
+
+       impl<S: routing::Score> Router<S> for TestRouter {
+               fn find_route(
+                       &self,
+                       _payer: &PublicKey,
+                       params: &RouteParameters,
+                       _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
+               ) -> Result<Route, LightningError> {
+                       Ok(Route {
+                               payee: Some(params.payee.clone()), ..Self::route_for_value(params.final_value_msat)
+                       })
+               }
+       }
+
+       struct FailingRouter;
+
+       impl<S: routing::Score> Router<S> for FailingRouter {
+               fn find_route(
+                       &self,
+                       _payer: &PublicKey,
+                       _params: &RouteParameters,
+                       _first_hops: Option<&[&ChannelDetails]>,
+                       _scorer: &S,
+               ) -> Result<Route, LightningError> {
+                       Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError })
+               }
+       }
+
+       struct TestScorer {
+               expectations: VecDeque<u64>,
+       }
+
+       impl TestScorer {
+               fn new() -> Self {
+                       Self {
+                               expectations: VecDeque::new(),
+                       }
+               }
+
+               fn expect_channel_failure(mut self, short_channel_id: u64) -> Self {
+                       self.expectations.push_back(short_channel_id);
+                       self
+               }
+       }
+
+       impl routing::Score for TestScorer {
+               fn channel_penalty_msat(
+                       &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+               ) -> u64 { 0 }
+
+               fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) {
+                       if let Some(expected_short_channel_id) = self.expectations.pop_front() {
+                               assert_eq!(short_channel_id, expected_short_channel_id);
+                       }
+               }
+       }
+
+       impl Drop for TestScorer {
+               fn drop(&mut self) {
+                       if std::thread::panicking() {
+                               return;
+                       }
+
+                       if !self.expectations.is_empty() {
+                               panic!("Unsatisfied channel failure expectations: {:?}", self.expectations);
+                       }
+               }
+       }
+
+       struct TestPayer {
+               expectations: core::cell::RefCell<VecDeque<u64>>,
+               attempts: core::cell::RefCell<usize>,
+               failing_on_attempt: Option<usize>,
+       }
+
+       impl TestPayer {
+               fn new() -> Self {
+                       Self {
+                               expectations: core::cell::RefCell::new(VecDeque::new()),
+                               attempts: core::cell::RefCell::new(0),
+                               failing_on_attempt: None,
+                       }
+               }
+
+               fn expect_value_msat(self, value_msat: u64) -> Self {
+                       self.expectations.borrow_mut().push_back(value_msat);
+                       self
+               }
+
+               fn fails_on_attempt(self, attempt: usize) -> Self {
+                       Self {
+                               expectations: core::cell::RefCell::new(self.expectations.borrow().clone()),
+                               attempts: core::cell::RefCell::new(0),
+                               failing_on_attempt: Some(attempt),
+                       }
+               }
+
+               fn check_attempts(&self) -> bool {
+                       let mut attempts = self.attempts.borrow_mut();
+                       *attempts += 1;
+                       match self.failing_on_attempt {
+                               None => true,
+                               Some(attempt) if attempt != *attempts => true,
+                               Some(_) => false,
+                       }
+               }
+
+               fn check_value_msats(&self, route: &Route) {
+                       let expected_value_msats = self.expectations.borrow_mut().pop_front();
+                       if let Some(expected_value_msats) = expected_value_msats {
+                               let actual_value_msats = route.get_total_amount();
+                               assert_eq!(actual_value_msats, expected_value_msats);
+                       }
+               }
+       }
+
+       impl Drop for TestPayer {
+               fn drop(&mut self) {
+                       if std::thread::panicking() {
+                               return;
+                       }
+
+                       if !self.expectations.borrow().is_empty() {
+                               panic!("Unsatisfied payment expectations: {:?}", self.expectations.borrow());
+                       }
+               }
+       }
+
+       impl Payer for TestPayer {
+               fn node_id(&self) -> PublicKey {
+                       let secp_ctx = Secp256k1::new();
+                       PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap())
+               }
+
+               fn first_hops(&self) -> Vec<ChannelDetails> {
+                       Vec::new()
+               }
+
+               fn send_payment(
+                       &self,
+                       route: &Route,
+                       _payment_hash: PaymentHash,
+                       _payment_secret: &Option<PaymentSecret>
+               ) -> Result<PaymentId, PaymentSendFailure> {
+                       if self.check_attempts() {
+                               self.check_value_msats(route);
+                               Ok(PaymentId([1; 32]))
+                       } else {
+                               Err(PaymentSendFailure::ParameterError(APIError::MonitorUpdateFailed))
+                       }
+               }
+
+               fn retry_payment(
+                       &self, route: &Route, _payment_id: PaymentId
+               ) -> Result<(), PaymentSendFailure> {
+                       if self.check_attempts() {
+                               self.check_value_msats(route);
+                               Ok(())
+                       } else {
+                               Err(PaymentSendFailure::ParameterError(APIError::MonitorUpdateFailed))
+                       }
+               }
+       }
+
+       // *** Full Featured Functional Tests with a Real ChannelManager ***
+       struct ManualRouter(RefCell<VecDeque<Result<Route, LightningError>>>);
+
+       impl<S: routing::Score> Router<S> for ManualRouter {
+               fn find_route(&self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>, _scorer: &S)
+               -> Result<Route, LightningError> {
+                       self.0.borrow_mut().pop_front().unwrap()
+               }
+       }
+       impl ManualRouter {
+               fn expect_find_route(&self, result: Result<Route, LightningError>) {
+                       self.0.borrow_mut().push_back(result);
+               }
+       }
+       impl Drop for ManualRouter {
+               fn drop(&mut self) {
+                       if std::thread::panicking() {
+                               return;
+                       }
+                       assert!(self.0.borrow_mut().is_empty());
+               }
+       }
+
+       #[test]
+       fn retry_multi_path_single_failed_payment() {
+               // Tests that we can/will retry after a single path of an MPP payment failed immediately
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None, None]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+               create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1_000_000, 0, InitFeatures::known(), InitFeatures::known());
+               create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1_000_000, 0, InitFeatures::known(), InitFeatures::known());
+               let chans = nodes[0].node.list_usable_channels();
+               let mut route = Route {
+                       paths: vec![
+                               vec![RouteHop {
+                                       pubkey: nodes[1].node.get_our_node_id(),
+                                       node_features: NodeFeatures::known(),
+                                       short_channel_id: chans[0].short_channel_id.unwrap(),
+                                       channel_features: ChannelFeatures::known(),
+                                       fee_msat: 10_000,
+                                       cltv_expiry_delta: 100,
+                               }],
+                               vec![RouteHop {
+                                       pubkey: nodes[1].node.get_our_node_id(),
+                                       node_features: NodeFeatures::known(),
+                                       short_channel_id: chans[1].short_channel_id.unwrap(),
+                                       channel_features: ChannelFeatures::known(),
+                                       fee_msat: 100_000_001, // Our default max-HTLC-value is 10% of the channel value, which this is one more than
+                                       cltv_expiry_delta: 100,
+                               }],
+                       ],
+                       payee: Some(Payee::from_node_id(nodes[1].node.get_our_node_id())),
+               };
+               let router = ManualRouter(RefCell::new(VecDeque::new()));
+               router.expect_find_route(Ok(route.clone()));
+               // On retry, split the payment across both channels.
+               route.paths[0][0].fee_msat = 50_000_001;
+               route.paths[1][0].fee_msat = 50_000_000;
+               router.expect_find_route(Ok(route.clone()));
+
+               let event_handler = |_: &_| { panic!(); };
+               let scorer = RefCell::new(TestScorer::new());
+               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1));
+
+               assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager(
+                       &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string()).unwrap())
+                       .is_ok());
+               let htlc_msgs = nodes[0].node.get_and_clear_pending_msg_events();
+               assert_eq!(htlc_msgs.len(), 2);
+               check_added_monitors!(nodes[0], 2);
+       }
+
+       #[test]
+       fn immediate_retry_on_failure() {
+               // Tests that we can/will retry immediately after a failure
+               let chanmon_cfgs = create_chanmon_cfgs(2);
+               let node_cfgs = create_node_cfgs(2, &chanmon_cfgs);
+               let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None, None]);
+               let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
+
+               create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1_000_000, 0, InitFeatures::known(), InitFeatures::known());
+               create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1_000_000, 0, InitFeatures::known(), InitFeatures::known());
+               let chans = nodes[0].node.list_usable_channels();
+               let mut route = Route {
+                       paths: vec![
+                               vec![RouteHop {
+                                       pubkey: nodes[1].node.get_our_node_id(),
+                                       node_features: NodeFeatures::known(),
+                                       short_channel_id: chans[0].short_channel_id.unwrap(),
+                                       channel_features: ChannelFeatures::known(),
+                                       fee_msat: 100_000_001, // Our default max-HTLC-value is 10% of the channel value, which this is one more than
+                                       cltv_expiry_delta: 100,
+                               }],
+                       ],
+                       payee: Some(Payee::from_node_id(nodes[1].node.get_our_node_id())),
+               };
+               let router = ManualRouter(RefCell::new(VecDeque::new()));
+               router.expect_find_route(Ok(route.clone()));
+               // On retry, split the payment across both channels.
+               route.paths.push(route.paths[0].clone());
+               route.paths[0][0].short_channel_id = chans[1].short_channel_id.unwrap();
+               route.paths[0][0].fee_msat = 50_000_000;
+               route.paths[1][0].fee_msat = 50_000_001;
+               router.expect_find_route(Ok(route.clone()));
+
+               let event_handler = |_: &_| { panic!(); };
+               let scorer = RefCell::new(TestScorer::new());
+               let invoice_payer = InvoicePayer::new(nodes[0].node, router, &scorer, nodes[0].logger, event_handler, RetryAttempts(1));
+
+               assert!(invoice_payer.pay_invoice(&create_invoice_from_channelmanager(
+                       &nodes[1].node, nodes[1].keys_manager, Currency::Bitcoin, Some(100_010_000), "Invoice".to_string()).unwrap())
+                       .is_ok());
+               let htlc_msgs = nodes[0].node.get_and_clear_pending_msg_events();
+               assert_eq!(htlc_msgs.len(), 2);
+               check_added_monitors!(nodes[0], 2);
+       }
+}
index 31c1ab2550e4a5a53b079355ecd90f74f111a0a0..35a74b6a5ac6a3bf6a41984babc8619551a19622 100644 (file)
@@ -1,14 +1,21 @@
 //! Convenient utilities to create an invoice.
+
 use {Currency, DEFAULT_EXPIRY_TIME, Invoice, InvoiceBuilder, SignOrCreationError, RawInvoice};
+use payment::{Payer, Router};
+
 use bech32::ToBase32;
 use bitcoin_hashes::Hash;
 use lightning::chain;
 use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
 use lightning::chain::keysinterface::{Sign, KeysInterface};
-use lightning::ln::channelmanager::{ChannelManager, MIN_FINAL_CLTV_EXPIRY};
-use lightning::routing::network_graph::RoutingFees;
-use lightning::routing::router::{RouteHint, RouteHintHop};
+use lightning::ln::{PaymentHash, PaymentSecret};
+use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY};
+use lightning::ln::msgs::LightningError;
+use lightning::routing;
+use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
+use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route};
 use lightning::util::logger::Logger;
+use secp256k1::key::PublicKey;
 use std::convert::TryInto;
 use std::ops::Deref;
 
@@ -89,6 +96,58 @@ where
        }
 }
 
+/// A [`Router`] implemented using [`find_route`].
+pub struct DefaultRouter<G, L: Deref> where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+       network_graph: G,
+       logger: L,
+}
+
+impl<G, L: Deref> DefaultRouter<G, L> where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+       /// Creates a new router using the given [`NetworkGraph`] and  [`Logger`].
+       pub fn new(network_graph: G, logger: L) -> Self {
+               Self { network_graph, logger }
+       }
+}
+
+impl<G, L: Deref, S: routing::Score> Router<S> for DefaultRouter<G, L>
+where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+       fn find_route(
+               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
+       ) -> Result<Route, LightningError> {
+               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer)
+       }
+}
+
+impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Payer for ChannelManager<Signer, M, T, K, F, L>
+where
+       M::Target: chain::Watch<Signer>,
+       T::Target: BroadcasterInterface,
+       K::Target: KeysInterface<Signer = Signer>,
+       F::Target: FeeEstimator,
+       L::Target: Logger,
+{
+       fn node_id(&self) -> PublicKey {
+               self.get_our_node_id()
+       }
+
+       fn first_hops(&self) -> Vec<ChannelDetails> {
+               self.list_usable_channels()
+       }
+
+       fn send_payment(
+               &self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>
+       ) -> Result<PaymentId, PaymentSendFailure> {
+               self.send_payment(route, payment_hash, payment_secret)
+       }
+
+       fn retry_payment(
+               &self, route: &Route, payment_id: PaymentId
+       ) -> Result<(), PaymentSendFailure> {
+               self.retry_payment(route, payment_id)
+       }
+}
+
 #[cfg(test)]
 mod test {
        use {Currency, Description, InvoiceDescription};
@@ -97,8 +156,7 @@ mod test {
        use lightning::ln::functional_test_utils::*;
        use lightning::ln::features::InitFeatures;
        use lightning::ln::msgs::ChannelMessageHandler;
-       use lightning::routing::router;
-       use lightning::routing::scorer::Scorer;
+       use lightning::routing::router::{Payee, RouteParameters, find_route};
        use lightning::util::events::MessageSendEventsProvider;
        use lightning::util::test_utils;
        #[test]
@@ -113,23 +171,21 @@ mod test {
                assert_eq!(invoice.min_final_cltv_expiry(), MIN_FINAL_CLTV_EXPIRY as u64);
                assert_eq!(invoice.description(), InvoiceDescription::Direct(&Description("test".to_string())));
 
-               let amt_msat = invoice.amount_pico_btc().unwrap() / 10;
+               let payee = Payee::from_node_id(invoice.recover_payee_pub_key())
+                       .with_features(invoice.features().unwrap().clone())
+                       .with_route_hints(invoice.route_hints());
+               let params = RouteParameters {
+                       payee,
+                       final_value_msat: invoice.amount_milli_satoshis().unwrap(),
+                       final_cltv_expiry_delta: invoice.min_final_cltv_expiry() as u32,
+               };
                let first_hops = nodes[0].node.list_usable_channels();
-               let last_hops = invoice.route_hints();
-               let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
+               let network_graph = node_cfgs[0].network_graph;
                let logger = test_utils::TestLogger::new();
-               let scorer = Scorer::new(0);
-               let route = router::get_route(
-                       &nodes[0].node.get_our_node_id(),
-                       network_graph,
-                       &invoice.recover_payee_pub_key(),
-                       Some(invoice.features().unwrap().clone()),
-                       Some(&first_hops.iter().collect::<Vec<_>>()),
-                       &last_hops,
-                       amt_msat,
-                       invoice.min_final_cltv_expiry() as u32,
-                       &logger,
-                       &scorer,
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let route = find_route(
+                       &nodes[0].node.get_our_node_id(), &params, network_graph,
+                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer,
                ).unwrap();
 
                let payment_event = {
index 3d9a31f182b83d3ed9b0f6ed5d6223290acaf016..ad3dc008e6a69c1874e532a6b311ca202dbb8a1c 100644 (file)
@@ -1,6 +1,6 @@
 [package]
 name = "lightning-net-tokio"
-version = "0.0.102"
+version = "0.0.103"
 authors = ["Matt Corallo"]
 license = "MIT OR Apache-2.0"
 repository = "https://github.com/rust-bitcoin/rust-lightning/"
@@ -12,7 +12,7 @@ edition = "2018"
 
 [dependencies]
 bitcoin = "0.27"
-lightning = { version = "0.0.102", path = "../lightning" }
+lightning = { version = "0.0.103", path = "../lightning" }
 tokio = { version = "1.0", features = [ "io-util", "macros", "rt", "sync", "net", "time" ] }
 
 [dev-dependencies]
index 9e072dce78f89a44046e6530580256e349392c0c..ef8631059c986fb3d11468e0098391318492a964 100644 (file)
@@ -1,6 +1,6 @@
 [package]
 name = "lightning-persister"
-version = "0.0.102"
+version = "0.0.103"
 authors = ["Valentine Wallace", "Matt Corallo"]
 license = "MIT OR Apache-2.0"
 repository = "https://github.com/rust-bitcoin/rust-lightning/"
@@ -13,11 +13,11 @@ unstable = ["lightning/unstable"]
 
 [dependencies]
 bitcoin = "0.27"
-lightning = { version = "0.0.102", path = "../lightning" }
+lightning = { version = "0.0.103", path = "../lightning" }
 libc = "0.2"
 
 [target.'cfg(windows)'.dependencies]
 winapi = { version = "0.3", features = ["winbase"] }
 
 [dev-dependencies]
-lightning = { version = "0.0.102", path = "../lightning", features = ["_test_utils"] }
+lightning = { version = "0.0.103", path = "../lightning", features = ["_test_utils"] }
index 271a3d662abcf459a5c977cfc9b03bfd19cc1ead..8202219bf23c9f8a2bb45b467b0f3b5ea4b73f9b 100644 (file)
@@ -1,6 +1,6 @@
 [package]
 name = "lightning"
-version = "0.0.102"
+version = "0.0.103"
 authors = ["Matt Corallo"]
 license = "MIT OR Apache-2.0"
 repository = "https://github.com/rust-bitcoin/rust-lightning/"
index 0b1c16d35779948316b733ac1f431f12286533f3..0ae3ff29d87ed82e5c26beecc225c25401658fad 100644 (file)
@@ -304,7 +304,7 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
                                let events_3 = nodes[0].node.get_and_clear_pending_events();
                                assert_eq!(events_3.len(), 1);
                                match events_3[0] {
-                                       Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+                                       Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                                                assert_eq!(*payment_preimage, payment_preimage_1);
                                                assert_eq!(*payment_hash, payment_hash_1);
                                        },
@@ -397,7 +397,7 @@ fn do_test_monitor_temporary_update_fail(disconnect_count: usize) {
                        let events_3 = nodes[0].node.get_and_clear_pending_events();
                        assert_eq!(events_3.len(), 1);
                        match events_3[0] {
-                               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+                               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                                        assert_eq!(*payment_preimage, payment_preimage_1);
                                        assert_eq!(*payment_hash, payment_hash_1);
                                },
@@ -1399,7 +1399,7 @@ fn claim_while_disconnected_monitor_update_fail() {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match events[0] {
-               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                        assert_eq!(*payment_preimage, payment_preimage_1);
                        assert_eq!(*payment_hash, payment_hash_1);
                },
@@ -1806,7 +1806,7 @@ fn monitor_update_claim_fail_no_response() {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match events[0] {
-               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                        assert_eq!(*payment_preimage, payment_preimage_1);
                        assert_eq!(*payment_hash, payment_hash_1);
                },
@@ -1950,7 +1950,7 @@ fn test_path_paused_mpp() {
        // Now check that we get the right return value, indicating that the first path succeeded but
        // the second got a MonitorUpdateFailed err. This implies PaymentSendFailure::PartialFailure as
        // some paths succeeded, preventing retry.
-       if let Err(PaymentSendFailure::PartialFailure(results)) = nodes[0].node.send_payment(&route, payment_hash, &Some(payment_secret)) {
+       if let Err(PaymentSendFailure::PartialFailure { results, ..}) = nodes[0].node.send_payment(&route, payment_hash, &Some(payment_secret)) {
                assert_eq!(results.len(), 2);
                if let Ok(()) = results[0] {} else { panic!(); }
                if let Err(APIError::MonitorUpdateFailed) = results[1] {} else { panic!(); }
index 68bcf89985199954cf28bf284b2de2421e6d2de0..e39478bd740c063ce1f29a7dc853238ec75fcf48 100644 (file)
@@ -5810,6 +5810,7 @@ mod tests {
                                first_hop_htlc_msat: 548,
                                payment_id: PaymentId([42; 32]),
                                payment_secret: None,
+                               payee: None,
                        }
                });
 
index 9e4708dcfd6a6dbeb4e676a5e7f495d09cac7c97..a5f84abfdb3e2865505db43149fbdebde96ee71c 100644 (file)
@@ -45,7 +45,7 @@ use chain::transaction::{OutPoint, TransactionData};
 use ln::{PaymentHash, PaymentPreimage, PaymentSecret};
 use ln::channel::{Channel, ChannelError, ChannelUpdateStatus, UpdateFulfillCommitFetch};
 use ln::features::{InitFeatures, NodeFeatures};
-use routing::router::{Route, RouteHop};
+use routing::router::{Payee, Route, RouteHop, RoutePath, RouteParameters};
 use ln::msgs;
 use ln::msgs::NetAddress;
 use ln::onion_utils;
@@ -173,6 +173,7 @@ struct ClaimableHTLC {
 }
 
 /// A payment identifier used to uniquely identify a payment to LDK.
+/// (C-not exported) as we just use [u8; 32] directly
 #[derive(Hash, Copy, Clone, PartialEq, Eq, Debug)]
 pub struct PaymentId(pub [u8; 32]);
 
@@ -201,6 +202,7 @@ pub(crate) enum HTLCSource {
                first_hop_htlc_msat: u64,
                payment_id: PaymentId,
                payment_secret: Option<PaymentSecret>,
+               payee: Option<Payee>,
        },
 }
 #[allow(clippy::derive_hash_xor_eq)] // Our Hash is faithful to the data, we just don't have SecretKey::hash
@@ -211,13 +213,14 @@ impl core::hash::Hash for HTLCSource {
                                0u8.hash(hasher);
                                prev_hop_data.hash(hasher);
                        },
-                       HTLCSource::OutboundRoute { path, session_priv, payment_id, payment_secret, first_hop_htlc_msat } => {
+                       HTLCSource::OutboundRoute { path, session_priv, payment_id, payment_secret, first_hop_htlc_msat, payee } => {
                                1u8.hash(hasher);
                                path.hash(hasher);
                                session_priv[..].hash(hasher);
                                payment_id.hash(hasher);
                                payment_secret.hash(hasher);
                                first_hop_htlc_msat.hash(hasher);
+                               payee.hash(hasher);
                        },
                }
        }
@@ -231,6 +234,7 @@ impl HTLCSource {
                        first_hop_htlc_msat: 0,
                        payment_id: PaymentId([2; 32]),
                        payment_secret: None,
+                       payee: None,
                }
        }
 }
@@ -433,6 +437,8 @@ pub(crate) enum PendingOutboundPayment {
                payment_hash: PaymentHash,
                payment_secret: Option<PaymentSecret>,
                pending_amt_msat: u64,
+               /// Used to track the fee paid. Only present if the payment was serialized on 0.0.103+.
+               pending_fee_msat: Option<u64>,
                /// The total payment amount across all paths, used to verify that a retry is not overpaying.
                total_msat: u64,
                /// Our best known block height at the time this payment was initiated.
@@ -459,6 +465,12 @@ impl PendingOutboundPayment {
                        _ => false,
                }
        }
+       fn get_pending_fee_msat(&self) -> Option<u64> {
+               match self {
+                       PendingOutboundPayment::Retryable { pending_fee_msat, .. } => pending_fee_msat.clone(),
+                       _ => None,
+               }
+       }
 
        fn mark_fulfilled(&mut self) {
                let mut session_privs = HashSet::new();
@@ -471,8 +483,8 @@ impl PendingOutboundPayment {
                *self = PendingOutboundPayment::Fulfilled { session_privs };
        }
 
-       /// panics if part_amt_msat is None and !self.is_fulfilled
-       fn remove(&mut self, session_priv: &[u8; 32], part_amt_msat: Option<u64>) -> bool {
+       /// panics if path is None and !self.is_fulfilled
+       fn remove(&mut self, session_priv: &[u8; 32], path: Option<&Vec<RouteHop>>) -> bool {
                let remove_res = match self {
                        PendingOutboundPayment::Legacy { session_privs } |
                        PendingOutboundPayment::Retryable { session_privs, .. } |
@@ -481,14 +493,19 @@ impl PendingOutboundPayment {
                        }
                };
                if remove_res {
-                       if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, .. } = self {
-                               *pending_amt_msat -= part_amt_msat.expect("We must only not provide an amount if the payment was already fulfilled");
+                       if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, ref mut pending_fee_msat, .. } = self {
+                               let path = path.expect("Fulfilling a payment should always come with a path");
+                               let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
+                               *pending_amt_msat -= path_last_hop.fee_msat;
+                               if let Some(fee_msat) = pending_fee_msat.as_mut() {
+                                       *fee_msat -= path.get_path_fees();
+                               }
                        }
                }
                remove_res
        }
 
-       fn insert(&mut self, session_priv: [u8; 32], part_amt_msat: u64) -> bool {
+       fn insert(&mut self, session_priv: [u8; 32], path: &Vec<RouteHop>) -> bool {
                let insert_res = match self {
                        PendingOutboundPayment::Legacy { session_privs } |
                        PendingOutboundPayment::Retryable { session_privs, .. } => {
@@ -497,8 +514,12 @@ impl PendingOutboundPayment {
                        PendingOutboundPayment::Fulfilled { .. } => false
                };
                if insert_res {
-                       if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, .. } = self {
-                               *pending_amt_msat += part_amt_msat;
+                       if let PendingOutboundPayment::Retryable { ref mut pending_amt_msat, ref mut pending_fee_msat, .. } = self {
+                               let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
+                               *pending_amt_msat += path_last_hop.fee_msat;
+                               if let Some(fee_msat) = pending_fee_msat.as_mut() {
+                                       *fee_msat += path.get_path_fees();
+                               }
                        }
                }
                insert_res
@@ -925,7 +946,16 @@ pub enum PaymentSendFailure {
        /// as they will result in over-/re-payment. These HTLCs all either successfully sent (in the
        /// case of Ok(())) or will send once channel_monitor_updated is called on the next-hop channel
        /// with the latest update_id.
-       PartialFailure(Vec<Result<(), APIError>>),
+       PartialFailure {
+               /// The errors themselves, in the same order as the route hops.
+               results: Vec<Result<(), APIError>>,
+               /// If some paths failed without irrevocably committing to the new HTLC(s), this will
+               /// contain a [`RouteParameters`] object which can be used to calculate a new route that
+               /// will pay all remaining unpaid balance.
+               failed_paths_retry: Option<RouteParameters>,
+               /// The payment id for the payment, which is now at least partially pending.
+               payment_id: PaymentId,
+       },
 }
 
 macro_rules! handle_error {
@@ -1111,7 +1141,7 @@ macro_rules! handle_monitor_err {
                res
        } };
        ($self: ident, $err: expr, $channel_state: expr, $entry: expr, $action_type: path, $resend_raa: expr, $resend_commitment: expr, $failed_forwards: expr, $failed_fails: expr) => {
-               handle_monitor_err!($self, $err, $channel_state, $entry, $action_type, $resend_raa, $resend_commitment, $failed_forwards, $failed_fails, Vec::new());
+               handle_monitor_err!($self, $err, $channel_state, $entry, $action_type, $resend_raa, $resend_commitment, $failed_forwards, $failed_fails, Vec::new())
        }
 }
 
@@ -2021,7 +2051,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        }
 
        // Only public for testing, this should otherwise never be called direcly
-       pub(crate) fn send_payment_along_path(&self, path: &Vec<RouteHop>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option<PaymentPreimage>) -> Result<(), APIError> {
+       pub(crate) fn send_payment_along_path(&self, path: &Vec<RouteHop>, payee: &Option<Payee>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32, payment_id: PaymentId, keysend_preimage: &Option<PaymentPreimage>) -> Result<(), APIError> {
                log_trace!(self.logger, "Attempting to send payment for path with next hop {}", path.first().unwrap().short_channel_id);
                let prng_seed = self.keys_manager.get_secure_random_bytes();
                let session_priv_bytes = self.keys_manager.get_secure_random_bytes();
@@ -2071,18 +2101,20 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                        first_hop_htlc_msat: htlc_msat,
                                                        payment_id,
                                                        payment_secret: payment_secret.clone(),
+                                                       payee: payee.clone(),
                                                }, onion_packet, &self.logger),
                                        channel_state, chan);
 
                                        let payment = payment_entry.or_insert_with(|| PendingOutboundPayment::Retryable {
                                                session_privs: HashSet::new(),
                                                pending_amt_msat: 0,
+                                               pending_fee_msat: Some(0),
                                                payment_hash: *payment_hash,
                                                payment_secret: *payment_secret,
                                                starting_block_height: self.best_block.read().unwrap().height(),
                                                total_msat: total_value,
                                        });
-                                       assert!(payment.insert(session_priv_bytes, path.last().unwrap().fee_msat));
+                                       assert!(payment.insert(session_priv_bytes, path));
 
                                        send_res
                                } {
@@ -2209,11 +2241,13 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let cur_height = self.best_block.read().unwrap().height() + 1;
                let mut results = Vec::new();
                for path in route.paths.iter() {
-                       results.push(self.send_payment_along_path(&path, &payment_hash, payment_secret, total_value, cur_height, payment_id, &keysend_preimage));
+                       results.push(self.send_payment_along_path(&path, &route.payee, &payment_hash, payment_secret, total_value, cur_height, payment_id, &keysend_preimage));
                }
                let mut has_ok = false;
                let mut has_err = false;
-               for res in results.iter() {
+               let mut pending_amt_unsent = 0;
+               let mut max_unsent_cltv_delta = 0;
+               for (res, path) in results.iter().zip(route.paths.iter()) {
                        if res.is_ok() { has_ok = true; }
                        if res.is_err() { has_err = true; }
                        if let &Err(APIError::MonitorUpdateFailed) = res {
@@ -2221,11 +2255,25 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                // PartialFailure.
                                has_err = true;
                                has_ok = true;
-                               break;
+                       } else if res.is_err() {
+                               pending_amt_unsent += path.last().unwrap().fee_msat;
+                               max_unsent_cltv_delta = cmp::max(max_unsent_cltv_delta, path.last().unwrap().cltv_expiry_delta);
                        }
                }
                if has_err && has_ok {
-                       Err(PaymentSendFailure::PartialFailure(results))
+                       Err(PaymentSendFailure::PartialFailure {
+                               results,
+                               payment_id,
+                               failed_paths_retry: if pending_amt_unsent != 0 {
+                                       if let Some(payee) = &route.payee {
+                                               Some(RouteParameters {
+                                                       payee: payee.clone(),
+                                                       final_value_msat: pending_amt_unsent,
+                                                       final_cltv_expiry_delta: max_unsent_cltv_delta,
+                                               })
+                                       } else { None }
+                               } else { None },
+                       })
                } else if has_err {
                        Err(PaymentSendFailure::AllFailedRetrySafe(results.drain(..).map(|r| r.unwrap_err()).collect()))
                } else {
@@ -3098,22 +3146,30 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        self.fail_htlc_backwards_internal(channel_state,
                                                htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data});
                                },
-                               HTLCSource::OutboundRoute { session_priv, payment_id, path, .. } => {
+                               HTLCSource::OutboundRoute { session_priv, payment_id, path, payee, .. } => {
                                        let mut session_priv_bytes = [0; 32];
                                        session_priv_bytes.copy_from_slice(&session_priv[..]);
                                        let mut outbounds = self.pending_outbound_payments.lock().unwrap();
                                        if let hash_map::Entry::Occupied(mut payment) = outbounds.entry(payment_id) {
-                                               if payment.get_mut().remove(&session_priv_bytes, Some(path.last().unwrap().fee_msat)) &&
-                                                       !payment.get().is_fulfilled()
-                                               {
+                                               if payment.get_mut().remove(&session_priv_bytes, Some(&path)) && !payment.get().is_fulfilled() {
+                                                       let retry = if let Some(payee_data) = payee {
+                                                               let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
+                                                               Some(RouteParameters {
+                                                                       payee: payee_data,
+                                                                       final_value_msat: path_last_hop.fee_msat,
+                                                                       final_cltv_expiry_delta: path_last_hop.cltv_expiry_delta,
+                                                               })
+                                                       } else { None };
                                                        self.pending_events.lock().unwrap().push(
                                                                events::Event::PaymentPathFailed {
+                                                                       payment_id: Some(payment_id),
                                                                        payment_hash,
                                                                        rejected_by_dest: false,
                                                                        network_update: None,
                                                                        all_paths_failed: payment.get().remaining_parts() == 0,
                                                                        path: path.clone(),
                                                                        short_channel_id: None,
+                                                                       retry,
                                                                        #[cfg(test)]
                                                                        error_code: None,
                                                                        #[cfg(test)]
@@ -3145,13 +3201,13 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // from block_connected which may run during initialization prior to the chain_monitor
                // being fully configured. See the docs for `ChannelManagerReadArgs` for more.
                match source {
-                       HTLCSource::OutboundRoute { ref path, session_priv, payment_id, .. } => {
+                       HTLCSource::OutboundRoute { ref path, session_priv, payment_id, ref payee, .. } => {
                                let mut session_priv_bytes = [0; 32];
                                session_priv_bytes.copy_from_slice(&session_priv[..]);
                                let mut outbounds = self.pending_outbound_payments.lock().unwrap();
                                let mut all_paths_failed = false;
                                if let hash_map::Entry::Occupied(mut payment) = outbounds.entry(payment_id) {
-                                       if !payment.get_mut().remove(&session_priv_bytes, Some(path.last().unwrap().fee_msat)) {
+                                       if !payment.get_mut().remove(&session_priv_bytes, Some(&path)) {
                                                log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                                return;
                                        }
@@ -3166,8 +3222,16 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                        return;
                                }
-                               log_trace!(self.logger, "Failing outbound payment HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                mem::drop(channel_state_lock);
+                               let retry = if let Some(payee_data) = payee {
+                                       let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
+                                       Some(RouteParameters {
+                                               payee: payee_data.clone(),
+                                               final_value_msat: path_last_hop.fee_msat,
+                                               final_cltv_expiry_delta: path_last_hop.cltv_expiry_delta,
+                                       })
+                               } else { None };
+                               log_trace!(self.logger, "Failing outbound payment HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                match &onion_error {
                                        &HTLCFailReason::LightningError { ref err } => {
 #[cfg(test)]
@@ -3179,12 +3243,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                // next-hop is needlessly blaming us!
                                                self.pending_events.lock().unwrap().push(
                                                        events::Event::PaymentPathFailed {
+                                                               payment_id: Some(payment_id),
                                                                payment_hash: payment_hash.clone(),
                                                                rejected_by_dest: !payment_retryable,
                                                                network_update,
                                                                all_paths_failed,
                                                                path: path.clone(),
                                                                short_channel_id,
+                                                               retry,
 #[cfg(test)]
                                                                error_code: onion_error_code,
 #[cfg(test)]
@@ -3207,12 +3273,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                // channel here as we apparently can't relay through them anyway.
                                                self.pending_events.lock().unwrap().push(
                                                        events::Event::PaymentPathFailed {
+                                                               payment_id: Some(payment_id),
                                                                payment_hash: payment_hash.clone(),
                                                                rejected_by_dest: path.len() == 1,
                                                                network_update: None,
                                                                all_paths_failed,
                                                                path: path.clone(),
                                                                short_channel_id: Some(path.first().unwrap().short_channel_id),
+                                                               retry,
 #[cfg(test)]
                                                                error_code: Some(*failure_code),
 #[cfg(test)]
@@ -3431,8 +3499,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                let mut session_priv_bytes = [0; 32];
                                session_priv_bytes.copy_from_slice(&session_priv[..]);
                                let mut outbounds = self.pending_outbound_payments.lock().unwrap();
-                               let found_payment = if let hash_map::Entry::Occupied(mut payment) = outbounds.entry(payment_id) {
+                               if let hash_map::Entry::Occupied(mut payment) = outbounds.entry(payment_id) {
                                        let found_payment = !payment.get().is_fulfilled();
+                                       let fee_paid_msat = payment.get().get_pending_fee_msat();
                                        payment.get_mut().mark_fulfilled();
                                        if from_onchain {
                                                // We currently immediately remove HTLCs which were fulfilled on-chain.
@@ -3441,21 +3510,22 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                // restart.
                                                // TODO: We should have a second monitor event that informs us of payments
                                                // irrevocably fulfilled.
-                                               payment.get_mut().remove(&session_priv_bytes, Some(path.last().unwrap().fee_msat));
+                                               payment.get_mut().remove(&session_priv_bytes, Some(&path));
                                                if payment.get().remaining_parts() == 0 {
                                                        payment.remove();
                                                }
                                        }
-                                       found_payment
-                               } else { false };
-                               if found_payment {
-                                       let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner());
-                                       self.pending_events.lock().unwrap().push(
-                                               events::Event::PaymentSent {
-                                                       payment_preimage,
-                                                       payment_hash: payment_hash
-                                               }
-                                       );
+                                       if found_payment {
+                                               let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner());
+                                               self.pending_events.lock().unwrap().push(
+                                                       events::Event::PaymentSent {
+                                                               payment_id: Some(payment_id),
+                                                               payment_preimage,
+                                                               payment_hash: payment_hash,
+                                                               fee_paid_msat,
+                                                       }
+                                               );
+                                       }
                                } else {
                                        log_trace!(self.logger, "Received duplicative fulfill for HTLC with payment_preimage {}", log_bytes!(payment_preimage.0));
                                }
@@ -5375,12 +5445,14 @@ impl Readable for HTLCSource {
                                let mut path = Some(Vec::new());
                                let mut payment_id = None;
                                let mut payment_secret = None;
+                               let mut payee = None;
                                read_tlv_fields!(reader, {
                                        (0, session_priv, required),
                                        (1, payment_id, option),
                                        (2, first_hop_htlc_msat, required),
                                        (3, payment_secret, option),
                                        (4, path, vec_type),
+                                       (5, payee, option),
                                });
                                if payment_id.is_none() {
                                        // For backwards compat, if there was no payment_id written, use the session_priv bytes
@@ -5393,6 +5465,7 @@ impl Readable for HTLCSource {
                                        path: path.unwrap(),
                                        payment_id: payment_id.unwrap(),
                                        payment_secret,
+                                       payee,
                                })
                        }
                        1 => Ok(HTLCSource::PreviousHopData(Readable::read(reader)?)),
@@ -5404,7 +5477,7 @@ impl Readable for HTLCSource {
 impl Writeable for HTLCSource {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::io::Error> {
                match self {
-                       HTLCSource::OutboundRoute { ref session_priv, ref first_hop_htlc_msat, ref path, payment_id, payment_secret } => {
+                       HTLCSource::OutboundRoute { ref session_priv, ref first_hop_htlc_msat, ref path, payment_id, payment_secret, payee } => {
                                0u8.write(writer)?;
                                let payment_id_opt = Some(payment_id);
                                write_tlv_fields!(writer, {
@@ -5413,6 +5486,7 @@ impl Writeable for HTLCSource {
                                        (2, first_hop_htlc_msat, required),
                                        (3, payment_secret, option),
                                        (4, path, vec_type),
+                                       (5, payee, option),
                                 });
                        }
                        HTLCSource::PreviousHopData(ref field) => {
@@ -5464,6 +5538,7 @@ impl_writeable_tlv_based_enum_upgradable!(PendingOutboundPayment,
        },
        (2, Retryable) => {
                (0, session_privs, required),
+               (1, pending_fee_msat, option),
                (2, payment_hash, required),
                (4, payment_secret, option),
                (6, total_msat, required),
@@ -5920,16 +5995,18 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                                        session_priv_bytes[..].copy_from_slice(&session_priv[..]);
                                                        match pending_outbound_payments.as_mut().unwrap().entry(payment_id) {
                                                                hash_map::Entry::Occupied(mut entry) => {
-                                                                       let newly_added = entry.get_mut().insert(session_priv_bytes, path_amt);
+                                                                       let newly_added = entry.get_mut().insert(session_priv_bytes, &path);
                                                                        log_info!(args.logger, "{} a pending payment path for {} msat for session priv {} on an existing pending payment with payment hash {}",
                                                                                if newly_added { "Added" } else { "Had" }, path_amt, log_bytes!(session_priv_bytes), log_bytes!(htlc.payment_hash.0));
                                                                },
                                                                hash_map::Entry::Vacant(entry) => {
+                                                                       let path_fee = path.get_path_fees();
                                                                        entry.insert(PendingOutboundPayment::Retryable {
                                                                                session_privs: [session_priv_bytes].iter().map(|a| *a).collect(),
                                                                                payment_hash: htlc.payment_hash,
                                                                                payment_secret,
                                                                                pending_amt_msat: path_amt,
+                                                                               pending_fee_msat: Some(path_fee),
                                                                                total_msat: path_amt,
                                                                                starting_block_height: best_block_height,
                                                                        });
@@ -6005,12 +6082,11 @@ mod tests {
        use core::time::Duration;
        use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
        use ln::channelmanager::{PaymentId, PaymentSendFailure};
-       use ln::features::{InitFeatures, InvoiceFeatures};
+       use ln::features::InitFeatures;
        use ln::functional_test_utils::*;
        use ln::msgs;
        use ln::msgs::ChannelMessageHandler;
-       use routing::router::{get_keysend_route, get_route};
-       use routing::scorer::Scorer;
+       use routing::router::{Payee, RouteParameters, find_route};
        use util::errors::APIError;
        use util::events::{Event, MessageSendEvent, MessageSendEventsProvider};
        use util::test_utils;
@@ -6157,7 +6233,7 @@ mod tests {
                // Use the utility function send_payment_along_path to send the payment with MPP data which
                // indicates there are more HTLCs coming.
                let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match.
-               nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None).unwrap();
+               nodes[0].node.send_payment_along_path(&route.paths[0], &route.payee, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -6187,7 +6263,7 @@ mod tests {
                expect_payment_failed!(nodes[0], our_payment_hash, true);
 
                // Send the second half of the original MPP payment.
-               nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None).unwrap();
+               nodes[0].node.send_payment_along_path(&route.paths[0], &route.payee, &our_payment_hash, &Some(payment_secret), 200_000, cur_height, payment_id, &None).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -6229,7 +6305,8 @@ mod tests {
                // further events will be generated for subsequence path successes.
                let events = nodes[0].node.get_and_clear_pending_events();
                match events[0] {
-                       Event::PaymentSent { payment_preimage: ref preimage, payment_hash: ref hash } => {
+                       Event::PaymentSent { payment_id: ref id, payment_preimage: ref preimage, payment_hash: ref hash, .. } => {
+                               assert_eq!(Some(payment_id), *id);
                                assert_eq!(payment_preimage, *preimage);
                                assert_eq!(our_payment_hash, *hash);
                        },
@@ -6248,15 +6325,22 @@ mod tests {
                let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]);
                let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
                create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-               let logger = test_utils::TestLogger::new();
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // To start (1), send a regular payment but don't claim it.
                let expected_route = [&nodes[1]];
                let (payment_preimage, payment_hash, _) = route_payment(&nodes[0], &expected_route, 100_000);
 
                // Next, attempt a keysend payment and make sure it fails.
-               let route = get_route(&nodes[0].node.get_our_node_id(), &nodes[0].net_graph_msg_handler.network_graph, &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), 100_000, TEST_FINAL_CLTV, &logger, &scorer).unwrap();
+               let params = RouteParameters {
+                       payee: Payee::for_keysend(expected_route.last().unwrap().node.get_our_node_id()),
+                       final_value_msat: 100_000,
+                       final_cltv_expiry_delta: TEST_FINAL_CLTV,
+               };
+               let route = find_route(
+                       &nodes[0].node.get_our_node_id(), &params, nodes[0].network_graph, None,
+                       nodes[0].logger, &scorer
+               ).unwrap();
                nodes[0].node.send_spontaneous_payment(&route, Some(payment_preimage)).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
@@ -6284,7 +6368,10 @@ mod tests {
 
                // To start (2), send a keysend payment but don't claim it.
                let payment_preimage = PaymentPreimage([42; 32]);
-               let route = get_route(&nodes[0].node.get_our_node_id(), &nodes[0].net_graph_msg_handler.network_graph, &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), 100_000, TEST_FINAL_CLTV, &logger, &scorer).unwrap();
+               let route = find_route(
+                       &nodes[0].node.get_our_node_id(), &params, nodes[0].network_graph, None,
+                       nodes[0].logger, &scorer
+               ).unwrap();
                let (payment_hash, _) = nodes[0].node.send_spontaneous_payment(&route, Some(payment_preimage)).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
@@ -6336,12 +6423,18 @@ mod tests {
                nodes[1].node.peer_connected(&payer_pubkey, &msgs::Init { features: InitFeatures::known() });
 
                let _chan = create_chan_between_nodes(&nodes[0], &nodes[1], InitFeatures::known(), InitFeatures::known());
-               let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
+               let params = RouteParameters {
+                       payee: Payee::for_keysend(payee_pubkey),
+                       final_value_msat: 10000,
+                       final_cltv_expiry_delta: 40,
+               };
+               let network_graph = nodes[0].network_graph;
                let first_hops = nodes[0].node.list_usable_channels();
-               let scorer = Scorer::new(0);
-               let route = get_keysend_route(&payer_pubkey, network_graph, &payee_pubkey,
-                                  Some(&first_hops.iter().collect::<Vec<_>>()), &vec![], 10000, 40,
-                                  nodes[0].logger, &scorer).unwrap();
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let route = find_route(
+                       &payer_pubkey, &params, network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
+                       nodes[0].logger, &scorer
+               ).unwrap();
 
                let test_preimage = PaymentPreimage([42; 32]);
                let mismatch_payment_hash = PaymentHash([43; 32]);
@@ -6373,12 +6466,18 @@ mod tests {
                nodes[1].node.peer_connected(&payer_pubkey, &msgs::Init { features: InitFeatures::known() });
 
                let _chan = create_chan_between_nodes(&nodes[0], &nodes[1], InitFeatures::known(), InitFeatures::known());
-               let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
+               let params = RouteParameters {
+                       payee: Payee::for_keysend(payee_pubkey),
+                       final_value_msat: 10000,
+                       final_cltv_expiry_delta: 40,
+               };
+               let network_graph = nodes[0].network_graph;
                let first_hops = nodes[0].node.list_usable_channels();
-               let scorer = Scorer::new(0);
-               let route = get_keysend_route(&payer_pubkey, network_graph, &payee_pubkey,
-                                  Some(&first_hops.iter().collect::<Vec<_>>()), &vec![], 10000, 40,
-                                  nodes[0].logger, &scorer).unwrap();
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let route = find_route(
+                       &payer_pubkey, &params, network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
+                       nodes[0].logger, &scorer
+               ).unwrap();
 
                let test_preimage = PaymentPreimage([42; 32]);
                let test_secret = PaymentSecret([43; 32]);
@@ -6438,7 +6537,7 @@ pub mod bench {
        use ln::functional_test_utils::*;
        use ln::msgs::{ChannelMessageHandler, Init};
        use routing::network_graph::NetworkGraph;
-       use routing::router::get_route;
+       use routing::router::{Payee, get_route};
        use routing::scorer::Scorer;
        use util::test_utils;
        use util::config::UserConfig;
@@ -6547,9 +6646,11 @@ pub mod bench {
                macro_rules! send_payment {
                        ($node_a: expr, $node_b: expr) => {
                                let usable_channels = $node_a.list_usable_channels();
-                               let scorer = Scorer::new(0);
-                               let route = get_route(&$node_a.get_our_node_id(), &dummy_graph, &$node_b.get_our_node_id(), Some(InvoiceFeatures::known()),
-                                       Some(&usable_channels.iter().map(|r| r).collect::<Vec<_>>()), &[], 10_000, TEST_FINAL_CLTV, &logger_a, &scorer).unwrap();
+                               let payee = Payee::from_node_id($node_b.get_our_node_id())
+                                       .with_features(InvoiceFeatures::known());
+                               let scorer = Scorer::with_fixed_penalty(0);
+                               let route = get_route(&$node_a.get_our_node_id(), &payee, &dummy_graph,
+                                       Some(&usable_channels.iter().map(|r| r).collect::<Vec<_>>()), 10_000, TEST_FINAL_CLTV, &logger_a, &scorer).unwrap();
 
                                let mut payment_preimage = PaymentPreimage([0; 32]);
                                payment_preimage.0[0..8].copy_from_slice(&payment_count.to_le_bytes());
index 46a296001ef2411a45407fe03cc8e04d786a4312..32ba9de758632036942318c756f64708793e1b1a 100644 (file)
@@ -468,11 +468,11 @@ impl InvoiceFeatures {
        /// Getting a route for a keysend payment to a private node requires providing the payee's
        /// features (since they were not announced in a node announcement). However, keysend payments
        /// don't have an invoice to pull the payee's features from, so this method is provided for use in
-       /// [`get_keysend_route`], thus omitting the need for payers to manually construct an
-       /// `InvoiceFeatures` for [`get_route`].
+       /// [`Payee::for_keysend`], thus omitting the need for payers to manually construct an
+       /// `InvoiceFeatures` for [`find_route`].
        ///
-       /// [`get_keysend_route`]: crate::routing::router::get_keysend_route
-       /// [`get_route`]: crate::routing::router::get_route
+       /// [`Payee::for_keysend`]: crate::routing::router::Payee::for_keysend
+       /// [`find_route`]: crate::routing::router::find_route
        pub(crate) fn for_keysend() -> InvoiceFeatures {
                InvoiceFeatures::empty().set_variable_length_onion_optional()
        }
index fc57bafdd03c47d10a6e0bfa2f376fe27cae8c48..e6b2467077422a974eff3657c1ae6a2d0c615f97 100644 (file)
@@ -16,8 +16,7 @@ use chain::transaction::OutPoint;
 use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
 use ln::channelmanager::{ChainParameters, ChannelManager, ChannelManagerReadArgs, RAACommitmentOrder, PaymentSendFailure, PaymentId};
 use routing::network_graph::{NetGraphMsgHandler, NetworkGraph};
-use routing::router::{Route, get_route};
-use routing::scorer::Scorer;
+use routing::router::{Payee, Route, get_route};
 use ln::features::{InitFeatures, InvoiceFeatures};
 use ln::msgs;
 use ln::msgs::{ChannelMessageHandler,RoutingMessageHandler};
@@ -190,6 +189,7 @@ pub struct TestChanMonCfg {
        pub persister: test_utils::TestPersister,
        pub logger: test_utils::TestLogger,
        pub keys_manager: test_utils::TestKeysInterface,
+       pub network_graph: NetworkGraph,
 }
 
 pub struct NodeCfg<'a> {
@@ -199,6 +199,7 @@ pub struct NodeCfg<'a> {
        pub chain_monitor: test_utils::TestChainMonitor<'a>,
        pub keys_manager: &'a test_utils::TestKeysInterface,
        pub logger: &'a test_utils::TestLogger,
+       pub network_graph: &'a NetworkGraph,
        pub node_seed: [u8; 32],
        pub features: InitFeatures,
 }
@@ -209,7 +210,8 @@ pub struct Node<'a, 'b: 'a, 'c: 'b> {
        pub chain_monitor: &'b test_utils::TestChainMonitor<'c>,
        pub keys_manager: &'b test_utils::TestKeysInterface,
        pub node: &'a ChannelManager<EnforcingSigner, &'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'c test_utils::TestLogger>,
-       pub net_graph_msg_handler: NetGraphMsgHandler<&'c test_utils::TestChainSource, &'c test_utils::TestLogger>,
+       pub network_graph: &'c NetworkGraph,
+       pub net_graph_msg_handler: NetGraphMsgHandler<&'c NetworkGraph, &'c test_utils::TestChainSource, &'c test_utils::TestLogger>,
        pub node_seed: [u8; 32],
        pub network_payment_count: Rc<RefCell<u8>>,
        pub network_chan_count: Rc<RefCell<u32>>,
@@ -240,12 +242,11 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                        // Check that if we serialize the Router, we can deserialize it again.
                        {
                                let mut w = test_utils::TestVecWriter(Vec::new());
-                               let network_graph_ser = &self.net_graph_msg_handler.network_graph;
-                               network_graph_ser.write(&mut w).unwrap();
+                               self.network_graph.write(&mut w).unwrap();
                                let network_graph_deser = <NetworkGraph>::read(&mut io::Cursor::new(&w.0)).unwrap();
-                               assert!(network_graph_deser == self.net_graph_msg_handler.network_graph);
+                               assert!(network_graph_deser == *self.network_graph);
                                let net_graph_msg_handler = NetGraphMsgHandler::new(
-                                       network_graph_deser, Some(self.chain_source), self.logger
+                                       &network_graph_deser, Some(self.chain_source), self.logger
                                );
                                let mut chan_progress = 0;
                                loop {
@@ -482,9 +483,9 @@ macro_rules! unwrap_send_err {
                                        _ => panic!(),
                                }
                        },
-                       &Err(PaymentSendFailure::PartialFailure(ref fails)) if !$all_failed => {
-                               assert_eq!(fails.len(), 1);
-                               match fails[0] {
+                       &Err(PaymentSendFailure::PartialFailure { ref results, .. }) if !$all_failed => {
+                               assert_eq!(results.len(), 1);
+                               match results[0] {
                                        Err($type) => { $check },
                                        _ => panic!(),
                                }
@@ -1011,13 +1012,14 @@ macro_rules! get_route_and_payment_hash {
        }};
        ($send_node: expr, $recv_node: expr, $last_hops: expr, $recv_value: expr, $cltv: expr) => {{
                let (payment_preimage, payment_hash, payment_secret) = get_payment_preimage_hash!($recv_node, Some($recv_value));
-               let net_graph_msg_handler = &$send_node.net_graph_msg_handler;
-               let scorer = ::routing::scorer::Scorer::new(0);
+               let payee = $crate::routing::router::Payee::from_node_id($recv_node.node.get_our_node_id())
+                       .with_features($crate::ln::features::InvoiceFeatures::known())
+                       .with_route_hints($last_hops);
+               let scorer = ::util::test_utils::TestScorer::with_fixed_penalty(0);
                let route = ::routing::router::get_route(
-                       &$send_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph,
-                       &$recv_node.node.get_our_node_id(), Some(::ln::features::InvoiceFeatures::known()),
+                       &$send_node.node.get_our_node_id(), &payee, $send_node.network_graph,
                        Some(&$send_node.node.list_usable_channels().iter().collect::<Vec<_>>()),
-                       &$last_hops, $recv_value, $cltv, $send_node.logger, &scorer
+                       $recv_value, $cltv, $send_node.logger, &scorer
                ).unwrap();
                (route, payment_hash, payment_preimage, payment_secret)
        }}
@@ -1079,13 +1081,20 @@ macro_rules! expect_payment_received {
 
 macro_rules! expect_payment_sent {
        ($node: expr, $expected_payment_preimage: expr) => {
+               expect_payment_sent!($node, $expected_payment_preimage, None::<u64>);
+       };
+       ($node: expr, $expected_payment_preimage: expr, $expected_fee_msat_opt: expr) => {
                let events = $node.node.get_and_clear_pending_events();
                let expected_payment_hash = PaymentHash(Sha256::hash(&$expected_payment_preimage.0).into_inner());
                assert_eq!(events.len(), 1);
                match events[0] {
-                       Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+                       Event::PaymentSent { payment_id: _, ref payment_preimage, ref payment_hash, ref fee_paid_msat } => {
                                assert_eq!($expected_payment_preimage, *payment_preimage);
                                assert_eq!(expected_payment_hash, *payment_hash);
+                               assert!(fee_paid_msat.is_some());
+                               if $expected_fee_msat_opt.is_some() {
+                                       assert_eq!(*fee_paid_msat, $expected_fee_msat_opt);
+                               }
                        },
                        _ => panic!("Unexpected event"),
                }
@@ -1112,9 +1121,12 @@ macro_rules! expect_payment_failed_with_update {
                let events = $node.node.get_and_clear_pending_events();
                assert_eq!(events.len(), 1);
                match events[0] {
-                       Event::PaymentPathFailed { ref payment_hash, rejected_by_dest, ref network_update, ref error_code, ref error_data, .. } => {
+                       Event::PaymentPathFailed { ref payment_hash, rejected_by_dest, ref network_update, ref error_code, ref error_data, ref path, ref retry, .. } => {
                                assert_eq!(*payment_hash, $expected_payment_hash, "unexpected payment_hash");
                                assert_eq!(rejected_by_dest, $rejected_by_dest, "unexpected rejected_by_dest value");
+                               assert!(retry.is_some(), "expected retry.is_some()");
+                               assert_eq!(retry.as_ref().unwrap().final_value_msat, path.last().unwrap().fee_msat, "Retry amount should match last hop in path");
+                               assert_eq!(retry.as_ref().unwrap().payee.pubkey, path.last().unwrap().pubkey, "Retry payee node_id should match last hop in path");
                                assert!(error_code.is_some(), "expected error_code.is_some() = true");
                                assert!(error_data.is_some(), "expected error_data.is_some() = true");
                                match network_update {
@@ -1141,9 +1153,12 @@ macro_rules! expect_payment_failed {
                let events = $node.node.get_and_clear_pending_events();
                assert_eq!(events.len(), 1);
                match events[0] {
-                       Event::PaymentPathFailed { ref payment_hash, rejected_by_dest, network_update: _, ref error_code, ref error_data, .. } => {
+                       Event::PaymentPathFailed { ref payment_hash, rejected_by_dest, network_update: _, ref error_code, ref error_data, ref path, ref retry, .. } => {
                                assert_eq!(*payment_hash, $expected_payment_hash, "unexpected payment_hash");
                                assert_eq!(rejected_by_dest, $rejected_by_dest, "unexpected rejected_by_dest value");
+                               assert!(retry.is_some(), "expected retry.is_some()");
+                               assert_eq!(retry.as_ref().unwrap().final_value_msat, path.last().unwrap().fee_msat, "Retry amount should match last hop in path");
+                               assert_eq!(retry.as_ref().unwrap().payee.pubkey, path.last().unwrap().pubkey, "Retry payee node_id should match last hop in path");
                                assert!(error_code.is_some(), "expected error_code.is_some() = true");
                                assert!(error_data.is_some(), "expected error_data.is_some() = true");
                                $(
@@ -1229,13 +1244,15 @@ pub fn send_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, route: Route
        (our_payment_preimage, our_payment_hash, our_payment_secret, payment_id)
 }
 
-pub fn claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_preimage: PaymentPreimage) {
+pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_preimage: PaymentPreimage) -> u64 {
        for path in expected_paths.iter() {
                assert_eq!(path.last().unwrap().node.get_our_node_id(), expected_paths[0].last().unwrap().node.get_our_node_id());
        }
        assert!(expected_paths[0].last().unwrap().node.claim_funds(our_payment_preimage));
        check_added_monitors!(expected_paths[0].last().unwrap(), expected_paths.len());
 
+       let mut expected_total_fee_msat = 0;
+
        macro_rules! msgs_from_ev {
                ($ev: expr) => {
                        match $ev {
@@ -1278,6 +1295,7 @@ pub fn claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, exp
                                        $node.node.handle_update_fulfill_htlc(&$prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0);
                                        let fee = $node.node.channel_state.lock().unwrap().by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap().config.forwarding_fee_base_msat;
                                        expect_payment_forwarded!($node, Some(fee as u64), false);
+                                       expected_total_fee_msat += fee as u64;
                                        check_added_monitors!($node, 1);
                                        let new_next_msgs = if $new_msgs {
                                                let events = $node.node.get_and_clear_pending_msg_events();
@@ -1316,8 +1334,12 @@ pub fn claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, exp
                        last_update_fulfill_dance!(origin_node, expected_route.first().unwrap());
                }
        }
+       expected_total_fee_msat
+}
+pub fn claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_preimage: PaymentPreimage) {
+       let expected_total_fee_msat = do_claim_payment_along_route(origin_node, expected_paths, skip_last, our_payment_preimage);
        if !skip_last {
-               expect_payment_sent!(origin_node, our_payment_preimage);
+               expect_payment_sent!(origin_node, our_payment_preimage, Some(expected_total_fee_msat));
        }
 }
 
@@ -1328,11 +1350,12 @@ pub fn claim_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route:
 pub const TEST_FINAL_CLTV: u32 = 70;
 
 pub fn route_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) -> (PaymentPreimage, PaymentHash, PaymentSecret) {
-       let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
-       let scorer = Scorer::new(0);
-       let route = get_route(&origin_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph,
-               &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()),
-               Some(&origin_node.node.list_usable_channels().iter().collect::<Vec<_>>()), &[],
+       let payee = Payee::from_node_id(expected_route.last().unwrap().node.get_our_node_id())
+               .with_features(InvoiceFeatures::known());
+       let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+       let route = get_route(
+               &origin_node.node.get_our_node_id(), &payee, &origin_node.network_graph,
+               Some(&origin_node.node.list_usable_channels().iter().collect::<Vec<_>>()),
                recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap();
        assert_eq!(route.paths.len(), 1);
        assert_eq!(route.paths[0].len(), expected_route.len());
@@ -1345,9 +1368,10 @@ pub fn route_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route:
 }
 
 pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64)  {
-       let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
-       let scorer = Scorer::new(0);
-       let route = get_route(&origin_node.node.get_our_node_id(), &net_graph_msg_handler.network_graph, &expected_route.last().unwrap().node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap();
+       let payee = Payee::from_node_id(expected_route.last().unwrap().node.get_our_node_id())
+               .with_features(InvoiceFeatures::known());
+       let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+       let route = get_route(&origin_node.node.get_our_node_id(), &payee, origin_node.network_graph, None, recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap();
        assert_eq!(route.paths.len(), 1);
        assert_eq!(route.paths[0].len(), expected_route.len());
        for (node, hop) in expected_route.iter().zip(route.paths[0].iter()) {
@@ -1473,8 +1497,9 @@ pub fn create_chanmon_cfgs(node_count: usize) -> Vec<TestChanMonCfg> {
                let persister = test_utils::TestPersister::new();
                let seed = [i as u8; 32];
                let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
+               let network_graph = NetworkGraph::new(chain_source.genesis_hash);
 
-               chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager });
+               chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager, network_graph });
        }
 
        chan_mon_cfgs
@@ -1495,6 +1520,7 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec<TestChanMon
                        keys_manager: &chanmon_cfgs[i].keys_manager,
                        node_seed: seed,
                        features: InitFeatures::known(),
+                       network_graph: &chanmon_cfgs[i].network_graph,
                });
        }
 
@@ -1540,15 +1566,15 @@ pub fn create_network<'a, 'b: 'a, 'c: 'b>(node_count: usize, cfgs: &'b Vec<NodeC
        let connect_style = Rc::new(RefCell::new(ConnectStyle::FullBlockViaListen));
 
        for i in 0..node_count {
-               let network_graph = NetworkGraph::new(cfgs[i].chain_source.genesis_hash);
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, cfgs[i].logger);
-               nodes.push(Node{ chain_source: cfgs[i].chain_source,
-                                tx_broadcaster: cfgs[i].tx_broadcaster, chain_monitor: &cfgs[i].chain_monitor,
-                                keys_manager: &cfgs[i].keys_manager, node: &chan_mgrs[i], net_graph_msg_handler,
-                                node_seed: cfgs[i].node_seed, network_chan_count: chan_count.clone(),
-                                network_payment_count: payment_count.clone(), logger: cfgs[i].logger,
-                                blocks: Arc::clone(&cfgs[i].tx_broadcaster.blocks),
-                                connect_style: Rc::clone(&connect_style),
+               let net_graph_msg_handler = NetGraphMsgHandler::new(cfgs[i].network_graph, None, cfgs[i].logger);
+               nodes.push(Node{
+                       chain_source: cfgs[i].chain_source, tx_broadcaster: cfgs[i].tx_broadcaster,
+                       chain_monitor: &cfgs[i].chain_monitor, keys_manager: &cfgs[i].keys_manager,
+                       node: &chan_mgrs[i], network_graph: &cfgs[i].network_graph, net_graph_msg_handler,
+                       node_seed: cfgs[i].node_seed, network_chan_count: chan_count.clone(),
+                       network_payment_count: payment_count.clone(), logger: cfgs[i].logger,
+                       blocks: Arc::clone(&cfgs[i].tx_broadcaster.blocks),
+                       connect_style: Rc::clone(&connect_style),
                })
        }
 
index 4e2e15502093155927084c8554cd6fe39afbe337..5e87aaff3c40a93c67ddcf012a67098a33d427b7 100644 (file)
@@ -24,8 +24,7 @@ use ln::channel::{Channel, ChannelError};
 use ln::{chan_utils, onion_utils};
 use ln::chan_utils::HTLC_SUCCESS_TX_WEIGHT;
 use routing::network_graph::{NetworkUpdate, RoutingFees};
-use routing::router::{Route, RouteHop, RouteHint, RouteHintHop, get_route, get_keysend_route};
-use routing::scorer::Scorer;
+use routing::router::{Payee, Route, RouteHop, RouteHint, RouteHintHop, RouteParameters, find_route, get_route};
 use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
 use ln::msgs;
 use ln::msgs::{ChannelMessageHandler, RoutingMessageHandler, ErrorAction};
@@ -919,7 +918,7 @@ fn fake_network_test() {
        });
        hops[1].fee_msat = chan_4.1.contents.fee_base_msat as u64 + chan_4.1.contents.fee_proportional_millionths as u64 * hops[2].fee_msat as u64 / 1000000;
        hops[0].fee_msat = chan_3.0.contents.fee_base_msat as u64 + chan_3.0.contents.fee_proportional_millionths as u64 * hops[1].fee_msat as u64 / 1000000;
-       let payment_preimage_1 = send_along_route(&nodes[1], Route { paths: vec![hops] }, &vec!(&nodes[2], &nodes[3], &nodes[1])[..], 1000000).0;
+       let payment_preimage_1 = send_along_route(&nodes[1], Route { paths: vec![hops], payee: None }, &vec!(&nodes[2], &nodes[3], &nodes[1])[..], 1000000).0;
 
        let mut hops = Vec::with_capacity(3);
        hops.push(RouteHop {
@@ -948,7 +947,7 @@ fn fake_network_test() {
        });
        hops[1].fee_msat = chan_2.1.contents.fee_base_msat as u64 + chan_2.1.contents.fee_proportional_millionths as u64 * hops[2].fee_msat as u64 / 1000000;
        hops[0].fee_msat = chan_3.1.contents.fee_base_msat as u64 + chan_3.1.contents.fee_proportional_millionths as u64 * hops[1].fee_msat as u64 / 1000000;
-       let payment_hash_2 = send_along_route(&nodes[1], Route { paths: vec![hops] }, &vec!(&nodes[3], &nodes[2], &nodes[1])[..], 1000000).1;
+       let payment_hash_2 = send_along_route(&nodes[1], Route { paths: vec![hops], payee: None }, &vec!(&nodes[3], &nodes[2], &nodes[1])[..], 1000000).1;
 
        // Claim the rebalances...
        fail_payment(&nodes[1], &vec!(&nodes[3], &nodes[2], &nodes[1])[..], payment_hash_2);
@@ -2659,7 +2658,7 @@ fn test_htlc_on_chain_success() {
        let mut first_claimed = false;
        for event in events {
                match event {
-                       Event::PaymentSent { payment_preimage, payment_hash } => {
+                       Event::PaymentSent { payment_preimage, payment_hash, .. } => {
                                if payment_preimage == our_payment_preimage && payment_hash == payment_hash_1 {
                                        assert!(!first_claimed);
                                        first_claimed = true;
@@ -3350,7 +3349,7 @@ fn test_simple_peer_disconnect() {
                let events = nodes[0].node.get_and_clear_pending_events();
                assert_eq!(events.len(), 2);
                match events[0] {
-                       Event::PaymentSent { payment_preimage, payment_hash } => {
+                       Event::PaymentSent { payment_preimage, payment_hash, .. } => {
                                assert_eq!(payment_preimage, payment_preimage_3);
                                assert_eq!(payment_hash, payment_hash_3);
                        },
@@ -3514,7 +3513,7 @@ fn do_test_drop_messages_peer_disconnect(messages_delivered: u8, simulate_broken
                let events_4 = nodes[0].node.get_and_clear_pending_events();
                assert_eq!(events_4.len(), 1);
                match events_4[0] {
-                       Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+                       Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                                assert_eq!(payment_preimage_1, *payment_preimage);
                                assert_eq!(payment_hash_1, *payment_hash);
                        },
@@ -3555,7 +3554,7 @@ fn do_test_drop_messages_peer_disconnect(messages_delivered: u8, simulate_broken
                        let events_4 = nodes[0].node.get_and_clear_pending_events();
                        assert_eq!(events_4.len(), 1);
                        match events_4[0] {
-                               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+                               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                                        assert_eq!(payment_preimage_1, *payment_preimage);
                                        assert_eq!(payment_hash_1, *payment_hash);
                                },
@@ -3790,7 +3789,7 @@ fn test_drop_messages_peer_disconnect_dual_htlc() {
                        let events_3 = nodes[0].node.get_and_clear_pending_events();
                        assert_eq!(events_3.len(), 1);
                        match events_3[0] {
-                               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+                               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                                        assert_eq!(*payment_preimage, payment_preimage_1);
                                        assert_eq!(*payment_hash, payment_hash_1);
                                },
@@ -3912,7 +3911,7 @@ fn do_test_htlc_timeout(send_partial_mpp: bool) {
                // indicates there are more HTLCs coming.
                let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match.
                let payment_id = PaymentId([42; 32]);
-               nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200000, cur_height, payment_id, &None).unwrap();
+               nodes[0].node.send_payment_along_path(&route.paths[0], &route.payee, &our_payment_hash, &Some(payment_secret), 200000, cur_height, payment_id, &None).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -5059,7 +5058,7 @@ fn test_duplicate_payment_hash_one_failure_one_success() {
 
        let events = nodes[0].node.get_and_clear_pending_events();
        match events[0] {
-               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                        assert_eq!(*payment_preimage, our_payment_preimage);
                        assert_eq!(*payment_hash, duplicate_payment_hash);
                }
@@ -5454,7 +5453,7 @@ fn test_key_derivation_params() {
        let seed = [42; 32];
        let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
        let chain_monitor = test_utils::TestChainMonitor::new(Some(&chanmon_cfgs[0].chain_source), &chanmon_cfgs[0].tx_broadcaster, &chanmon_cfgs[0].logger, &chanmon_cfgs[0].fee_estimator, &chanmon_cfgs[0].persister, &keys_manager);
-       let node = NodeCfg { chain_source: &chanmon_cfgs[0].chain_source, logger: &chanmon_cfgs[0].logger, tx_broadcaster: &chanmon_cfgs[0].tx_broadcaster, fee_estimator: &chanmon_cfgs[0].fee_estimator, chain_monitor, keys_manager: &keys_manager, node_seed: seed, features: InitFeatures::known() };
+       let node = NodeCfg { chain_source: &chanmon_cfgs[0].chain_source, logger: &chanmon_cfgs[0].logger, tx_broadcaster: &chanmon_cfgs[0].tx_broadcaster, fee_estimator: &chanmon_cfgs[0].fee_estimator, chain_monitor, keys_manager: &keys_manager, network_graph: &chanmon_cfgs[0].network_graph, node_seed: seed, features: InitFeatures::known() };
        let mut node_cfgs = create_node_cfgs(3, &chanmon_cfgs);
        node_cfgs.remove(0);
        node_cfgs.insert(0, node);
@@ -5572,7 +5571,7 @@ fn do_htlc_claim_local_commitment_only(use_dust: bool) {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match events[0] {
-               Event::PaymentSent { payment_preimage, payment_hash } => {
+               Event::PaymentSent { payment_preimage, payment_hash, .. } => {
                        assert_eq!(payment_preimage, our_payment_preimage);
                        assert_eq!(payment_hash, our_payment_hash);
                },
@@ -5844,7 +5843,7 @@ fn test_fail_holding_cell_htlc_upon_free() {
        let (route, our_payment_hash, _, our_payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[1], max_can_send);
 
        // Send a payment which passes reserve checks but gets stuck in the holding cell.
-       nodes[0].node.send_payment(&route, our_payment_hash, &Some(our_payment_secret)).unwrap();
+       let our_payment_id = nodes[0].node.send_payment(&route, our_payment_hash, &Some(our_payment_secret)).unwrap();
        chan_stat = get_channel_value_stat!(nodes[0], chan.2);
        assert_eq!(chan_stat.holding_cell_outbound_amount_msat, max_can_send);
 
@@ -5869,7 +5868,8 @@ fn test_fail_holding_cell_htlc_upon_free() {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match &events[0] {
-               &Event::PaymentPathFailed { ref payment_hash, ref rejected_by_dest, ref network_update, ref all_paths_failed, path: _, ref short_channel_id, ref error_code, ref error_data } => {
+               &Event::PaymentPathFailed { ref payment_id, ref payment_hash, ref rejected_by_dest, ref network_update, ref all_paths_failed, ref short_channel_id, ref error_code, ref error_data, .. } => {
+                       assert_eq!(our_payment_id, *payment_id.as_ref().unwrap());
                        assert_eq!(our_payment_hash.clone(), *payment_hash);
                        assert_eq!(*rejected_by_dest, false);
                        assert_eq!(*all_paths_failed, true);
@@ -5927,7 +5927,7 @@ fn test_free_and_fail_holding_cell_htlcs() {
        nodes[0].node.send_payment(&route_1, payment_hash_1, &Some(payment_secret_1)).unwrap();
        chan_stat = get_channel_value_stat!(nodes[0], chan.2);
        assert_eq!(chan_stat.holding_cell_outbound_amount_msat, amt_1);
-       nodes[0].node.send_payment(&route_2, payment_hash_2, &Some(payment_secret_2)).unwrap();
+       let payment_id_2 = nodes[0].node.send_payment(&route_2, payment_hash_2, &Some(payment_secret_2)).unwrap();
        chan_stat = get_channel_value_stat!(nodes[0], chan.2);
        assert_eq!(chan_stat.holding_cell_outbound_amount_msat, amt_1 + amt_2);
 
@@ -5953,7 +5953,8 @@ fn test_free_and_fail_holding_cell_htlcs() {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match &events[0] {
-               &Event::PaymentPathFailed { ref payment_hash, ref rejected_by_dest, ref network_update, ref all_paths_failed, path: _, ref short_channel_id, ref error_code, ref error_data } => {
+               &Event::PaymentPathFailed { ref payment_id, ref payment_hash, ref rejected_by_dest, ref network_update, ref all_paths_failed, ref short_channel_id, ref error_code, ref error_data, .. } => {
+                       assert_eq!(payment_id_2, *payment_id.as_ref().unwrap());
                        assert_eq!(payment_hash_2.clone(), *payment_hash);
                        assert_eq!(*rejected_by_dest, false);
                        assert_eq!(*all_paths_failed, true);
@@ -6000,7 +6001,7 @@ fn test_free_and_fail_holding_cell_htlcs() {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match events[0] {
-               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                        assert_eq!(*payment_preimage, payment_preimage_1);
                        assert_eq!(*payment_hash, payment_hash_1);
                }
@@ -7159,8 +7160,9 @@ fn test_check_htlc_underpaying() {
        // Create some initial channels
        create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
 
-       let scorer = Scorer::new(0);
-       let route = get_route(&nodes[0].node.get_our_node_id(), &nodes[0].net_graph_msg_handler.network_graph, &nodes[1].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), 10_000, TEST_FINAL_CLTV, nodes[0].logger, &scorer).unwrap();
+       let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+       let payee = Payee::from_node_id(nodes[1].node.get_our_node_id()).with_features(InvoiceFeatures::known());
+       let route = get_route(&nodes[0].node.get_our_node_id(), &payee, nodes[0].network_graph, None, 10_000, TEST_FINAL_CLTV, nodes[0].logger, &scorer).unwrap();
        let (_, our_payment_hash, _) = get_payment_preimage_hash!(nodes[0]);
        let our_payment_secret = nodes[1].node.create_inbound_payment_for_hash(our_payment_hash, Some(100_000), 7200, 0).unwrap();
        nodes[0].node.send_payment(&route, our_payment_hash, &Some(our_payment_secret)).unwrap();
@@ -7348,7 +7350,7 @@ fn test_priv_forwarding_rejection() {
                htlc_minimum_msat: None,
                htlc_maximum_msat: None,
        }]);
-       let last_hops = vec![&route_hint];
+       let last_hops = vec![route_hint];
        let (route, our_payment_hash, our_payment_preimage, our_payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[2], last_hops, 10_000, TEST_FINAL_CLTV);
 
        nodes[0].node.send_payment(&route, our_payment_hash, &Some(our_payment_secret)).unwrap();
@@ -7557,12 +7559,14 @@ fn test_bump_penalty_txn_on_revoked_htlcs() {
 
        let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1000000, 59000000, InitFeatures::known(), InitFeatures::known());
        // Lock HTLC in both directions (using a slightly lower CLTV delay to provide timely RBF bumps)
-       let scorer = Scorer::new(0);
-       let route = get_route(&nodes[0].node.get_our_node_id(), &nodes[0].net_graph_msg_handler.network_graph,
-               &nodes[1].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), 3_000_000, 50, nodes[0].logger, &scorer).unwrap();
+       let payee = Payee::from_node_id(nodes[1].node.get_our_node_id()).with_features(InvoiceFeatures::known());
+       let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+       let route = get_route(&nodes[0].node.get_our_node_id(), &payee, &nodes[0].network_graph, None,
+               3_000_000, 50, nodes[0].logger, &scorer).unwrap();
        let payment_preimage = send_along_route(&nodes[0], route, &[&nodes[1]], 3_000_000).0;
-       let route = get_route(&nodes[1].node.get_our_node_id(), &nodes[1].net_graph_msg_handler.network_graph,
-               &nodes[0].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), 3_000_000, 50, nodes[0].logger, &scorer).unwrap();
+       let payee = Payee::from_node_id(nodes[0].node.get_our_node_id()).with_features(InvoiceFeatures::known());
+       let route = get_route(&nodes[1].node.get_our_node_id(), &payee, nodes[1].network_graph, None,
+               3_000_000, 50, nodes[0].logger, &scorer).unwrap();
        send_along_route(&nodes[1], route, &[&nodes[0]], 3_000_000);
 
        let revoked_local_txn = get_local_commitment_txn!(nodes[1], chan.2);
@@ -9048,13 +9052,16 @@ fn test_keysend_payments_to_public_node() {
        let nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 
        let _chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 100000, 10001, InitFeatures::known(), InitFeatures::known());
-       let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
+       let network_graph = nodes[0].network_graph;
        let payer_pubkey = nodes[0].node.get_our_node_id();
        let payee_pubkey = nodes[1].node.get_our_node_id();
-       let scorer = Scorer::new(0);
-       let route = get_keysend_route(
-               &payer_pubkey, &network_graph, &payee_pubkey, None, &vec![], 10000, 40, nodes[0].logger, &scorer
-       ).unwrap();
+       let params = RouteParameters {
+               payee: Payee::for_keysend(payee_pubkey),
+               final_value_msat: 10000,
+               final_cltv_expiry_delta: 40,
+       };
+       let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+       let route = find_route(&payer_pubkey, &params, network_graph, None, nodes[0].logger, &scorer).unwrap();
 
        let test_preimage = PaymentPreimage([42; 32]);
        let (payment_hash, _) = nodes[0].node.send_spontaneous_payment(&route, Some(test_preimage)).unwrap();
@@ -9080,12 +9087,17 @@ fn test_keysend_payments_to_private_node() {
        nodes[1].node.peer_connected(&payer_pubkey, &msgs::Init { features: InitFeatures::known() });
 
        let _chan = create_chan_between_nodes(&nodes[0], &nodes[1], InitFeatures::known(), InitFeatures::known());
-       let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
+       let params = RouteParameters {
+               payee: Payee::for_keysend(payee_pubkey),
+               final_value_msat: 10000,
+               final_cltv_expiry_delta: 40,
+       };
+       let network_graph = nodes[0].network_graph;
        let first_hops = nodes[0].node.list_usable_channels();
-       let scorer = Scorer::new(0);
-       let route = get_keysend_route(
-               &payer_pubkey, &network_graph, &payee_pubkey, Some(&first_hops.iter().collect::<Vec<_>>()),
-               &vec![], 10000, 40, nodes[0].logger, &scorer
+       let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+       let route = find_route(
+               &payer_pubkey, &params, network_graph, Some(&first_hops.iter().collect::<Vec<_>>()),
+               nodes[0].logger, &scorer
        ).unwrap();
 
        let test_preimage = PaymentPreimage([42; 32]);
index c6f55e384d8432c6d5bfb0958662aa62b5d3b884..f3681946d194f614583c64488bb809103fccc9ec 100644 (file)
@@ -162,7 +162,7 @@ fn run_onion_failure_test_with_fail_intercept<F1,F2,F3>(_name: &str, test_case:
 
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
-       if let &Event::PaymentPathFailed { payment_hash: _, ref rejected_by_dest, ref network_update, ref all_paths_failed, path: _, ref short_channel_id, ref error_code, error_data: _ } = &events[0] {
+       if let &Event::PaymentPathFailed { ref rejected_by_dest, ref network_update, ref all_paths_failed, ref short_channel_id, ref error_code, .. } = &events[0] {
                assert_eq!(*rejected_by_dest, !expected_retryable);
                assert_eq!(*all_paths_failed, true);
                assert_eq!(*error_code, expected_error_code);
index 20ff0c8344d4836738b7a63bad1c41db4fc3e36b..c1767c025a3531fe61e523ea39a7250eef5f9732 100644 (file)
@@ -555,6 +555,7 @@ mod tests {
                                                short_channel_id: 0, fee_msat: 0, cltv_expiry_delta: 0 // Test vectors are garbage and not generateble from a RouteHop, we fill in payloads manually
                                        },
                        ]],
+                       payee: None,
                };
 
                let session_priv = SecretKey::from_slice(&hex::decode("4141414141414141414141414141414141414141414141414141414141414141").unwrap()[..]).unwrap();
index 1b9836621e7761a3a0316c4f441ac4d2e826abbc..f685c375f38913c6fbedc9ce40a064d7e7373598 100644 (file)
@@ -296,7 +296,7 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        let mut nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
        let (_, _, chan_id, funding_tx) = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
-       create_announced_chan_between_nodes(&nodes, 1, 2, InitFeatures::known(), InitFeatures::known());
+       let (_, _, chan_id_2, _) = create_announced_chan_between_nodes(&nodes, 1, 2, InitFeatures::known(), InitFeatures::known());
 
        // Serialize the ChannelManager prior to sending payments
        let nodes_0_serialized = nodes[0].node.encode();
@@ -445,7 +445,13 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        // Finally, retry the payment (which was reloaded from the ChannelMonitor when nodes[0] was
        // reloaded) via a route over the new channel, which work without issue and eventually be
        // received and claimed at the recipient just like any other payment.
-       let (new_route, _, _, _) = get_route_and_payment_hash!(nodes[0], nodes[2], 1_000_000);
+       let (mut new_route, _, _, _) = get_route_and_payment_hash!(nodes[0], nodes[2], 1_000_000);
+
+       // Update the fee on the middle hop to ensure PaymentSent events have the correct (retried) fee
+       // and not the original fee. We also update node[1]'s relevant config as
+       // do_claim_payment_along_route expects us to never overpay.
+       nodes[1].node.channel_state.lock().unwrap().by_id.get_mut(&chan_id_2).unwrap().config.forwarding_fee_base_msat += 100_000;
+       new_route.paths[0][0].fee_msat += 100_000;
 
        assert!(nodes[0].node.retry_payment(&new_route, payment_id_1).is_err()); // Shouldn't be allowed to retry a fulfilled payment
        nodes[0].node.retry_payment(&new_route, payment_id).unwrap();
@@ -453,7 +459,8 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) {
        let mut events = nodes[0].node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), 1);
        pass_along_path(&nodes[0], &[&nodes[1], &nodes[2]], 1_000_000, payment_hash, Some(payment_secret), events.pop().unwrap(), true, None);
-       claim_payment_along_route(&nodes[0], &[&[&nodes[1], &nodes[2]]], false, payment_preimage);
+       do_claim_payment_along_route(&nodes[0], &[&[&nodes[1], &nodes[2]]], false, payment_preimage);
+       expect_payment_sent!(nodes[0], payment_preimage, Some(new_route.paths[0][0].fee_msat));
 }
 
 #[test]
index ebb0322f810a901c9e46a3e78dadce6b54e99611..f5d106333ba1da96c553f3018bd2f7e994cec7b8 100644 (file)
@@ -27,7 +27,7 @@ use ln::wire;
 use util::atomic_counter::AtomicCounter;
 use util::events::{MessageSendEvent, MessageSendEventsProvider};
 use util::logger::Logger;
-use routing::network_graph::NetGraphMsgHandler;
+use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
 
 use prelude::*;
 use io;
@@ -284,6 +284,10 @@ enum InitSyncTracker{
        NodesSyncing(PublicKey),
 }
 
+/// The ratio between buffer sizes at which we stop sending initial sync messages vs when we stop
+/// forwarding gossip messages to peers altogether.
+const FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO: usize = 2;
+
 /// When the outbound buffer has this many messages, we'll stop reading bytes from the peer until
 /// we have fewer than this many messages in the outbound buffer again.
 /// We also use this as the target number of outbound gossip messages to keep in the write buffer,
@@ -291,7 +295,29 @@ enum InitSyncTracker{
 const OUTBOUND_BUFFER_LIMIT_READ_PAUSE: usize = 10;
 /// When the outbound buffer has this many messages, we'll simply skip relaying gossip messages to
 /// the peer.
-const OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP: usize = 20;
+const OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP: usize = OUTBOUND_BUFFER_LIMIT_READ_PAUSE * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO;
+
+/// If we've sent a ping, and are still awaiting a response, we may need to churn our way through
+/// the socket receive buffer before receiving the ping.
+///
+/// On a fairly old Arm64 board, with Linux defaults, this can take as long as 20 seconds, not
+/// including any network delays, outbound traffic, or the same for messages from other peers.
+///
+/// Thus, to avoid needlessly disconnecting a peer, we allow a peer to take this many timer ticks
+/// per connected peer to respond to a ping, as long as they send us at least one message during
+/// each tick, ensuring we aren't actually just disconnected.
+/// With a timer tick interval of five seconds, this translates to about 30 seconds per connected
+/// peer.
+///
+/// When we improve parallelism somewhat we should reduce this to e.g. this many timer ticks per
+/// two connected peers, assuming most LDK-running systems have at least two cores.
+const MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER: i8 = 6;
+
+/// This is the minimum number of messages we expect a peer to be able to handle within one timer
+/// tick. Once we have sent this many messages since the last ping, we send a ping right away to
+/// ensures we don't just fill up our send buffer and leave the peer with too many messages to
+/// process before the next ping.
+const BUFFER_DRAIN_MSGS_PER_TICK: usize = 32;
 
 struct Peer {
        channel_encryptor: PeerChannelEncryptor,
@@ -308,7 +334,9 @@ struct Peer {
 
        sync_status: InitSyncTracker,
 
-       awaiting_pong: bool,
+       msgs_sent_since_pong: usize,
+       awaiting_pong_timer_tick_intervals: i8,
+       received_message_since_timer_tick: bool,
 }
 
 impl Peer {
@@ -347,7 +375,7 @@ struct PeerHolder<Descriptor: SocketDescriptor> {
 /// lifetimes). Other times you can afford a reference, which is more efficient, in which case
 /// SimpleRefPeerManager is the more appropriate type. Defining these type aliases prevents
 /// issues such as overly long function definitions.
-pub type SimpleArcPeerManager<SD, M, T, F, C, L> = PeerManager<SD, Arc<SimpleArcChannelManager<M, T, F, L>>, Arc<NetGraphMsgHandler<Arc<C>, Arc<L>>>, Arc<L>, Arc<IgnoringMessageHandler>>;
+pub type SimpleArcPeerManager<SD, M, T, F, C, L> = PeerManager<SD, Arc<SimpleArcChannelManager<M, T, F, L>>, Arc<NetGraphMsgHandler<Arc<NetworkGraph>, Arc<C>, Arc<L>>>, Arc<L>, Arc<IgnoringMessageHandler>>;
 
 /// SimpleRefPeerManager is a type alias for a PeerManager reference, and is the reference
 /// counterpart to the SimpleArcPeerManager type alias. Use this type by default when you don't
@@ -355,7 +383,7 @@ pub type SimpleArcPeerManager<SD, M, T, F, C, L> = PeerManager<SD, Arc<SimpleArc
 /// usage of lightning-net-tokio (since tokio::spawn requires parameters with static lifetimes).
 /// But if this is not necessary, using a reference is more efficient. Defining these type aliases
 /// helps with issues such as long function definitions.
-pub type SimpleRefPeerManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, SD, M, T, F, C, L> = PeerManager<SD, SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L>, &'e NetGraphMsgHandler<&'g C, &'f L>, &'f L, IgnoringMessageHandler>;
+pub type SimpleRefPeerManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, SD, M, T, F, C, L> = PeerManager<SD, SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L>, &'e NetGraphMsgHandler<&'g NetworkGraph, &'h C, &'f L>, &'f L, IgnoringMessageHandler>;
 
 /// A PeerManager manages a set of peers, described by their [`SocketDescriptor`] and marshalls
 /// socket events into messages which it passes on to its [`MessageHandler`].
@@ -455,6 +483,17 @@ impl<Descriptor: SocketDescriptor, RM: Deref, L: Deref> PeerManager<Descriptor,
        }
 }
 
+/// A simple wrapper that optionally prints " from <pubkey>" for an optional pubkey.
+/// This works around `format!()` taking a reference to each argument, preventing
+/// `if let Some(node_id) = peer.their_node_id { format!(.., node_id) } else { .. }` from compiling
+/// due to lifetime errors.
+struct OptionalFromDebugger<'a>(&'a Option<PublicKey>);
+impl core::fmt::Display for OptionalFromDebugger<'_> {
+       fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
+               if let Some(node_id) = self.0 { write!(f, " from {}", log_pubkey!(node_id)) } else { Ok(()) }
+       }
+}
+
 impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> PeerManager<Descriptor, CM, RM, L, CMH> where
                CM::Target: ChannelMessageHandler,
                RM::Target: RoutingMessageHandler,
@@ -534,7 +573,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
 
                        sync_status: InitSyncTracker::NoSyncRequested,
 
-                       awaiting_pong: false,
+                       msgs_sent_since_pong: 0,
+                       awaiting_pong_timer_tick_intervals: 0,
+                       received_message_since_timer_tick: false,
                }).is_some() {
                        panic!("PeerManager driver duplicated descriptors!");
                };
@@ -572,7 +613,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
 
                        sync_status: InitSyncTracker::NoSyncRequested,
 
-                       awaiting_pong: false,
+                       msgs_sent_since_pong: 0,
+                       awaiting_pong_timer_tick_intervals: 0,
+                       received_message_since_timer_tick: false,
                }).is_some() {
                        panic!("PeerManager driver duplicated descriptors!");
                };
@@ -581,7 +624,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
 
        fn do_attempt_write_data(&self, descriptor: &mut Descriptor, peer: &mut Peer) {
                while !peer.awaiting_write_event {
-                       if peer.pending_outbound_buffer.len() < OUTBOUND_BUFFER_LIMIT_READ_PAUSE {
+                       if peer.pending_outbound_buffer.len() < OUTBOUND_BUFFER_LIMIT_READ_PAUSE && peer.msgs_sent_since_pong < BUFFER_DRAIN_MSGS_PER_TICK {
                                match peer.sync_status {
                                        InitSyncTracker::NoSyncRequested => {},
                                        InitSyncTracker::ChannelsSyncing(c) if c < 0xffff_ffff_ffff_ffff => {
@@ -626,6 +669,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                        },
                                }
                        }
+                       if peer.msgs_sent_since_pong >= BUFFER_DRAIN_MSGS_PER_TICK {
+                               self.maybe_send_extra_ping(peer);
+                       }
 
                        if {
                                let next_buff = match peer.pending_outbound_buffer.front() {
@@ -701,14 +747,18 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                }
        }
 
-       /// Append a message to a peer's pending outbound/write buffer, and update the map of peers needing sends accordingly.
+       /// Append a message to a peer's pending outbound/write buffer
+       fn enqueue_encoded_message(&self, peer: &mut Peer, encoded_message: &Vec<u8>) {
+               peer.msgs_sent_since_pong += 1;
+               peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encoded_message[..]));
+       }
+
+       /// Append a message to a peer's pending outbound/write buffer
        fn enqueue_message<M: wire::Type>(&self, peer: &mut Peer, message: &M) {
                let mut buffer = VecWriter(Vec::with_capacity(2048));
                wire::write(message, &mut buffer).unwrap(); // crash if the write failed
-               let encoded_message = buffer.0;
-
                log_trace!(self.logger, "Enqueueing message {:?} to {}", message, log_pubkey!(peer.their_node_id.unwrap()));
-               peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encoded_message[..]));
+               self.enqueue_encoded_message(peer, &buffer.0);
        }
 
        fn do_read_event(&self, peer_descriptor: &mut Descriptor, data: &[u8]) -> Result<bool, PeerHandleError> {
@@ -748,19 +798,19 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                                                        match e.action {
                                                                                                msgs::ErrorAction::DisconnectPeer { msg: _ } => {
                                                                                                        //TODO: Try to push msg
-                                                                                                       log_debug!(self.logger, "Error handling message; disconnecting peer with: {}", e.err);
+                                                                                                       log_debug!(self.logger, "Error handling message{}; disconnecting peer with: {}", OptionalFromDebugger(&peer.their_node_id), e.err);
                                                                                                        return Err(PeerHandleError{ no_connection_possible: false });
                                                                                                },
                                                                                                msgs::ErrorAction::IgnoreAndLog(level) => {
-                                                                                                       log_given_level!(self.logger, level, "Error handling message; ignoring: {}", e.err);
+                                                                                                       log_given_level!(self.logger, level, "Error handling message{}; ignoring: {}", OptionalFromDebugger(&peer.their_node_id), e.err);
                                                                                                        continue
                                                                                                },
                                                                                                msgs::ErrorAction::IgnoreError => {
-                                                                                                       log_debug!(self.logger, "Error handling message; ignoring: {}", e.err);
+                                                                                                       log_debug!(self.logger, "Error handling message{}; ignoring: {}", OptionalFromDebugger(&peer.their_node_id), e.err);
                                                                                                        continue;
                                                                                                },
                                                                                                msgs::ErrorAction::SendErrorMessage { msg } => {
-                                                                                                       log_debug!(self.logger, "Error handling message; sending error message with: {}", e.err);
+                                                                                                       log_debug!(self.logger, "Error handling message{}; sending error message with: {}", OptionalFromDebugger(&peer.their_node_id), e.err);
                                                                                                        self.enqueue_message(peer, &msg);
                                                                                                        continue;
                                                                                                },
@@ -804,6 +854,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                                        let features = InitFeatures::known();
                                                                        let resp = msgs::Init { features };
                                                                        self.enqueue_message(peer, &resp);
+                                                                       peer.awaiting_pong_timer_tick_intervals = 0;
                                                                },
                                                                NextNoiseStep::ActThree => {
                                                                        let their_node_id = try_potential_handleerror!(peer.channel_encryptor.process_act_three(&peer.pending_read_buffer[..]));
@@ -814,6 +865,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                                        let features = InitFeatures::known();
                                                                        let resp = msgs::Init { features };
                                                                        self.enqueue_message(peer, &resp);
+                                                                       peer.awaiting_pong_timer_tick_intervals = 0;
                                                                },
                                                                NextNoiseStep::NoiseComplete => {
                                                                        if peer.pending_read_is_header {
@@ -902,6 +954,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                message: wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>
        ) -> Result<Option<wire::Message<<<CMH as core::ops::Deref>::Target as wire::CustomMessageReader>::CustomMessage>>, MessageHandlingError> {
                log_trace!(self.logger, "Received message {:?} from {}", message, log_pubkey!(peer.their_node_id.unwrap()));
+               peer.received_message_since_timer_tick = true;
 
                // Need an Init as first message
                if let wire::Message::Init(_) = message {
@@ -923,7 +976,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                        return Err(PeerHandleError{ no_connection_possible: false }.into());
                                }
 
-                               log_info!(self.logger, "Received peer Init message: {}", msg.features);
+                               log_info!(self.logger, "Received peer Init message from {}: {}", log_pubkey!(peer.their_node_id.unwrap()), msg.features);
 
                                if msg.features.initial_routing_sync() {
                                        peer.sync_status = InitSyncTracker::ChannelsSyncing(0);
@@ -965,7 +1018,8 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                }
                        },
                        wire::Message::Pong(_msg) => {
-                               peer.awaiting_pong = false;
+                               peer.awaiting_pong_timer_tick_intervals = 0;
+                               peer.msgs_sent_since_pong = 0;
                        },
 
                        // Channel messages:
@@ -1086,7 +1140,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                        !peer.should_forward_channel_announcement(msg.contents.short_channel_id) {
                                                continue
                                        }
-                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP {
+                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP
+                                               || peer.msgs_sent_since_pong > BUFFER_DRAIN_MSGS_PER_TICK * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO
+                                       {
                                                log_trace!(self.logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id);
                                                continue;
                                        }
@@ -1097,7 +1153,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                        if except_node.is_some() && peer.their_node_id.as_ref() == except_node {
                                                continue;
                                        }
-                                       peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encoded_msg[..]));
+                                       self.enqueue_encoded_message(peer, &encoded_msg);
                                }
                        },
                        wire::Message::NodeAnnouncement(ref msg) => {
@@ -1109,7 +1165,9 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                        !peer.should_forward_node_announcement(msg.contents.node_id) {
                                                continue
                                        }
-                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP {
+                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP
+                                               || peer.msgs_sent_since_pong > BUFFER_DRAIN_MSGS_PER_TICK * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO
+                                       {
                                                log_trace!(self.logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id);
                                                continue;
                                        }
@@ -1119,7 +1177,7 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                        if except_node.is_some() && peer.their_node_id.as_ref() == except_node {
                                                continue;
                                        }
-                                       peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encoded_msg[..]));
+                                       self.enqueue_encoded_message(peer, &encoded_msg);
                                }
                        },
                        wire::Message::ChannelUpdate(ref msg) => {
@@ -1131,14 +1189,16 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                        !peer.should_forward_channel_announcement(msg.contents.short_channel_id)  {
                                                continue
                                        }
-                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP {
+                                       if peer.pending_outbound_buffer.len() > OUTBOUND_BUFFER_LIMIT_DROP_GOSSIP
+                                               || peer.msgs_sent_since_pong > BUFFER_DRAIN_MSGS_PER_TICK * FORWARD_INIT_SYNC_BUFFER_LIMIT_RATIO
+                                       {
                                                log_trace!(self.logger, "Skipping broadcast message to {:?} as its outbound buffer is full", peer.their_node_id);
                                                continue;
                                        }
                                        if except_node.is_some() && peer.their_node_id.as_ref() == except_node {
                                                continue;
                                        }
-                                       peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encoded_msg[..]));
+                                       self.enqueue_encoded_message(peer, &encoded_msg);
                                }
                        },
                        _ => debug_assert!(false, "We shouldn't attempt to forward anything but gossip messages"),
@@ -1414,6 +1474,37 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                }
        }
 
+       /// Disconnects all currently-connected peers. This is useful on platforms where there may be
+       /// an indication that TCP sockets have stalled even if we weren't around to time them out
+       /// using regular ping/pongs.
+       pub fn disconnect_all_peers(&self) {
+               let mut peers_lock = self.peers.lock().unwrap();
+               let peers = &mut *peers_lock;
+               for (mut descriptor, peer) in peers.peers.drain() {
+                       if let Some(node_id) = peer.their_node_id {
+                               log_trace!(self.logger, "Disconnecting peer with id {} due to client request to disconnect all peers", node_id);
+                               peers.node_id_to_descriptor.remove(&node_id);
+                               self.message_handler.chan_handler.peer_disconnected(&node_id, false);
+                       }
+                       descriptor.disconnect_socket();
+               }
+               debug_assert!(peers.node_id_to_descriptor.is_empty());
+       }
+
+       /// This is called when we're blocked on sending additional gossip messages until we receive a
+       /// pong. If we aren't waiting on a pong, we take this opportunity to send a ping (setting
+       /// `awaiting_pong_timer_tick_intervals` to a special flag value to indicate this).
+       fn maybe_send_extra_ping(&self, peer: &mut Peer) {
+               if peer.awaiting_pong_timer_tick_intervals == 0 {
+                       peer.awaiting_pong_timer_tick_intervals = -1;
+                       let ping = msgs::Ping {
+                               ponglen: 0,
+                               byteslen: 64,
+                       };
+                       self.enqueue_message(peer, &ping);
+               }
+       }
+
        /// Send pings to each peer and disconnect those which did not respond to the last round of
        /// pings.
        ///
@@ -1432,9 +1523,35 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                        let node_id_to_descriptor = &mut peers.node_id_to_descriptor;
                        let peers = &mut peers.peers;
                        let mut descriptors_needing_disconnect = Vec::new();
+                       let peer_count = peers.len();
 
                        peers.retain(|descriptor, peer| {
-                               if peer.awaiting_pong {
+                               let mut do_disconnect_peer = false;
+                               if !peer.channel_encryptor.is_ready_for_encryption() || peer.their_node_id.is_none() {
+                                       // The peer needs to complete its handshake before we can exchange messages. We
+                                       // give peers one timer tick to complete handshake, reusing
+                                       // `awaiting_pong_timer_tick_intervals` to track number of timer ticks taken
+                                       // for handshake completion.
+                                       if peer.awaiting_pong_timer_tick_intervals != 0 {
+                                               do_disconnect_peer = true;
+                                       } else {
+                                               peer.awaiting_pong_timer_tick_intervals = 1;
+                                               return true;
+                                       }
+                               }
+
+                               if peer.awaiting_pong_timer_tick_intervals == -1 {
+                                       // Magic value set in `maybe_send_extra_ping`.
+                                       peer.awaiting_pong_timer_tick_intervals = 1;
+                                       peer.received_message_since_timer_tick = false;
+                                       return true;
+                               }
+
+                               if do_disconnect_peer
+                                       || (peer.awaiting_pong_timer_tick_intervals > 0 && !peer.received_message_since_timer_tick)
+                                       || peer.awaiting_pong_timer_tick_intervals as u64 >
+                                               MAX_BUFFER_DRAIN_TICK_INTERVALS_PER_PEER as u64 * peer_count as u64
+                               {
                                        descriptors_needing_disconnect.push(descriptor.clone());
                                        match peer.their_node_id {
                                                Some(node_id) => {
@@ -1442,30 +1559,25 @@ impl<Descriptor: SocketDescriptor, CM: Deref, RM: Deref, L: Deref, CMH: Deref> P
                                                        node_id_to_descriptor.remove(&node_id);
                                                        self.message_handler.chan_handler.peer_disconnected(&node_id, false);
                                                }
-                                               None => {
-                                                       // This can't actually happen as we should have hit
-                                                       // is_ready_for_encryption() previously on this same peer.
-                                                       unreachable!();
-                                               },
+                                               None => {},
                                        }
                                        return false;
                                }
+                               peer.received_message_since_timer_tick = false;
 
-                               if !peer.channel_encryptor.is_ready_for_encryption() {
-                                       // The peer needs to complete its handshake before we can exchange messages
+                               if peer.awaiting_pong_timer_tick_intervals > 0 {
+                                       peer.awaiting_pong_timer_tick_intervals += 1;
                                        return true;
                                }
 
+                               peer.awaiting_pong_timer_tick_intervals = 1;
                                let ping = msgs::Ping {
                                        ponglen: 0,
                                        byteslen: 64,
                                };
                                self.enqueue_message(peer, &ping);
+                               self.do_attempt_write_data(&mut (descriptor.clone()), &mut *peer);
 
-                               let mut descriptor_clone = descriptor.clone();
-                               self.do_attempt_write_data(&mut descriptor_clone, peer);
-
-                               peer.awaiting_pong = true;
                                true
                        });
 
@@ -1624,11 +1736,23 @@ mod tests {
                // than can fit into a peer's buffer).
                let (mut fd_a, mut fd_b) = establish_connection(&peers[0], &peers[1]);
 
-               // Make each peer to read the messages that the other peer just wrote to them.
-               peers[0].process_events();
-               peers[1].read_event(&mut fd_b, &fd_a.outbound_data.lock().unwrap().split_off(0)).unwrap();
-               peers[1].process_events();
-               peers[0].read_event(&mut fd_a, &fd_b.outbound_data.lock().unwrap().split_off(0)).unwrap();
+               // Make each peer to read the messages that the other peer just wrote to them. Note that
+               // due to the max-messagse-before-ping limits this may take a few iterations to complete.
+               for _ in 0..150/super::BUFFER_DRAIN_MSGS_PER_TICK + 1 {
+                       peers[0].process_events();
+                       let b_read_data = fd_a.outbound_data.lock().unwrap().split_off(0);
+                       assert!(!b_read_data.is_empty());
+
+                       peers[1].read_event(&mut fd_b, &b_read_data).unwrap();
+                       peers[1].process_events();
+
+                       let a_read_data = fd_b.outbound_data.lock().unwrap().split_off(0);
+                       assert!(!a_read_data.is_empty());
+                       peers[0].read_event(&mut fd_a, &a_read_data).unwrap();
+
+                       peers[1].process_events();
+                       assert_eq!(fd_b.outbound_data.lock().unwrap().len(), 0, "Until B receives data, it shouldn't send more messages");
+               }
 
                // Check that each peer has received the expected number of channel updates and channel
                // announcements.
@@ -1637,4 +1761,37 @@ mod tests {
                assert_eq!(cfgs[1].routing_handler.chan_upds_recvd.load(Ordering::Acquire), 100);
                assert_eq!(cfgs[1].routing_handler.chan_anns_recvd.load(Ordering::Acquire), 50);
        }
+
+       #[test]
+       fn test_handshake_timeout() {
+               // Tests that we time out a peer still waiting on handshake completion after a full timer
+               // tick.
+               let cfgs = create_peermgr_cfgs(2);
+               cfgs[0].routing_handler.request_full_sync.store(true, Ordering::Release);
+               cfgs[1].routing_handler.request_full_sync.store(true, Ordering::Release);
+               let peers = create_network(2, &cfgs);
+
+               let secp_ctx = Secp256k1::new();
+               let a_id = PublicKey::from_secret_key(&secp_ctx, &peers[0].our_node_secret);
+               let mut fd_a = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let mut fd_b = FileDescriptor { fd: 1, outbound_data: Arc::new(Mutex::new(Vec::new())) };
+               let initial_data = peers[1].new_outbound_connection(a_id, fd_b.clone()).unwrap();
+               peers[0].new_inbound_connection(fd_a.clone()).unwrap();
+
+               // If we get a single timer tick before completion, that's fine
+               assert_eq!(peers[0].peers.lock().unwrap().peers.len(), 1);
+               peers[0].timer_tick_occurred();
+               assert_eq!(peers[0].peers.lock().unwrap().peers.len(), 1);
+
+               assert_eq!(peers[0].read_event(&mut fd_a, &initial_data).unwrap(), false);
+               peers[0].process_events();
+               assert_eq!(peers[1].read_event(&mut fd_b, &fd_a.outbound_data.lock().unwrap().split_off(0)).unwrap(), false);
+               peers[1].process_events();
+
+               // ...but if we get a second timer tick, we should disconnect the peer
+               peers[0].timer_tick_occurred();
+               assert_eq!(peers[0].peers.lock().unwrap().peers.len(), 0);
+
+               assert!(peers[0].read_event(&mut fd_a, &fd_b.outbound_data.lock().unwrap().split_off(0)).is_err());
+       }
 }
index e82cdae60f2b31617f801221163b4fb581c51a63..fcec4703e44f76306b0889d138c71ebf0b28e1a7 100644 (file)
@@ -13,9 +13,8 @@ use chain::keysinterface::KeysInterface;
 use chain::transaction::OutPoint;
 use ln::{PaymentPreimage, PaymentHash};
 use ln::channelmanager::PaymentSendFailure;
-use routing::router::get_route;
+use routing::router::{Payee, get_route};
 use routing::network_graph::NetworkUpdate;
-use routing::scorer::Scorer;
 use ln::features::{InitFeatures, InvoiceFeatures};
 use ln::msgs;
 use ln::msgs::{ChannelMessageHandler, ErrorAction};
@@ -82,7 +81,7 @@ fn updates_shutdown_wait() {
        let chan_1 = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
        let chan_2 = create_announced_chan_between_nodes(&nodes, 1, 2, InitFeatures::known(), InitFeatures::known());
        let logger = test_utils::TestLogger::new();
-       let scorer = Scorer::new(0);
+       let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
        let (our_payment_preimage, our_payment_hash, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 100000);
 
@@ -97,10 +96,10 @@ fn updates_shutdown_wait() {
 
        let (_, payment_hash, payment_secret) = get_payment_preimage_hash!(nodes[0]);
 
-       let net_graph_msg_handler0 = &nodes[0].net_graph_msg_handler;
-       let net_graph_msg_handler1 = &nodes[1].net_graph_msg_handler;
-       let route_1 = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler0.network_graph, &nodes[1].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &[], 100000, TEST_FINAL_CLTV, &logger, &scorer).unwrap();
-       let route_2 = get_route(&nodes[1].node.get_our_node_id(), &net_graph_msg_handler1.network_graph, &nodes[0].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &[], 100000, TEST_FINAL_CLTV, &logger, &scorer).unwrap();
+       let payee_1 = Payee::from_node_id(nodes[1].node.get_our_node_id()).with_features(InvoiceFeatures::known());
+       let route_1 = get_route(&nodes[0].node.get_our_node_id(), &payee_1, nodes[0].network_graph, None, 100000, TEST_FINAL_CLTV, &logger, &scorer).unwrap();
+       let payee_2 = Payee::from_node_id(nodes[0].node.get_our_node_id()).with_features(InvoiceFeatures::known());
+       let route_2 = get_route(&nodes[1].node.get_our_node_id(), &payee_2, nodes[1].network_graph, None, 100000, TEST_FINAL_CLTV, &logger, &scorer).unwrap();
        unwrap_send_err!(nodes[0].node.send_payment(&route_1, payment_hash, &Some(payment_secret)), true, APIError::ChannelUnavailable {..}, {});
        unwrap_send_err!(nodes[1].node.send_payment(&route_2, payment_hash, &Some(payment_secret)), true, APIError::ChannelUnavailable {..}, {});
 
@@ -129,7 +128,7 @@ fn updates_shutdown_wait() {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match events[0] {
-               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                        assert_eq!(our_payment_preimage, *payment_preimage);
                        assert_eq!(our_payment_hash, *payment_hash);
                },
@@ -307,7 +306,7 @@ fn do_test_shutdown_rebroadcast(recv_count: u8) {
        let events = nodes[0].node.get_and_clear_pending_events();
        assert_eq!(events.len(), 1);
        match events[0] {
-               Event::PaymentSent { ref payment_preimage, ref payment_hash } => {
+               Event::PaymentSent { ref payment_preimage, ref payment_hash, .. } => {
                        assert_eq!(our_payment_preimage, *payment_preimage);
                        assert_eq!(our_payment_hash, *payment_hash);
                },
index 51ffd91b50483d58bc2016440bdb98d9f82886e5..3a48ffe93ddf42c77fb3b3c9b17d727dba14c8f0 100644 (file)
@@ -14,6 +14,11 @@ pub mod router;
 pub mod scorer;
 
 use routing::network_graph::NodeId;
+use routing::router::RouteHop;
+
+use core::cell::{RefCell, RefMut};
+use core::ops::DerefMut;
+use sync::{Mutex, MutexGuard};
 
 /// An interface used to score payment channels for path finding.
 ///
@@ -22,4 +27,49 @@ pub trait Score {
        /// Returns the fee in msats willing to be paid to avoid routing through the given channel
        /// in the direction from `source` to `target`.
        fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64;
+
+       /// Handles updating channel penalties after failing to route through a channel.
+       fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64);
+}
+
+/// A scorer that is accessed under a lock.
+///
+/// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while
+/// having shared ownership of a scorer but without requiring internal locking in [`Score`]
+/// implementations. Internal locking would be detrimental to route finding performance and could
+/// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
+///
+/// [`find_route`]: crate::routing::router::find_route
+pub trait LockableScore<'a> {
+       /// The locked [`Score`] type.
+       type Locked: 'a + Score;
+
+       /// Returns the locked scorer.
+       fn lock(&'a self) -> Self::Locked;
+}
+
+impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
+       type Locked = MutexGuard<'a, T>;
+
+       fn lock(&'a self) -> MutexGuard<'a, T> {
+               Mutex::lock(self).unwrap()
+       }
+}
+
+impl<'a, T: 'a + Score> LockableScore<'a> for RefCell<T> {
+       type Locked = RefMut<'a, T>;
+
+       fn lock(&'a self) -> RefMut<'a, T> {
+               self.borrow_mut()
+       }
+}
+
+impl<S: Score, T: DerefMut<Target=S>> Score for T {
+       fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 {
+               self.deref().channel_penalty_msat(short_channel_id, source, target)
+       }
+
+       fn payment_path_failed(&mut self, path: &[&RouteHop], short_channel_id: u64) {
+               self.deref_mut().payment_path_failed(path, short_channel_id)
+       }
 }
index f65b8fa657a5fc2e77da24a33578ecf51089e8c4..851e01b5088e6988eae597982c91fa4c865b3172 100644 (file)
@@ -186,7 +186,7 @@ impl_writeable_tlv_based_enum_upgradable!(NetworkUpdate,
        },
 );
 
-impl<C: Deref, L: Deref> EventHandler for NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> EventHandler for NetGraphMsgHandler<G, C, L>
 where C::Target: chain::Access, L::Target: Logger {
        fn handle_event(&self, event: &Event) {
                if let Event::PaymentPathFailed { payment_hash: _, rejected_by_dest: _, network_update, .. } = event {
@@ -205,19 +205,18 @@ where C::Target: chain::Access, L::Target: Logger {
 ///
 /// Serves as an [`EventHandler`] for applying updates from [`Event::PaymentPathFailed`] to the
 /// [`NetworkGraph`].
-pub struct NetGraphMsgHandler<C: Deref, L: Deref>
+pub struct NetGraphMsgHandler<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref>
 where C::Target: chain::Access, L::Target: Logger
 {
        secp_ctx: Secp256k1<secp256k1::VerifyOnly>,
-       /// Representation of the payment channel network
-       pub network_graph: NetworkGraph,
+       network_graph: G,
        chain_access: Option<C>,
        full_syncs_requested: AtomicUsize,
        pending_events: Mutex<Vec<MessageSendEvent>>,
        logger: L,
 }
 
-impl<C: Deref, L: Deref> NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> NetGraphMsgHandler<G, C, L>
 where C::Target: chain::Access, L::Target: Logger
 {
        /// Creates a new tracker of the actual state of the network of channels and nodes,
@@ -225,7 +224,7 @@ where C::Target: chain::Access, L::Target: Logger
        /// Chain monitor is used to make sure announced channels exist on-chain,
        /// channel data is correct, and that the announcement is signed with
        /// channel owners' keys.
-       pub fn new(network_graph: NetworkGraph, chain_access: Option<C>, logger: L) -> Self {
+       pub fn new(network_graph: G, chain_access: Option<C>, logger: L) -> Self {
                NetGraphMsgHandler {
                        secp_ctx: Secp256k1::verification_only(),
                        network_graph,
@@ -288,7 +287,7 @@ macro_rules! secp_verify_sig {
        };
 }
 
-impl<C: Deref, L: Deref> RoutingMessageHandler for NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> RoutingMessageHandler for NetGraphMsgHandler<G, C, L>
 where C::Target: chain::Access, L::Target: Logger
 {
        fn handle_node_announcement(&self, msg: &msgs::NodeAnnouncement) -> Result<bool, LightningError> {
@@ -554,7 +553,7 @@ where C::Target: chain::Access, L::Target: Logger
        }
 }
 
-impl<C: Deref, L: Deref> MessageSendEventsProvider for NetGraphMsgHandler<C, L>
+impl<G: Deref<Target=NetworkGraph>, C: Deref, L: Deref> MessageSendEventsProvider for NetGraphMsgHandler<G, C, L>
 where
        C::Target: chain::Access,
        L::Target: Logger,
@@ -1255,18 +1254,24 @@ mod tests {
        use prelude::*;
        use sync::Arc;
 
-       fn create_net_graph_msg_handler() -> (Secp256k1<All>, NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>) {
+       fn create_network_graph() -> NetworkGraph {
+               let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
+               NetworkGraph::new(genesis_hash)
+       }
+
+       fn create_net_graph_msg_handler(network_graph: &NetworkGraph) -> (
+               Secp256k1<All>, NetGraphMsgHandler<&NetworkGraph, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>
+       ) {
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(test_utils::TestLogger::new());
-               let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
-               let network_graph = NetworkGraph::new(genesis_hash);
                let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
                (secp_ctx, net_graph_msg_handler)
        }
 
        #[test]
        fn request_full_sync_finite_times() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode("0202020202020202020202020202020202020202020202020202020202020202").unwrap()[..]).unwrap());
 
                assert!(net_graph_msg_handler.should_request_full_sync(&node_id));
@@ -1279,7 +1284,8 @@ mod tests {
 
        #[test]
        fn handling_node_announcements() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
@@ -1422,15 +1428,14 @@ mod tests {
 
                // Test if the UTXO lookups were not supported
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let mut net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
+               let mut net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, None, Arc::clone(&logger));
                match net_graph_msg_handler.handle_channel_announcement(&valid_announcement) {
                        Ok(res) => assert!(res),
                        _ => panic!()
                };
 
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&unsigned_announcement.short_channel_id) {
+                       match network_graph.read_only().channels().get(&unsigned_announcement.short_channel_id) {
                                None => panic!(),
                                Some(_) => ()
                        };
@@ -1447,7 +1452,7 @@ mod tests {
                let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
                *chain_source.utxo_ret.lock().unwrap() = Err(chain::AccessError::UnknownTx);
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), Arc::clone(&logger));
+               net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, Some(chain_source.clone()), Arc::clone(&logger));
                unsigned_announcement.short_channel_id += 1;
 
                msghash = hash_to_message!(&Sha256dHash::hash(&unsigned_announcement.encode()[..])[..]);
@@ -1482,8 +1487,7 @@ mod tests {
                };
 
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&unsigned_announcement.short_channel_id) {
+                       match network_graph.read_only().channels().get(&unsigned_announcement.short_channel_id) {
                                None => panic!(),
                                Some(_) => ()
                        };
@@ -1513,8 +1517,7 @@ mod tests {
                        _ => panic!()
                };
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&unsigned_announcement.short_channel_id) {
+                       match network_graph.read_only().channels().get(&unsigned_announcement.short_channel_id) {
                                Some(channel_entry) => {
                                        assert_eq!(channel_entry.features, ChannelFeatures::empty());
                                },
@@ -1572,7 +1575,7 @@ mod tests {
                let logger: Arc<Logger> = Arc::new(test_utils::TestLogger::new());
                let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), Arc::clone(&logger));
+               let net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, Some(chain_source.clone()), Arc::clone(&logger));
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
@@ -1644,8 +1647,7 @@ mod tests {
                };
 
                {
-                       let network = &net_graph_msg_handler.network_graph;
-                       match network.read_only().channels().get(&short_channel_id) {
+                       match network_graph.read_only().channels().get(&short_channel_id) {
                                None => panic!(),
                                Some(channel_info) => {
                                        assert_eq!(channel_info.one_to_two.as_ref().unwrap().cltv_expiry_delta, 144);
@@ -1741,7 +1743,7 @@ mod tests {
                let chain_source = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
                let genesis_hash = genesis_block(Network::Testnet).header.block_hash();
                let network_graph = NetworkGraph::new(genesis_hash);
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, Some(chain_source.clone()), &logger);
+               let net_graph_msg_handler = NetGraphMsgHandler::new(&network_graph, Some(chain_source.clone()), &logger);
                let secp_ctx = Secp256k1::new();
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
@@ -1753,7 +1755,6 @@ mod tests {
 
                let short_channel_id = 0;
                let chain_hash = genesis_block(Network::Testnet).header.block_hash();
-               let network_graph = &net_graph_msg_handler.network_graph;
 
                {
                        // There is no nodes in the table at the beginning.
@@ -1806,6 +1807,7 @@ mod tests {
                        assert!(network_graph.read_only().channels().get(&short_channel_id).unwrap().one_to_two.is_none());
 
                        net_graph_msg_handler.handle_event(&Event::PaymentPathFailed {
+                               payment_id: None,
                                payment_hash: PaymentHash([0; 32]),
                                rejected_by_dest: false,
                                all_paths_failed: true,
@@ -1814,6 +1816,7 @@ mod tests {
                                        msg: valid_channel_update,
                                }),
                                short_channel_id: None,
+                               retry: None,
                                error_code: None,
                                error_data: None,
                        });
@@ -1831,6 +1834,7 @@ mod tests {
                        };
 
                        net_graph_msg_handler.handle_event(&Event::PaymentPathFailed {
+                               payment_id: None,
                                payment_hash: PaymentHash([0; 32]),
                                rejected_by_dest: false,
                                all_paths_failed: true,
@@ -1840,6 +1844,7 @@ mod tests {
                                        is_permanent: false,
                                }),
                                short_channel_id: None,
+                               retry: None,
                                error_code: None,
                                error_data: None,
                        });
@@ -1855,6 +1860,7 @@ mod tests {
                // Permanent closing deletes a channel
                {
                        net_graph_msg_handler.handle_event(&Event::PaymentPathFailed {
+                               payment_id: None,
                                payment_hash: PaymentHash([0; 32]),
                                rejected_by_dest: false,
                                all_paths_failed: true,
@@ -1864,6 +1870,7 @@ mod tests {
                                        is_permanent: true,
                                }),
                                short_channel_id: None,
+                               retry: None,
                                error_code: None,
                                error_data: None,
                        });
@@ -1877,7 +1884,8 @@ mod tests {
 
        #[test]
        fn getting_next_channel_announcements() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_1_privkey);
@@ -2011,7 +2019,8 @@ mod tests {
 
        #[test]
        fn getting_next_node_announcements() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_1_privkey);
@@ -2128,7 +2137,8 @@ mod tests {
 
        #[test]
        fn network_graph_serialization() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
 
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_2_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
@@ -2185,17 +2195,17 @@ mod tests {
                        Err(_) => panic!()
                };
 
-               let network = &net_graph_msg_handler.network_graph;
                let mut w = test_utils::TestVecWriter(Vec::new());
-               assert!(!network.read_only().nodes().is_empty());
-               assert!(!network.read_only().channels().is_empty());
-               network.write(&mut w).unwrap();
-               assert!(<NetworkGraph>::read(&mut io::Cursor::new(&w.0)).unwrap() == *network);
+               assert!(!network_graph.read_only().nodes().is_empty());
+               assert!(!network_graph.read_only().channels().is_empty());
+               network_graph.write(&mut w).unwrap();
+               assert!(<NetworkGraph>::read(&mut io::Cursor::new(&w.0)).unwrap() == network_graph);
        }
 
        #[test]
        fn calling_sync_routing_table() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey_1 = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_privkey_1);
 
@@ -2232,7 +2242,8 @@ mod tests {
                // The initial implementation allows syncing with the first 5 peers after
                // which should_request_full_sync will return false
                {
-                       let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+                       let network_graph = create_network_graph();
+                       let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                        let init_msg = Init { features: InitFeatures::known() };
                        for n in 1..7 {
                                let node_privkey = &SecretKey::from_slice(&[n; 32]).unwrap();
@@ -2251,7 +2262,8 @@ mod tests {
 
        #[test]
        fn handling_reply_channel_range() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey_1 = &SecretKey::from_slice(&[42; 32]).unwrap();
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_privkey_1);
 
@@ -2299,7 +2311,8 @@ mod tests {
 
        #[test]
        fn handling_reply_short_channel_ids() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id = PublicKey::from_secret_key(&secp_ctx, node_privkey);
 
@@ -2328,7 +2341,8 @@ mod tests {
 
        #[test]
        fn handling_query_channel_range() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
 
                let chain_hash = genesis_block(Network::Testnet).header.block_hash();
                let node_1_privkey = &SecretKey::from_slice(&[42; 32]).unwrap();
@@ -2590,7 +2604,7 @@ mod tests {
        }
 
        fn do_handling_query_channel_range(
-               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               net_graph_msg_handler: &NetGraphMsgHandler<&NetworkGraph, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
                test_node_id: &PublicKey,
                msg: QueryChannelRange,
                expected_ok: bool,
@@ -2639,7 +2653,8 @@ mod tests {
 
        #[test]
        fn handling_query_short_channel_ids() {
-               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler();
+               let network_graph = create_network_graph();
+               let (secp_ctx, net_graph_msg_handler) = create_net_graph_msg_handler(&network_graph);
                let node_privkey = &SecretKey::from_slice(&[41; 32]).unwrap();
                let node_id = PublicKey::from_secret_key(&secp_ctx, node_privkey);
 
index b617eebd42d834996279c9346e07a743bee7c1de..974ae74e4960f0c636670f721106aa4cb8d87d53 100644 (file)
@@ -70,19 +70,34 @@ pub struct Route {
        /// given path is variable, keeping the length of any path to less than 20 should currently
        /// ensure it is viable.
        pub paths: Vec<Vec<RouteHop>>,
+       /// The `payee` parameter passed to [`find_route`].
+       /// This is used by `ChannelManager` to track information which may be required for retries,
+       /// provided back to you via [`Event::PaymentPathFailed`].
+       ///
+       /// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed
+       pub payee: Option<Payee>,
+}
+
+pub(crate) trait RoutePath {
+       /// Gets the fees for a given path, excluding any excess paid to the recipient.
+       fn get_path_fees(&self) -> u64;
+}
+impl RoutePath for Vec<RouteHop> {
+       fn get_path_fees(&self) -> u64 {
+               // Do not count last hop of each path since that's the full value of the payment
+               self.split_last().map(|(_, path_prefix)| path_prefix).unwrap_or(&[])
+                       .iter().map(|hop| &hop.fee_msat)
+                       .sum()
+       }
 }
 
 impl Route {
        /// Returns the total amount of fees paid on this [`Route`].
        ///
        /// This doesn't include any extra payment made to the recipient, which can happen in excess of
-       /// the amount passed to [`get_route`]'s `final_value_msat`.
+       /// the amount passed to [`find_route`]'s `params.final_value_msat`.
        pub fn get_total_fees(&self) -> u64 {
-               // Do not count last hop of each path since that's the full value of the payment
-               return self.paths.iter()
-                       .flat_map(|path| path.split_last().map(|(_, path_prefix)| path_prefix).unwrap_or(&[]))
-                       .map(|hop| &hop.fee_msat)
-                       .sum();
+               self.paths.iter().map(|path| path.get_path_fees()).sum()
        }
 
        /// Returns the total amount paid on this [`Route`], excluding the fees.
@@ -106,7 +121,9 @@ impl Writeable for Route {
                                hop.write(writer)?;
                        }
                }
-               write_tlv_fields!(writer, {});
+               write_tlv_fields!(writer, {
+                       (1, self.payee, option),
+               });
                Ok(())
        }
 }
@@ -124,8 +141,101 @@ impl Readable for Route {
                        }
                        paths.push(hops);
                }
-               read_tlv_fields!(reader, {});
-               Ok(Route { paths })
+               let mut payee = None;
+               read_tlv_fields!(reader, {
+                       (1, payee, option),
+               });
+               Ok(Route { paths, payee })
+       }
+}
+
+/// Parameters needed to find a [`Route`] for paying a [`Payee`].
+///
+/// Passed to [`find_route`] and also provided in [`Event::PaymentPathFailed`] for retrying a failed
+/// payment path.
+///
+/// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed
+#[derive(Clone, Debug)]
+pub struct RouteParameters {
+       /// The recipient of the failed payment path.
+       pub payee: Payee,
+
+       /// The amount in msats sent on the failed payment path.
+       pub final_value_msat: u64,
+
+       /// The CLTV on the final hop of the failed payment path.
+       pub final_cltv_expiry_delta: u32,
+}
+
+impl_writeable_tlv_based!(RouteParameters, {
+       (0, payee, required),
+       (2, final_value_msat, required),
+       (4, final_cltv_expiry_delta, required),
+});
+
+/// The recipient of a payment.
+#[derive(Clone, Debug, Hash, PartialEq, Eq)]
+pub struct Payee {
+       /// The node id of the payee.
+       pub pubkey: PublicKey,
+
+       /// Features supported by the payee.
+       ///
+       /// May be set from the payee's invoice or via [`for_keysend`]. May be `None` if the invoice
+       /// does not contain any features.
+       ///
+       /// [`for_keysend`]: Self::for_keysend
+       pub features: Option<InvoiceFeatures>,
+
+       /// Hints for routing to the payee, containing channels connecting the payee to public nodes.
+       pub route_hints: Vec<RouteHint>,
+
+       /// Expiration of a payment to the payee, in seconds relative to the UNIX epoch.
+       pub expiry_time: Option<u64>,
+}
+
+impl_writeable_tlv_based!(Payee, {
+       (0, pubkey, required),
+       (2, features, option),
+       (4, route_hints, vec_type),
+       (6, expiry_time, option),
+});
+
+impl Payee {
+       /// Creates a payee with the node id of the given `pubkey`.
+       pub fn from_node_id(pubkey: PublicKey) -> Self {
+               Self {
+                       pubkey,
+                       features: None,
+                       route_hints: vec![],
+                       expiry_time: None,
+               }
+       }
+
+       /// Creates a payee with the node id of the given `pubkey` to use for keysend payments.
+       pub fn for_keysend(pubkey: PublicKey) -> Self {
+               Self::from_node_id(pubkey).with_features(InvoiceFeatures::for_keysend())
+       }
+
+       /// Includes the payee's features.
+       ///
+       /// (C-not exported) since bindings don't support move semantics
+       pub fn with_features(self, features: InvoiceFeatures) -> Self {
+               Self { features: Some(features), ..self }
+       }
+
+       /// Includes hints for routing to the payee.
+       ///
+       /// (C-not exported) since bindings don't support move semantics
+       pub fn with_route_hints(self, route_hints: Vec<RouteHint>) -> Self {
+               Self { route_hints, ..self }
+       }
+
+       /// Includes a payment expiration in seconds relative to the UNIX epoch.
+       ///
+       /// (C-not exported) since bindings don't support move semantics
+       pub fn with_expiry_time(self, expiry_time: u64) -> Self {
+               Self { expiry_time: Some(expiry_time), ..self }
        }
 }
 
@@ -133,6 +243,28 @@ impl Readable for Route {
 #[derive(Clone, Debug, Hash, Eq, PartialEq)]
 pub struct RouteHint(pub Vec<RouteHintHop>);
 
+
+impl Writeable for RouteHint {
+       fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
+               (self.0.len() as u64).write(writer)?;
+               for hop in self.0.iter() {
+                       hop.write(writer)?;
+               }
+               Ok(())
+       }
+}
+
+impl Readable for RouteHint {
+       fn read<R: io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
+               let hop_count: u64 = Readable::read(reader)?;
+               let mut hops = Vec::with_capacity(cmp::min(hop_count, 16) as usize);
+               for _ in 0..hop_count {
+                       hops.push(Readable::read(reader)?);
+               }
+               Ok(Self(hops))
+       }
+}
+
 /// A channel descriptor for a hop along a payment path.
 #[derive(Clone, Debug, Hash, Eq, PartialEq)]
 pub struct RouteHintHop {
@@ -150,6 +282,15 @@ pub struct RouteHintHop {
        pub htlc_maximum_msat: Option<u64>,
 }
 
+impl_writeable_tlv_based!(RouteHintHop, {
+       (0, src_node_id, required),
+       (1, htlc_minimum_msat, option),
+       (2, short_channel_id, required),
+       (3, htlc_maximum_msat, option),
+       (4, fees, required),
+       (6, cltv_expiry_delta, required),
+});
+
 #[derive(Eq, PartialEq)]
 struct RouteGraphNode {
        node_id: NodeId,
@@ -360,48 +501,52 @@ fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option<u64> {
        }
 }
 
-/// Gets a keysend route from us (payer) to the given target node (payee). This is needed because
-/// keysend payments do not have an invoice from which to pull the payee's supported features, which
-/// makes it tricky to otherwise supply the `payee_features` parameter of `get_route`.
-pub fn get_keysend_route<L: Deref, S: routing::Score>(
-       our_node_pubkey: &PublicKey, network: &NetworkGraph, payee: &PublicKey,
-       first_hops: Option<&[&ChannelDetails]>, last_hops: &[&RouteHint], final_value_msat: u64,
-       final_cltv: u32, logger: L, scorer: &S
-) -> Result<Route, LightningError>
-where L::Target: Logger {
-       let invoice_features = InvoiceFeatures::for_keysend();
-       get_route(
-               our_node_pubkey, network, payee, Some(invoice_features), first_hops, last_hops,
-               final_value_msat, final_cltv, logger, scorer
-       )
-}
-
-/// Gets a route from us (payer) to the given target node (payee).
+/// Finds a route from us (payer) to the given target node (payee).
 ///
-/// If the payee provided features in their invoice, they should be provided via payee_features.
+/// If the payee provided features in their invoice, they should be provided via `params.payee`.
 /// Without this, MPP will only be used if the payee's features are available in the network graph.
 ///
-/// Private routing paths between a public node and the target may be included in `last_hops`.
-/// Currently, only the last hop in each path is considered.
+/// Private routing paths between a public node and the target may be included in `params.payee`.
+///
+/// If some channels aren't announced, it may be useful to fill in `first_hops` with the results
+/// from [`ChannelManager::list_usable_channels`]. If it is filled in, the view of our local
+/// channels from [`NetworkGraph`] will be ignored, and only those in `first_hops` will be used.
+///
+/// The fees on channels from us to the next hop are ignored as they are assumed to all be equal.
+/// However, the enabled/disabled bit on such channels as well as the `htlc_minimum_msat` /
+/// `htlc_maximum_msat` *are* checked as they may change based on the receiving node.
+///
+/// # Note
+///
+/// May be used to re-compute a [`Route`] when handling a [`Event::PaymentPathFailed`]. Any
+/// adjustments to the [`NetworkGraph`] and channel scores should be made prior to calling this
+/// function.
 ///
-/// If some channels aren't announced, it may be useful to fill in a first_hops with the
-/// results from a local ChannelManager::list_usable_channels() call. If it is filled in, our
-/// view of our local channels (from net_graph_msg_handler) will be ignored, and only those
-/// in first_hops will be used.
+/// # Panics
 ///
-/// Panics if first_hops contains channels without short_channel_ids
-/// (ChannelManager::list_usable_channels will never include such channels).
+/// Panics if first_hops contains channels without short_channel_ids;
+/// [`ChannelManager::list_usable_channels`] will never include such channels.
 ///
-/// The fees on channels from us to next-hops are ignored (as they are assumed to all be
-/// equal), however the enabled/disabled bit on such channels as well as the
-/// htlc_minimum_msat/htlc_maximum_msat *are* checked as they may change based on the receiving node.
-pub fn get_route<L: Deref, S: routing::Score>(
-       our_node_pubkey: &PublicKey, network: &NetworkGraph, payee: &PublicKey,
-       payee_features: Option<InvoiceFeatures>, first_hops: Option<&[&ChannelDetails]>,
-       last_hops: &[&RouteHint], final_value_msat: u64, final_cltv: u32, logger: L, scorer: &S
+/// [`ChannelManager::list_usable_channels`]: crate::ln::channelmanager::ChannelManager::list_usable_channels
+/// [`Event::PaymentPathFailed`]: crate::util::events::Event::PaymentPathFailed
+pub fn find_route<L: Deref, S: routing::Score>(
+       our_node_pubkey: &PublicKey, params: &RouteParameters, network: &NetworkGraph,
+       first_hops: Option<&[&ChannelDetails]>, logger: L, scorer: &S
 ) -> Result<Route, LightningError>
 where L::Target: Logger {
-       let payee_node_id = NodeId::from_pubkey(&payee);
+       get_route(
+               our_node_pubkey, &params.payee, network, first_hops, params.final_value_msat,
+               params.final_cltv_expiry_delta, logger, scorer
+       )
+}
+
+pub(crate) fn get_route<L: Deref, S: routing::Score>(
+       our_node_pubkey: &PublicKey, payee: &Payee, network: &NetworkGraph,
+       first_hops: Option<&[&ChannelDetails]>, final_value_msat: u64, final_cltv_expiry_delta: u32,
+       logger: L, scorer: &S
+) -> Result<Route, LightningError>
+where L::Target: Logger {
+       let payee_node_id = NodeId::from_pubkey(&payee.pubkey);
        let our_node_id = NodeId::from_pubkey(&our_node_pubkey);
 
        if payee_node_id == our_node_id {
@@ -416,10 +561,10 @@ where L::Target: Logger {
                return Err(LightningError{err: "Cannot send a payment of 0 msat".to_owned(), action: ErrorAction::IgnoreError});
        }
 
-       for route in last_hops.iter() {
+       for route in payee.route_hints.iter() {
                for hop in &route.0 {
-                       if hop.src_node_id == *payee {
-                               return Err(LightningError{err: "Last hop cannot have a payee as a source.".to_owned(), action: ErrorAction::IgnoreError});
+                       if hop.src_node_id == payee.pubkey {
+                               return Err(LightningError{err: "Route hint cannot have the payee as the source.".to_owned(), action: ErrorAction::IgnoreError});
                        }
                }
        }
@@ -500,15 +645,15 @@ where L::Target: Logger {
        // Allow MPP only if we have a features set from somewhere that indicates the payee supports
        // it. If the payee supports it they're supposed to include it in the invoice, so that should
        // work reliably.
-       let allow_mpp = if let Some(features) = &payee_features {
+       let allow_mpp = if let Some(features) = &payee.features {
                features.supports_basic_mpp()
        } else if let Some(node) = network_nodes.get(&payee_node_id) {
                if let Some(node_info) = node.announcement_info.as_ref() {
                        node_info.features.supports_basic_mpp()
                } else { false }
        } else { false };
-       log_trace!(logger, "Searching for a route from payer {} to payee {} {} MPP", our_node_pubkey, payee,
-               if allow_mpp { "with" } else { "without" });
+       log_trace!(logger, "Searching for a route from payer {} to payee {} {} MPP", our_node_pubkey,
+               payee.pubkey, if allow_mpp { "with" } else { "without" });
 
        // Step (1).
        // Prepare the data we'll use for payee-to-payer search by
@@ -567,7 +712,7 @@ where L::Target: Logger {
        // - when we want to stop looking for new paths.
        let mut already_collected_value_msat = 0;
 
-       log_trace!(logger, "Building path from {} (payee) to {} (us/payer) for value {} msat.", payee, our_node_pubkey, final_value_msat);
+       log_trace!(logger, "Building path from {} (payee) to {} (us/payer) for value {} msat.", payee.pubkey, our_node_pubkey, final_value_msat);
 
        macro_rules! add_entry {
                // Adds entry which goes from $src_node_id to $dest_node_id
@@ -937,7 +1082,7 @@ where L::Target: Logger {
                // If a caller provided us with last hops, add them to routing targets. Since this happens
                // earlier than general path finding, they will be somewhat prioritized, although currently
                // it matters only if the fees are exactly the same.
-               for route in last_hops.iter().filter(|route| !route.0.is_empty()) {
+               for route in payee.route_hints.iter().filter(|route| !route.0.is_empty()) {
                        let first_hop_in_route = &(route.0)[0];
                        let have_hop_src_in_graph =
                                // Only add the hops in this route to our candidate set if either
@@ -949,7 +1094,7 @@ where L::Target: Logger {
                                // We start building the path from reverse, i.e., from payee
                                // to the first RouteHintHop in the path.
                                let hop_iter = route.0.iter().rev();
-                               let prev_hop_iter = core::iter::once(payee).chain(
+                               let prev_hop_iter = core::iter::once(&payee.pubkey).chain(
                                        route.0.iter().skip(1).rev().map(|hop| &hop.src_node_id));
                                let mut hop_used = true;
                                let mut aggregate_next_hops_fee_msat: u64 = 0;
@@ -1108,7 +1253,7 @@ where L::Target: Logger {
                                }
                                ordered_hops.last_mut().unwrap().0.fee_msat = value_contribution_msat;
                                ordered_hops.last_mut().unwrap().0.hop_use_fee_msat = 0;
-                               ordered_hops.last_mut().unwrap().0.cltv_expiry_delta = final_cltv;
+                               ordered_hops.last_mut().unwrap().0.cltv_expiry_delta = final_cltv_expiry_delta;
 
                                log_trace!(logger, "Found a path back to us from the target with {} hops contributing up to {} msat: {:?}",
                                        ordered_hops.len(), value_contribution_msat, ordered_hops);
@@ -1309,7 +1454,7 @@ where L::Target: Logger {
                }).collect());
        }
 
-       if let Some(features) = &payee_features {
+       if let Some(features) = &payee.features {
                for path in selected_paths.iter_mut() {
                        if let Ok(route_hop) = path.last_mut().unwrap() {
                                route_hop.node_features = features.to_context();
@@ -1317,8 +1462,11 @@ where L::Target: Logger {
                }
        }
 
-       let route = Route { paths: selected_paths.into_iter().map(|path| path.into_iter().collect()).collect::<Result<Vec<_>, _>>()? };
-       log_info!(logger, "Got route to {}: {}", payee, log_route!(route));
+       let route = Route {
+               paths: selected_paths.into_iter().map(|path| path.into_iter().collect()).collect::<Result<Vec<_>, _>>()?,
+               payee: Some(payee.clone()),
+       };
+       log_info!(logger, "Got route to {}: {}", payee.pubkey, log_route!(route));
        Ok(route)
 }
 
@@ -1326,8 +1474,7 @@ where L::Target: Logger {
 mod tests {
        use routing;
        use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
-       use routing::router::{get_route, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
-       use routing::scorer::Scorer;
+       use routing::router::{get_route, Payee, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
        use chain::transaction::OutPoint;
        use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
        use ln::msgs::{ErrorAction, LightningError, OptionalField, UnsignedChannelAnnouncement, ChannelAnnouncement, RoutingMessageHandler,
@@ -1378,7 +1525,7 @@ mod tests {
 
        // Using the same keys for LN and BTC ids
        fn add_channel(
-               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               net_graph_msg_handler: &NetGraphMsgHandler<Arc<NetworkGraph>, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
                secp_ctx: &Secp256k1<All>, node_1_privkey: &SecretKey, node_2_privkey: &SecretKey, features: ChannelFeatures, short_channel_id: u64
        ) {
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_1_privkey);
@@ -1410,7 +1557,7 @@ mod tests {
        }
 
        fn update_channel(
-               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               net_graph_msg_handler: &NetGraphMsgHandler<Arc<NetworkGraph>, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
                secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey, update: UnsignedChannelUpdate
        ) {
                let msghash = hash_to_message!(&Sha256dHash::hash(&update.encode()[..])[..]);
@@ -1426,7 +1573,7 @@ mod tests {
        }
 
        fn add_or_update_node(
-               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               net_graph_msg_handler: &NetGraphMsgHandler<Arc<NetworkGraph>, Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
                secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey, features: NodeFeatures, timestamp: u32
        ) {
                let node_id = PublicKey::from_secret_key(&secp_ctx, node_privkey);
@@ -1480,12 +1627,18 @@ mod tests {
                }
        }
 
-       fn build_graph() -> (Secp256k1<All>, NetGraphMsgHandler<sync::Arc<test_utils::TestChainSource>, sync::Arc<crate::util::test_utils::TestLogger>>, sync::Arc<test_utils::TestChainSource>, sync::Arc<test_utils::TestLogger>) {
+       fn build_graph() -> (
+               Secp256k1<All>,
+               sync::Arc<NetworkGraph>,
+               NetGraphMsgHandler<sync::Arc<NetworkGraph>, sync::Arc<test_utils::TestChainSource>, sync::Arc<crate::util::test_utils::TestLogger>>,
+               sync::Arc<test_utils::TestChainSource>,
+               sync::Arc<test_utils::TestLogger>,
+       ) {
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(test_utils::TestLogger::new());
                let chain_monitor = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
-               let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
+               let network_graph = Arc::new(NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()));
+               let net_graph_msg_handler = NetGraphMsgHandler::new(Arc::clone(&network_graph), None, Arc::clone(&logger));
                // Build network from our_id to node6:
                //
                //        -1(1)2-  node0  -1(3)2-
@@ -1783,22 +1936,23 @@ mod tests {
 
                add_or_update_node(&net_graph_msg_handler, &secp_ctx, &privkeys[5], NodeFeatures::from_le_bytes(id_to_feature_flags(6)), 0);
 
-               (secp_ctx, net_graph_msg_handler, chain_monitor, logger)
+               (secp_ctx, network_graph, net_graph_msg_handler, chain_monitor, logger)
        }
 
        #[test]
        fn simple_route_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[2]);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Simple route to 2 via 1
 
-               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 0, 42, Arc::clone(&logger), &scorer) {
+               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payee, &network_graph, None, 0, 42, Arc::clone(&logger), &scorer) {
                        assert_eq!(err, "Cannot send a payment of 0 msat");
                } else { panic!(); }
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 2);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -1818,27 +1972,29 @@ mod tests {
 
        #[test]
        fn invalid_first_hop_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[2]);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Simple route to 2 via 1
 
                let our_chans = vec![get_channel_details(Some(2), our_id, InitFeatures::from_le_bytes(vec![0b11]), 100000)];
 
-               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, Some(&our_chans.iter().collect::<Vec<_>>()), &Vec::new(), 100, 42, Arc::clone(&logger), &scorer) {
+               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payee, &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), 100, 42, Arc::clone(&logger), &scorer) {
                        assert_eq!(err, "First hop cannot have our_node_pubkey as a destination.");
                } else { panic!(); }
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 2);
        }
 
        #[test]
        fn htlc_minimum_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[2]);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Simple route to 2 via 1
 
@@ -1935,7 +2091,7 @@ mod tests {
                });
 
                // Not possible to send 199_999_999, because the minimum on channel=2 is 200_000_000.
-               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 199_999_999, 42, Arc::clone(&logger), &scorer) {
+               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payee, &network_graph, None, 199_999_999, 42, Arc::clone(&logger), &scorer) {
                        assert_eq!(err, "Failed to find a path to the given destination");
                } else { panic!(); }
 
@@ -1954,15 +2110,16 @@ mod tests {
                });
 
                // A payment above the minimum should pass
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 199_999_999, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 199_999_999, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 2);
        }
 
        #[test]
        fn htlc_minimum_overpay_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[2]).with_features(InvoiceFeatures::known());
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // A route to node#2 via two paths.
                // One path allows transferring 35-40 sats, another one also allows 35-40 sats.
@@ -2032,8 +2189,7 @@ mod tests {
                        excess_data: Vec::new()
                });
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                       Some(InvoiceFeatures::known()), None, &Vec::new(), 60_000, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 60_000, 42, Arc::clone(&logger), &scorer).unwrap();
                // Overpay fees to hit htlc_minimum_msat.
                let overpaid_fees = route.paths[0][0].fee_msat + route.paths[1][0].fee_msat;
                // TODO: this could be better balanced to overpay 10k and not 15k.
@@ -2078,16 +2234,14 @@ mod tests {
                        excess_data: Vec::new()
                });
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                       Some(InvoiceFeatures::known()), None, &Vec::new(), 60_000, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 60_000, 42, Arc::clone(&logger), &scorer).unwrap();
                // Fine to overpay for htlc_minimum_msat if it allows us to save fee.
                assert_eq!(route.paths.len(), 1);
                assert_eq!(route.paths[0][0].short_channel_id, 12);
                let fees = route.paths[0][0].fee_msat;
                assert_eq!(fees, 5_000);
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                       Some(InvoiceFeatures::known()), None, &Vec::new(), 50_000, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 50_000, 42, Arc::clone(&logger), &scorer).unwrap();
                // Not fine to overpay for htlc_minimum_msat if it requires paying more than fee on
                // the other channel.
                assert_eq!(route.paths.len(), 1);
@@ -2098,9 +2252,10 @@ mod tests {
 
        #[test]
        fn disable_channels_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[2]);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // // Disable channels 4 and 12 by flags=2
                update_channel(&net_graph_msg_handler, &secp_ctx, &privkeys[1], UnsignedChannelUpdate {
@@ -2129,13 +2284,13 @@ mod tests {
                });
 
                // If all the channels require some features we don't understand, route should fail
-               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 100, 42, Arc::clone(&logger), &scorer) {
+               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer) {
                        assert_eq!(err, "Failed to find a path to the given destination");
                } else { panic!(); }
 
                // If we specify a channel to node7, that overrides our local channel view and that gets used
                let our_chans = vec![get_channel_details(Some(42), nodes[7].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)];
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, Some(&our_chans.iter().collect::<Vec<_>>()),  &Vec::new(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 2);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[7]);
@@ -2155,9 +2310,10 @@ mod tests {
 
        #[test]
        fn disable_node_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[2]);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Disable nodes 1, 2, and 8 by requiring unknown feature bits
                let unknown_features = NodeFeatures::known().set_unknown_feature_required();
@@ -2166,13 +2322,13 @@ mod tests {
                add_or_update_node(&net_graph_msg_handler, &secp_ctx, &privkeys[7], unknown_features.clone(), 1);
 
                // If all nodes require some features we don't understand, route should fail
-               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 100, 42, Arc::clone(&logger), &scorer) {
+               if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer) {
                        assert_eq!(err, "Failed to find a path to the given destination");
                } else { panic!(); }
 
                // If we specify a channel to node7, that overrides our local channel view and that gets used
                let our_chans = vec![get_channel_details(Some(42), nodes[7].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)];
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, Some(&our_chans.iter().collect::<Vec<_>>()), &Vec::new(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 2);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[7]);
@@ -2196,12 +2352,13 @@ mod tests {
 
        #[test]
        fn our_chans_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Route to 1 via 2 and 3 because our channel to 1 is disabled
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[0], None, None, &Vec::new(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let payee = Payee::from_node_id(nodes[0]);
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 3);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -2226,8 +2383,9 @@ mod tests {
                assert_eq!(route.paths[0][2].channel_features.le_flags(), &id_to_feature_flags(3));
 
                // If we specify a channel to node7, that overrides our local channel view and that gets used
+               let payee = Payee::from_node_id(nodes[2]);
                let our_chans = vec![get_channel_details(Some(42), nodes[7].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)];
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, Some(&our_chans.iter().collect::<Vec<_>>()), &Vec::new(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 2);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[7]);
@@ -2323,9 +2481,9 @@ mod tests {
 
        #[test]
        fn partial_route_hint_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Simple test across 2, 3, 5, and 4 via a last_hop channel
                // Tests the behaviour when the RouteHint contains a suboptimal hop.
@@ -2347,12 +2505,14 @@ mod tests {
                let mut invalid_last_hops = last_hops_multi_private_channels(&nodes);
                invalid_last_hops.push(invalid_last_hop);
                {
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &invalid_last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer) {
-                               assert_eq!(err, "Last hop cannot have a payee as a source.");
+                       let payee = Payee::from_node_id(nodes[6]).with_route_hints(invalid_last_hops);
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer) {
+                               assert_eq!(err, "Route hint cannot have the payee as the source.");
                        } else { panic!(); }
                }
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &last_hops_multi_private_channels(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(last_hops_multi_private_channels(&nodes));
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 5);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -2419,13 +2579,14 @@ mod tests {
 
        #[test]
        fn ignores_empty_last_hops_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(empty_last_hop(&nodes));
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Test handling of an empty RouteHint passed in Invoice.
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &empty_last_hop(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 5);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -2500,9 +2661,10 @@ mod tests {
 
        #[test]
        fn multi_hint_last_hops_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(multi_hint_last_hops(&nodes));
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
                // Test through channels 2, 3, 5, 8.
                // Test shows that multiple hop hints are considered.
 
@@ -2532,7 +2694,7 @@ mod tests {
                        excess_data: Vec::new()
                });
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &multi_hint_last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 4);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -2605,13 +2767,14 @@ mod tests {
 
        #[test]
        fn last_hops_with_public_channel_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(last_hops_with_public_channel(&nodes));
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
                // This test shows that public routes can be present in the invoice
                // which would be handled in the same manner.
 
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &last_hops_with_public_channel(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 5);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -2654,14 +2817,15 @@ mod tests {
 
        #[test]
        fn our_chans_last_hop_connect_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // Simple test with outbound channel to 4 to test that last_hops and first_hops connect
                let our_chans = vec![get_channel_details(Some(42), nodes[3].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)];
                let mut last_hops = last_hops(&nodes);
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, Some(&our_chans.iter().collect::<Vec<_>>()), &last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(last_hops.clone());
+               let route = get_route(&our_id, &payee, &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 2);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[3]);
@@ -2681,7 +2845,8 @@ mod tests {
                last_hops[0].0[0].fees.base_msat = 1000;
 
                // Revert to via 6 as the fee on 8 goes up
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(last_hops);
+               let route = get_route(&our_id, &payee, &network_graph, None, 100, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 4);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -2715,7 +2880,7 @@ mod tests {
                assert_eq!(route.paths[0][3].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
 
                // ...but still use 8 for larger payments as 6 has a variable feerate
-               let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &last_hops.iter().collect::<Vec<_>>(), 2000, 42, Arc::clone(&logger), &scorer).unwrap();
+               let route = get_route(&our_id, &payee, &network_graph, None, 2000, 42, Arc::clone(&logger), &scorer).unwrap();
                assert_eq!(route.paths[0].len(), 5);
 
                assert_eq!(route.paths[0][0].pubkey, nodes[1]);
@@ -2773,9 +2938,10 @@ mod tests {
                        htlc_minimum_msat: None,
                        htlc_maximum_msat: last_hop_htlc_max,
                }]);
+               let payee = Payee::from_node_id(target_node_id).with_route_hints(vec![last_hops]);
                let our_chans = vec![get_channel_details(Some(42), middle_node_id, InitFeatures::from_le_bytes(vec![0b11]), outbound_capacity_msat)];
-               let scorer = Scorer::new(0);
-               get_route(&source_node_id, &NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()), &target_node_id, None, Some(&our_chans.iter().collect::<Vec<_>>()), &vec![&last_hops], route_val, 42, &test_utils::TestLogger::new(), &scorer)
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               get_route(&source_node_id, &payee, &NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()), Some(&our_chans.iter().collect::<Vec<_>>()), route_val, 42, &test_utils::TestLogger::new(), &scorer)
        }
 
        #[test]
@@ -2826,9 +2992,10 @@ mod tests {
        fn available_amount_while_routing_test() {
                // Tests whether we choose the correct available channel amount while routing.
 
-               let (secp_ctx, mut net_graph_msg_handler, chain_monitor, logger) = build_graph();
+               let (secp_ctx, network_graph, mut net_graph_msg_handler, chain_monitor, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We will use a simple single-path route from
                // our node to node2 via node0: channels {1, 3}.
@@ -2891,16 +3058,15 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 250_000_001, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 250_000_001, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
 
                {
                        // Now, attempt to route an exact amount we have should be fine.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 250_000_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 250_000_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
@@ -2928,16 +3094,15 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                                       Some(InvoiceFeatures::known()), Some(&our_chans.iter().collect::<Vec<_>>()), &Vec::new(), 200_000_001, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), 200_000_001, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
 
                {
                        // Now, attempt to route an exact amount we have should be fine.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), Some(&our_chans.iter().collect::<Vec<_>>()), &Vec::new(), 200_000_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), 200_000_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
@@ -2976,16 +3141,15 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 15_001, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 15_001, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
 
                {
                        // Now, attempt to route an exact amount we have should be fine.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 15_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 15_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
@@ -3047,16 +3211,15 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 15_001, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 15_001, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
 
                {
                        // Now, attempt to route an exact amount we have should be fine.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 15_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 15_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
@@ -3080,16 +3243,15 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 10_001, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 10_001, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
 
                {
                        // Now, attempt to route an exact amount we have should be fine.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 10_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 10_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let path = route.paths.last().unwrap();
                        assert_eq!(path.len(), 2);
@@ -3102,9 +3264,10 @@ mod tests {
        fn available_liquidity_last_hop_test() {
                // Check that available liquidity properly limits the path even when only
                // one of the latter hops is limited.
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[3]).with_features(InvoiceFeatures::known());
 
                // Path via {node7, node2, node4} is channels {12, 13, 6, 11}.
                // {12, 13, 11} have the capacities of 100, {6} has a capacity of 50.
@@ -3189,16 +3352,15 @@ mod tests {
                });
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 60_000, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 60_000, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
 
                {
                        // Now, attempt to route 49 sats (just a bit below the capacity).
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 49_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 49_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
@@ -3211,8 +3373,7 @@ mod tests {
 
                {
                        // Attempt to route an exact amount is also fine
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 50_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 50_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
@@ -3226,9 +3387,10 @@ mod tests {
 
        #[test]
        fn ignore_fee_first_hop_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[2]);
 
                // Path via node0 is channels {1, 3}. Limit them to 100 and 50 sats (total limit 50).
                update_channel(&net_graph_msg_handler, &secp_ctx, &our_privkey, UnsignedChannelUpdate {
@@ -3257,7 +3419,7 @@ mod tests {
                });
 
                {
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 50_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 50_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
@@ -3271,9 +3433,10 @@ mod tests {
 
        #[test]
        fn simple_mpp_route_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 3 paths:
                // From our node to node2 via node0, node7, node1 (three paths one hop each).
@@ -3365,8 +3528,8 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph,
-                                       &nodes[2], Some(InvoiceFeatures::known()), None, &Vec::new(), 300_000, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 300_000, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
@@ -3374,8 +3537,7 @@ mod tests {
                {
                        // Now, attempt to route 250 sats (just a bit below the capacity).
                        // Our algorithm should provide us with these 3 paths.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 250_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 250_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 3);
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
@@ -3388,8 +3550,7 @@ mod tests {
 
                {
                        // Attempt to route an exact amount is also fine
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 290_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 290_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 3);
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
@@ -3403,9 +3564,10 @@ mod tests {
 
        #[test]
        fn long_mpp_route_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[3]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 3 paths:
                // From our node to node3 via {node0, node2}, {node7, node2, node4} and {node7, node2}.
@@ -3540,8 +3702,8 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 350_000, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 350_000, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
@@ -3549,8 +3711,7 @@ mod tests {
                {
                        // Now, attempt to route 300 sats (exact amount we can route).
                        // Our algorithm should provide us with these 3 paths, 100 sats each.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 300_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 300_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 3);
 
                        let mut total_amount_paid_msat = 0;
@@ -3565,9 +3726,10 @@ mod tests {
 
        #[test]
        fn mpp_cheaper_route_test() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[3]).with_features(InvoiceFeatures::known());
 
                // This test checks that if we have two cheaper paths and one more expensive path,
                // so that liquidity-wise any 2 of 3 combination is sufficient,
@@ -3707,8 +3869,7 @@ mod tests {
                {
                        // Now, attempt to route 180 sats.
                        // Our algorithm should provide us with these 2 paths.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 180_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 180_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 2);
 
                        let mut total_value_transferred_msat = 0;
@@ -3732,9 +3893,10 @@ mod tests {
                // This test makes sure that MPP algorithm properly takes into account
                // fees charged on the channels, by making the fees impactful:
                // if the fee is not properly accounted for, the behavior is different.
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[3]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 2 paths:
                // From our node to node3 via {node0, node2} and {node7, node2, node4}.
@@ -3874,16 +4036,15 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 210_000, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 210_000, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
 
                {
                        // Now, attempt to route 200 sats (exact amount we can route).
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[3],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 200_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 200_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 2);
 
                        let mut total_amount_paid_msat = 0;
@@ -3901,9 +4062,10 @@ mod tests {
        fn drop_lowest_channel_mpp_route_test() {
                // This test checks that low-capacity channel is dropped when after
                // path finding we realize that we found more capacity than we need.
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We need a route consisting of 3 paths:
                // From our node to node2 via node0, node7, node1 (three paths one hop each).
@@ -3994,8 +4156,8 @@ mod tests {
 
                {
                        // Attempt to route more than available results in a failure.
-                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                                       Some(InvoiceFeatures::known()), None, &Vec::new(), 150_000, 42, Arc::clone(&logger), &scorer) {
+                       if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route(
+                                       &our_id, &payee, &network_graph, None, 150_000, 42, Arc::clone(&logger), &scorer) {
                                assert_eq!(err, "Failed to find a sufficient route to the given destination");
                        } else { panic!(); }
                }
@@ -4003,8 +4165,7 @@ mod tests {
                {
                        // Now, attempt to route 125 sats (just a bit below the capacity of 3 channels).
                        // Our algorithm should provide us with these 3 paths.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 125_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 125_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 3);
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
@@ -4017,8 +4178,7 @@ mod tests {
 
                {
                        // Attempt to route without the last small cheap channel
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2],
-                               Some(InvoiceFeatures::known()), None, &Vec::new(), 90_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 90_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 2);
                        let mut total_amount_paid_msat = 0;
                        for path in &route.paths {
@@ -4058,10 +4218,11 @@ mod tests {
                // "previous hop" being set to node 3, creating a loop in the path.
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(test_utils::TestLogger::new());
-               let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
+               let network_graph = Arc::new(NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()));
+               let net_graph_msg_handler = NetGraphMsgHandler::new(Arc::clone(&network_graph), None, Arc::clone(&logger));
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[6]);
 
                add_channel(&net_graph_msg_handler, &secp_ctx, &our_privkey, &privkeys[1], ChannelFeatures::from_le_bytes(id_to_feature_flags(6)), 6);
                update_channel(&net_graph_msg_handler, &secp_ctx, &our_privkey, UnsignedChannelUpdate {
@@ -4154,7 +4315,7 @@ mod tests {
 
                {
                        // Now ensure the route flows simply over nodes 1 and 4 to 6.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &Vec::new(), 10_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 10_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        assert_eq!(route.paths[0].len(), 3);
 
@@ -4187,9 +4348,10 @@ mod tests {
                // Test that if, while walking the graph, we find a hop that has exactly enough liquidity
                // for us, including later hop fees, we take it. In the first version of our MPP algorithm
                // we calculated fees on a higher value, resulting in us ignoring such paths.
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, _, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[2]);
 
                // We modify the graph to set the htlc_maximum of channel 2 to below the value we wish to
                // send.
@@ -4222,7 +4384,7 @@ mod tests {
                {
                        // Now, attempt to route 90 sats, which is exactly 90 sats at the last hop, plus the
                        // 200% fee charged channel 13 in the 1-to-2 direction.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], None, None, &Vec::new(), 90_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 90_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        assert_eq!(route.paths[0].len(), 2);
 
@@ -4248,9 +4410,10 @@ mod tests {
                // htlc_maximum_msat, we don't end up undershooting a later htlc_minimum_msat. In the
                // initial version of MPP we'd accept such routes but reject them while recalculating fees,
                // resulting in us thinking there is no possible path, even if other paths exist.
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, net_graph_msg_handler, _, logger) = build_graph();
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[2]).with_features(InvoiceFeatures::known());
 
                // We modify the graph to set the htlc_minimum of channel 2 and 4 as needed - channel 2
                // gets an htlc_maximum_msat of 80_000 and channel 4 an htlc_minimum_msat of 90_000. We
@@ -4284,7 +4447,7 @@ mod tests {
                        // Now, attempt to route 90 sats, hitting the htlc_minimum on channel 4, but
                        // overshooting the htlc_maximum on channel 2. Thus, we should pick the (absurdly
                        // expensive) channels 12-13 path.
-                       let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[2], Some(InvoiceFeatures::known()), None, &Vec::new(), 90_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       let route = get_route(&our_id, &payee, &network_graph, None, 90_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        assert_eq!(route.paths[0].len(), 2);
 
@@ -4316,13 +4479,14 @@ mod tests {
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
                let logger = Arc::new(test_utils::TestLogger::new());
                let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
+               let payee = Payee::from_node_id(nodes[0]).with_features(InvoiceFeatures::known());
 
                {
-                       let route = get_route(&our_id, &network_graph, &nodes[0], Some(InvoiceFeatures::known()), Some(&[
+                       let route = get_route(&our_id, &payee, &network_graph, Some(&[
                                &get_channel_details(Some(3), nodes[0], InitFeatures::known(), 200_000),
                                &get_channel_details(Some(2), nodes[0], InitFeatures::known(), 10_000),
-                       ]), &[], 100_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       ]), 100_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 1);
                        assert_eq!(route.paths[0].len(), 1);
 
@@ -4331,10 +4495,10 @@ mod tests {
                        assert_eq!(route.paths[0][0].fee_msat, 100_000);
                }
                {
-                       let route = get_route(&our_id, &network_graph, &nodes[0], Some(InvoiceFeatures::known()), Some(&[
+                       let route = get_route(&our_id, &payee, &network_graph, Some(&[
                                &get_channel_details(Some(3), nodes[0], InitFeatures::known(), 50_000),
                                &get_channel_details(Some(2), nodes[0], InitFeatures::known(), 50_000),
-                       ]), &[], 100_000, 42, Arc::clone(&logger), &scorer).unwrap();
+                       ]), 100_000, 42, Arc::clone(&logger), &scorer).unwrap();
                        assert_eq!(route.paths.len(), 2);
                        assert_eq!(route.paths[0].len(), 1);
                        assert_eq!(route.paths[1].len(), 1);
@@ -4351,14 +4515,15 @@ mod tests {
 
        #[test]
        fn prefers_shorter_route_with_higher_fees() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(last_hops(&nodes));
 
                // Without penalizing each hop 100 msats, a longer path with lower fees is chosen.
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
                let route = get_route(
-                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
-                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+                       &our_id, &payee, &network_graph, None, 100, 42,
+                       Arc::clone(&logger), &scorer
                ).unwrap();
                let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
 
@@ -4368,10 +4533,10 @@ mod tests {
 
                // Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6]
                // from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper.
-               let scorer = Scorer::new(100);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(100);
                let route = get_route(
-                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
-                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+                       &our_id, &payee, &network_graph, None, 100, 42,
+                       Arc::clone(&logger), &scorer
                ).unwrap();
                let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
 
@@ -4388,6 +4553,8 @@ mod tests {
                fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
                        if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
                }
+
+               fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
        }
 
        struct BadNodeScorer {
@@ -4398,18 +4565,21 @@ mod tests {
                fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
                        if *target == self.node_id { u64::max_value() } else { 0 }
                }
+
+               fn payment_path_failed(&mut self, _path: &[&RouteHop], _short_channel_id: u64) {}
        }
 
        #[test]
        fn avoids_routing_through_bad_channels_and_nodes() {
-               let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
+               let (secp_ctx, network_graph, _, _, logger) = build_graph();
                let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
+               let payee = Payee::from_node_id(nodes[6]).with_route_hints(last_hops(&nodes));
 
                // A path to nodes[6] exists when no penalties are applied to any channel.
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
                let route = get_route(
-                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
-                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+                       &our_id, &payee, &network_graph, None, 100, 42,
+                       Arc::clone(&logger), &scorer
                ).unwrap();
                let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
 
@@ -4420,8 +4590,8 @@ mod tests {
                // A different path to nodes[6] exists if channel 6 cannot be routed over.
                let scorer = BadChannelScorer { short_channel_id: 6 };
                let route = get_route(
-                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
-                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+                       &our_id, &payee, &network_graph, None, 100, 42,
+                       Arc::clone(&logger), &scorer
                ).unwrap();
                let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
 
@@ -4432,8 +4602,8 @@ mod tests {
                // A path to nodes[6] does not exist if nodes[2] cannot be routed through.
                let scorer = BadNodeScorer { node_id: NodeId::from_pubkey(&nodes[2]) };
                match get_route(
-                       &our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
-                       &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
+                       &our_id, &payee, &network_graph, None, 100, 42,
+                       Arc::clone(&logger), &scorer
                ) {
                        Err(LightningError { err, .. } ) => {
                                assert_eq!(err, "Failed to find a path to the given destination");
@@ -4462,6 +4632,7 @@ mod tests {
                                        short_channel_id: 0, fee_msat: 225, cltv_expiry_delta: 0
                                },
                        ]],
+                       payee: None,
                };
 
                assert_eq!(route.get_total_fees(), 250);
@@ -4494,6 +4665,7 @@ mod tests {
                                        short_channel_id: 0, fee_msat: 150, cltv_expiry_delta: 0
                                },
                        ]],
+                       payee: None,
                };
 
                assert_eq!(route.get_total_fees(), 200);
@@ -4505,7 +4677,7 @@ mod tests {
                // In an earlier version of `Route::get_total_fees` and `Route::get_total_amount`, they
                // would both panic if the route was completely empty. We test to ensure they return 0
                // here, even though its somewhat nonsensical as a route.
-               let route = Route { paths: Vec::new() };
+               let route = Route { paths: Vec::new(), payee: None };
 
                assert_eq!(route.get_total_fees(), 0);
                assert_eq!(route.get_total_amount(), 0);
@@ -4533,7 +4705,7 @@ mod tests {
                        },
                };
                let graph = NetworkGraph::read(&mut d).unwrap();
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
@@ -4543,9 +4715,10 @@ mod tests {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
                                let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let payee = Payee::from_node_id(dst);
                                let amt = seed as u64 % 200_000_000;
-                               if get_route(src, &graph, dst, None, None, &[], amt, 42, &test_utils::TestLogger::new(), &scorer).is_ok() {
+                               if get_route(src, &payee, &graph, None, amt, 42, &test_utils::TestLogger::new(), &scorer).is_ok() {
                                        continue 'load_endpoints;
                                }
                        }
@@ -4563,7 +4736,7 @@ mod tests {
                        },
                };
                let graph = NetworkGraph::read(&mut d).unwrap();
-               let scorer = Scorer::new(0);
+               let scorer = test_utils::TestScorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
@@ -4573,9 +4746,10 @@ mod tests {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
                                let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed = seed.overflowing_mul(0xdeadbeef).0;
-                               let dst = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let payee = Payee::from_node_id(dst).with_features(InvoiceFeatures::known());
                                let amt = seed as u64 % 200_000_000;
-                               if get_route(src, &graph, dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &test_utils::TestLogger::new(), &scorer).is_ok() {
+                               if get_route(src, &payee, &graph, None, amt, 42, &test_utils::TestLogger::new(), &scorer).is_ok() {
                                        continue 'load_endpoints;
                                }
                        }
@@ -4628,7 +4802,7 @@ mod benches {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
                let nodes = graph.read_only().nodes().clone();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
@@ -4639,8 +4813,9 @@ mod benches {
                                let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed *= 0xdeadbeef;
                                let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let payee = Payee::from_node_id(dst);
                                let amt = seed as u64 % 1_000_000;
-                               if get_route(&src, &graph, &dst, None, None, &[], amt, 42, &DummyLogger{}, &scorer).is_ok() {
+                               if get_route(&src, &payee, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok() {
                                        path_endpoints.push((src, dst, amt));
                                        continue 'load_endpoints;
                                }
@@ -4651,7 +4826,8 @@ mod benches {
                let mut idx = 0;
                bench.iter(|| {
                        let (src, dst, amt) = path_endpoints[idx % path_endpoints.len()];
-                       assert!(get_route(&src, &graph, &dst, None, None, &[], amt, 42, &DummyLogger{}, &scorer).is_ok());
+                       let payee = Payee::from_node_id(dst);
+                       assert!(get_route(&src, &payee, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok());
                        idx += 1;
                });
        }
@@ -4661,7 +4837,7 @@ mod benches {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
                let nodes = graph.read_only().nodes().clone();
-               let scorer = Scorer::new(0);
+               let scorer = Scorer::with_fixed_penalty(0);
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
@@ -4672,8 +4848,9 @@ mod benches {
                                let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
                                seed *= 0xdeadbeef;
                                let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap();
+                               let payee = Payee::from_node_id(dst).with_features(InvoiceFeatures::known());
                                let amt = seed as u64 % 1_000_000;
-                               if get_route(&src, &graph, &dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &DummyLogger{}, &scorer).is_ok() {
+                               if get_route(&src, &payee, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok() {
                                        path_endpoints.push((src, dst, amt));
                                        continue 'load_endpoints;
                                }
@@ -4684,7 +4861,8 @@ mod benches {
                let mut idx = 0;
                bench.iter(|| {
                        let (src, dst, amt) = path_endpoints[idx % path_endpoints.len()];
-                       assert!(get_route(&src, &graph, &dst, Some(InvoiceFeatures::known()), None, &[], amt, 42, &DummyLogger{}, &scorer).is_ok());
+                       let payee = Payee::from_node_id(dst).with_features(InvoiceFeatures::known());
+                       assert!(get_route(&src, &payee, &graph, None, amt, 42, &DummyLogger{}, &scorer).is_ok());
                        idx += 1;
                });
        }
index 0f43c3d79283ae8e458b506a783aa72d03f90347..8b5fae70abdbffa34e6a365f312eb654f5165d84 100644 (file)
@@ -9,7 +9,7 @@
 
 //! Utilities for scoring payment channels.
 //!
-//! [`Scorer`] may be given to [`get_route`] to score payment channels during path finding when a
+//! [`Scorer`] may be given to [`find_route`] to score payment channels during path finding when a
 //! custom [`routing::Score`] implementation is not needed.
 //!
 //! # Example
@@ -18,8 +18,8 @@
 //! # extern crate secp256k1;
 //! #
 //! # use lightning::routing::network_graph::NetworkGraph;
-//! # use lightning::routing::router::get_route;
-//! # use lightning::routing::scorer::Scorer;
+//! # use lightning::routing::router::{RouteParameters, find_route};
+//! # use lightning::routing::scorer::{Scorer, ScoringParameters};
 //! # use lightning::util::logger::{Logger, Record};
 //! # use secp256k1::key::PublicKey;
 //! #
 //! # impl Logger for FakeLogger {
 //! #     fn log(&self, record: &Record) { unimplemented!() }
 //! # }
-//! # fn find_scored_route(payer: PublicKey, payee: PublicKey, network_graph: NetworkGraph) {
+//! # fn find_scored_route(payer: PublicKey, params: RouteParameters, network_graph: NetworkGraph) {
 //! # let logger = FakeLogger {};
 //! #
-//! // Use the default channel penalty.
+//! // Use the default channel penalties.
 //! let scorer = Scorer::default();
 //!
-//! // Or use a custom channel penalty.
-//! let scorer = Scorer::new(1_000);
+//! // Or use custom channel penalties.
+//! let scorer = Scorer::new(ScoringParameters {
+//!     base_penalty_msat: 1000,
+//!     failure_penalty_msat: 2 * 1024 * 1000,
+//!     ..ScoringParameters::default()
+//! });
 //!
-//! let route = get_route(&payer, &network_graph, &payee, None, None, &vec![], 1_000, 42, &logger, &scorer);
+//! let route = find_route(&payer, &params, &network_graph, None, &logger, &scorer);
 //! # }
 //! ```
 //!
-//! [`get_route`]: crate::routing::router::get_route
+//! # Note
+//!
+//! If persisting [`Scorer`], it must be restored using the same [`Time`] parameterization. Using a
+//! different type results in undefined behavior. Specifically, persisting when built with feature
+//! `no-std` and restoring without it, or vice versa, uses different types and thus is undefined.
+//!
+//! [`find_route`]: crate::routing::router::find_route
 
 use routing;
 
+use ln::msgs::DecodeError;
 use routing::network_graph::NodeId;
+use routing::router::RouteHop;
+use util::ser::{Readable, Writeable, Writer};
+
+use prelude::*;
+use core::ops::Sub;
+use core::time::Duration;
+use io::{self, Read};
 
 /// [`routing::Score`] implementation that provides reasonable default behavior.
 ///
 /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
-/// slightly higher fees are available.
+/// slightly higher fees are available. Will further penalize channels that fail to relay payments.
 ///
 /// See [module-level documentation] for usage.
 ///
 /// [module-level documentation]: crate::routing::scorer
-pub struct Scorer {
-       base_penalty_msat: u64,
+pub type Scorer = ScorerUsingTime::<DefaultTime>;
+
+/// Time used by [`Scorer`].
+#[cfg(not(feature = "no-std"))]
+pub type DefaultTime = std::time::Instant;
+
+/// Time used by [`Scorer`].
+#[cfg(feature = "no-std")]
+pub type DefaultTime = Eternity;
+
+/// [`routing::Score`] implementation parameterized by [`Time`].
+///
+/// See [`Scorer`] for details.
+///
+/// # Note
+///
+/// Mixing [`Time`] types between serialization and deserialization results in undefined behavior.
+pub struct ScorerUsingTime<T: Time> {
+       params: ScoringParameters,
+       // TODO: Remove entries of closed channels.
+       channel_failures: HashMap<u64, ChannelFailure<T>>,
+}
+
+/// Parameters for configuring [`Scorer`].
+pub struct ScoringParameters {
+       /// A fixed penalty in msats to apply to each channel.
+       pub base_penalty_msat: u64,
+
+       /// A penalty in msats to apply to a channel upon failing to relay a payment.
+       ///
+       /// This accumulates for each failure but may be reduced over time based on
+       /// [`failure_penalty_half_life`].
+       ///
+       /// [`failure_penalty_half_life`]: Self::failure_penalty_half_life
+       pub failure_penalty_msat: u64,
+
+       /// The time required to elapse before any accumulated [`failure_penalty_msat`] penalties are
+       /// cut in half.
+       ///
+       /// # Note
+       ///
+       /// When time is an [`Eternity`], as is default when enabling feature `no-std`, it will never
+       /// elapse. Therefore, this penalty will never decay.
+       ///
+       /// [`failure_penalty_msat`]: Self::failure_penalty_msat
+       pub failure_penalty_half_life: Duration,
+}
+
+impl_writeable_tlv_based!(ScoringParameters, {
+       (0, base_penalty_msat, required),
+       (2, failure_penalty_msat, required),
+       (4, failure_penalty_half_life, required),
+});
+
+/// Accounting for penalties against a channel for failing to relay any payments.
+///
+/// Penalties decay over time, though accumulate as more failures occur.
+struct ChannelFailure<T: Time> {
+       /// Accumulated penalty in msats for the channel as of `last_failed`.
+       undecayed_penalty_msat: u64,
+
+       /// Last time the channel failed. Used to decay `undecayed_penalty_msat`.
+       last_failed: T,
 }
 
-impl Scorer {
-       /// Creates a new scorer using `base_penalty_msat` as the channel penalty.
-       pub fn new(base_penalty_msat: u64) -> Self {
-               Self { base_penalty_msat }
+/// A measurement of time.
+pub trait Time: Sub<Duration, Output = Self> where Self: Sized {
+       /// Returns an instance corresponding to the current moment.
+       fn now() -> Self;
+
+       /// Returns the amount of time elapsed since `self` was created.
+       fn elapsed(&self) -> Duration;
+
+       /// Returns the amount of time passed since the beginning of [`Time`].
+       ///
+       /// Used during (de-)serialization.
+       fn duration_since_epoch() -> Duration;
+}
+
+impl<T: Time> ScorerUsingTime<T> {
+       /// Creates a new scorer using the given scoring parameters.
+       pub fn new(params: ScoringParameters) -> Self {
+               Self {
+                       params,
+                       channel_failures: HashMap::new(),
+               }
+       }
+
+       /// Creates a new scorer using `penalty_msat` as a fixed channel penalty.
+       #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))]
+       pub fn with_fixed_penalty(penalty_msat: u64) -> Self {
+               Self::new(ScoringParameters {
+                       base_penalty_msat: penalty_msat,
+                       failure_penalty_msat: 0,
+                       failure_penalty_half_life: Duration::from_secs(0),
+               })
+       }
+}
+
+impl<T: Time> ChannelFailure<T> {
+       fn new(failure_penalty_msat: u64) -> Self {
+               Self {
+                       undecayed_penalty_msat: failure_penalty_msat,
+                       last_failed: T::now(),
+               }
+       }
+
+       fn add_penalty(&mut self, failure_penalty_msat: u64, half_life: Duration) {
+               self.undecayed_penalty_msat = self.decayed_penalty_msat(half_life) + failure_penalty_msat;
+               self.last_failed = T::now();
+       }
+
+       fn decayed_penalty_msat(&self, half_life: Duration) -> u64 {
+               let decays = self.last_failed.elapsed().as_secs().checked_div(half_life.as_secs());
+               match decays {
+                       Some(decays) => self.undecayed_penalty_msat >> decays,
+                       None => 0,
+               }
+       }
+}
+
+impl<T: Time> Default for ScorerUsingTime<T> {
+       fn default() -> Self {
+               Self::new(ScoringParameters::default())
        }
 }
 
-impl Default for Scorer {
-       /// Creates a new scorer using 500 msat as the channel penalty.
+impl Default for ScoringParameters {
        fn default() -> Self {
-               Scorer::new(500)
+               Self {
+                       base_penalty_msat: 500,
+                       failure_penalty_msat: 1024 * 1000,
+                       failure_penalty_half_life: Duration::from_secs(3600),
+               }
        }
 }
 
-impl routing::Score for Scorer {
+impl<T: Time> routing::Score for ScorerUsingTime<T> {
        fn channel_penalty_msat(
-               &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
+               &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId
        ) -> u64 {
-               self.base_penalty_msat
+               let failure_penalty_msat = self.channel_failures
+                       .get(&short_channel_id)
+                       .map_or(0, |value| value.decayed_penalty_msat(self.params.failure_penalty_half_life));
+
+               self.params.base_penalty_msat + failure_penalty_msat
+       }
+
+       fn payment_path_failed(&mut self, _path: &[&RouteHop], short_channel_id: u64) {
+               let failure_penalty_msat = self.params.failure_penalty_msat;
+               let half_life = self.params.failure_penalty_half_life;
+               self.channel_failures
+                       .entry(short_channel_id)
+                       .and_modify(|failure| failure.add_penalty(failure_penalty_msat, half_life))
+                       .or_insert_with(|| ChannelFailure::new(failure_penalty_msat));
+       }
+}
+
+#[cfg(not(feature = "no-std"))]
+impl Time for std::time::Instant {
+       fn now() -> Self {
+               std::time::Instant::now()
+       }
+
+       fn duration_since_epoch() -> Duration {
+               use std::time::SystemTime;
+               SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap()
+       }
+
+       fn elapsed(&self) -> Duration {
+               std::time::Instant::elapsed(self)
+       }
+}
+
+/// A state in which time has no meaning.
+pub struct Eternity;
+
+impl Time for Eternity {
+       fn now() -> Self {
+               Self
+       }
+
+       fn duration_since_epoch() -> Duration {
+               Duration::from_secs(0)
+       }
+
+       fn elapsed(&self) -> Duration {
+               Duration::from_secs(0)
+       }
+}
+
+impl Sub<Duration> for Eternity {
+       type Output = Self;
+
+       fn sub(self, _other: Duration) -> Self {
+               self
+       }
+}
+
+impl<T: Time> Writeable for ScorerUsingTime<T> {
+       #[inline]
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.params.write(w)?;
+               self.channel_failures.write(w)?;
+               write_tlv_fields!(w, {});
+               Ok(())
+       }
+}
+
+impl<T: Time> Readable for ScorerUsingTime<T> {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let res = Ok(Self {
+                       params: Readable::read(r)?,
+                       channel_failures: Readable::read(r)?,
+               });
+               read_tlv_fields!(r, {});
+               res
+       }
+}
+
+impl<T: Time> Writeable for ChannelFailure<T> {
+       #[inline]
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               let duration_since_epoch = T::duration_since_epoch() - self.last_failed.elapsed();
+               write_tlv_fields!(w, {
+                       (0, self.undecayed_penalty_msat, required),
+                       (2, duration_since_epoch, required),
+               });
+               Ok(())
+       }
+}
+
+impl<T: Time> Readable for ChannelFailure<T> {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let mut undecayed_penalty_msat = 0;
+               let mut duration_since_epoch = Duration::from_secs(0);
+               read_tlv_fields!(r, {
+                       (0, undecayed_penalty_msat, required),
+                       (2, duration_since_epoch, required),
+               });
+               Ok(Self {
+                       undecayed_penalty_msat,
+                       last_failed: T::now() - (T::duration_since_epoch() - duration_since_epoch),
+               })
        }
 }
index 7ea369f551408bed9533c520a6adc8232ebf7b44..e0f0a63dbb341d8507a80ecd07b4a94f3cb91699 100644 (file)
 //! few other things.
 
 use chain::keysinterface::SpendableOutputDescriptor;
+use ln::channelmanager::PaymentId;
 use ln::msgs;
 use ln::msgs::DecodeError;
 use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
 use routing::network_graph::NetworkUpdate;
 use util::ser::{BigSize, FixedLengthReader, Writeable, Writer, MaybeReadable, Readable, VecReadWrapper, VecWriteWrapper};
-use routing::router::RouteHop;
+use routing::router::{RouteHop, RouteParameters};
 
+use bitcoin::Transaction;
 use bitcoin::blockdata::script::Script;
 use bitcoin::hashes::Hash;
 use bitcoin::hashes::sha256::Hash as Sha256;
@@ -31,7 +33,7 @@ use io;
 use prelude::*;
 use core::time::Duration;
 use core::ops::Deref;
-use bitcoin::Transaction;
+use sync::Arc;
 
 /// Some information provided on receipt of payment depends on whether the payment received is a
 /// spontaneous payment or a "conventional" lightning payment that's paying an invoice.
@@ -178,6 +180,12 @@ pub enum Event {
        /// Note for MPP payments: in rare cases, this event may be preceded by a `PaymentPathFailed`
        /// event. In this situation, you SHOULD treat this payment as having succeeded.
        PaymentSent {
+               /// The id returned by [`ChannelManager::send_payment`] and used with
+               /// [`ChannelManager::retry_payment`].
+               ///
+               /// [`ChannelManager::send_payment`]: crate::ln::channelmanager::ChannelManager::send_payment
+               /// [`ChannelManager::retry_payment`]: crate::ln::channelmanager::ChannelManager::retry_payment
+               payment_id: Option<PaymentId>,
                /// The preimage to the hash given to ChannelManager::send_payment.
                /// Note that this serves as a payment receipt, if you wish to have such a thing, you must
                /// store it somehow!
@@ -186,10 +194,26 @@ pub enum Event {
                ///
                /// [`ChannelManager::send_payment`]: crate::ln::channelmanager::ChannelManager::send_payment
                payment_hash: PaymentHash,
+               /// The total fee which was spent at intermediate hops in this payment, across all paths.
+               ///
+               /// Note that, like [`Route::get_total_fees`] this does *not* include any potential
+               /// overpayment to the recipient node.
+               ///
+               /// If the recipient or an intermediate node misbehaves and gives us free money, this may
+               /// overstate the amount paid, though this is unlikely.
+               ///
+               /// [`Route::get_total_fees`]: crate::routing::router::Route::get_total_fees
+               fee_paid_msat: Option<u64>,
        },
        /// Indicates an outbound payment we made failed. Probably some intermediary node dropped
        /// something. You may wish to retry with a different route.
        PaymentPathFailed {
+               /// The id returned by [`ChannelManager::send_payment`] and used with
+               /// [`ChannelManager::retry_payment`].
+               ///
+               /// [`ChannelManager::send_payment`]: crate::ln::channelmanager::ChannelManager::send_payment
+               /// [`ChannelManager::retry_payment`]: crate::ln::channelmanager::ChannelManager::retry_payment
+               payment_id: Option<PaymentId>,
                /// The hash which was given to ChannelManager::send_payment.
                payment_hash: PaymentHash,
                /// Indicates the payment was rejected for some reason by the recipient. This implies that
@@ -216,6 +240,13 @@ pub enum Event {
                /// If this is `Some`, then the corresponding channel should be avoided when the payment is
                /// retried. May be `None` for older [`Event`] serializations.
                short_channel_id: Option<u64>,
+               /// Parameters needed to compute a new [`Route`] when retrying the failed payment path.
+               ///
+               /// See [`find_route`] for details.
+               ///
+               /// [`Route`]: crate::routing::router::Route
+               /// [`find_route`]: crate::routing::router::find_route
+               retry: Option<RouteParameters>,
 #[cfg(test)]
                error_code: Option<u16>,
 #[cfg(test)]
@@ -315,15 +346,18 @@ impl Writeable for Event {
                                        (8, payment_preimage, option),
                                });
                        },
-                       &Event::PaymentSent { ref payment_preimage, ref payment_hash} => {
+                       &Event::PaymentSent { ref payment_id, ref payment_preimage, ref payment_hash, ref fee_paid_msat } => {
                                2u8.write(writer)?;
                                write_tlv_fields!(writer, {
                                        (0, payment_preimage, required),
                                        (1, payment_hash, required),
+                                       (3, payment_id, option),
+                                       (5, fee_paid_msat, option),
                                });
                        },
-                       &Event::PaymentPathFailed { ref payment_hash, ref rejected_by_dest, ref network_update,
-                                                   ref all_paths_failed, ref path, ref short_channel_id,
+                       &Event::PaymentPathFailed {
+                               ref payment_id, ref payment_hash, ref rejected_by_dest, ref network_update,
+                               ref all_paths_failed, ref path, ref short_channel_id, ref retry,
                                #[cfg(test)]
                                ref error_code,
                                #[cfg(test)]
@@ -341,6 +375,8 @@ impl Writeable for Event {
                                        (3, all_paths_failed, required),
                                        (5, path, vec_type),
                                        (7, short_channel_id, option),
+                                       (9, retry, option),
+                                       (11, payment_id, option),
                                });
                        },
                        &Event::PendingHTLCsForwardable { time_forwardable: _ } => {
@@ -426,16 +462,22 @@ impl MaybeReadable for Event {
                                let f = || {
                                        let mut payment_preimage = PaymentPreimage([0; 32]);
                                        let mut payment_hash = None;
+                                       let mut payment_id = None;
+                                       let mut fee_paid_msat = None;
                                        read_tlv_fields!(reader, {
                                                (0, payment_preimage, required),
                                                (1, payment_hash, option),
+                                               (3, payment_id, option),
+                                               (5, fee_paid_msat, option),
                                        });
                                        if payment_hash.is_none() {
                                                payment_hash = Some(PaymentHash(Sha256::hash(&payment_preimage.0[..]).into_inner()));
                                        }
                                        Ok(Some(Event::PaymentSent {
+                                               payment_id,
                                                payment_preimage,
                                                payment_hash: payment_hash.unwrap(),
+                                               fee_paid_msat,
                                        }))
                                };
                                f()
@@ -452,21 +494,27 @@ impl MaybeReadable for Event {
                                        let mut all_paths_failed = Some(true);
                                        let mut path: Option<Vec<RouteHop>> = Some(vec![]);
                                        let mut short_channel_id = None;
+                                       let mut retry = None;
+                                       let mut payment_id = None;
                                        read_tlv_fields!(reader, {
                                                (0, payment_hash, required),
                                                (1, network_update, ignorable),
                                                (2, rejected_by_dest, required),
                                                (3, all_paths_failed, option),
                                                (5, path, vec_type),
-                                               (7, short_channel_id, ignorable),
+                                               (7, short_channel_id, option),
+                                               (9, retry, option),
+                                               (11, payment_id, option),
                                        });
                                        Ok(Some(Event::PaymentPathFailed {
+                                               payment_id,
                                                payment_hash,
                                                rejected_by_dest,
                                                network_update,
                                                all_paths_failed: all_paths_failed.unwrap(),
                                                path: path.unwrap(),
                                                short_channel_id,
+                                               retry,
                                                #[cfg(test)]
                                                error_code,
                                                #[cfg(test)]
@@ -748,3 +796,9 @@ impl<F> EventHandler for F where F: Fn(&Event) {
                self(event)
        }
 }
+
+impl<T: EventHandler> EventHandler for Arc<T> {
+       fn handle_event(&self, event: &Event) {
+               self.deref().handle_event(event)
+       }
+}
index b34d1f0b782d8618756edc28e3d4a4791be0779f..2717effb209d94750b1965a28c64a856647f3d80 100644 (file)
@@ -120,6 +120,19 @@ pub trait Logger {
        fn log(&self, record: &Record);
 }
 
+/// Wrapper for logging byte slices in hex format.
+/// (C-not exported) as fmt can't be used in C
+#[doc(hidden)]
+pub struct DebugBytes<'a>(pub &'a [u8]);
+impl<'a> core::fmt::Display for DebugBytes<'a> {
+       fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
+               for i in self.0 {
+                       write!(f, "{:02x}", i)?;
+               }
+               Ok(())
+       }
+}
+
 #[cfg(test)]
 mod tests {
        use util::logger::{Logger, Level};
index d7849d77ae0b970796208ca7a52f726671047c26..7477526133e7b99394e699e7193dda3c7a21bbba 100644 (file)
@@ -16,6 +16,7 @@ use bitcoin::secp256k1::key::PublicKey;
 
 use routing::router::Route;
 use ln::chan_utils::HTLCType;
+use util::logger::DebugBytes;
 
 pub(crate) struct DebugPubKey<'a>(pub &'a PublicKey);
 impl<'a> core::fmt::Display for DebugPubKey<'a> {
@@ -32,18 +33,11 @@ macro_rules! log_pubkey {
        }
 }
 
-pub(crate) struct DebugBytes<'a>(pub &'a [u8]);
-impl<'a> core::fmt::Display for DebugBytes<'a> {
-       fn fmt(&self, f: &mut core::fmt::Formatter) -> Result<(), core::fmt::Error> {
-               for i in self.0 {
-                       write!(f, "{:02x}", i)?;
-               }
-               Ok(())
-       }
-}
+/// Logs a byte slice in hex format.
+#[macro_export]
 macro_rules! log_bytes {
        ($obj: expr) => {
-               ::util::macro_logger::DebugBytes(&$obj)
+               $crate::util::logger::DebugBytes(&$obj)
        }
 }
 
@@ -157,6 +151,7 @@ macro_rules! log_spendable {
 
 /// Create a new Record and log it. You probably don't want to use this macro directly,
 /// but it needs to be exported so `log_trace` etc can use it in external crates.
+#[doc(hidden)]
 #[macro_export]
 macro_rules! log_internal {
        ($logger: expr, $lvl:expr, $($arg:tt)+) => (
@@ -218,6 +213,6 @@ macro_rules! log_debug {
 #[macro_export]
 macro_rules! log_trace {
        ($logger: expr, $($arg:tt)*) => (
-               log_given_level!($logger, $crate::util::logger::Level::Trace, $($arg)*);
+               log_given_level!($logger, $crate::util::logger::Level::Trace, $($arg)*)
        )
 }
index ba8a9fd0d3551efe8062e907f4d52d5f2b635d51..8fb72a8f9fe643f56cfc259e5b7214b56a4fb761 100644 (file)
@@ -27,6 +27,7 @@ use bitcoin::consensus::Encodable;
 use bitcoin::hashes::sha256d::Hash as Sha256dHash;
 use bitcoin::hash_types::{Txid, BlockHash};
 use core::marker::Sized;
+use core::time::Duration;
 use ln::msgs::DecodeError;
 use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
 
@@ -911,3 +912,19 @@ impl Readable for String {
                Ok(ret)
        }
 }
+
+impl Writeable for Duration {
+       #[inline]
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.as_secs().write(w)?;
+               self.subsec_nanos().write(w)
+       }
+}
+impl Readable for Duration {
+       #[inline]
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let secs = Readable::read(r)?;
+               let nanos = Readable::read(r)?;
+               Ok(Duration::new(secs, nanos))
+       }
+}
index 0827a023aedfbe141ca6df425f8c3f11e447b17f..165d1f1edbacea61438ef3426828a1195336a39e 100644 (file)
@@ -310,7 +310,7 @@ macro_rules! write_ver_prefix {
 /// correctly.
 macro_rules! write_tlv_fields {
        ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),* $(,)*}) => {
-               encode_varint_length_prefixed_tlv!($stream, {$(($type, $field, $fieldty)),*});
+               encode_varint_length_prefixed_tlv!($stream, {$(($type, $field, $fieldty)),*})
        }
 }
 
index 5a125929f21a8f6cd6d369d78eaf302055755bc5..4734a1cb459da27f08101be0ef5e3dd2b46e55bd 100644 (file)
@@ -21,6 +21,7 @@ use ln::features::{ChannelFeatures, InitFeatures};
 use ln::msgs;
 use ln::msgs::OptionalField;
 use ln::script::ShutdownScript;
+use routing::scorer::{Eternity, ScorerUsingTime};
 use util::enforcing_trait_impls::{EnforcingSigner, EnforcementState};
 use util::events;
 use util::logger::{Logger, Level, Record};
@@ -320,7 +321,6 @@ fn get_dummy_channel_update(short_chan_id: u64) -> msgs::ChannelUpdate {
 pub struct TestRoutingMessageHandler {
        pub chan_upds_recvd: AtomicUsize,
        pub chan_anns_recvd: AtomicUsize,
-       pub chan_anns_sent: AtomicUsize,
        pub request_full_sync: AtomicBool,
 }
 
@@ -329,7 +329,6 @@ impl TestRoutingMessageHandler {
                TestRoutingMessageHandler {
                        chan_upds_recvd: AtomicUsize::new(0),
                        chan_anns_recvd: AtomicUsize::new(0),
-                       chan_anns_sent: AtomicUsize::new(0),
                        request_full_sync: AtomicBool::new(false),
                }
        }
@@ -348,8 +347,8 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
        }
        fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, Option<msgs::ChannelUpdate>, Option<msgs::ChannelUpdate>)> {
                let mut chan_anns = Vec::new();
-               const TOTAL_UPDS: u64 = 100;
-               let end: u64 = cmp::min(starting_point + batch_amount as u64, TOTAL_UPDS - self.chan_anns_sent.load(Ordering::Acquire) as u64);
+               const TOTAL_UPDS: u64 = 50;
+               let end: u64 = cmp::min(starting_point + batch_amount as u64, TOTAL_UPDS);
                for i in starting_point..end {
                        let chan_upd_1 = get_dummy_channel_update(i);
                        let chan_upd_2 = get_dummy_channel_update(i);
@@ -358,7 +357,6 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
                        chan_anns.push((chan_ann, Some(chan_upd_1), Some(chan_upd_2)));
                }
 
-               self.chan_anns_sent.fetch_add(chan_anns.len(), Ordering::AcqRel);
                chan_anns
        }
 
@@ -693,3 +691,6 @@ impl core::fmt::Debug for OnRegisterOutput {
                        .finish()
        }
 }
+
+/// A scorer useful in testing, when the passage of time isn't a concern.
+pub type TestScorer = ScorerUsingTime<Eternity>;