Merge pull request #1149 from jkczyz/2021-11-network-graph
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Mon, 1 Nov 2021 22:19:08 +0000 (22:19 +0000)
committerGitHub <noreply@github.com>
Mon, 1 Nov 2021 22:19:08 +0000 (22:19 +0000)
Shared ownership of NetworkGraph

1  2 
lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs

index cdde8bd85325f9a68f5ed1ed4510c891428421c7,017591d06b83839815e960a2d398e8a5c37104bd..f7441e0f1da89ee1194e899a9e1f03b6d87f4cbf
@@@ -173,7 -173,6 +173,7 @@@ struct ClaimableHTLC 
  }
  
  /// A payment identifier used to uniquely identify a payment to LDK.
 +/// (C-not exported) as we just use [u8; 32] directly
  #[derive(Hash, Copy, Clone, PartialEq, Eq, Debug)]
  pub struct PaymentId(pub [u8; 32]);
  
@@@ -946,16 -945,7 +946,16 @@@ pub enum PaymentSendFailure 
        /// as they will result in over-/re-payment. These HTLCs all either successfully sent (in the
        /// case of Ok(())) or will send once channel_monitor_updated is called on the next-hop channel
        /// with the latest update_id.
 -      PartialFailure(Vec<Result<(), APIError>>),
 +      PartialFailure {
 +              /// The errors themselves, in the same order as the route hops.
 +              results: Vec<Result<(), APIError>>,
 +              /// If some paths failed without irrevocably committing to the new HTLC(s), this will
 +              /// contain a [`RouteParameters`] object which can be used to calculate a new route that
 +              /// will pay all remaining unpaid balance.
 +              failed_paths_retry: Option<RouteParameters>,
 +              /// The payment id for the payment, which is now at least partially pending.
 +              payment_id: PaymentId,
 +      },
  }
  
  macro_rules! handle_error {
@@@ -2245,9 -2235,7 +2245,9 @@@ impl<Signer: Sign, M: Deref, T: Deref, 
                }
                let mut has_ok = false;
                let mut has_err = false;
 -              for res in results.iter() {
 +              let mut pending_amt_unsent = 0;
 +              let mut max_unsent_cltv_delta = 0;
 +              for (res, path) in results.iter().zip(route.paths.iter()) {
                        if res.is_ok() { has_ok = true; }
                        if res.is_err() { has_err = true; }
                        if let &Err(APIError::MonitorUpdateFailed) = res {
                                // PartialFailure.
                                has_err = true;
                                has_ok = true;
 -                              break;
 +                      } else if res.is_err() {
 +                              pending_amt_unsent += path.last().unwrap().fee_msat;
 +                              max_unsent_cltv_delta = cmp::max(max_unsent_cltv_delta, path.last().unwrap().cltv_expiry_delta);
                        }
                }
                if has_err && has_ok {
 -                      Err(PaymentSendFailure::PartialFailure(results))
 +                      Err(PaymentSendFailure::PartialFailure {
 +                              results,
 +                              payment_id,
 +                              failed_paths_retry: if pending_amt_unsent != 0 {
 +                                      if let Some(payee) = &route.payee {
 +                                              Some(RouteParameters {
 +                                                      payee: payee.clone(),
 +                                                      final_value_msat: pending_amt_unsent,
 +                                                      final_cltv_expiry_delta: max_unsent_cltv_delta,
 +                                              })
 +                                      } else { None }
 +                              } else { None },
 +                      })
                } else if has_err {
                        Err(PaymentSendFailure::AllFailedRetrySafe(results.drain(..).map(|r| r.unwrap_err()).collect()))
                } else {
@@@ -6338,8 -6312,8 +6338,8 @@@ mod tests 
                        final_cltv_expiry_delta: TEST_FINAL_CLTV,
                };
                let route = find_route(
-                       &nodes[0].node.get_our_node_id(), &params,
-                       &nodes[0].net_graph_msg_handler.network_graph, None, nodes[0].logger, &scorer
+                       &nodes[0].node.get_our_node_id(), &params, nodes[0].network_graph, None,
+                       nodes[0].logger, &scorer
                ).unwrap();
                nodes[0].node.send_spontaneous_payment(&route, Some(payment_preimage)).unwrap();
                check_added_monitors!(nodes[0], 1);
                // To start (2), send a keysend payment but don't claim it.
                let payment_preimage = PaymentPreimage([42; 32]);
                let route = find_route(
-                       &nodes[0].node.get_our_node_id(), &params,
-                       &nodes[0].net_graph_msg_handler.network_graph, None, nodes[0].logger, &scorer
+                       &nodes[0].node.get_our_node_id(), &params, nodes[0].network_graph, None,
+                       nodes[0].logger, &scorer
                ).unwrap();
                let (payment_hash, _) = nodes[0].node.send_spontaneous_payment(&route, Some(payment_preimage)).unwrap();
                check_added_monitors!(nodes[0], 1);
                        final_value_msat: 10000,
                        final_cltv_expiry_delta: 40,
                };
-               let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
+               let network_graph = nodes[0].network_graph;
                let first_hops = nodes[0].node.list_usable_channels();
                let scorer = Scorer::with_fixed_penalty(0);
                let route = find_route(
                        final_value_msat: 10000,
                        final_cltv_expiry_delta: 40,
                };
-               let network_graph = &nodes[0].net_graph_msg_handler.network_graph;
+               let network_graph = nodes[0].network_graph;
                let first_hops = nodes[0].node.list_usable_channels();
                let scorer = Scorer::with_fixed_penalty(0);
                let route = find_route(
index 14bc6beec5fb14a2d42981f2ce06a936b1dea9b2,139173d258bd083e6fac7059ac52a4e7119958ff..f605a63abd462fbe3ec3a14265d96d6ad5592542
@@@ -190,6 -190,7 +190,7 @@@ pub struct TestChanMonCfg 
        pub persister: test_utils::TestPersister,
        pub logger: test_utils::TestLogger,
        pub keys_manager: test_utils::TestKeysInterface,
+       pub network_graph: NetworkGraph,
  }
  
  pub struct NodeCfg<'a> {
        pub chain_monitor: test_utils::TestChainMonitor<'a>,
        pub keys_manager: &'a test_utils::TestKeysInterface,
        pub logger: &'a test_utils::TestLogger,
+       pub network_graph: &'a NetworkGraph,
        pub node_seed: [u8; 32],
        pub features: InitFeatures,
  }
@@@ -209,7 -211,8 +211,8 @@@ pub struct Node<'a, 'b: 'a, 'c: 'b> 
        pub chain_monitor: &'b test_utils::TestChainMonitor<'c>,
        pub keys_manager: &'b test_utils::TestKeysInterface,
        pub node: &'a ChannelManager<EnforcingSigner, &'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'c test_utils::TestLogger>,
-       pub net_graph_msg_handler: NetGraphMsgHandler<&'c test_utils::TestChainSource, &'c test_utils::TestLogger>,
+       pub network_graph: &'c NetworkGraph,
+       pub net_graph_msg_handler: NetGraphMsgHandler<&'c NetworkGraph, &'c test_utils::TestChainSource, &'c test_utils::TestLogger>,
        pub node_seed: [u8; 32],
        pub network_payment_count: Rc<RefCell<u8>>,
        pub network_chan_count: Rc<RefCell<u32>>,
@@@ -240,12 -243,11 +243,11 @@@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 
                        // Check that if we serialize the Router, we can deserialize it again.
                        {
                                let mut w = test_utils::TestVecWriter(Vec::new());
-                               let network_graph_ser = &self.net_graph_msg_handler.network_graph;
-                               network_graph_ser.write(&mut w).unwrap();
+                               self.network_graph.write(&mut w).unwrap();
                                let network_graph_deser = <NetworkGraph>::read(&mut io::Cursor::new(&w.0)).unwrap();
-                               assert!(network_graph_deser == self.net_graph_msg_handler.network_graph);
+                               assert!(network_graph_deser == *self.network_graph);
                                let net_graph_msg_handler = NetGraphMsgHandler::new(
-                                       network_graph_deser, Some(self.chain_source), self.logger
+                                       &network_graph_deser, Some(self.chain_source), self.logger
                                );
                                let mut chan_progress = 0;
                                loop {
@@@ -482,9 -484,9 +484,9 @@@ macro_rules! unwrap_send_err 
                                        _ => panic!(),
                                }
                        },
 -                      &Err(PaymentSendFailure::PartialFailure(ref fails)) if !$all_failed => {
 -                              assert_eq!(fails.len(), 1);
 -                              match fails[0] {
 +                      &Err(PaymentSendFailure::PartialFailure { ref results, .. }) if !$all_failed => {
 +                              assert_eq!(results.len(), 1);
 +                              match results[0] {
                                        Err($type) => { $check },
                                        _ => panic!(),
                                }
@@@ -1014,10 -1016,9 +1016,9 @@@ macro_rules! get_route_and_payment_has
                let payee = $crate::routing::router::Payee::new($recv_node.node.get_our_node_id())
                        .with_features($crate::ln::features::InvoiceFeatures::known())
                        .with_route_hints($last_hops);
-               let net_graph_msg_handler = &$send_node.net_graph_msg_handler;
                let scorer = ::routing::scorer::Scorer::with_fixed_penalty(0);
                let route = ::routing::router::get_route(
-                       &$send_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph,
+                       &$send_node.node.get_our_node_id(), &payee, $send_node.network_graph,
                        Some(&$send_node.node.list_usable_channels().iter().collect::<Vec<_>>()),
                        $recv_value, $cltv, $send_node.logger, &scorer
                ).unwrap();
@@@ -1352,10 -1353,9 +1353,9 @@@ pub const TEST_FINAL_CLTV: u32 = 70
  pub fn route_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64) -> (PaymentPreimage, PaymentHash, PaymentSecret) {
        let payee = Payee::new(expected_route.last().unwrap().node.get_our_node_id())
                .with_features(InvoiceFeatures::known());
-       let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
        let scorer = Scorer::with_fixed_penalty(0);
        let route = get_route(
-               &origin_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph,
+               &origin_node.node.get_our_node_id(), &payee, &origin_node.network_graph,
                Some(&origin_node.node.list_usable_channels().iter().collect::<Vec<_>>()),
                recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap();
        assert_eq!(route.paths.len(), 1);
  pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], recv_value: u64)  {
        let payee = Payee::new(expected_route.last().unwrap().node.get_our_node_id())
                .with_features(InvoiceFeatures::known());
-       let net_graph_msg_handler = &origin_node.net_graph_msg_handler;
        let scorer = Scorer::with_fixed_penalty(0);
-       let route = get_route(&origin_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph, None, recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap();
+       let route = get_route(&origin_node.node.get_our_node_id(), &payee, origin_node.network_graph, None, recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap();
        assert_eq!(route.paths.len(), 1);
        assert_eq!(route.paths[0].len(), expected_route.len());
        for (node, hop) in expected_route.iter().zip(route.paths[0].iter()) {
@@@ -1499,8 -1498,9 +1498,9 @@@ pub fn create_chanmon_cfgs(node_count: 
                let persister = test_utils::TestPersister::new();
                let seed = [i as u8; 32];
                let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
+               let network_graph = NetworkGraph::new(chain_source.genesis_hash);
  
-               chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager });
+               chan_mon_cfgs.push(TestChanMonCfg{ tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager, network_graph });
        }
  
        chan_mon_cfgs
@@@ -1521,6 -1521,7 +1521,7 @@@ pub fn create_node_cfgs<'a>(node_count
                        keys_manager: &chanmon_cfgs[i].keys_manager,
                        node_seed: seed,
                        features: InitFeatures::known(),
+                       network_graph: &chanmon_cfgs[i].network_graph,
                });
        }
  
@@@ -1566,15 -1567,15 +1567,15 @@@ pub fn create_network<'a, 'b: 'a, 'c: '
        let connect_style = Rc::new(RefCell::new(ConnectStyle::FullBlockViaListen));
  
        for i in 0..node_count {
-               let network_graph = NetworkGraph::new(cfgs[i].chain_source.genesis_hash);
-               let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, cfgs[i].logger);
-               nodes.push(Node{ chain_source: cfgs[i].chain_source,
-                                tx_broadcaster: cfgs[i].tx_broadcaster, chain_monitor: &cfgs[i].chain_monitor,
-                                keys_manager: &cfgs[i].keys_manager, node: &chan_mgrs[i], net_graph_msg_handler,
-                                node_seed: cfgs[i].node_seed, network_chan_count: chan_count.clone(),
-                                network_payment_count: payment_count.clone(), logger: cfgs[i].logger,
-                                blocks: Arc::clone(&cfgs[i].tx_broadcaster.blocks),
-                                connect_style: Rc::clone(&connect_style),
+               let net_graph_msg_handler = NetGraphMsgHandler::new(cfgs[i].network_graph, None, cfgs[i].logger);
+               nodes.push(Node{
+                       chain_source: cfgs[i].chain_source, tx_broadcaster: cfgs[i].tx_broadcaster,
+                       chain_monitor: &cfgs[i].chain_monitor, keys_manager: &cfgs[i].keys_manager,
+                       node: &chan_mgrs[i], network_graph: &cfgs[i].network_graph, net_graph_msg_handler,
+                       node_seed: cfgs[i].node_seed, network_chan_count: chan_count.clone(),
+                       network_payment_count: payment_count.clone(), logger: cfgs[i].logger,
+                       blocks: Arc::clone(&cfgs[i].tx_broadcaster.blocks),
+                       connect_style: Rc::clone(&connect_style),
                })
        }