Store channels per peer
[rust-lightning] / lightning / src / ln / functional_test_utils.rs
index a7e1f8cfd63a7e6f4e58c272a746b288abe19492..4430128ea533b0d940573bd76e25fe8b6b7e199c 100644 (file)
@@ -560,20 +560,22 @@ macro_rules! get_htlc_update_msgs {
 
 #[cfg(test)]
 macro_rules! get_channel_ref {
-       ($node: expr, $lock: ident, $channel_id: expr) => {
+       ($node: expr, $counterparty_node: expr, $per_peer_state_lock: ident, $peer_state_lock: ident, $channel_id: expr) => {
                {
-                       $lock = $node.node.channel_state.lock().unwrap();
-                       $lock.by_id.get_mut(&$channel_id).unwrap()
+                       $per_peer_state_lock = $node.node.per_peer_state.read().unwrap();
+                       $peer_state_lock = $per_peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+                       $peer_state_lock.channel_by_id.get_mut(&$channel_id).unwrap()
                }
        }
 }
 
 #[cfg(test)]
 macro_rules! get_feerate {
-       ($node: expr, $channel_id: expr) => {
+       ($node: expr, $counterparty_node: expr, $channel_id: expr) => {
                {
-                       let mut lock;
-                       let chan = get_channel_ref!($node, lock, $channel_id);
+                       let mut per_peer_state_lock;
+                       let mut peer_state_lock;
+                       let chan = get_channel_ref!($node, $counterparty_node, per_peer_state_lock, peer_state_lock, $channel_id);
                        chan.get_feerate()
                }
        }
@@ -581,10 +583,11 @@ macro_rules! get_feerate {
 
 #[cfg(test)]
 macro_rules! get_opt_anchors {
-       ($node: expr, $channel_id: expr) => {
+       ($node: expr, $counterparty_node: expr, $channel_id: expr) => {
                {
-                       let mut lock;
-                       let chan = get_channel_ref!($node, lock, $channel_id);
+                       let mut per_peer_state_lock;
+                       let mut peer_state_lock;
+                       let chan = get_channel_ref!($node, $counterparty_node, per_peer_state_lock, peer_state_lock, $channel_id);
                        chan.opt_anchors()
                }
        }
@@ -1896,9 +1899,10 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>,
                                {
                                        $node.node.handle_update_fulfill_htlc(&$prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0);
                                        let fee = {
-                                               let channel_state = $node.node.channel_state.lock().unwrap();
-                                               let channel = channel_state
-                                                       .by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap();
+                                               let per_peer_state = $node.node.per_peer_state.read().unwrap();
+                                               let peer_state = per_peer_state.get(&$prev_node.node.get_our_node_id())
+                                                       .unwrap().lock().unwrap();
+                                               let channel = peer_state.channel_by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap();
                                                if let Some(prev_config) = channel.prev_config() {
                                                        prev_config.forwarding_fee_base_msat
                                                } else {
@@ -2405,9 +2409,10 @@ pub fn get_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec<Node<'a, 'b,
 
 #[cfg(test)]
 macro_rules! get_channel_value_stat {
-       ($node: expr, $channel_id: expr) => {{
-               let chan_lock = $node.node.channel_state.lock().unwrap();
-               let chan = chan_lock.by_id.get(&$channel_id).unwrap();
+       ($node: expr, $counterparty_node: expr, $channel_id: expr) => {{
+               let peer_state_lock = $node.node.per_peer_state.read().unwrap();
+               let chan_lock = peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+               let chan = chan_lock.channel_by_id.get(&$channel_id).unwrap();
                chan.get_value_stat()
        }}
 }