Impl Base AMP in the receive pipeline and expose payment_secret
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 1a207bf9104f1b4dcc45b9329c0cedf43af3986f..efcf2159a068fa21b6f18e9674c98960a0f2f501 100644 (file)
@@ -171,6 +171,9 @@ pub struct PaymentHash(pub [u8;32]);
 /// payment_preimage type, use to route payment between hop
 #[derive(Hash, Copy, Clone, PartialEq, Eq, Debug)]
 pub struct PaymentPreimage(pub [u8;32]);
+/// payment_secret type, use to authenticate sender to the receiver and tie MPP HTLCs together
+#[derive(Hash, Copy, Clone, PartialEq, Eq, Debug)]
+pub struct PaymentSecret(pub [u8;32]);
 
 type ShutdownResult = (Option<OutPoint>, ChannelMonitorUpdate, Vec<(HTLCSource, PaymentHash)>);
 
@@ -288,11 +291,14 @@ pub(super) struct ChannelHolder<ChanSigner: ChannelKeys> {
        /// guarantees are made about the existence of a channel with the short id here, nor the short
        /// ids in the PendingHTLCInfo!
        pub(super) forward_htlcs: HashMap<u64, Vec<HTLCForwardInfo>>,
-       /// Tracks HTLCs that were to us and can be failed/claimed by the user
+       /// (payment_hash, payment_secret) -> Vec<HTLCs> for tracking HTLCs that
+       /// were to us and can be failed/claimed by the user
        /// Note that while this is held in the same mutex as the channels themselves, no consistency
        /// guarantees are made about the channels given here actually existing anymore by the time you
        /// go to read them!
-       claimable_htlcs: HashMap<PaymentHash, Vec<ClaimableHTLC>>,
+       /// TODO: We need to time out HTLCs sitting here which are waiting on other AMP HTLCs to
+       /// arrive.
+       claimable_htlcs: HashMap<(PaymentHash, Option<PaymentSecret>), Vec<ClaimableHTLC>>,
        /// Messages to send to peers - pushed to in the same lock that they are generated in (except
        /// for broadcast messages, where ordering isn't as strict).
        pub(super) pending_msg_events: Vec<events::MessageSendEvent>,
@@ -1215,7 +1221,16 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
        /// In case of APIError::MonitorUpdateFailed, the commitment update has been irrevocably
        /// committed on our end and we're just waiting for a monitor update to send it. Do NOT retry
        /// the payment via a different route unless you intend to pay twice!
-       pub fn send_payment(&self, route: Route, payment_hash: PaymentHash) -> Result<(), APIError> {
+       ///
+       /// payment_secret is unrelated to payment_hash (or PaymentPreimage) and exists to authenticate
+       /// the sender to the recipient and prevent payment-probing (deanonymization) attacks. For
+       /// newer nodes, it will be provided to you in the invoice. If you do not have one, the Route
+       /// must not contain multiple paths as multi-path payments require a recipient-provided
+       /// payment_secret.
+       /// If a payment_secret *is* provided, we assume that the invoice had the payment_secret feature
+       /// bit set (either as required or as available). If multiple paths are present in the Route,
+       /// we assume the invoice had the basic_mpp feature set.
+       pub fn send_payment(&self, route: Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>) -> Result<(), APIError> {
                if route.hops.len() < 1 || route.hops.len() > 20 {
                        return Err(APIError::RouteError{err: "Route didn't go anywhere/had bogus size"});
                }
@@ -1232,7 +1247,7 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
 
                let onion_keys = secp_call!(onion_utils::construct_onion_keys(&self.secp_ctx, &route, &session_priv),
                                APIError::RouteError{err: "Pubkey along hop was maliciously selected"});
-               let (onion_payloads, htlc_msat, htlc_cltv) = onion_utils::build_onion_payloads(&route, cur_height)?;
+               let (onion_payloads, htlc_msat, htlc_cltv) = onion_utils::build_onion_payloads(&route, payment_secret, cur_height)?;
                if onion_utils::route_size_insane(&onion_payloads) {
                        return Err(APIError::RouteError{err: "Route size too large considering onion data"});
                }
@@ -1461,7 +1476,9 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
                                                                                        htlc_id: prev_htlc_id,
                                                                                        incoming_packet_shared_secret: forward_info.incoming_shared_secret,
                                                                                });
-                                                                               failed_forwards.push((htlc_source, forward_info.payment_hash, 0x4000 | 10, None));
+                                                                               failed_forwards.push((htlc_source, forward_info.payment_hash,
+                                                                                       HTLCFailReason::Reason { failure_code: 0x4000 | 10, data: Vec::new() }
+                                                                               ));
                                                                        },
                                                                        HTLCForwardInfo::FailHTLC { .. } => {
                                                                                // Channel went away before we could fail it. This implies
@@ -1497,7 +1514,9 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
                                                                                                panic!("Stated return value requirements in send_htlc() were not met");
                                                                                        }
                                                                                        let chan_update = self.get_channel_update(chan.get()).unwrap();
-                                                                                       failed_forwards.push((htlc_source, payment_hash, 0x1000 | 7, Some(chan_update)));
+                                                                                       failed_forwards.push((htlc_source, payment_hash,
+                                                                                               HTLCFailReason::Reason { failure_code: 0x1000 | 7, data: chan_update.encode_with_len() }
+                                                                                       ));
                                                                                        continue;
                                                                                },
                                                                                Ok(update_add) => {
@@ -1604,15 +1623,49 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
                                                                        htlc_id: prev_htlc_id,
                                                                        incoming_packet_shared_secret: incoming_shared_secret,
                                                                };
-                                                               channel_state.claimable_htlcs.entry(payment_hash).or_insert(Vec::new()).push(ClaimableHTLC {
+
+                                                               let mut total_value = 0;
+                                                               let payment_secret_opt =
+                                                                       if let &Some(ref data) = &payment_data { Some(data.payment_secret.clone()) } else { None };
+                                                               let htlcs = channel_state.claimable_htlcs.entry((payment_hash, payment_secret_opt))
+                                                                       .or_insert(Vec::new());
+                                                               htlcs.push(ClaimableHTLC {
                                                                        prev_hop,
                                                                        value: amt_to_forward,
-                                                                       payment_data,
-                                                               });
-                                                               new_events.push(events::Event::PaymentReceived {
-                                                                       payment_hash: payment_hash,
-                                                                       amt: amt_to_forward,
+                                                                       payment_data: payment_data.clone(),
                                                                });
+                                                               if let &Some(ref data) = &payment_data {
+                                                                       for htlc in htlcs.iter() {
+                                                                               total_value += htlc.value;
+                                                                               if htlc.payment_data.as_ref().unwrap().total_msat != data.total_msat {
+                                                                                       total_value = msgs::MAX_VALUE_MSAT;
+                                                                               }
+                                                                               if total_value >= msgs::MAX_VALUE_MSAT { break; }
+                                                                       }
+                                                                       if total_value >= msgs::MAX_VALUE_MSAT || total_value > data.total_msat  {
+                                                                               for htlc in htlcs.iter() {
+                                                                                       failed_forwards.push((HTLCSource::PreviousHopData(HTLCPreviousHopData {
+                                                                                                       short_channel_id: htlc.prev_hop.short_channel_id,
+                                                                                                       htlc_id: htlc.prev_hop.htlc_id,
+                                                                                                       incoming_packet_shared_secret: htlc.prev_hop.incoming_packet_shared_secret,
+                                                                                               }), payment_hash,
+                                                                                               HTLCFailReason::Reason { failure_code: 0x4000 | 15, data: byte_utils::be64_to_array(htlc.value).to_vec() }
+                                                                                       ));
+                                                                               }
+                                                                       } else if total_value == data.total_msat {
+                                                                               new_events.push(events::Event::PaymentReceived {
+                                                                                       payment_hash: payment_hash,
+                                                                                       payment_secret: Some(data.payment_secret),
+                                                                                       amt: total_value,
+                                                                               });
+                                                                       }
+                                                               } else {
+                                                                       new_events.push(events::Event::PaymentReceived {
+                                                                               payment_hash: payment_hash,
+                                                                               payment_secret: None,
+                                                                               amt: amt_to_forward,
+                                                                       });
+                                                               }
                                                        },
                                                        HTLCForwardInfo::AddHTLC { .. } => {
                                                                panic!("short_channel_id == 0 should imply any pending_forward entries are of type Receive");
@@ -1626,11 +1679,8 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
                        }
                }
 
-               for (htlc_source, payment_hash, failure_code, update) in failed_forwards.drain(..) {
-                       match update {
-                               None => self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source, &payment_hash, HTLCFailReason::Reason { failure_code, data: Vec::new() }),
-                               Some(chan_update) => self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source, &payment_hash, HTLCFailReason::Reason { failure_code, data: chan_update.encode_with_len() }),
-                       };
+               for (htlc_source, payment_hash, failure_reason) in failed_forwards.drain(..) {
+                       self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source, &payment_hash, failure_reason);
                }
 
                for (their_node_id, err) in handle_errors.drain(..) {
@@ -1672,11 +1722,11 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
        /// along the path (including in our own channel on which we received it).
        /// Returns false if no payment was found to fail backwards, true if the process of failing the
        /// HTLC backwards has been started.
-       pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash) -> bool {
+       pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash, payment_secret: &Option<PaymentSecret>) -> bool {
                let _ = self.total_consistency_lock.read().unwrap();
 
                let mut channel_state = Some(self.channel_state.lock().unwrap());
-               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(payment_hash);
+               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&(*payment_hash, *payment_secret));
                if let Some(mut sources) = removed_source {
                        for htlc in sources.drain(..) {
                                if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }
@@ -1798,13 +1848,13 @@ impl<ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref> ChannelMan
        /// motivated attackers.
        ///
        /// May panic if called except in response to a PaymentReceived event.
-       pub fn claim_funds(&self, payment_preimage: PaymentPreimage, expected_amount: u64) -> bool {
+       pub fn claim_funds(&self, payment_preimage: PaymentPreimage, payment_secret: &Option<PaymentSecret>, expected_amount: u64) -> bool {
                let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner());
 
                let _ = self.total_consistency_lock.read().unwrap();
 
                let mut channel_state = Some(self.channel_state.lock().unwrap());
-               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&payment_hash);
+               let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&(payment_hash, *payment_secret));
                if let Some(mut sources) = removed_source {
                        for htlc in sources.drain(..) {
                                if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }