Merge pull request #1053 from valentinewallace/2021-08-dedup-payment-sent
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 2af1d8086333ed420903e77903be0d5349991d7e..46cbaac32925f3ea47cfde796aa1d3510e1a0511 100644 (file)
@@ -172,6 +172,22 @@ struct ClaimableHTLC {
        onion_payload: OnionPayload,
 }
 
+/// A payment identifier used to correlate an MPP payment's per-path HTLC sources internally.
+#[derive(Hash, Copy, Clone, PartialEq, Eq, Debug)]
+pub(crate) struct MppId(pub [u8; 32]);
+
+impl Writeable for MppId {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
+               self.0.write(w)
+       }
+}
+
+impl Readable for MppId {
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let buf: [u8; 32] = Readable::read(r)?;
+               Ok(MppId(buf))
+       }
+}
 /// Tracks the inbound corresponding to an outbound HTLC
 #[derive(Clone, PartialEq)]
 pub(crate) enum HTLCSource {
@@ -182,6 +198,7 @@ pub(crate) enum HTLCSource {
                /// Technically we can recalculate this from the route, but we cache it here to avoid
                /// doing a double-pass on route when we get a failure back
                first_hop_htlc_msat: u64,
+               mpp_id: MppId,
        },
 }
 #[cfg(test)]
@@ -191,6 +208,7 @@ impl HTLCSource {
                        path: Vec::new(),
                        session_priv: SecretKey::from_slice(&[1; 32]).unwrap(),
                        first_hop_htlc_msat: 0,
+                       mpp_id: MppId([2; 32]),
                }
        }
 }
@@ -472,8 +490,11 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// which may generate a claim event, we may receive similar duplicate claim/fail MonitorEvents
        /// after reloading from disk while replaying blocks against ChannelMonitors.
        ///
+       /// Each payment has each of its MPP part's session_priv bytes in the HashSet of the map (even
+       /// payments over a single path).
+       ///
        /// Locked *after* channel_state.
-       pending_outbound_payments: Mutex<HashSet<[u8; 32]>>,
+       pending_outbound_payments: Mutex<HashMap<MppId, HashSet<[u8; 32]>>>,
 
        our_network_key: SecretKey,
        our_network_pubkey: PublicKey,
@@ -1150,7 +1171,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(HashMap::new()),
-                       pending_outbound_payments: Mutex::new(HashSet::new()),
+                       pending_outbound_payments: Mutex::new(HashMap::new()),
 
                        our_network_key: keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &keys_manager.get_node_secret()),
@@ -1832,7 +1853,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        }
 
        // Only public for testing, this should otherwise never be called direcly
-       pub(crate) fn send_payment_along_path(&self, path: &Vec<RouteHop>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32, keysend_preimage: &Option<PaymentPreimage>) -> Result<(), APIError> {
+       pub(crate) fn send_payment_along_path(&self, path: &Vec<RouteHop>, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>, total_value: u64, cur_height: u32, mpp_id: MppId, keysend_preimage: &Option<PaymentPreimage>) -> Result<(), APIError> {
                log_trace!(self.logger, "Attempting to send payment for path with next hop {}", path.first().unwrap().short_channel_id);
                let prng_seed = self.keys_manager.get_secure_random_bytes();
                let session_priv_bytes = self.keys_manager.get_secure_random_bytes();
@@ -1847,7 +1868,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash);
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
-               assert!(self.pending_outbound_payments.lock().unwrap().insert(session_priv_bytes));
+               let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap();
+               let sessions = pending_outbounds.entry(mpp_id).or_insert(HashSet::new());
+               assert!(sessions.insert(session_priv_bytes));
 
                let err: Result<(), _> = loop {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -1869,6 +1892,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                path: path.clone(),
                                                session_priv: session_priv.clone(),
                                                first_hop_htlc_msat: htlc_msat,
+                                               mpp_id,
                                        }, onion_packet, &self.logger), channel_state, chan)
                                } {
                                        Some((update_add, commitment_signed, monitor_update)) => {
@@ -1968,6 +1992,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut total_value = 0;
                let our_node_id = self.get_our_node_id();
                let mut path_errs = Vec::with_capacity(route.paths.len());
+               let mpp_id = MppId(self.keys_manager.get_secure_random_bytes());
                'path_check: for path in route.paths.iter() {
                        if path.len() < 1 || path.len() > 20 {
                                path_errs.push(Err(APIError::RouteError{err: "Path didn't go anywhere/had bogus size"}));
@@ -1989,7 +2014,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let cur_height = self.best_block.read().unwrap().height() + 1;
                let mut results = Vec::new();
                for path in route.paths.iter() {
-                       results.push(self.send_payment_along_path(&path, &payment_hash, payment_secret, total_value, cur_height, &keysend_preimage));
+                       results.push(self.send_payment_along_path(&path, &payment_hash, payment_secret, total_value, cur_height, mpp_id, &keysend_preimage));
                }
                let mut has_ok = false;
                let mut has_err = false;
@@ -2824,23 +2849,28 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        self.fail_htlc_backwards_internal(channel_state,
                                                htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data});
                                },
-                               HTLCSource::OutboundRoute { session_priv, .. } => {
-                                       if {
-                                               let mut session_priv_bytes = [0; 32];
-                                               session_priv_bytes.copy_from_slice(&session_priv[..]);
-                                               self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
-                                       } {
-                                               self.pending_events.lock().unwrap().push(
-                                                       events::Event::PaymentFailed {
-                                                               payment_hash,
-                                                               rejected_by_dest: false,
-                                                               network_update: None,
-#[cfg(test)]
-                                                               error_code: None,
-#[cfg(test)]
-                                                               error_data: None,
+                               HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
+                                       let mut session_priv_bytes = [0; 32];
+                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
+                                       let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+                                       if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(mpp_id) {
+                                               if sessions.get_mut().remove(&session_priv_bytes) {
+                                                       self.pending_events.lock().unwrap().push(
+                                                               events::Event::PaymentFailed {
+                                                                       payment_hash,
+                                                                       rejected_by_dest: false,
+                                                                       network_update: None,
+                                                                       all_paths_failed: sessions.get().len() == 0,
+                                                                       #[cfg(test)]
+                                                                       error_code: None,
+                                                                       #[cfg(test)]
+                                                                       error_data: None,
+                                                               }
+                                                       );
+                                                       if sessions.get().len() == 0 {
+                                                               sessions.remove();
                                                        }
-                                               )
+                                               }
                                        } else {
                                                log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                        }
@@ -2865,12 +2895,21 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                // from block_connected which may run during initialization prior to the chain_monitor
                // being fully configured. See the docs for `ChannelManagerReadArgs` for more.
                match source {
-                       HTLCSource::OutboundRoute { ref path, session_priv, .. } => {
-                               if {
-                                       let mut session_priv_bytes = [0; 32];
-                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
-                                       !self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
-                               } {
+                       HTLCSource::OutboundRoute { ref path, session_priv, mpp_id, .. } => {
+                               let mut session_priv_bytes = [0; 32];
+                               session_priv_bytes.copy_from_slice(&session_priv[..]);
+                               let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+                               let mut all_paths_failed = false;
+                               if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(mpp_id) {
+                                       if !sessions.get_mut().remove(&session_priv_bytes) {
+                                               log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
+                                               return;
+                                       }
+                                       if sessions.get().len() == 0 {
+                                               all_paths_failed = true;
+                                               sessions.remove();
+                                       }
+                               } else {
                                        log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
                                        return;
                                }
@@ -2890,6 +2929,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                payment_hash: payment_hash.clone(),
                                                                rejected_by_dest: !payment_retryable,
                                                                network_update,
+                                                               all_paths_failed,
 #[cfg(test)]
                                                                error_code: onion_error_code,
 #[cfg(test)]
@@ -2915,6 +2955,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                payment_hash: payment_hash.clone(),
                                                                rejected_by_dest: path.len() == 1,
                                                                network_update: None,
+                                                               all_paths_failed,
 #[cfg(test)]
                                                                error_code: Some(*failure_code),
 #[cfg(test)]
@@ -3111,17 +3152,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
        fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option<u64>, from_onchain: bool) {
                match source {
-                       HTLCSource::OutboundRoute { session_priv, .. } => {
+                       HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => {
                                mem::drop(channel_state_lock);
-                               if {
-                                       let mut session_priv_bytes = [0; 32];
-                                       session_priv_bytes.copy_from_slice(&session_priv[..]);
-                                       self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes)
-                               } {
-                                       let mut pending_events = self.pending_events.lock().unwrap();
-                                       pending_events.push(events::Event::PaymentSent {
-                                               payment_preimage
-                                       });
+                               let mut session_priv_bytes = [0; 32];
+                               session_priv_bytes.copy_from_slice(&session_priv[..]);
+                               let mut outbounds = self.pending_outbound_payments.lock().unwrap();
+                               let found_payment = if let Some(mut sessions) = outbounds.remove(&mpp_id) {
+                                       sessions.remove(&session_priv_bytes)
+                               } else { false };
+                               if found_payment {
+                                       self.pending_events.lock().unwrap().push(
+                                               events::Event::PaymentSent { payment_preimage }
+                                       );
                                } else {
                                        log_trace!(self.logger, "Received duplicative fulfill for HTLC with payment_preimage {}", log_bytes!(payment_preimage.0));
                                }
@@ -4923,14 +4965,60 @@ impl Readable for ClaimableHTLC {
        }
 }
 
-impl_writeable_tlv_based_enum!(HTLCSource,
-       (0, OutboundRoute) => {
-               (0, session_priv, required),
-               (2, first_hop_htlc_msat, required),
-               (4, path, vec_type),
-       }, ;
-       (1, PreviousHopData)
-);
+impl Readable for HTLCSource {
+       fn read<R: Read>(reader: &mut R) -> Result<Self, DecodeError> {
+               let id: u8 = Readable::read(reader)?;
+               match id {
+                       0 => {
+                               let mut session_priv: ::util::ser::OptionDeserWrapper<SecretKey> = ::util::ser::OptionDeserWrapper(None);
+                               let mut first_hop_htlc_msat: u64 = 0;
+                               let mut path = Some(Vec::new());
+                               let mut mpp_id = None;
+                               read_tlv_fields!(reader, {
+                                       (0, session_priv, required),
+                                       (1, mpp_id, option),
+                                       (2, first_hop_htlc_msat, required),
+                                       (4, path, vec_type),
+                               });
+                               if mpp_id.is_none() {
+                                       // For backwards compat, if there was no mpp_id written, use the session_priv bytes
+                                       // instead.
+                                       mpp_id = Some(MppId(*session_priv.0.unwrap().as_ref()));
+                               }
+                               Ok(HTLCSource::OutboundRoute {
+                                       session_priv: session_priv.0.unwrap(),
+                                       first_hop_htlc_msat: first_hop_htlc_msat,
+                                       path: path.unwrap(),
+                                       mpp_id: mpp_id.unwrap(),
+                               })
+                       }
+                       1 => Ok(HTLCSource::PreviousHopData(Readable::read(reader)?)),
+                       _ => Err(DecodeError::UnknownRequiredFeature),
+               }
+       }
+}
+
+impl Writeable for HTLCSource {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::io::Error> {
+               match self {
+                       HTLCSource::OutboundRoute { ref session_priv, ref first_hop_htlc_msat, ref path, mpp_id } => {
+                               0u8.write(writer)?;
+                               let mpp_id_opt = Some(mpp_id);
+                               write_tlv_fields!(writer, {
+                                       (0, session_priv, required),
+                                       (1, mpp_id_opt, option),
+                                       (2, first_hop_htlc_msat, required),
+                                       (4, path, vec_type),
+                                });
+                       }
+                       HTLCSource::PreviousHopData(ref field) => {
+                               1u8.write(writer)?;
+                               field.write(writer)?;
+                       }
+               }
+               Ok(())
+       }
+}
 
 impl_writeable_tlv_based_enum!(HTLCFailReason,
        (0, LightningError) => {
@@ -5051,12 +5139,21 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
                }
 
                let pending_outbound_payments = self.pending_outbound_payments.lock().unwrap();
-               (pending_outbound_payments.len() as u64).write(writer)?;
-               for session_priv in pending_outbound_payments.iter() {
-                       session_priv.write(writer)?;
+               // For backwards compat, write the session privs and their total length.
+               let mut num_pending_outbounds_compat: u64 = 0;
+               for (_, outbounds) in pending_outbound_payments.iter() {
+                       num_pending_outbounds_compat += outbounds.len() as u64;
+               }
+               num_pending_outbounds_compat.write(writer)?;
+               for (_, outbounds) in pending_outbound_payments.iter() {
+                       for outbound in outbounds.iter() {
+                               outbound.write(writer)?;
+                       }
                }
 
-               write_tlv_fields!(writer, {});
+               write_tlv_fields!(writer, {
+                       (1, pending_outbound_payments, required),
+               });
 
                Ok(())
        }
@@ -5309,15 +5406,23 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                        }
                }
 
-               let pending_outbound_payments_count: u64 = Readable::read(reader)?;
-               let mut pending_outbound_payments: HashSet<[u8; 32]> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32));
-               for _ in 0..pending_outbound_payments_count {
-                       if !pending_outbound_payments.insert(Readable::read(reader)?) {
-                               return Err(DecodeError::InvalidValue);
-                       }
+               let pending_outbound_payments_count_compat: u64 = Readable::read(reader)?;
+               let mut pending_outbound_payments_compat: HashMap<MppId, HashSet<[u8; 32]>> =
+                       HashMap::with_capacity(cmp::min(pending_outbound_payments_count_compat as usize, MAX_ALLOC_SIZE/32));
+               for _ in 0..pending_outbound_payments_count_compat {
+                       let session_priv = Readable::read(reader)?;
+                       if pending_outbound_payments_compat.insert(MppId(session_priv), [session_priv].iter().cloned().collect()).is_some() {
+                               return Err(DecodeError::InvalidValue)
+                       };
                }
 
-               read_tlv_fields!(reader, {});
+               let mut pending_outbound_payments = None;
+               read_tlv_fields!(reader, {
+                       (1, pending_outbound_payments, option),
+               });
+               if pending_outbound_payments.is_none() {
+                       pending_outbound_payments = Some(pending_outbound_payments_compat);
+               }
 
                let mut secp_ctx = Secp256k1::new();
                secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes());
@@ -5338,7 +5443,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                pending_msg_events: Vec::new(),
                        }),
                        pending_inbound_payments: Mutex::new(pending_inbound_payments),
-                       pending_outbound_payments: Mutex::new(pending_outbound_payments),
+                       pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()),
 
                        our_network_key: args.keys_manager.get_node_secret(),
                        our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &args.keys_manager.get_node_secret()),
@@ -5376,7 +5481,7 @@ mod tests {
        use bitcoin::hashes::sha256::Hash as Sha256;
        use core::time::Duration;
        use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
-       use ln::channelmanager::PaymentSendFailure;
+       use ln::channelmanager::{MppId, PaymentSendFailure};
        use ln::features::{InitFeatures, InvoiceFeatures};
        use ln::functional_test_utils::*;
        use ln::msgs;
@@ -5527,10 +5632,11 @@ mod tests {
                let net_graph_msg_handler = &nodes[0].net_graph_msg_handler;
                let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph, &nodes[1].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), 100_000, TEST_FINAL_CLTV, &logger).unwrap();
                let (payment_preimage, our_payment_hash, payment_secret) = get_payment_preimage_hash!(&nodes[1]);
+               let mpp_id = MppId([42; 32]);
                // Use the utility function send_payment_along_path to send the payment with MPP data which
                // indicates there are more HTLCs coming.
                let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match.
-               nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, &None).unwrap();
+               nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, mpp_id, &None).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -5560,7 +5666,7 @@ mod tests {
                expect_payment_failed!(nodes[0], our_payment_hash, true);
 
                // Send the second half of the original MPP payment.
-               nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, &None).unwrap();
+               nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, mpp_id, &None).unwrap();
                check_added_monitors!(nodes[0], 1);
                let mut events = nodes[0].node.get_and_clear_pending_msg_events();
                assert_eq!(events.len(), 1);
@@ -5598,7 +5704,8 @@ mod tests {
                nodes[0].node.handle_revoke_and_ack(&nodes[1].node.get_our_node_id(), &bs_third_raa);
                check_added_monitors!(nodes[0], 1);
 
-               // There's an existing bug that generates a PaymentSent event for each MPP path, so handle that here.
+               // Note that successful MPP payments will generate 1 event upon the first path's success. No
+               // further events will be generated for subsequence path successes.
                let events = nodes[0].node.get_and_clear_pending_events();
                match events[0] {
                        Event::PaymentSent { payment_preimage: ref preimage } => {
@@ -5606,12 +5713,6 @@ mod tests {
                        },
                        _ => panic!("Unexpected event"),
                }
-               match events[1] {
-                       Event::PaymentSent { payment_preimage: ref preimage } => {
-                               assert_eq!(payment_preimage, *preimage);
-                       },
-                       _ => panic!("Unexpected event"),
-               }
        }
 
        #[test]