}
}
+/// Fetches the first `msg_event` to the passed `node_id` in the passed `msg_events` vec.
+/// Returns the `msg_event`, along with an updated `msg_events` vec with the message removed.
+///
+/// Note that even though `BroadcastChannelAnnouncement` and `BroadcastChannelUpdate`
+/// `msg_events` are stored under specific peers, this function does not fetch such `msg_events` as
+/// such messages are intended to all peers.
+pub fn remove_first_msg_event_to_node(msg_node_id: &PublicKey, msg_events: &Vec<MessageSendEvent>) -> (MessageSendEvent, Vec<MessageSendEvent>) {
+ let ev_index = msg_events.iter().position(|e| { match e {
+ MessageSendEvent::SendAcceptChannel { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendOpenChannel { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendFundingCreated { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendFundingSigned { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendChannelReady { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendAnnouncementSignatures { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::UpdateHTLCs { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendRevokeAndACK { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendClosingSigned { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendShutdown { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendChannelReestablish { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendChannelAnnouncement { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::BroadcastChannelAnnouncement { .. } => {
+ false
+ },
+ MessageSendEvent::BroadcastChannelUpdate { .. } => {
+ false
+ },
+ MessageSendEvent::SendChannelUpdate { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::HandleError { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendChannelRangeQuery { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendShortIdsQuery { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendReplyChannelRange { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ MessageSendEvent::SendGossipTimestampFilter { node_id, .. } => {
+ node_id == msg_node_id
+ },
+ }});
+ if ev_index.is_some() {
+ let mut updated_msg_events = msg_events.to_vec();
+ let ev = updated_msg_events.remove(ev_index.unwrap());
+ (ev, updated_msg_events)
+ } else {
+ panic!("Couldn't find any MessageSendEvent to the node!")
+ }
+}
+
#[cfg(test)]
macro_rules! get_channel_ref {
- ($node: expr, $lock: ident, $channel_id: expr) => {
+ ($node: expr, $counterparty_node: expr, $per_peer_state_lock: ident, $peer_state_lock: ident, $channel_id: expr) => {
{
- $lock = $node.node.channel_state.lock().unwrap();
- $lock.by_id.get_mut(&$channel_id).unwrap()
+ $per_peer_state_lock = $node.node.per_peer_state.read().unwrap();
+ $peer_state_lock = $per_peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+ $peer_state_lock.channel_by_id.get_mut(&$channel_id).unwrap()
}
}
}
#[cfg(test)]
macro_rules! get_feerate {
- ($node: expr, $channel_id: expr) => {
+ ($node: expr, $counterparty_node: expr, $channel_id: expr) => {
{
- let mut lock;
- let chan = get_channel_ref!($node, lock, $channel_id);
+ let mut per_peer_state_lock;
+ let mut peer_state_lock;
+ let chan = get_channel_ref!($node, $counterparty_node, per_peer_state_lock, peer_state_lock, $channel_id);
chan.get_feerate()
}
}
#[cfg(test)]
macro_rules! get_opt_anchors {
- ($node: expr, $channel_id: expr) => {
+ ($node: expr, $counterparty_node: expr, $channel_id: expr) => {
{
- let mut lock;
- let chan = get_channel_ref!($node, lock, $channel_id);
+ let mut per_peer_state_lock;
+ let mut peer_state_lock;
+ let chan = get_channel_ref!($node, $counterparty_node, per_peer_state_lock, peer_state_lock, $channel_id);
chan.opt_anchors()
}
}
let (bs_revoke_and_ack, extra_msg_option) = {
let events = $node_b.node.get_and_clear_pending_msg_events();
assert!(events.len() <= 2);
- (match events[0] {
+ let (node_a_event, events) = remove_first_msg_event_to_node(&$node_a.node.get_our_node_id(), &events);
+ (match node_a_event {
MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => {
assert_eq!(*node_id, $node_a.node.get_our_node_id());
(*msg).clone()
},
_ => panic!("Unexpected event"),
- }, events.get(1).map(|e| e.clone()))
+ }, events.get(0).map(|e| e.clone()))
};
check_added_monitors!($node_b, 1);
if $fail_backwards {
expect_pending_htlcs_forwardable_and_htlc_handling_failed!($node_a, vec![$crate::util::events::HTLCDestination::NextHopChannel{ node_id: Some($node_b.node.get_our_node_id()), channel_id: $commitment_signed.channel_id }]);
check_added_monitors!($node_a, 1);
- let channel_state = $node_a.node.channel_state.lock().unwrap();
- assert_eq!(channel_state.pending_msg_events.len(), 1);
- if let MessageSendEvent::UpdateHTLCs { ref node_id, .. } = channel_state.pending_msg_events[0] {
- assert_ne!(*node_id, $node_b.node.get_our_node_id());
- } else { panic!("Unexpected event"); }
+ let node_a_per_peer_state = $node_a.node.per_peer_state.read().unwrap();
+ let mut number_of_msg_events = 0;
+ for (cp_id, peer_state_mutex) in node_a_per_peer_state.iter() {
+ let peer_state = peer_state_mutex.lock().unwrap();
+ let cp_pending_msg_events = &peer_state.pending_msg_events;
+ number_of_msg_events += cp_pending_msg_events.len();
+ if cp_pending_msg_events.len() == 1 {
+ if let MessageSendEvent::UpdateHTLCs { .. } = cp_pending_msg_events[0] {
+ assert_ne!(*cp_id, $node_b.node.get_our_node_id());
+ } else { panic!("Unexpected event"); }
+ }
+ }
+ // Expecting the failure backwards event to the previous hop (not `node_b`)
+ assert_eq!(number_of_msg_events, 1);
} else {
assert!($node_a.node.get_and_clear_pending_msg_events().is_empty());
}
pub fn pass_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&[&Node<'a, 'b, 'c>]], recv_value: u64, our_payment_hash: PaymentHash, our_payment_secret: PaymentSecret) {
let mut events = origin_node.node.get_and_clear_pending_msg_events();
assert_eq!(events.len(), expected_route.len());
- for (path_idx, (ev, expected_path)) in events.drain(..).zip(expected_route.iter()).enumerate() {
+ for (path_idx, expected_path) in expected_route.iter().enumerate() {
+ let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events);
+ events = updated_events;
// Once we've gotten through all the HTLCs, the last one should result in a
// PaymentClaimable (but each previous one should not!), .
let expect_payment = path_idx == expected_route.len() - 1;
}
}
let mut per_path_msgs: Vec<((msgs::UpdateFulfillHTLC, msgs::CommitmentSigned), PublicKey)> = Vec::with_capacity(expected_paths.len());
- let events = expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events();
+ let mut events = expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events();
assert_eq!(events.len(), expected_paths.len());
- for ev in events.iter() {
- per_path_msgs.push(msgs_from_ev!(ev));
+
+ if events.len() == 1 {
+ per_path_msgs.push(msgs_from_ev!(&events[0]));
+ } else {
+ for expected_path in expected_paths.iter() {
+ // For MPP payments, we always want the message to the first node in the path.
+ let (ev, updated_events) = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &events);
+ per_path_msgs.push(msgs_from_ev!(&ev));
+ events = updated_events;
+ }
}
for (expected_route, (path_msgs, next_hop)) in expected_paths.iter().zip(per_path_msgs.drain(..)) {
{
$node.node.handle_update_fulfill_htlc(&$prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0);
let fee = {
- let channel_state = $node.node.channel_state.lock().unwrap();
- let channel = channel_state
- .by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap();
+ let per_peer_state = $node.node.per_peer_state.read().unwrap();
+ let peer_state = per_peer_state.get(&$prev_node.node.get_our_node_id())
+ .unwrap().lock().unwrap();
+ let channel = peer_state.channel_by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap();
if let Some(prev_config) = channel.prev_config() {
prev_config.forwarding_fee_base_msat
} else {
#[cfg(test)]
macro_rules! get_channel_value_stat {
- ($node: expr, $channel_id: expr) => {{
- let chan_lock = $node.node.channel_state.lock().unwrap();
- let chan = chan_lock.by_id.get(&$channel_id).unwrap();
+ ($node: expr, $counterparty_node: expr, $channel_id: expr) => {{
+ let peer_state_lock = $node.node.per_peer_state.read().unwrap();
+ let chan_lock = peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+ let chan = chan_lock.channel_by_id.get(&$channel_id).unwrap();
chan.get_value_stat()
}}
}