Avoid cross-test statics in ChannelManager network tests
authorMatt Corallo <git@bluematt.me>
Sun, 9 Sep 2018 01:02:42 +0000 (21:02 -0400)
committerMatt Corallo <git@bluematt.me>
Wed, 12 Sep 2018 15:15:51 +0000 (11:15 -0400)
src/ln/channelmanager.rs

index 8baca601933e57ff5eeb835778b8d69c864f4c3b..d4548ce621a516fca6ba5fdcb1b5a87213373f16 100644 (file)
@@ -2212,8 +2212,10 @@ mod tests {
 
        use rand::{thread_rng,Rng};
 
+       use std::cell::RefCell;
        use std::collections::HashMap;
        use std::default::Default;
+       use std::rc::Rc;
        use std::sync::{Arc, Mutex};
        use std::time::Instant;
        use std::mem;
@@ -2384,9 +2386,10 @@ mod tests {
                chan_monitor: Arc<test_utils::TestChannelMonitor>,
                node: Arc<ChannelManager>,
                router: Router,
+               network_payment_count: Rc<RefCell<u8>>,
+               network_chan_count: Rc<RefCell<u32>>,
        }
 
-       static mut CHAN_COUNT: u32 = 0;
        fn create_chan_between_nodes(node_a: &Node, node_b: &Node) -> (msgs::ChannelAnnouncement, msgs::ChannelUpdate, msgs::ChannelUpdate, [u8; 32], Transaction) {
                node_a.node.create_channel(node_b.node.get_our_node_id(), 100000, 10001, 42).unwrap();
 
@@ -2402,7 +2405,7 @@ mod tests {
 
                node_a.node.handle_accept_channel(&node_b.node.get_our_node_id(), &accept_chan).unwrap();
 
-               let chan_id = unsafe { CHAN_COUNT };
+               let chan_id = *node_a.network_chan_count.borrow();
                let tx;
                let funding_output;
 
@@ -2508,9 +2511,7 @@ mod tests {
                        _ => panic!("Unexpected event"),
                };
 
-               unsafe {
-                       CHAN_COUNT += 1;
-               }
+               *node_a.network_chan_count.borrow_mut() += 1;
 
                ((*announcement).clone(), (*as_update).clone(), (*bs_update).clone(), channel_id, tx)
        }
@@ -2612,10 +2613,9 @@ mod tests {
                }
        }
 
-       static mut PAYMENT_COUNT: u8 = 0;
        fn send_along_route(origin_node: &Node, route: Route, expected_route: &[&Node], recv_value: u64) -> ([u8; 32], [u8; 32]) {
-               let our_payment_preimage = unsafe { [PAYMENT_COUNT; 32] };
-               unsafe { PAYMENT_COUNT += 1 };
+               let our_payment_preimage = [*origin_node.network_payment_count.borrow(); 32];
+               *origin_node.network_payment_count.borrow_mut() += 1;
                let our_payment_hash = {
                        let mut sha = Sha256::new();
                        sha.input(&our_payment_preimage[..]);
@@ -2807,8 +2807,8 @@ mod tests {
                        assert_eq!(hop.pubkey, node.node.get_our_node_id());
                }
 
-               let our_payment_preimage = unsafe { [PAYMENT_COUNT; 32] };
-               unsafe { PAYMENT_COUNT += 1 };
+               let our_payment_preimage = [*origin_node.network_payment_count.borrow(); 32];
+               *origin_node.network_payment_count.borrow_mut() += 1;
                let our_payment_hash = {
                        let mut sha = Sha256::new();
                        sha.input(&our_payment_preimage[..]);
@@ -2919,6 +2919,9 @@ mod tests {
                let secp_ctx = Secp256k1::new();
                let logger: Arc<Logger> = Arc::new(test_utils::TestLogger::new());
 
+               let chan_count = Rc::new(RefCell::new(0));
+               let payment_count = Rc::new(RefCell::new(0));
+
                for _ in 0..node_count {
                        let feeest = Arc::new(test_utils::TestFeeEstimator { sat_per_kw: 253 });
                        let chain_monitor = Arc::new(chaininterface::ChainWatchInterfaceUtil::new(Network::Testnet, Arc::clone(&logger)));
@@ -2931,7 +2934,10 @@ mod tests {
                        };
                        let node = ChannelManager::new(node_id.clone(), 0, true, Network::Testnet, feeest.clone(), chan_monitor.clone(), chain_monitor.clone(), tx_broadcaster.clone(), Arc::clone(&logger)).unwrap();
                        let router = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id), chain_monitor.clone(), Arc::clone(&logger));
-                       nodes.push(Node { chain_monitor, tx_broadcaster, chan_monitor, node, router });
+                       nodes.push(Node { chain_monitor, tx_broadcaster, chan_monitor, node, router,
+                               network_payment_count: payment_count.clone(),
+                               network_chan_count: chan_count.clone(),
+                       });
                }
 
                nodes