Allow overshooting `total_msat` for an MPP
[rust-lightning] / lightning / src / ln / functional_tests.rs
index 265bd49ac92ff785ae961872cd6866d0aaaa845e..25b0f792ccec9e91a21f01cd2740183c29f84e38 100644 (file)
@@ -7926,6 +7926,76 @@ fn test_can_not_accept_unknown_inbound_channel() {
        }
 }
 
+fn do_test_overshoot_mpp(msat_amounts: &[u64], total_msat: u64) {
+
+       let routing_node_count = msat_amounts.len();
+       let node_count = routing_node_count + 2;
+
+       let chanmon_cfgs = create_chanmon_cfgs(node_count);
+       let node_cfgs = create_node_cfgs(node_count, &chanmon_cfgs);
+       let node_chanmgrs = create_node_chanmgrs(node_count, &node_cfgs, &vec![None; node_count]);
+       let nodes = create_network(node_count, &node_cfgs, &node_chanmgrs);
+
+       let src_idx = 0;
+       let dst_idx = 1;
+
+       // Create channels for each amount
+       let mut expected_paths = Vec::with_capacity(routing_node_count);
+       let mut src_chan_ids = Vec::with_capacity(routing_node_count);
+       let mut dst_chan_ids = Vec::with_capacity(routing_node_count);
+       for i in 0..routing_node_count {
+               let routing_node = 2 + i;
+               let src_chan_id = create_announced_chan_between_nodes(&nodes, src_idx, routing_node).0.contents.short_channel_id;
+               src_chan_ids.push(src_chan_id);
+               let dst_chan_id = create_announced_chan_between_nodes(&nodes, routing_node, dst_idx).0.contents.short_channel_id;
+               dst_chan_ids.push(dst_chan_id);
+               let path = vec![&nodes[routing_node], &nodes[dst_idx]];
+               expected_paths.push(path);
+       }
+       let expected_paths: Vec<&[&Node]> = expected_paths.iter().map(|route| route.as_slice()).collect();
+
+       // Create a route for each amount
+       let example_amount = 100000;
+       let (mut route, our_payment_hash, our_payment_preimage, our_payment_secret) = get_route_and_payment_hash!(&nodes[src_idx], nodes[dst_idx], example_amount);
+       let sample_path = route.paths.pop().unwrap();
+       for i in 0..routing_node_count {
+               let routing_node = 2 + i;
+               let mut path = sample_path.clone();
+               path[0].pubkey = nodes[routing_node].node.get_our_node_id();
+               path[0].short_channel_id = src_chan_ids[i];
+               path[1].pubkey = nodes[dst_idx].node.get_our_node_id();
+               path[1].short_channel_id = dst_chan_ids[i];
+               path[1].fee_msat = msat_amounts[i];
+               route.paths.push(path);
+       }
+
+       // Send payment with manually set total_msat
+       let payment_id = PaymentId(nodes[src_idx].keys_manager.backing.get_secure_random_bytes());
+       let onion_session_privs = nodes[src_idx].node.test_add_new_pending_payment(our_payment_hash, Some(our_payment_secret), payment_id, &route).unwrap();
+       nodes[src_idx].node.test_send_payment_internal(&route, our_payment_hash, &Some(our_payment_secret), None, payment_id, Some(total_msat), onion_session_privs).unwrap();
+       check_added_monitors!(nodes[src_idx], expected_paths.len());
+
+       let mut events = nodes[src_idx].node.get_and_clear_pending_msg_events();
+       assert_eq!(events.len(), expected_paths.len());
+       let mut amount_received = 0;
+       for (path_idx, expected_path) in expected_paths.iter().enumerate() {
+               let ev = remove_first_msg_event_to_node(&expected_path[0].node.get_our_node_id(), &mut events);
+
+               let current_path_amount = msat_amounts[path_idx];
+               amount_received += current_path_amount;
+               let became_claimable_now = amount_received >= total_msat && amount_received - current_path_amount < total_msat;
+               pass_along_path(&nodes[src_idx], expected_path, amount_received, our_payment_hash.clone(), Some(our_payment_secret), ev, became_claimable_now, None);
+       }
+
+       claim_payment_along_route(&nodes[src_idx], &expected_paths, false, our_payment_preimage);
+}
+
+#[test]
+fn test_overshoot_mpp() {
+       do_test_overshoot_mpp(&[100_000, 101_000], 200_000);
+       do_test_overshoot_mpp(&[100_000, 10_000, 100_000], 200_000);
+}
+
 #[test]
 fn test_simple_mpp() {
        // Simple test of sending a multi-path payment.