Merge pull request #1141 from TheBlueMatt/2021-10-payment-id-on-partial-failure
[rust-lightning] / lightning-invoice / src / utils.rs
index df2bbfd8f12459381d3296dc44499cffa2030d07..8da9994a3f7729ea1df9c407dd567fc811782d6a 100644 (file)
@@ -1,14 +1,21 @@
 //! Convenient utilities to create an invoice.
+
 use {Currency, DEFAULT_EXPIRY_TIME, Invoice, InvoiceBuilder, SignOrCreationError, RawInvoice};
+use payment::{Payer, Router};
+
 use bech32::ToBase32;
 use bitcoin_hashes::Hash;
 use lightning::chain;
 use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
 use lightning::chain::keysinterface::{Sign, KeysInterface};
-use lightning::ln::channelmanager::{ChannelManager, MIN_FINAL_CLTV_EXPIRY};
-use lightning::routing::network_graph::RoutingFees;
-use lightning::routing::router::{RouteHint, RouteHintHop};
+use lightning::ln::{PaymentHash, PaymentSecret};
+use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY};
+use lightning::ln::msgs::LightningError;
+use lightning::routing;
+use lightning::routing::network_graph::{NetworkGraph, RoutingFees};
+use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route};
 use lightning::util::logger::Logger;
+use secp256k1::key::PublicKey;
 use std::convert::TryInto;
 use std::ops::Deref;
 
@@ -89,6 +96,58 @@ where
        }
 }
 
+/// A [`Router`] implemented using [`find_route`].
+pub struct DefaultRouter<G, L: Deref> where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+       network_graph: G,
+       logger: L,
+}
+
+impl<G, L: Deref> DefaultRouter<G, L> where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+       /// Creates a new router using the given [`NetworkGraph`] and  [`Logger`].
+       pub fn new(network_graph: G, logger: L) -> Self {
+               Self { network_graph, logger }
+       }
+}
+
+impl<G, L: Deref, S: routing::Score> Router<S> for DefaultRouter<G, L>
+where G: Deref<Target = NetworkGraph>, L::Target: Logger {
+       fn find_route(
+               &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>,
+               scorer: &S
+       ) -> Result<Route, LightningError> {
+               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer)
+       }
+}
+
+impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Payer for ChannelManager<Signer, M, T, K, F, L>
+where
+       M::Target: chain::Watch<Signer>,
+       T::Target: BroadcasterInterface,
+       K::Target: KeysInterface<Signer = Signer>,
+       F::Target: FeeEstimator,
+       L::Target: Logger,
+{
+       fn node_id(&self) -> PublicKey {
+               self.get_our_node_id()
+       }
+
+       fn first_hops(&self) -> Vec<ChannelDetails> {
+               self.list_usable_channels()
+       }
+
+       fn send_payment(
+               &self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>
+       ) -> Result<PaymentId, PaymentSendFailure> {
+               self.send_payment(route, payment_hash, payment_secret)
+       }
+
+       fn retry_payment(
+               &self, route: &Route, payment_id: PaymentId
+       ) -> Result<(), PaymentSendFailure> {
+               self.retry_payment(route, payment_id)
+       }
+}
+
 #[cfg(test)]
 mod test {
        use {Currency, Description, InvoiceDescription};
@@ -97,7 +156,8 @@ mod test {
        use lightning::ln::functional_test_utils::*;
        use lightning::ln::features::InitFeatures;
        use lightning::ln::msgs::ChannelMessageHandler;
-       use lightning::routing::router;
+       use lightning::routing::router::{Payee, RouteParameters, find_route};
+       use lightning::routing::scorer::Scorer;
        use lightning::util::events::MessageSendEventsProvider;
        use lightning::util::test_utils;
        #[test]
@@ -112,21 +172,21 @@ mod test {
                assert_eq!(invoice.min_final_cltv_expiry(), MIN_FINAL_CLTV_EXPIRY as u64);
                assert_eq!(invoice.description(), InvoiceDescription::Direct(&Description("test".to_string())));
 
-               let amt_msat = invoice.amount_pico_btc().unwrap() / 10;
+               let payee = Payee::new(invoice.recover_payee_pub_key())
+                       .with_features(invoice.features().unwrap().clone())
+                       .with_route_hints(invoice.route_hints());
+               let params = RouteParameters {
+                       payee,
+                       final_value_msat: invoice.amount_milli_satoshis().unwrap(),
+                       final_cltv_expiry_delta: invoice.min_final_cltv_expiry() as u32,
+               };
                let first_hops = nodes[0].node.list_usable_channels();
-               let last_hops = invoice.route_hints();
-               let network_graph = nodes[0].net_graph_msg_handler.network_graph.read().unwrap();
+               let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
                let logger = test_utils::TestLogger::new();
-               let route = router::get_route(
-                       &nodes[0].node.get_our_node_id(),
-                       &network_graph,
-                       &invoice.recover_payee_pub_key(),
-                       Some(invoice.features().unwrap().clone()),
-                       Some(&first_hops.iter().collect::<Vec<_>>()),
-                       &last_hops,
-                       amt_msat,
-                       invoice.min_final_cltv_expiry() as u32,
-                       &logger,
+               let scorer = Scorer::with_fixed_penalty(0);
+               let route = find_route(
+                       &nodes[0].node.get_our_node_id(), &params, network_graph,
+                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer,
                ).unwrap();
 
                let payment_event = {