Split up generic parameters that used to comprise KeysInterface.
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index e813d012fec0085d30f0daa25f529d10a8a76d08..fabc7a5ee80900da21552cb2375b30bf7da2dcb4 100644 (file)
@@ -317,7 +317,7 @@ pub struct Node<'a, 'b: 'a, 'c: 'b> {
        pub router: &'b test_utils::TestRouter<'c>,
        pub chain_monitor: &'b test_utils::TestChainMonitor<'c>,
        pub keys_manager: &'b test_utils::TestKeysInterface,
-       pub node: &'a ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger>,
+       pub node: &'a ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger>,
        pub network_graph: &'a NetworkGraph<&'c test_utils::TestLogger>,
        pub gossip_sync: P2PGossipSync<&'b NetworkGraph<&'c test_utils::TestLogger>, &'c test_utils::TestChainSource, &'c test_utils::TestLogger>,
        pub node_seed: [u8; 32],
@@ -398,7 +398,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        let mut w = test_utils::TestVecWriter(Vec::new());
                                        self.chain_monitor.chain_monitor.get_monitor(outpoint).unwrap().write(&mut w).unwrap();
                                        let (_, deserialized_monitor) = <(BlockHash, ChannelMonitor<EnforcingSigner>)>::read(
-                                               &mut io::Cursor::new(&w.0), self.keys_manager).unwrap();
+                                               &mut io::Cursor::new(&w.0), (self.keys_manager, self.keys_manager)).unwrap();
                                        deserialized_monitors.push(deserialized_monitor);
                                }
                        }
@@ -418,9 +418,11 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
 
                                let mut w = test_utils::TestVecWriter(Vec::new());
                                self.node.write(&mut w).unwrap();
-                               <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>)>::read(&mut io::Cursor::new(w.0), ChannelManagerReadArgs {
+                               <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>)>::read(&mut io::Cursor::new(w.0), ChannelManagerReadArgs {
                                        default_config: *self.node.get_current_default_configuration(),
-                                       keys_manager: self.keys_manager,
+                                       entropy_source: self.keys_manager,
+                                       node_signer: self.keys_manager,
+                                       signer_provider: self.keys_manager,
                                        fee_estimator: &test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) },
                                        router: &test_utils::TestRouter::new(Arc::new(network_graph)),
                                        chain_monitor: self.chain_monitor,
@@ -558,6 +560,84 @@ macro_rules! get_htlc_update_msgs {
        }
 }
 
+/// Fetches the first `msg_event` to the passed `node_id` in the passed `msg_events` vec.
+/// Returns the `msg_event`, along with an updated `msg_events` vec with the message removed.
+///
+/// Note that even though `BroadcastChannelAnnouncement` and `BroadcastChannelUpdate`
+/// `msg_events` are stored under specific peers, this function does not fetch such `msg_events` as
+/// such messages are intended to all peers.
+pub fn remove_first_msg_event_to_node(msg_node_id: &PublicKey, msg_events: &Vec<MessageSendEvent>) -> (MessageSendEvent, Vec<MessageSendEvent>) {
+       let ev_index = msg_events.iter().position(|e| { match e {
+               MessageSendEvent::SendAcceptChannel { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendOpenChannel { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendFundingCreated { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendFundingSigned { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendChannelReady { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendAnnouncementSignatures { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::UpdateHTLCs { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendRevokeAndACK { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendClosingSigned { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendShutdown { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendChannelReestablish { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendChannelAnnouncement { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::BroadcastChannelAnnouncement { .. } => {
+                       false
+               },
+               MessageSendEvent::BroadcastChannelUpdate { .. } => {
+                       false
+               },
+               MessageSendEvent::SendChannelUpdate { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::HandleError { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendChannelRangeQuery { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendShortIdsQuery { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendReplyChannelRange { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+               MessageSendEvent::SendGossipTimestampFilter { node_id, .. } => {
+                       node_id == msg_node_id
+               },
+       }});
+       if ev_index.is_some() {
+               let mut updated_msg_events = msg_events.to_vec();
+               let ev = updated_msg_events.remove(ev_index.unwrap());
+               (ev, updated_msg_events)
+       } else {
+               panic!("Couldn't find any MessageSendEvent to the node!")
+       }
+}
+
 #[cfg(test)]
 macro_rules! get_channel_ref {
        ($node: expr, $counterparty_node: expr, $per_peer_state_lock: ident, $peer_state_lock: ident, $channel_id: expr) => {
@@ -662,12 +742,12 @@ macro_rules! check_added_monitors {
        }
 }
 
-pub fn _reload_node<'a, 'b, 'c, 'd>(node: &'a Node<'b, 'c, 'd>, default_config: UserConfig, chanman_encoded: &[u8], monitors_encoded: &[&[u8]]) -> ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger> {
+pub fn _reload_node<'a, 'b, 'c, 'd>(node: &'a Node<'b, 'c, 'd>, default_config: UserConfig, chanman_encoded: &[u8], monitors_encoded: &[&[u8]]) -> ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger> {
        let mut monitors_read = Vec::with_capacity(monitors_encoded.len());
        for encoded in monitors_encoded {
                let mut monitor_read = &encoded[..];
                let (_, monitor) = <(BlockHash, ChannelMonitor<EnforcingSigner>)>
-                       ::read(&mut monitor_read, node.keys_manager).unwrap();
+                       ::read(&mut monitor_read, (node.keys_manager, node.keys_manager)).unwrap();
                assert!(monitor_read.is_empty());
                monitors_read.push(monitor);
        }
@@ -678,9 +758,11 @@ pub fn _reload_node<'a, 'b, 'c, 'd>(node: &'a Node<'b, 'c, 'd>, default_config:
                for monitor in monitors_read.iter_mut() {
                        assert!(channel_monitors.insert(monitor.get_funding_txo().0, monitor).is_none());
                }
-               <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>)>::read(&mut node_read, ChannelManagerReadArgs {
+               <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>)>::read(&mut node_read, ChannelManagerReadArgs {
                        default_config,
-                       keys_manager: node.keys_manager,
+                       entropy_source: node.keys_manager,
+                       node_signer: node.keys_manager,
+                       signer_provider: node.keys_manager,
                        fee_estimator: node.fee_estimator,
                        router: node.router,
                        chain_monitor: node.chain_monitor,
@@ -1276,13 +1358,14 @@ macro_rules! commitment_signed_dance {
                        let (bs_revoke_and_ack, extra_msg_option) = {
                                let events = $node_b.node.get_and_clear_pending_msg_events();
                                assert!(events.len() <= 2);
-                               (match events[0] {
+                               let (node_a_event, events) = remove_first_msg_event_to_node(&$node_a.node.get_our_node_id(), &events);
+                               (match node_a_event {
                                        MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => {
                                                assert_eq!(*node_id, $node_a.node.get_our_node_id());
                                                (*msg).clone()
                                        },
                                        _ => panic!("Unexpected event"),
-                               }, events.get(1).map(|e| e.clone()))
+                               }, events.get(0).map(|e| e.clone()))
                        };
                        check_added_monitors!($node_b, 1);
                        if $fail_backwards {
@@ -1323,7 +1406,6 @@ macro_rules! commitment_signed_dance {
                                expect_pending_htlcs_forwardable_and_htlc_handling_failed!($node_a, vec![$crate::util::events::HTLCDestination::NextHopChannel{ node_id: Some($node_b.node.get_our_node_id()), channel_id: $commitment_signed.channel_id }]);
                                check_added_monitors!($node_a, 1);
 
-                               let channel_state = $node_a.node.channel_state.lock().unwrap();
                                let node_a_per_peer_state = $node_a.node.per_peer_state.read().unwrap();
                                let mut number_of_msg_events = 0;
                                for (cp_id, peer_state_mutex) in node_a_per_peer_state.iter() {
@@ -1833,7 +1915,9 @@ pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path
 pub fn pass_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&[&Node<'a, 'b, 'c>]], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: PaymentSecret) {
        let mut events = origin_node.node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), expected_route.len());
-       for (path_idx, (ev, expected_path)) in events.drain(..).zip(expected_route.iter()).enumerate() {
+       for (path_idx, expected_path) in expected_route.iter().enumerate() {
+               let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events);
+               events = updated_events;
                // Once we've gotten through all the HTLCs, the last one should result in a
                // PaymentClaimable (but each previous one should not!), .
                let expect_payment = path_idx == expected_route.len() - 1;
@@ -1884,10 +1968,18 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
                }
        }
        let mut per_path_msgs: Vec<((msgs::UpdateFulfillHTLC, msgs::CommitmentSigned), PublicKey)> = Vec::with_capacity(expected_paths.len());
-       let events = expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events();
+       let mut events = expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), expected_paths.len());
-       for ev in events.iter() {
-               per_path_msgs.push(msgs_from_ev!(ev));
+
+       if events.len() == 1 {
+               per_path_msgs.push(msgs_from_ev!(&events[0]));
+       } else {
+               for expected_path in expected_paths.iter() {
+                       // For MPP payments, we always want the message to the first node in the path.
+                       let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events);
+                       per_path_msgs.push(msgs_from_ev!(&ev));
+                       events = updated_events;
+               }
        }
 
        for (expected_route, (path_msgs, next_hop)) in expected_paths.iter().zip(per_path_msgs.drain(..)) {
@@ -2217,7 +2309,7 @@ pub fn test_default_channel_config() -> UserConfig {
        default_config
 }
 
-pub fn create_node_chanmgrs<'a, 'b>(node_count: usize, cfgs: &'a Vec<NodeCfg<'b>>, node_config: &[Option<UserConfig>]) -> Vec<ChannelManager<&'a TestChainMonitor<'b>, &'b test_utils::TestBroadcaster, &'a test_utils::TestKeysInterface, &'b test_utils::TestFeeEstimator, &'a test_utils::TestRouter<'b>, &'b test_utils::TestLogger>> {
+pub fn create_node_chanmgrs<'a, 'b>(node_count: usize, cfgs: &'a Vec<NodeCfg<'b>>, node_config: &[Option<UserConfig>]) -> Vec<ChannelManager<&'a TestChainMonitor<'b>, &'b test_utils::TestBroadcaster, &'a test_utils::TestKeysInterface, &'a test_utils::TestKeysInterface, &'a test_utils::TestKeysInterface, &'b test_utils::TestFeeEstimator, &'a test_utils::TestRouter<'b>, &'b test_utils::TestLogger>> {
        let mut chanmgrs = Vec::new();
        for i in 0..node_count {
                let network = Network::Testnet;
@@ -2226,14 +2318,14 @@ pub fn create_node_chanmgrs<'a, 'b>(node_count: usize, cfgs: &'a Vec<NodeCfg<'b>
                        best_block: BestBlock::from_genesis(network),
                };
                let node = ChannelManager::new(cfgs[i].fee_estimator, &cfgs[i].chain_monitor, cfgs[i].tx_broadcaster, &cfgs[i].router, cfgs[i].logger, cfgs[i].keys_manager,
-                       if node_config[i].is_some() { node_config[i].clone().unwrap() } else { test_default_channel_config() }, params);
+                       cfgs[i].keys_manager, cfgs[i].keys_manager, if node_config[i].is_some() { node_config[i].clone().unwrap() } else { test_default_channel_config() }, params);
                chanmgrs.push(node);
        }
 
        chanmgrs
 }
 
-pub fn create_network<'a, 'b: 'a, 'c: 'b>(node_count: usize, cfgs: &'b Vec<NodeCfg<'c>>, chan_mgrs: &'a Vec<ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'c test_utils::TestRouter, &'c test_utils::TestLogger>>) -> Vec<Node<'a, 'b, 'c>> {
+pub fn create_network<'a, 'b: 'a, 'c: 'b>(node_count: usize, cfgs: &'b Vec<NodeCfg<'c>>, chan_mgrs: &'a Vec<ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'c test_utils::TestRouter, &'c test_utils::TestLogger>>) -> Vec<Node<'a, 'b, 'c>> {
        let mut nodes = Vec::new();
        let chan_count = Rc::new(RefCell::new(0));
        let payment_count = Rc::new(RefCell::new(0));