Add test of an initial message other than `Init`
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index c2ce158f897b97e4b7573bf71ad9486ae760f5d6..9f888848fe91db7d2e158ca1bdf922af2dfa729d 100644 (file)
@@ -305,6 +305,7 @@ pub struct TestChanMonCfg {
        pub persister: test_utils::TestPersister,
        pub logger: test_utils::TestLogger,
        pub keys_manager: test_utils::TestKeysInterface,
+       pub scorer: Mutex<test_utils::TestScorer>,
 }
 
 pub struct NodeCfg<'a> {
@@ -427,6 +428,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        channel_monitors.insert(monitor.get_funding_txo().0, monitor);
                                }
 
+                               let scorer = Mutex::new(test_utils::TestScorer::new());
                                let mut w = test_utils::TestVecWriter(Vec::new());
                                self.node.write(&mut w).unwrap();
                                <(BlockHash, ChannelManager<&test_utils::TestChainMonitor, &test_utils::TestBroadcaster, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestKeysInterface, &test_utils::TestFeeEstimator, &test_utils::TestRouter, &test_utils::TestLogger>)>::read(&mut io::Cursor::new(w.0), ChannelManagerReadArgs {
@@ -435,7 +437,7 @@ impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
                                        node_signer: self.keys_manager,
                                        signer_provider: self.keys_manager,
                                        fee_estimator: &test_utils::TestFeeEstimator { sat_per_kw: Mutex::new(253) },
-                                       router: &test_utils::TestRouter::new(Arc::new(network_graph)),
+                                       router: &test_utils::TestRouter::new(Arc::new(network_graph), &scorer),
                                        chain_monitor: self.chain_monitor,
                                        tx_broadcaster: &broadcaster,
                                        logger: &self.logger,
@@ -572,12 +574,12 @@ macro_rules! get_htlc_update_msgs {
 }
 
 /// Fetches the first `msg_event` to the passed `node_id` in the passed `msg_events` vec.
-/// Returns the `msg_event`, along with an updated `msg_events` vec with the message removed.
+/// Returns the `msg_event`.
 ///
 /// Note that even though `BroadcastChannelAnnouncement` and `BroadcastChannelUpdate`
 /// `msg_events` are stored under specific peers, this function does not fetch such `msg_events` as
 /// such messages are intended to all peers.
-pub fn remove_first_msg_event_to_node(msg_node_id: &PublicKey, msg_events: &Vec<MessageSendEvent>) -> (MessageSendEvent, Vec<MessageSendEvent>) {
+pub fn remove_first_msg_event_to_node(msg_node_id: &PublicKey, msg_events: &mut Vec<MessageSendEvent>) -> MessageSendEvent {
        let ev_index = msg_events.iter().position(|e| { match e {
                MessageSendEvent::SendAcceptChannel { node_id, .. } => {
                        node_id == msg_node_id
@@ -621,6 +623,9 @@ pub fn remove_first_msg_event_to_node(msg_node_id: &PublicKey, msg_events: &Vec<
                MessageSendEvent::BroadcastChannelUpdate { .. } => {
                        false
                },
+               MessageSendEvent::BroadcastNodeAnnouncement { .. } => {
+                       false
+               },
                MessageSendEvent::SendChannelUpdate { node_id, .. } => {
                        node_id == msg_node_id
                },
@@ -641,9 +646,7 @@ pub fn remove_first_msg_event_to_node(msg_node_id: &PublicKey, msg_events: &Vec<
                },
        }});
        if ev_index.is_some() {
-               let mut updated_msg_events = msg_events.to_vec();
-               let ev = updated_msg_events.remove(ev_index.unwrap());
-               (ev, updated_msg_events)
+               msg_events.remove(ev_index.unwrap())
        } else {
                panic!("Couldn't find any MessageSendEvent to the node!")
        }
@@ -1010,7 +1013,7 @@ pub fn create_chan_between_nodes_with_value_b<'a, 'b, 'c>(node_a: &Node<'a, 'b,
        assert_eq!(events_7.len(), 1);
        let (announcement, bs_update) = match events_7[0] {
                MessageSendEvent::BroadcastChannelAnnouncement { ref msg, ref update_msg } => {
-                       (msg, update_msg)
+                       (msg, update_msg.clone().unwrap())
                },
                _ => panic!("Unexpected event"),
        };
@@ -1021,6 +1024,7 @@ pub fn create_chan_between_nodes_with_value_b<'a, 'b, 'c>(node_a: &Node<'a, 'b,
        let as_update = match events_8[0] {
                MessageSendEvent::BroadcastChannelAnnouncement { ref msg, ref update_msg } => {
                        assert!(*announcement == *msg);
+                       let update_msg = update_msg.clone().unwrap();
                        assert_eq!(update_msg.contents.short_channel_id, announcement.contents.short_channel_id);
                        assert_eq!(update_msg.contents.short_channel_id, bs_update.contents.short_channel_id);
                        update_msg
@@ -1031,7 +1035,7 @@ pub fn create_chan_between_nodes_with_value_b<'a, 'b, 'c>(node_a: &Node<'a, 'b,
        *node_a.network_chan_count.borrow_mut() += 1;
 
        expect_channel_ready_event(&node_b, &node_a.node.get_our_node_id());
-       ((*announcement).clone(), (*as_update).clone(), (*bs_update).clone())
+       ((*announcement).clone(), as_update, bs_update)
 }
 
 pub fn create_announced_chan_between_nodes<'a, 'b, 'c, 'd>(nodes: &'a Vec<Node<'b, 'c, 'd>>, a: usize, b: usize) -> (msgs::ChannelUpdate, msgs::ChannelUpdate, [u8; 32], Transaction) {
@@ -1482,9 +1486,9 @@ pub fn do_main_commitment_signed_dance(node_a: &Node<'_, '_, '_>, node_b: &Node<
        check_added_monitors!(node_b, 1);
        node_b.node.handle_commitment_signed(&node_a.node.get_our_node_id(), &as_commitment_signed);
        let (bs_revoke_and_ack, extra_msg_option) = {
-               let events = node_b.node.get_and_clear_pending_msg_events();
+               let mut events = node_b.node.get_and_clear_pending_msg_events();
                assert!(events.len() <= 2);
-               let (node_a_event, events) = remove_first_msg_event_to_node(&node_a.node.get_our_node_id(), &events);
+               let node_a_event = remove_first_msg_event_to_node(&node_a.node.get_our_node_id(), &mut events);
                (match node_a_event {
                        MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => {
                                assert_eq!(*node_id, node_a.node.get_our_node_id());
@@ -1566,7 +1570,7 @@ macro_rules! get_payment_preimage_hash {
 macro_rules! get_route {
        ($send_node: expr, $payment_params: expr, $recv_value: expr, $cltv: expr) => {{
                use $crate::chain::keysinterface::EntropySource;
-               let scorer = $crate::util::test_utils::TestScorer::with_penalty(0);
+               let scorer = $crate::util::test_utils::TestScorer::new();
                let keys_manager = $crate::util::test_utils::TestKeysInterface::new(&[0u8; 32], bitcoin::network::constants::Network::Testnet);
                let random_seed_bytes = keys_manager.get_secure_random_bytes();
                $crate::routing::router::get_route(
@@ -1939,8 +1943,7 @@ pub fn pass_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou
        let mut events = origin_node.node.get_and_clear_pending_msg_events();
        assert_eq!(events.len(), expected_route.len());
        for (path_idx, expected_path) in expected_route.iter().enumerate() {
-               let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events);
-               events = updated_events;
+               let ev = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &mut events);
                // Once we've gotten through all the HTLCs, the last one should result in a
                // PaymentClaimable (but each previous one should not!), .
                let expect_payment = path_idx == expected_route.len() - 1;
@@ -1999,9 +2002,8 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
        } else {
                for expected_path in expected_paths.iter() {
                        // For MPP payments, we always want the message to the first node in the path.
-                       let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events);
+                       let ev = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &mut events);
                        per_path_msgs.push(msgs_from_ev!(&ev));
-                       events = updated_events;
                }
        }
 
@@ -2120,7 +2122,7 @@ pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou
        let payment_params = PaymentParameters::from_node_id(expected_route.last().unwrap().node.get_our_node_id(), TEST_FINAL_CLTV)
                .with_features(expected_route.last().unwrap().node.invoice_features());
        let network_graph = origin_node.network_graph.read_only();
-       let scorer = test_utils::TestScorer::with_penalty(0);
+       let scorer = test_utils::TestScorer::new();
        let seed = [0u8; 32];
        let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
        let random_seed_bytes = keys_manager.get_secure_random_bytes();
@@ -2285,8 +2287,9 @@ pub fn create_chanmon_cfgs(node_count: usize) -> Vec<TestChanMonCfg> {
                let persister = test_utils::TestPersister::new();
                let seed = [i as u8; 32];
                let keys_manager = test_utils::TestKeysInterface::new(&seed, Network::Testnet);
+               let scorer = Mutex::new(test_utils::TestScorer::new());
 
-               chan_mon_cfgs.push(TestChanMonCfg { tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager });
+               chan_mon_cfgs.push(TestChanMonCfg { tx_broadcaster, fee_estimator, chain_source, logger, persister, keys_manager, scorer });
        }
 
        chan_mon_cfgs
@@ -2304,7 +2307,7 @@ pub fn create_node_cfgs<'a>(node_count: usize, chanmon_cfgs: &'a Vec<TestChanMon
                        logger: &chanmon_cfgs[i].logger,
                        tx_broadcaster: &chanmon_cfgs[i].tx_broadcaster,
                        fee_estimator: &chanmon_cfgs[i].fee_estimator,
-                       router: test_utils::TestRouter::new(network_graph.clone()),
+                       router: test_utils::TestRouter::new(network_graph.clone(), &chanmon_cfgs[i].scorer),
                        chain_monitor,
                        keys_manager: &chanmon_cfgs[i].keys_manager,
                        node_seed: seed,