Merge pull request #1061 from sr-gi/add-transaction-convert
[rust-lightning] / lightning / src / routing / router.rs
index 6d6d603c56a377383a64f268b44634c190f4fdf9..e60fd1ee1b3f63b72c733491496cda6926028be3 100644 (file)
@@ -71,6 +71,27 @@ pub struct Route {
        pub paths: Vec<Vec<RouteHop>>,
 }
 
+impl Route {
+       /// Returns the total amount of fees paid on this [`Route`].
+       ///
+       /// This doesn't include any extra payment made to the recipient, which can happen in excess of
+       /// the amount passed to [`get_route`]'s `final_value_msat`.
+       pub fn get_total_fees(&self) -> u64 {
+               // Do not count last hop of each path since that's the full value of the payment
+               return self.paths.iter()
+                       .flat_map(|path| path.split_last().map(|(_, path_prefix)| path_prefix).unwrap_or(&[]))
+                       .map(|hop| &hop.fee_msat)
+                       .sum();
+       }
+
+       /// Returns the total amount paid on this [`Route`], excluding the fees.
+       pub fn get_total_amount(&self) -> u64 {
+               return self.paths.iter()
+                       .map(|path| path.split_last().map(|(hop, _)| hop.fee_msat).unwrap_or(0))
+                       .sum();
+       }
+}
+
 const SERIALIZATION_VERSION: u8 = 1;
 const MIN_SERIALIZATION_VERSION: u8 = 1;
 
@@ -443,8 +464,9 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
        // to use as the A* heuristic beyond just the cost to get one node further than the current
        // one.
 
-       let network_channels = network.get_channels();
-       let network_nodes = network.get_nodes();
+       let network_graph = network.read_only();
+       let network_channels = network_graph.channels();
+       let network_nodes = network_graph.nodes();
        let dummy_directional_info = DummyDirectionalChannelInfo { // used for first_hops routes
                cltv_expiry_delta: 0,
                htlc_minimum_msat: 0,
@@ -1244,7 +1266,7 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, paye
 
 #[cfg(test)]
 mod tests {
-       use routing::router::{get_route, Route, RouteHint, RouteHintHop, RoutingFees};
+       use routing::router::{get_route, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
        use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use chain::transaction::OutPoint;
        use ln::features::{ChannelFeatures, InitFeatures, InvoiceFeatures, NodeFeatures};
@@ -1295,8 +1317,10 @@ mod tests {
        }
 
        // Using the same keys for LN and BTC ids
-       fn add_channel(net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_1_privkey: &SecretKey,
-          node_2_privkey: &SecretKey, features: ChannelFeatures, short_channel_id: u64) {
+       fn add_channel(
+               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               secp_ctx: &Secp256k1<All>, node_1_privkey: &SecretKey, node_2_privkey: &SecretKey, features: ChannelFeatures, short_channel_id: u64
+       ) {
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_1_privkey);
                let node_id_2 = PublicKey::from_secret_key(&secp_ctx, node_2_privkey);
 
@@ -1325,7 +1349,10 @@ mod tests {
                };
        }
 
-       fn update_channel(net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey, update: UnsignedChannelUpdate) {
+       fn update_channel(
+               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey, update: UnsignedChannelUpdate
+       ) {
                let msghash = hash_to_message!(&Sha256dHash::hash(&update.encode()[..])[..]);
                let valid_channel_update = ChannelUpdate {
                        signature: secp_ctx.sign(&msghash, node_privkey),
@@ -1338,8 +1365,10 @@ mod tests {
                };
        }
 
-       fn add_or_update_node(net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey,
-          features: NodeFeatures, timestamp: u32) {
+       fn add_or_update_node(
+               net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>,
+               secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey, features: NodeFeatures, timestamp: u32
+       ) {
                let node_id = PublicKey::from_secret_key(&secp_ctx, node_privkey);
                let unsigned_announcement = UnsignedNodeAnnouncement {
                        features,
@@ -1395,7 +1424,8 @@ mod tests {
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(test_utils::TestLogger::new());
                let chain_monitor = Arc::new(test_utils::TestChainSource::new(Network::Testnet));
-               let net_graph_msg_handler = NetGraphMsgHandler::new(genesis_block(Network::Testnet).header.block_hash(), None, Arc::clone(&logger));
+               let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
+               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
                // Build network from our_id to node6:
                //
                //        -1(1)2-  node0  -1(3)2-
@@ -3782,6 +3812,7 @@ mod tests {
                                total_amount_paid_msat += path.last().unwrap().fee_msat;
                        }
                        assert_eq!(total_amount_paid_msat, 200_000);
+                       assert_eq!(route.get_total_fees(), 150_000);
                }
 
        }
@@ -3946,7 +3977,8 @@ mod tests {
                // "previous hop" being set to node 3, creating a loop in the path.
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(test_utils::TestLogger::new());
-               let net_graph_msg_handler = NetGraphMsgHandler::new(genesis_block(Network::Testnet).header.block_hash(), None, Arc::clone(&logger));
+               let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash());
+               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger));
                let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx);
 
                add_channel(&net_graph_msg_handler, &secp_ctx, &our_privkey, &privkeys[1], ChannelFeatures::from_le_bytes(id_to_feature_flags(6)), 6);
@@ -4188,6 +4220,75 @@ mod tests {
                }
        }
 
+       #[test]
+       fn total_fees_single_path() {
+               let route = Route {
+                       paths: vec![vec![
+                               RouteHop {
+                                       pubkey: PublicKey::from_slice(&hex::decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619").unwrap()[..]).unwrap(),
+                                       channel_features: ChannelFeatures::empty(), node_features: NodeFeatures::empty(),
+                                       short_channel_id: 0, fee_msat: 100, cltv_expiry_delta: 0
+                               },
+                               RouteHop {
+                                       pubkey: PublicKey::from_slice(&hex::decode("0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c").unwrap()[..]).unwrap(),
+                                       channel_features: ChannelFeatures::empty(), node_features: NodeFeatures::empty(),
+                                       short_channel_id: 0, fee_msat: 150, cltv_expiry_delta: 0
+                               },
+                               RouteHop {
+                                       pubkey: PublicKey::from_slice(&hex::decode("027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007").unwrap()[..]).unwrap(),
+                                       channel_features: ChannelFeatures::empty(), node_features: NodeFeatures::empty(),
+                                       short_channel_id: 0, fee_msat: 225, cltv_expiry_delta: 0
+                               },
+                       ]],
+               };
+
+               assert_eq!(route.get_total_fees(), 250);
+               assert_eq!(route.get_total_amount(), 225);
+       }
+
+       #[test]
+       fn total_fees_multi_path() {
+               let route = Route {
+                       paths: vec![vec![
+                               RouteHop {
+                                       pubkey: PublicKey::from_slice(&hex::decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619").unwrap()[..]).unwrap(),
+                                       channel_features: ChannelFeatures::empty(), node_features: NodeFeatures::empty(),
+                                       short_channel_id: 0, fee_msat: 100, cltv_expiry_delta: 0
+                               },
+                               RouteHop {
+                                       pubkey: PublicKey::from_slice(&hex::decode("0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c").unwrap()[..]).unwrap(),
+                                       channel_features: ChannelFeatures::empty(), node_features: NodeFeatures::empty(),
+                                       short_channel_id: 0, fee_msat: 150, cltv_expiry_delta: 0
+                               },
+                       ],vec![
+                               RouteHop {
+                                       pubkey: PublicKey::from_slice(&hex::decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619").unwrap()[..]).unwrap(),
+                                       channel_features: ChannelFeatures::empty(), node_features: NodeFeatures::empty(),
+                                       short_channel_id: 0, fee_msat: 100, cltv_expiry_delta: 0
+                               },
+                               RouteHop {
+                                       pubkey: PublicKey::from_slice(&hex::decode("0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c").unwrap()[..]).unwrap(),
+                                       channel_features: ChannelFeatures::empty(), node_features: NodeFeatures::empty(),
+                                       short_channel_id: 0, fee_msat: 150, cltv_expiry_delta: 0
+                               },
+                       ]],
+               };
+
+               assert_eq!(route.get_total_fees(), 200);
+               assert_eq!(route.get_total_amount(), 300);
+       }
+
+       #[test]
+       fn total_empty_route_no_panic() {
+               // In an earlier version of `Route::get_total_fees` and `Route::get_total_amount`, they
+               // would both panic if the route was completely empty. We test to ensure they return 0
+               // here, even though its somewhat nonsensical as a route.
+               let route = Route { paths: Vec::new() };
+
+               assert_eq!(route.get_total_fees(), 0);
+               assert_eq!(route.get_total_amount(), 0);
+       }
+
        #[cfg(not(feature = "no-std"))]
        pub(super) fn random_init_seed() -> u64 {
                // Because the default HashMap in std pulls OS randomness, we can use it as a (bad) RNG.
@@ -4213,7 +4314,7 @@ mod tests {
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
-               let nodes = graph.get_nodes();
+               let nodes = graph.read_only().nodes().clone();
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
@@ -4242,7 +4343,7 @@ mod tests {
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut seed = random_init_seed() as usize;
-               let nodes = graph.get_nodes();
+               let nodes = graph.read_only().nodes().clone();
                'load_endpoints: for _ in 0..10 {
                        loop {
                                seed = seed.overflowing_mul(0xdeadbeef).0;
@@ -4301,7 +4402,7 @@ mod benches {
        fn generate_routes(bench: &mut Bencher) {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
-               let nodes = graph.get_nodes();
+               let nodes = graph.read_only().nodes().clone();
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();
@@ -4333,7 +4434,7 @@ mod benches {
        fn generate_mpp_routes(bench: &mut Bencher) {
                let mut d = test_utils::get_route_file().unwrap();
                let graph = NetworkGraph::read(&mut d).unwrap();
-               let nodes = graph.get_nodes();
+               let nodes = graph.read_only().nodes().clone();
 
                // First, get 100 (source, destination) pairs for which route-getting actually succeeds...
                let mut path_endpoints = Vec::new();