From 6b3cc8bb4dd645f6afaa00afe6a4e02d8ff8e2db Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Sat, 8 Sep 2018 21:02:42 -0400 Subject: [PATCH] Avoid cross-test statics in ChannelManager network tests --- src/ln/channelmanager.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/ln/channelmanager.rs b/src/ln/channelmanager.rs index 8baca601..d4548ce6 100644 --- a/src/ln/channelmanager.rs +++ b/src/ln/channelmanager.rs @@ -2212,8 +2212,10 @@ mod tests { use rand::{thread_rng,Rng}; + use std::cell::RefCell; use std::collections::HashMap; use std::default::Default; + use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::time::Instant; use std::mem; @@ -2384,9 +2386,10 @@ mod tests { chan_monitor: Arc, node: Arc, router: Router, + network_payment_count: Rc>, + network_chan_count: Rc>, } - static mut CHAN_COUNT: u32 = 0; fn create_chan_between_nodes(node_a: &Node, node_b: &Node) -> (msgs::ChannelAnnouncement, msgs::ChannelUpdate, msgs::ChannelUpdate, [u8; 32], Transaction) { node_a.node.create_channel(node_b.node.get_our_node_id(), 100000, 10001, 42).unwrap(); @@ -2402,7 +2405,7 @@ mod tests { node_a.node.handle_accept_channel(&node_b.node.get_our_node_id(), &accept_chan).unwrap(); - let chan_id = unsafe { CHAN_COUNT }; + let chan_id = *node_a.network_chan_count.borrow(); let tx; let funding_output; @@ -2508,9 +2511,7 @@ mod tests { _ => panic!("Unexpected event"), }; - unsafe { - CHAN_COUNT += 1; - } + *node_a.network_chan_count.borrow_mut() += 1; ((*announcement).clone(), (*as_update).clone(), (*bs_update).clone(), channel_id, tx) } @@ -2612,10 +2613,9 @@ mod tests { } } - static mut PAYMENT_COUNT: u8 = 0; fn send_along_route(origin_node: &Node, route: Route, expected_route: &[&Node], recv_value: u64) -> ([u8; 32], [u8; 32]) { - let our_payment_preimage = unsafe { [PAYMENT_COUNT; 32] }; - unsafe { PAYMENT_COUNT += 1 }; + let our_payment_preimage = [*origin_node.network_payment_count.borrow(); 32]; + *origin_node.network_payment_count.borrow_mut() += 1; let our_payment_hash = { let mut sha = Sha256::new(); sha.input(&our_payment_preimage[..]); @@ -2807,8 +2807,8 @@ mod tests { assert_eq!(hop.pubkey, node.node.get_our_node_id()); } - let our_payment_preimage = unsafe { [PAYMENT_COUNT; 32] }; - unsafe { PAYMENT_COUNT += 1 }; + let our_payment_preimage = [*origin_node.network_payment_count.borrow(); 32]; + *origin_node.network_payment_count.borrow_mut() += 1; let our_payment_hash = { let mut sha = Sha256::new(); sha.input(&our_payment_preimage[..]); @@ -2919,6 +2919,9 @@ mod tests { let secp_ctx = Secp256k1::new(); let logger: Arc = Arc::new(test_utils::TestLogger::new()); + let chan_count = Rc::new(RefCell::new(0)); + let payment_count = Rc::new(RefCell::new(0)); + for _ in 0..node_count { let feeest = Arc::new(test_utils::TestFeeEstimator { sat_per_kw: 253 }); let chain_monitor = Arc::new(chaininterface::ChainWatchInterfaceUtil::new(Network::Testnet, Arc::clone(&logger))); @@ -2931,7 +2934,10 @@ mod tests { }; let node = ChannelManager::new(node_id.clone(), 0, true, Network::Testnet, feeest.clone(), chan_monitor.clone(), chain_monitor.clone(), tx_broadcaster.clone(), Arc::clone(&logger)).unwrap(); let router = Router::new(PublicKey::from_secret_key(&secp_ctx, &node_id), chain_monitor.clone(), Arc::clone(&logger)); - nodes.push(Node { chain_monitor, tx_broadcaster, chan_monitor, node, router }); + nodes.push(Node { chain_monitor, tx_broadcaster, chan_monitor, node, router, + network_payment_count: payment_count.clone(), + network_chan_count: chan_count.clone(), + }); } nodes -- 2.30.2