Refactor send_payment internals for retries
authorValentine Wallace <vwallace@protonmail.com>
Thu, 23 Sep 2021 20:30:15 +0000 (16:30 -0400)
committerValentine Wallace <vwallace@protonmail.com>
Tue, 28 Sep 2021 23:39:37 +0000 (19:39 -0400)
We want to reuse send_payment internal functions for retries,
so some need to now be parameterized by PaymentId to avoid
generating a new PaymentId on retry

lightning/src/ln/channelmanager.rs

index c231258eaa02809bdc99a12d6946eaf7200936bc..32bb0a8e85514abcc9f59ed404f6cd519899f49e 100644 (file)
@@ -1998,10 +1998,10 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// bit set (either as required or as available). If multiple paths are present in the Route,
        /// we assume the invoice had the basic_mpp feature set.
        pub fn send_payment(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>) -> Result<PaymentId, PaymentSendFailure> {
-               self.send_payment_internal(route, payment_hash, payment_secret, None)
+               self.send_payment_internal(route, payment_hash, payment_secret, None, None)
        }
 
-       fn send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, keysend_preimage: Option<PaymentPreimage>) -> Result<PaymentId, PaymentSendFailure> {
+       fn send_payment_internal(&self, route: &Route, payment_hash: PaymentHash, payment_secret: &Option<PaymentSecret>, keysend_preimage: Option<PaymentPreimage>, payment_id: Option<PaymentId>) -> Result<PaymentId, PaymentSendFailure> {
                if route.paths.len() < 1 {
                        return Err(PaymentSendFailure::ParameterError(APIError::RouteError{err: "There must be at least one path to send over"}));
                }
@@ -2017,7 +2017,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut total_value = 0;
                let our_node_id = self.get_our_node_id();
                let mut path_errs = Vec::with_capacity(route.paths.len());
-               let payment_id = PaymentId(self.keys_manager.get_secure_random_bytes());
+               let payment_id = if let Some(id) = payment_id { id } else { PaymentId(self.keys_manager.get_secure_random_bytes()) };
                'path_check: for path in route.paths.iter() {
                        if path.len() < 1 || path.len() > 20 {
                                path_errs.push(Err(APIError::RouteError{err: "Path didn't go anywhere/had bogus size"}));
@@ -2083,7 +2083,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        None => PaymentPreimage(self.keys_manager.get_secure_random_bytes()),
                };
                let payment_hash = PaymentHash(Sha256::hash(&preimage.0).into_inner());
-               match self.send_payment_internal(route, payment_hash, &None, Some(preimage)) {
+               match self.send_payment_internal(route, payment_hash, &None, Some(preimage), None) {
                        Ok(payment_id) => Ok((payment_hash, payment_id)),
                        Err(e) => Err(e)
                }
@@ -5936,7 +5936,7 @@ mod tests {
 
                let test_preimage = PaymentPreimage([42; 32]);
                let mismatch_payment_hash = PaymentHash([43; 32]);
-               let _ = nodes[0].node.send_payment_internal(&route, mismatch_payment_hash, &None, Some(test_preimage)).unwrap();
+               let _ = nodes[0].node.send_payment_internal(&route, mismatch_payment_hash, &None, Some(test_preimage), None).unwrap();
                check_added_monitors!(nodes[0], 1);
 
                let updates = get_htlc_update_msgs!(nodes[0], nodes[1].node.get_our_node_id());
@@ -5973,7 +5973,7 @@ mod tests {
                let test_preimage = PaymentPreimage([42; 32]);
                let test_secret = PaymentSecret([43; 32]);
                let payment_hash = PaymentHash(Sha256::hash(&test_preimage.0).into_inner());
-               let _ = nodes[0].node.send_payment_internal(&route, payment_hash, &Some(test_secret), Some(test_preimage)).unwrap();
+               let _ = nodes[0].node.send_payment_internal(&route, payment_hash, &Some(test_secret), Some(test_preimage), None).unwrap();
                check_added_monitors!(nodes[0], 1);
 
                let updates = get_htlc_update_msgs!(nodes[0], nodes[1].node.get_our_node_id());