Merge pull request #790 from bmancini55/sync_complete
[rust-lightning] / lightning / src / routing / router.rs
index 1960c7b4e52cde08a5a59c49dd7942d6a11635d0..32084fad581ad55e232d997584251af168a853bc 100644 (file)
@@ -45,6 +45,7 @@ pub struct RouteHop {
        pub cltv_expiry_delta: u32,
 }
 
+/// (C-not exported)
 impl Writeable for Vec<RouteHop> {
        fn write<W: ::util::ser::Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
                (self.len() as u8).write(writer)?;
@@ -60,6 +61,7 @@ impl Writeable for Vec<RouteHop> {
        }
 }
 
+/// (C-not exported)
 impl Readable for Vec<RouteHop> {
        fn read<R: ::std::io::Read>(reader: &mut R) -> Result<Vec<RouteHop>, DecodeError> {
                let hops_count: u8 = Readable::read(reader)?;
@@ -113,6 +115,7 @@ impl Readable for Route {
 }
 
 /// A channel descriptor which provides a last-hop route to get_route
+#[derive(Clone)]
 pub struct RouteHint {
        /// The node_id of the non-target end of the route
        pub src_node_id: PublicKey,
@@ -236,13 +239,12 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, targ
                                        let mut total_fee = $starting_fee_msat as u64;
                                        let hm_entry = dist.entry(&$src_node_id);
                                        let old_entry = hm_entry.or_insert_with(|| {
-                                               let node = network.get_nodes().get(&$src_node_id).unwrap();
                                                let mut fee_base_msat = u32::max_value();
                                                let mut fee_proportional_millionths = u32::max_value();
-                                               if let Some(fees) = node.lowest_inbound_channel_fees {
+                                               if let Some(fees) = network.get_nodes().get(&$src_node_id).and_then(|node| node.lowest_inbound_channel_fees) {
                                                        fee_base_msat = fees.base_msat;
                                                        fee_proportional_millionths = fees.proportional_millionths;
-                                               };
+                                               }
                                                (u64::max_value(),
                                                        fee_base_msat,
                                                        fee_proportional_millionths,
@@ -341,21 +343,26 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, targ
        }
 
        for hop in last_hops.iter() {
-               if first_hops.is_none() || hop.src_node_id != *our_node_id { // first_hop overrules last_hops
-                       if network.get_nodes().get(&hop.src_node_id).is_some() {
-                               if first_hops.is_some() {
-                                       if let Some(&(ref first_hop, ref features)) = first_hop_targets.get(&hop.src_node_id) {
-                                               // Currently there are no channel-context features defined, so we are a
-                                               // bit lazy here. In the future, we should pull them out via our
-                                               // ChannelManager, but there's no reason to waste the space until we
-                                               // need them.
-                                               add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, features.to_context(), 0);
-                                       }
-                               }
-                               // BOLT 11 doesn't allow inclusion of features for the last hop hints, which
-                               // really sucks, cause we're gonna need that eventually.
-                               add_entry!(hop.short_channel_id, hop.src_node_id, target, hop, ChannelFeatures::empty(), 0);
-                       }
+               let have_hop_src_in_graph =
+                       if let Some(&(ref first_hop, ref features)) = first_hop_targets.get(&hop.src_node_id) {
+                               // If this hop connects to a node with which we have a direct channel, ignore the
+                               // network graph and add both the hop and our direct channel to the candidate set:
+                               //
+                               // Currently there are no channel-context features defined, so we are a
+                               // bit lazy here. In the future, we should pull them out via our
+                               // ChannelManager, but there's no reason to waste the space until we
+                               // need them.
+                               add_entry!(first_hop, *our_node_id , hop.src_node_id, dummy_directional_info, features.to_context(), 0);
+                               true
+                       } else {
+                               // In any other case, only add the hop if the source is in the regular network
+                               // graph:
+                               network.get_nodes().get(&hop.src_node_id).is_some()
+                       };
+               if have_hop_src_in_graph {
+                       // BOLT 11 doesn't allow inclusion of features for the last hop hints, which
+                       // really sucks, cause we're gonna need that eventually.
+                       add_entry!(hop.short_channel_id, hop.src_node_id, target, hop, ChannelFeatures::empty(), 0);
                }
        }
 
@@ -410,9 +417,8 @@ pub fn get_route<L: Deref>(our_node_id: &PublicKey, network: &NetworkGraph, targ
 
 #[cfg(test)]
 mod tests {
-       use chain::chaininterface;
        use routing::router::{get_route, RouteHint, RoutingFees};
-       use routing::network_graph::NetGraphMsgHandler;
+       use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
        use ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
        use ln::msgs::{ErrorAction, LightningError, OptionalField, UnsignedChannelAnnouncement, ChannelAnnouncement, RoutingMessageHandler,
           NodeAnnouncement, UnsignedNodeAnnouncement, ChannelUpdate, UnsignedChannelUpdate};
@@ -433,7 +439,7 @@ mod tests {
        use std::sync::Arc;
 
        // Using the same keys for LN and BTC ids
-       fn add_channel(net_graph_msg_handler: &NetGraphMsgHandler<Arc<chaininterface::ChainWatchInterfaceUtil>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_1_privkey: &SecretKey,
+       fn add_channel(net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_1_privkey: &SecretKey,
           node_2_privkey: &SecretKey, features: ChannelFeatures, short_channel_id: u64) {
                let node_id_1 = PublicKey::from_secret_key(&secp_ctx, node_1_privkey);
                let node_id_2 = PublicKey::from_secret_key(&secp_ctx, node_2_privkey);
@@ -463,7 +469,7 @@ mod tests {
                };
        }
 
-       fn update_channel(net_graph_msg_handler: &NetGraphMsgHandler<Arc<chaininterface::ChainWatchInterfaceUtil>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey, update: UnsignedChannelUpdate) {
+       fn update_channel(net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey, update: UnsignedChannelUpdate) {
                let msghash = hash_to_message!(&Sha256dHash::hash(&update.encode()[..])[..]);
                let valid_channel_update = ChannelUpdate {
                        signature: secp_ctx.sign(&msghash, node_privkey),
@@ -478,7 +484,7 @@ mod tests {
        }
 
 
-       fn add_or_update_node(net_graph_msg_handler: &NetGraphMsgHandler<Arc<chaininterface::ChainWatchInterfaceUtil>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey,
+       fn add_or_update_node(net_graph_msg_handler: &NetGraphMsgHandler<Arc<test_utils::TestChainSource>, Arc<test_utils::TestLogger>>, secp_ctx: &Secp256k1<All>, node_privkey: &SecretKey,
           features: NodeFeatures, timestamp: u32) {
                let node_id = PublicKey::from_secret_key(&secp_ctx, node_privkey);
                let unsigned_announcement = UnsignedNodeAnnouncement {
@@ -531,11 +537,10 @@ mod tests {
                }
        }
 
-       fn build_graph() -> (Secp256k1<All>, NetGraphMsgHandler<std::sync::Arc<crate::chain::chaininterface::ChainWatchInterfaceUtil>, std::sync::Arc<crate::util::test_utils::TestLogger>>, std::sync::Arc<test_utils::TestLogger>) {
+       fn build_graph() -> (Secp256k1<All>, NetGraphMsgHandler<std::sync::Arc<crate::util::test_utils::TestChainSource>, std::sync::Arc<crate::util::test_utils::TestLogger>>, std::sync::Arc<test_utils::TestLogger>) {
                let secp_ctx = Secp256k1::new();
                let logger = Arc::new(test_utils::TestLogger::new());
-               let chain_monitor = Arc::new(chaininterface::ChainWatchInterfaceUtil::new(Network::Testnet));
-               let net_graph_msg_handler = NetGraphMsgHandler::new(chain_monitor, Arc::clone(&logger));
+               let net_graph_msg_handler = NetGraphMsgHandler::new(genesis_block(Network::Testnet).header.block_hash(), None, Arc::clone(&logger));
                // Build network from our_id to node7:
                //
                //        -1(1)2-  node0  -1(3)2-
@@ -1223,4 +1228,54 @@ mod tests {
                assert_eq!(route.paths[0][4].node_features.le_flags(), &Vec::<u8>::new()); // We dont pass flags in from invoices yet
                assert_eq!(route.paths[0][4].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
        }
+
+       #[test]
+       fn unannounced_path_test() {
+               // We should be able to send a payment to a destination without any help of a routing graph
+               // if we have a channel with a common counterparty that appears in the first and last hop
+               // hints.
+               let source_node_id = PublicKey::from_secret_key(&Secp256k1::new(), &SecretKey::from_slice(&hex::decode(format!("{:02}", 41).repeat(32)).unwrap()[..]).unwrap());
+               let middle_node_id = PublicKey::from_secret_key(&Secp256k1::new(), &SecretKey::from_slice(&hex::decode(format!("{:02}", 42).repeat(32)).unwrap()[..]).unwrap());
+               let target_node_id = PublicKey::from_secret_key(&Secp256k1::new(), &SecretKey::from_slice(&hex::decode(format!("{:02}", 43).repeat(32)).unwrap()[..]).unwrap());
+
+               // If we specify a channel to a middle hop, that overrides our local channel view and that gets used
+               let last_hops = vec![RouteHint {
+                       src_node_id: middle_node_id,
+                       short_channel_id: 8,
+                       fees: RoutingFees {
+                               base_msat: 1000,
+                               proportional_millionths: 0,
+                       },
+                       cltv_expiry_delta: (8 << 8) | 1,
+                       htlc_minimum_msat: 0,
+               }];
+               let our_chans = vec![channelmanager::ChannelDetails {
+                       channel_id: [0; 32],
+                       short_channel_id: Some(42),
+                       remote_network_id: middle_node_id,
+                       counterparty_features: InitFeatures::from_le_bytes(vec![0b11]),
+                       channel_value_satoshis: 100000,
+                       user_id: 0,
+                       outbound_capacity_msat: 100000,
+                       inbound_capacity_msat: 100000,
+                       is_live: true,
+               }];
+               let route = get_route(&source_node_id, &NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()), &target_node_id, Some(&our_chans.iter().collect::<Vec<_>>()), &last_hops.iter().collect::<Vec<_>>(), 100, 42, Arc::new(test_utils::TestLogger::new())).unwrap();
+
+               assert_eq!(route.paths[0].len(), 2);
+
+               assert_eq!(route.paths[0][0].pubkey, middle_node_id);
+               assert_eq!(route.paths[0][0].short_channel_id, 42);
+               assert_eq!(route.paths[0][0].fee_msat, 1000);
+               assert_eq!(route.paths[0][0].cltv_expiry_delta, (8 << 8) | 1);
+               assert_eq!(route.paths[0][0].node_features.le_flags(), &[0b11]);
+               assert_eq!(route.paths[0][0].channel_features.le_flags(), &[0; 0]); // We can't learn any flags from invoices, sadly
+
+               assert_eq!(route.paths[0][1].pubkey, target_node_id);
+               assert_eq!(route.paths[0][1].short_channel_id, 8);
+               assert_eq!(route.paths[0][1].fee_msat, 100);
+               assert_eq!(route.paths[0][1].cltv_expiry_delta, 42);
+               assert_eq!(route.paths[0][1].node_features.le_flags(), &[0; 0]); // We dont pass flags in from invoices yet
+               assert_eq!(route.paths[0][1].channel_features.le_flags(), &[0; 0]); // We can't learn any flags from invoices, sadly
+       }
 }