Support receiving MPP keysend
[rust-lightning] / lightning / src / ln / channelmanager.rs
index f3ab0ac5ab358670109dd009a7f8ae61d1ea6d79..702ae86fbb3300a305f209cb465857e5b4e19d9a 100644 (file)
@@ -3505,7 +3505,7 @@ where
                                                                                panic!("short_channel_id == 0 should imply any pending_forward entries are of type Receive");
                                                                        }
                                                                };
-                                                               let mut claimable_htlc = ClaimableHTLC {
+                                                               let claimable_htlc = ClaimableHTLC {
                                                                        prev_hop: HTLCPreviousHopData {
                                                                                short_channel_id: prev_short_channel_id,
                                                                                outpoint: prev_funding_outpoint,
@@ -3555,13 +3555,11 @@ where
                                                                }
 
                                                                macro_rules! check_total_value {
-                                                                       ($payment_data: expr, $payment_preimage: expr) => {{
+                                                                       ($purpose: expr) => {{
                                                                                let mut payment_claimable_generated = false;
-                                                                               let purpose = || {
-                                                                                       events::PaymentPurpose::InvoicePayment {
-                                                                                               payment_preimage: $payment_preimage,
-                                                                                               payment_secret: $payment_data.payment_secret,
-                                                                                       }
+                                                                               let is_keysend = match $purpose {
+                                                                                       events::PaymentPurpose::SpontaneousPayment(_) => true,
+                                                                                       events::PaymentPurpose::InvoicePayment { .. } => false,
                                                                                };
                                                                                let mut claimable_payments = self.claimable_payments.lock().unwrap();
                                                                                if claimable_payments.pending_claiming_payments.contains_key(&payment_hash) {
@@ -3573,9 +3571,18 @@ where
                                                                                        .or_insert_with(|| {
                                                                                                committed_to_claimable = true;
                                                                                                ClaimablePayment {
-                                                                                                       purpose: purpose(), htlcs: Vec::new(), onion_fields: None,
+                                                                                                       purpose: $purpose.clone(), htlcs: Vec::new(), onion_fields: None,
                                                                                                }
                                                                                        });
+                                                                               if $purpose != claimable_payment.purpose {
+                                                                                       let log_keysend = |keysend| if keysend { "keysend" } else { "non-keysend" };
+                                                                                       log_trace!(self.logger, "Failing new {} HTLC with payment_hash {} as we already had an existing {} HTLC with the same payment hash", log_keysend(is_keysend), log_bytes!(payment_hash.0), log_keysend(!is_keysend));
+                                                                                       fail_htlc!(claimable_htlc, payment_hash);
+                                                                               }
+                                                                               if !self.default_configuration.accept_mpp_keysend && is_keysend && !claimable_payment.htlcs.is_empty() {
+                                                                                       log_trace!(self.logger, "Failing new keysend HTLC with payment_hash {} as we already had an existing keysend HTLC with the same payment hash and our config states we don't accept MPP keysend", log_bytes!(payment_hash.0));
+                                                                                       fail_htlc!(claimable_htlc, payment_hash);
+                                                                               }
                                                                                if let Some(earlier_fields) = &mut claimable_payment.onion_fields {
                                                                                        if earlier_fields.check_merge(&mut onion_fields).is_err() {
                                                                                                fail_htlc!(claimable_htlc, payment_hash);
@@ -3584,38 +3591,27 @@ where
                                                                                        claimable_payment.onion_fields = Some(onion_fields);
                                                                                }
                                                                                let ref mut htlcs = &mut claimable_payment.htlcs;
-                                                                               if htlcs.len() == 1 {
-                                                                                       if let OnionPayload::Spontaneous(_) = htlcs[0].onion_payload {
-                                                                                               log_trace!(self.logger, "Failing new HTLC with payment_hash {} as we already had an existing keysend HTLC with the same payment hash", log_bytes!(payment_hash.0));
-                                                                                               fail_htlc!(claimable_htlc, payment_hash);
-                                                                                       }
-                                                                               }
                                                                                let mut total_value = claimable_htlc.sender_intended_value;
                                                                                let mut earliest_expiry = claimable_htlc.cltv_expiry;
                                                                                for htlc in htlcs.iter() {
                                                                                        total_value += htlc.sender_intended_value;
                                                                                        earliest_expiry = cmp::min(earliest_expiry, htlc.cltv_expiry);
-                                                                                       match &htlc.onion_payload {
-                                                                                               OnionPayload::Invoice { .. } => {
-                                                                                                       if htlc.total_msat != $payment_data.total_msat {
-                                                                                                               log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the HTLCs had inconsistent total values (eg {} and {})",
-                                                                                                                       log_bytes!(payment_hash.0), $payment_data.total_msat, htlc.total_msat);
-                                                                                                               total_value = msgs::MAX_VALUE_MSAT;
-                                                                                                       }
-                                                                                                       if total_value >= msgs::MAX_VALUE_MSAT { break; }
-                                                                                               },
-                                                                                               _ => unreachable!(),
+                                                                                       if htlc.total_msat != claimable_htlc.total_msat {
+                                                                                               log_trace!(self.logger, "Failing HTLCs with payment_hash {} as the HTLCs had inconsistent total values (eg {} and {})",
+                                                                                                       log_bytes!(payment_hash.0), claimable_htlc.total_msat, htlc.total_msat);
+                                                                                               total_value = msgs::MAX_VALUE_MSAT;
                                                                                        }
+                                                                                       if total_value >= msgs::MAX_VALUE_MSAT { break; }
                                                                                }
                                                                                // The condition determining whether an MPP is complete must
                                                                                // match exactly the condition used in `timer_tick_occurred`
                                                                                if total_value >= msgs::MAX_VALUE_MSAT {
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
-                                                                               } else if total_value - claimable_htlc.sender_intended_value >= $payment_data.total_msat {
+                                                                               } else if total_value - claimable_htlc.sender_intended_value >= claimable_htlc.total_msat {
                                                                                        log_trace!(self.logger, "Failing HTLC with payment_hash {} as payment is already claimable",
                                                                                                log_bytes!(payment_hash.0));
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
-                                                                               } else if total_value >= $payment_data.total_msat {
+                                                                               } else if total_value >= claimable_htlc.total_msat {
                                                                                        #[allow(unused_assignments)] {
                                                                                                committed_to_claimable = true;
                                                                                        }
@@ -3626,7 +3622,7 @@ where
                                                                                        new_events.push_back((events::Event::PaymentClaimable {
                                                                                                receiver_node_id: Some(receiver_node_id),
                                                                                                payment_hash,
-                                                                                               purpose: purpose(),
+                                                                                               purpose: $purpose,
                                                                                                amount_msat,
                                                                                                via_channel_id: Some(prev_channel_id),
                                                                                                via_user_channel_id: Some(prev_user_channel_id),
@@ -3674,49 +3670,23 @@ where
                                                                                                                fail_htlc!(claimable_htlc, payment_hash);
                                                                                                        }
                                                                                                }
-                                                                                               check_total_value!(payment_data, payment_preimage);
+                                                                                               let purpose = events::PaymentPurpose::InvoicePayment {
+                                                                                                       payment_preimage: payment_preimage.clone(),
+                                                                                                       payment_secret: payment_data.payment_secret,
+                                                                                               };
+                                                                                               check_total_value!(purpose);
                                                                                        },
                                                                                        OnionPayload::Spontaneous(preimage) => {
-                                                                                               let mut claimable_payments = self.claimable_payments.lock().unwrap();
-                                                                                               if claimable_payments.pending_claiming_payments.contains_key(&payment_hash) {
-                                                                                                       fail_htlc!(claimable_htlc, payment_hash);
-                                                                                               }
-                                                                                               match claimable_payments.claimable_payments.entry(payment_hash) {
-                                                                                                       hash_map::Entry::Vacant(e) => {
-                                                                                                               let amount_msat = claimable_htlc.value;
-                                                                                                               claimable_htlc.total_value_received = Some(amount_msat);
-                                                                                                               let claim_deadline = Some(claimable_htlc.cltv_expiry - HTLC_FAIL_BACK_BUFFER);
-                                                                                                               let purpose = events::PaymentPurpose::SpontaneousPayment(preimage);
-                                                                                                               e.insert(ClaimablePayment {
-                                                                                                                       purpose: purpose.clone(),
-                                                                                                                       onion_fields: Some(onion_fields.clone()),
-                                                                                                                       htlcs: vec![claimable_htlc],
-                                                                                                               });
-                                                                                                               let prev_channel_id = prev_funding_outpoint.to_channel_id();
-                                                                                                               new_events.push_back((events::Event::PaymentClaimable {
-                                                                                                                       receiver_node_id: Some(receiver_node_id),
-                                                                                                                       payment_hash,
-                                                                                                                       amount_msat,
-                                                                                                                       purpose,
-                                                                                                                       via_channel_id: Some(prev_channel_id),
-                                                                                                                       via_user_channel_id: Some(prev_user_channel_id),
-                                                                                                                       claim_deadline,
-                                                                                                                       onion_fields: Some(onion_fields),
-                                                                                                               }, None));
-                                                                                                       },
-                                                                                                       hash_map::Entry::Occupied(_) => {
-                                                                                                               log_trace!(self.logger, "Failing new keysend HTLC with payment_hash {} for a duplicative payment hash", log_bytes!(payment_hash.0));
-                                                                                                               fail_htlc!(claimable_htlc, payment_hash);
-                                                                                                       }
-                                                                                               }
+                                                                                               let purpose = events::PaymentPurpose::SpontaneousPayment(preimage);
+                                                                                               check_total_value!(purpose);
                                                                                        }
                                                                                }
                                                                        },
                                                                        hash_map::Entry::Occupied(inbound_payment) => {
-                                                                               if payment_data.is_none() {
+                                                                               if let OnionPayload::Spontaneous(_) = claimable_htlc.onion_payload {
                                                                                        log_trace!(self.logger, "Failing new keysend HTLC with payment_hash {} because we already have an inbound payment with the same payment hash", log_bytes!(payment_hash.0));
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
-                                                                               };
+                                                                               }
                                                                                let payment_data = payment_data.unwrap();
                                                                                if inbound_payment.get().payment_secret != payment_data.payment_secret {
                                                                                        log_trace!(self.logger, "Failing new HTLC with payment_hash {} as it didn't match our expected payment secret.", log_bytes!(payment_hash.0));
@@ -3726,7 +3696,11 @@ where
                                                                                                log_bytes!(payment_hash.0), payment_data.total_msat, inbound_payment.get().min_value_msat.unwrap());
                                                                                        fail_htlc!(claimable_htlc, payment_hash);
                                                                                } else {
-                                                                                       let payment_claimable_generated = check_total_value!(payment_data, inbound_payment.get().payment_preimage);
+                                                                                       let purpose = events::PaymentPurpose::InvoicePayment {
+                                                                                               payment_preimage: inbound_payment.get().payment_preimage,
+                                                                                               payment_secret: payment_data.payment_secret,
+                                                                                       };
+                                                                                       let payment_claimable_generated = check_total_value!(purpose);
                                                                                        if payment_claimable_generated {
                                                                                                inbound_payment.remove_entry();
                                                                                        }
@@ -4265,18 +4239,6 @@ where
                                break;
                        }
                        expected_amt_msat = htlc.total_value_received;
-
-                       if let OnionPayload::Spontaneous(_) = &htlc.onion_payload {
-                               // We don't currently support MPP for spontaneous payments, so just check
-                               // that there's one payment here and move on.
-                               if sources.len() != 1 {
-                                       log_error!(self.logger, "Somehow ended up with an MPP spontaneous payment - this should not be reachable!");
-                                       debug_assert!(false);
-                                       valid_mpp = false;
-                                       break;
-                               }
-                       }
-
                        claimable_amt_msat += htlc.value;
                }
                mem::drop(per_peer_state);