X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Fln%2Ffunctional_test_utils.rs;fp=lightning%2Fsrc%2Fln%2Ffunctional_test_utils.rs;h=bcfa571f9ac317f9eb4d47230fab4025fe3f4f45;hb=5221e4a861687751c92d79cf3a54bb9cc1f7aee2;hp=a7e1f8cfd63a7e6f4e58c272a746b288abe19492;hpb=fad52d8b98467e18e4112006cebdb1dec39d199a;p=rust-lightning diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index a7e1f8cf..bcfa571f 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -558,22 +558,102 @@ macro_rules! get_htlc_update_msgs { } } +/// Fetches the first `msg_event` to the passed `node_id` in the passed `msg_events` vec. +/// Returns the `msg_event`, along with an updated `msg_events` vec with the message removed. +/// +/// Note that even though `BroadcastChannelAnnouncement` and `BroadcastChannelUpdate` +/// `msg_events` are stored under specific peers, this function does not fetch such `msg_events` as +/// such messages are intended to all peers. +pub fn remove_first_msg_event_to_node(msg_node_id: &PublicKey, msg_events: &Vec) -> (MessageSendEvent, Vec) { + let ev_index = msg_events.iter().position(|e| { match e { + MessageSendEvent::SendAcceptChannel { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendOpenChannel { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendFundingCreated { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendFundingSigned { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendChannelReady { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendAnnouncementSignatures { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::UpdateHTLCs { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendRevokeAndACK { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendClosingSigned { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendShutdown { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendChannelReestablish { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendChannelAnnouncement { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::BroadcastChannelAnnouncement { .. } => { + false + }, + MessageSendEvent::BroadcastChannelUpdate { .. } => { + false + }, + MessageSendEvent::SendChannelUpdate { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::HandleError { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendChannelRangeQuery { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendShortIdsQuery { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendReplyChannelRange { node_id, .. } => { + node_id == msg_node_id + }, + MessageSendEvent::SendGossipTimestampFilter { node_id, .. } => { + node_id == msg_node_id + }, + }}); + if ev_index.is_some() { + let mut updated_msg_events = msg_events.to_vec(); + let ev = updated_msg_events.remove(ev_index.unwrap()); + (ev, updated_msg_events) + } else { + panic!("Couldn't find any MessageSendEvent to the node!") + } +} + #[cfg(test)] macro_rules! get_channel_ref { - ($node: expr, $lock: ident, $channel_id: expr) => { + ($node: expr, $counterparty_node: expr, $per_peer_state_lock: ident, $peer_state_lock: ident, $channel_id: expr) => { { - $lock = $node.node.channel_state.lock().unwrap(); - $lock.by_id.get_mut(&$channel_id).unwrap() + $per_peer_state_lock = $node.node.per_peer_state.read().unwrap(); + $peer_state_lock = $per_peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap(); + $peer_state_lock.channel_by_id.get_mut(&$channel_id).unwrap() } } } #[cfg(test)] macro_rules! get_feerate { - ($node: expr, $channel_id: expr) => { + ($node: expr, $counterparty_node: expr, $channel_id: expr) => { { - let mut lock; - let chan = get_channel_ref!($node, lock, $channel_id); + let mut per_peer_state_lock; + let mut peer_state_lock; + let chan = get_channel_ref!($node, $counterparty_node, per_peer_state_lock, peer_state_lock, $channel_id); chan.get_feerate() } } @@ -581,10 +661,11 @@ macro_rules! get_feerate { #[cfg(test)] macro_rules! get_opt_anchors { - ($node: expr, $channel_id: expr) => { + ($node: expr, $counterparty_node: expr, $channel_id: expr) => { { - let mut lock; - let chan = get_channel_ref!($node, lock, $channel_id); + let mut per_peer_state_lock; + let mut peer_state_lock; + let chan = get_channel_ref!($node, $counterparty_node, per_peer_state_lock, peer_state_lock, $channel_id); chan.opt_anchors() } } @@ -1273,13 +1354,14 @@ macro_rules! commitment_signed_dance { let (bs_revoke_and_ack, extra_msg_option) = { let events = $node_b.node.get_and_clear_pending_msg_events(); assert!(events.len() <= 2); - (match events[0] { + let (node_a_event, events) = remove_first_msg_event_to_node(&$node_a.node.get_our_node_id(), &events); + (match node_a_event { MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => { assert_eq!(*node_id, $node_a.node.get_our_node_id()); (*msg).clone() }, _ => panic!("Unexpected event"), - }, events.get(1).map(|e| e.clone())) + }, events.get(0).map(|e| e.clone())) }; check_added_monitors!($node_b, 1); if $fail_backwards { @@ -1320,11 +1402,20 @@ macro_rules! commitment_signed_dance { expect_pending_htlcs_forwardable_and_htlc_handling_failed!($node_a, vec![$crate::util::events::HTLCDestination::NextHopChannel{ node_id: Some($node_b.node.get_our_node_id()), channel_id: $commitment_signed.channel_id }]); check_added_monitors!($node_a, 1); - let channel_state = $node_a.node.channel_state.lock().unwrap(); - assert_eq!(channel_state.pending_msg_events.len(), 1); - if let MessageSendEvent::UpdateHTLCs { ref node_id, .. } = channel_state.pending_msg_events[0] { - assert_ne!(*node_id, $node_b.node.get_our_node_id()); - } else { panic!("Unexpected event"); } + let node_a_per_peer_state = $node_a.node.per_peer_state.read().unwrap(); + let mut number_of_msg_events = 0; + for (cp_id, peer_state_mutex) in node_a_per_peer_state.iter() { + let peer_state = peer_state_mutex.lock().unwrap(); + let cp_pending_msg_events = &peer_state.pending_msg_events; + number_of_msg_events += cp_pending_msg_events.len(); + if cp_pending_msg_events.len() == 1 { + if let MessageSendEvent::UpdateHTLCs { .. } = cp_pending_msg_events[0] { + assert_ne!(*cp_id, $node_b.node.get_our_node_id()); + } else { panic!("Unexpected event"); } + } + } + // Expecting the failure backwards event to the previous hop (not `node_b`) + assert_eq!(number_of_msg_events, 1); } else { assert!($node_a.node.get_and_clear_pending_msg_events().is_empty()); } @@ -1820,7 +1911,9 @@ pub fn pass_along_path<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path pub fn pass_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&[&Node<'a, 'b, 'c>]], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: PaymentSecret) { let mut events = origin_node.node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), expected_route.len()); - for (path_idx, (ev, expected_path)) in events.drain(..).zip(expected_route.iter()).enumerate() { + for (path_idx, expected_path) in expected_route.iter().enumerate() { + let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events); + events = updated_events; // Once we've gotten through all the HTLCs, the last one should result in a // PaymentClaimable (but each previous one should not!), . let expect_payment = path_idx == expected_route.len() - 1; @@ -1871,10 +1964,18 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, } } let mut per_path_msgs: Vec<((msgs::UpdateFulfillHTLC, msgs::CommitmentSigned), PublicKey)> = Vec::with_capacity(expected_paths.len()); - let events = expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events(); + let mut events = expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), expected_paths.len()); - for ev in events.iter() { - per_path_msgs.push(msgs_from_ev!(ev)); + + if events.len() == 1 { + per_path_msgs.push(msgs_from_ev!(&events[0])); + } else { + for expected_path in expected_paths.iter() { + // For MPP payments, we always want the message to the first node in the path. + let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events); + per_path_msgs.push(msgs_from_ev!(&ev)); + events = updated_events; + } } for (expected_route, (path_msgs, next_hop)) in expected_paths.iter().zip(per_path_msgs.drain(..)) { @@ -1896,9 +1997,10 @@ pub fn do_claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, { $node.node.handle_update_fulfill_htlc(&$prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0); let fee = { - let channel_state = $node.node.channel_state.lock().unwrap(); - let channel = channel_state - .by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap(); + let per_peer_state = $node.node.per_peer_state.read().unwrap(); + let peer_state = per_peer_state.get(&$prev_node.node.get_our_node_id()) + .unwrap().lock().unwrap(); + let channel = peer_state.channel_by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap(); if let Some(prev_config) = channel.prev_config() { prev_config.forwarding_fee_base_msat } else { @@ -2405,9 +2507,10 @@ pub fn get_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec {{ - let chan_lock = $node.node.channel_state.lock().unwrap(); - let chan = chan_lock.by_id.get(&$channel_id).unwrap(); + ($node: expr, $counterparty_node: expr, $channel_id: expr) => {{ + let peer_state_lock = $node.node.per_peer_state.read().unwrap(); + let chan_lock = peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap(); + let chan = chan_lock.channel_by_id.get(&$channel_id).unwrap(); chan.get_value_stat() }} }