Add a random per-path CLTV offset for privacy.
[rust-lightning] / lightning-invoice / src / utils.rs
index 78a092e24cc53de30e1f2cd92740625053375a05..c2bd4b49c6b2091e3c381a8a63017c3741fc9774 100644 (file)
@@ -4,7 +4,7 @@ use {CreationError, Currency, DEFAULT_EXPIRY_TIME, Invoice, InvoiceBuilder, Sign
 use payment::{Payer, Router};
 
 use bech32::ToBase32;
-use bitcoin_hashes::Hash;
+use bitcoin_hashes::{Hash, sha256};
 use crate::prelude::*;
 use lightning::chain;
 use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
@@ -22,6 +22,7 @@ use secp256k1::key::PublicKey;
 use core::convert::TryInto;
 use core::ops::Deref;
 use core::time::Duration;
+use sync::Mutex;
 
 #[cfg(feature = "std")]
 /// Utility to create an invoice that can be paid to one of multiple nodes, or a "phantom invoice."
@@ -224,12 +225,15 @@ where
 pub struct DefaultRouter<G: Deref<Target = NetworkGraph>, L: Deref> where L::Target: Logger {
        network_graph: G,
        logger: L,
+       random_seed_bytes: Mutex<[u8; 32]>,
 }
 
 impl<G: Deref<Target = NetworkGraph>, L: Deref> DefaultRouter<G, L> where L::Target: Logger {
-       /// Creates a new router using the given [`NetworkGraph`] and  [`Logger`].
-       pub fn new(network_graph: G, logger: L) -> Self {
-               Self { network_graph, logger }
+       /// Creates a new router using the given [`NetworkGraph`], a [`Logger`], and a randomness source
+       /// `random_seed_bytes`.
+       pub fn new(network_graph: G, logger: L, random_seed_bytes: [u8; 32]) -> Self {
+               let random_seed_bytes = Mutex::new(random_seed_bytes);
+               Self { network_graph, logger, random_seed_bytes }
        }
 }
 
@@ -239,7 +243,12 @@ where L::Target: Logger {
                &self, payer: &PublicKey, params: &RouteParameters, _payment_hash: &PaymentHash,
                first_hops: Option<&[&ChannelDetails]>, scorer: &S
        ) -> Result<Route, LightningError> {
-               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer)
+               let random_seed_bytes = {
+                       let mut locked_random_seed_bytes = self.random_seed_bytes.lock().unwrap();
+                       *locked_random_seed_bytes = sha256::Hash::hash(&*locked_random_seed_bytes).into_inner();
+                       *locked_random_seed_bytes
+               };
+               find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer, &random_seed_bytes)
        }
 }
 
@@ -299,6 +308,7 @@ mod test {
        use lightning::util::enforcing_trait_impls::EnforcingSigner;
        use lightning::util::events::{MessageSendEvent, MessageSendEventsProvider, Event};
        use lightning::util::test_utils;
+       use lightning::chain::keysinterface::KeysInterface;
        use utils::create_invoice_from_channelmanager_and_duration_since_epoch;
 
        #[test]
@@ -327,9 +337,10 @@ mod test {
                let network_graph = node_cfgs[0].network_graph;
                let logger = test_utils::TestLogger::new();
                let scorer = test_utils::TestScorer::with_penalty(0);
+               let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
                        &nodes[0].node.get_our_node_id(), &route_params, network_graph,
-                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer,
+                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer, &random_seed_bytes
                ).unwrap();
 
                let payment_event = {
@@ -415,9 +426,10 @@ mod test {
                let network_graph = node_cfgs[0].network_graph;
                let logger = test_utils::TestLogger::new();
                let scorer = test_utils::TestScorer::with_penalty(0);
+               let random_seed_bytes = chanmon_cfgs[1].keys_manager.get_secure_random_bytes();
                let route = find_route(
                        &nodes[0].node.get_our_node_id(), &params, network_graph,
-                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer,
+                       Some(&first_hops.iter().collect::<Vec<_>>()), &logger, &scorer, &random_seed_bytes
                ).unwrap();
                let (payment_event, fwd_idx) = {
                        let mut payment_hash = PaymentHash([0; 32]);