Split claim and fail payment functions to be able to skip one hop
[rust-lightning] / src / ln / channelmanager.rs
index 4f0d0efd087d48d0e19ca76a04fc9c7cf7fc1e8d..31a0491f75cfd3982a94855fb73ac4ad2b7550af 100644 (file)
@@ -2741,7 +2741,7 @@ mod tests {
                (our_payment_preimage, our_payment_hash)
        }
 
-       fn claim_payment(origin_node: &Node, expected_route: &[&Node], our_payment_preimage: [u8; 32]) {
+       fn claim_payment_along_route(origin_node: &Node, expected_route: &[&Node], skip_last: bool, our_payment_preimage: [u8; 32]) {
                assert!(expected_route.last().unwrap().node.claim_funds(our_payment_preimage));
                {
                        let mut added_monitors = expected_route.last().unwrap().chan_monitor.added_monitors.lock().unwrap();
@@ -2770,40 +2770,51 @@ mod tests {
 
                let mut expected_next_node = expected_route.last().unwrap().node.get_our_node_id();
                let mut prev_node = expected_route.last().unwrap();
-               for node in expected_route.iter().rev() {
+               for (idx, node) in expected_route.iter().rev().enumerate() {
                        assert_eq!(expected_next_node, node.node.get_our_node_id());
                        if next_msgs.is_some() {
                                update_fulfill_dance!(node, prev_node, false);
                        }
 
                        let events = node.node.get_and_clear_pending_events();
+                       if !skip_last || idx != expected_route.len() - 1 {
+                               assert_eq!(events.len(), 1);
+                               match events[0] {
+                                       Event::UpdateHTLCs { ref node_id, updates: msgs::CommitmentUpdate { ref update_add_htlcs, ref update_fulfill_htlcs, ref update_fail_htlcs, ref update_fail_malformed_htlcs, ref commitment_signed } } => {
+                                               assert!(update_add_htlcs.is_empty());
+                                               assert_eq!(update_fulfill_htlcs.len(), 1);
+                                               assert!(update_fail_htlcs.is_empty());
+                                               assert!(update_fail_malformed_htlcs.is_empty());
+                                               expected_next_node = node_id.clone();
+                                               next_msgs = Some((update_fulfill_htlcs[0].clone(), commitment_signed.clone()));
+                                       },
+                                       _ => panic!("Unexpected event"),
+                               }
+                       } else {
+                               assert!(events.is_empty());
+                       }
+                       if !skip_last && idx == expected_route.len() - 1 {
+                               assert_eq!(expected_next_node, origin_node.node.get_our_node_id());
+                       }
+
+                       prev_node = node;
+               }
+
+               if !skip_last {
+                       update_fulfill_dance!(origin_node, expected_route.first().unwrap(), true);
+                       let events = origin_node.node.get_and_clear_pending_events();
                        assert_eq!(events.len(), 1);
                        match events[0] {
-                               Event::UpdateHTLCs { ref node_id, updates: msgs::CommitmentUpdate { ref update_add_htlcs, ref update_fulfill_htlcs, ref update_fail_htlcs, ref update_fail_malformed_htlcs, ref commitment_signed } } => {
-                                       assert!(update_add_htlcs.is_empty());
-                                       assert_eq!(update_fulfill_htlcs.len(), 1);
-                                       assert!(update_fail_htlcs.is_empty());
-                                       assert!(update_fail_malformed_htlcs.is_empty());
-                                       expected_next_node = node_id.clone();
-                                       next_msgs = Some((update_fulfill_htlcs[0].clone(), commitment_signed.clone()));
+                               Event::PaymentSent { payment_preimage } => {
+                                       assert_eq!(payment_preimage, our_payment_preimage);
                                },
                                _ => panic!("Unexpected event"),
-                       };
-
-                       prev_node = node;
+                       }
                }
+       }
 
-               assert_eq!(expected_next_node, origin_node.node.get_our_node_id());
-               update_fulfill_dance!(origin_node, expected_route.first().unwrap(), true);
-
-               let events = origin_node.node.get_and_clear_pending_events();
-               assert_eq!(events.len(), 1);
-               match events[0] {
-                       Event::PaymentSent { payment_preimage } => {
-                               assert_eq!(payment_preimage, our_payment_preimage);
-                       },
-                       _ => panic!("Unexpected event"),
-               }
+       fn claim_payment(origin_node: &Node, expected_route: &[&Node], our_payment_preimage: [u8; 32]) {
+               claim_payment_along_route(origin_node, expected_route, false, our_payment_preimage);
        }
 
        const TEST_FINAL_CLTV: u32 = 32;
@@ -2847,7 +2858,7 @@ mod tests {
                claim_payment(&origin, expected_route, our_payment_preimage);
        }
 
-       fn fail_payment(origin_node: &Node, expected_route: &[&Node], our_payment_hash: [u8; 32]) {
+       fn fail_payment_along_route(origin_node: &Node, expected_route: &[&Node], skip_last: bool, our_payment_hash: [u8; 32]) {
                assert!(expected_route.last().unwrap().node.fail_htlc_backwards(&our_payment_hash));
                {
                        let mut added_monitors = expected_route.last().unwrap().chan_monitor.added_monitors.lock().unwrap();
@@ -2867,42 +2878,57 @@ mod tests {
 
                let mut expected_next_node = expected_route.last().unwrap().node.get_our_node_id();
                let mut prev_node = expected_route.last().unwrap();
-               for node in expected_route.iter().rev() {
+               for (idx, node) in expected_route.iter().rev().enumerate() {
                        assert_eq!(expected_next_node, node.node.get_our_node_id());
                        if next_msgs.is_some() {
-                               update_fail_dance!(node, prev_node, false);
+                               // We may be the "last node" for the purpose of the commitment dance if we're
+                               // skipping the last node (implying it is disconnected) and we're the
+                               // second-to-last node!
+                               update_fail_dance!(node, prev_node, skip_last && idx == expected_route.len() - 1);
                        }
 
                        let events = node.node.get_and_clear_pending_events();
-                       assert_eq!(events.len(), 1);
-                       match events[0] {
-                               Event::UpdateHTLCs { ref node_id, updates: msgs::CommitmentUpdate { ref update_add_htlcs, ref update_fulfill_htlcs, ref update_fail_htlcs, ref update_fail_malformed_htlcs, ref commitment_signed } } => {
-                                       assert!(update_add_htlcs.is_empty());
-                                       assert!(update_fulfill_htlcs.is_empty());
-                                       assert_eq!(update_fail_htlcs.len(), 1);
-                                       assert!(update_fail_malformed_htlcs.is_empty());
-                                       expected_next_node = node_id.clone();
-                                       next_msgs = Some((update_fail_htlcs[0].clone(), commitment_signed.clone()));
-                               },
-                               _ => panic!("Unexpected event"),
-                       };
+                       if !skip_last || idx != expected_route.len() - 1 {
+                               assert_eq!(events.len(), 1);
+                               match events[0] {
+                                       Event::UpdateHTLCs { ref node_id, updates: msgs::CommitmentUpdate { ref update_add_htlcs, ref update_fulfill_htlcs, ref update_fail_htlcs, ref update_fail_malformed_htlcs, ref commitment_signed } } => {
+                                               assert!(update_add_htlcs.is_empty());
+                                               assert!(update_fulfill_htlcs.is_empty());
+                                               assert_eq!(update_fail_htlcs.len(), 1);
+                                               assert!(update_fail_malformed_htlcs.is_empty());
+                                               expected_next_node = node_id.clone();
+                                               next_msgs = Some((update_fail_htlcs[0].clone(), commitment_signed.clone()));
+                                       },
+                                       _ => panic!("Unexpected event"),
+                               }
+                       } else {
+                               assert!(events.is_empty());
+                       }
+                       if !skip_last && idx == expected_route.len() - 1 {
+                               assert_eq!(expected_next_node, origin_node.node.get_our_node_id());
+                       }
 
                        prev_node = node;
                }
 
-               assert_eq!(expected_next_node, origin_node.node.get_our_node_id());
-               update_fail_dance!(origin_node, expected_route.first().unwrap(), true);
+               if !skip_last {
+                       update_fail_dance!(origin_node, expected_route.first().unwrap(), true);
 
-               let events = origin_node.node.get_and_clear_pending_events();
-               assert_eq!(events.len(), 1);
-               match events[0] {
-                       Event::PaymentFailed { payment_hash } => {
-                               assert_eq!(payment_hash, our_payment_hash);
-                       },
-                       _ => panic!("Unexpected event"),
+                       let events = origin_node.node.get_and_clear_pending_events();
+                       assert_eq!(events.len(), 1);
+                       match events[0] {
+                               Event::PaymentFailed { payment_hash } => {
+                                       assert_eq!(payment_hash, our_payment_hash);
+                               },
+                               _ => panic!("Unexpected event"),
+                       }
                }
        }
 
+       fn fail_payment(origin_node: &Node, expected_route: &[&Node], our_payment_hash: [u8; 32]) {
+               fail_payment_along_route(origin_node, expected_route, false, our_payment_hash);
+       }
+
        fn create_network(node_count: usize) -> Vec<Node> {
                let mut nodes = Vec::new();
                let mut rng = thread_rng();