Add FailureCode enum and ChannelManager::fail_htlc_backwards_with_reason
[rust-lightning] / lightning / src / ln / channelmanager.rs
index f502a3336cadd935a474b4abe89ef189f8432b0f..72bf349e9b4f8fd698479888865cc241fff30f40 100644 (file)
@@ -26,12 +26,10 @@ use bitcoin::network::constants::Network;
 
 use bitcoin::hashes::Hash;
 use bitcoin::hashes::sha256::Hash as Sha256;
-use bitcoin::hashes::sha256d::Hash as Sha256dHash;
 use bitcoin::hash_types::{BlockHash, Txid};
 
 use bitcoin::secp256k1::{SecretKey,PublicKey};
 use bitcoin::secp256k1::Secp256k1;
-use bitcoin::secp256k1::ecdh::SharedSecret;
 use bitcoin::{LockTime, secp256k1, Sequence};
 
 use crate::chain;
@@ -291,6 +289,25 @@ struct ReceiveError {
        msg: &'static str,
 }
 
+/// This enum is used to specify which error data to send to peers when failing back an HTLC
+/// using [`ChannelManager::fail_htlc_backwards_with_reason`].
+///
+/// For more info on failure codes, see <https://github.com/lightning/bolts/blob/master/04-onion-routing.md#failure-messages>.
+#[derive(Clone, Copy)]
+pub enum FailureCode {
+       /// We had a temporary error processing the payment. Useful if no other error codes fit
+       /// and you want to indicate that the payer may want to retry.
+       TemporaryNodeFailure             = 0x2000 | 2,
+       /// We have a required feature which was not in this onion. For example, you may require
+       /// some additional metadata that was not provided with this payment.
+       RequiredNodeFeatureMissing       = 0x4000 | 0x2000 | 3,
+       /// You may wish to use this when a `payment_preimage` is unknown, or the CLTV expiry of
+       /// the HTLC is too close to the current block height for safe handling.
+       /// Using this failure code in [`ChannelManager::fail_htlc_backwards_with_reason`] is
+       /// equivalent to calling [`ChannelManager::fail_htlc_backwards`].
+       IncorrectOrUnknownPaymentDetails = 0x4000 | 15,
+}
+
 type ShutdownResult = (Option<(OutPoint, ChannelMonitorUpdate)>, Vec<(HTLCSource, PaymentHash, PublicKey, [u8; 32])>);
 
 /// Error type returned across the peer_state mutex boundary. When an Err is generated for a
@@ -721,7 +738,6 @@ where
        #[cfg(not(test))]
        short_to_chan_info: FairRwLock<HashMap<u64, (PublicKey, [u8; 32])>>,
 
-       our_network_key: SecretKey,
        our_network_pubkey: PublicKey,
 
        inbound_payment_key: inbound_payment::ExpandedKey,
@@ -1457,8 +1473,7 @@ where
                        id_to_peer: Mutex::new(HashMap::new()),
                        short_to_chan_info: FairRwLock::new(HashMap::new()),
 
-                       our_network_key: node_signer.get_node_secret(Recipient::Node).unwrap(),
-                       our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &node_signer.get_node_secret(Recipient::Node).unwrap()),
+                       our_network_pubkey: node_signer.get_node_id(Recipient::Node).unwrap(),
                        secp_ctx,
 
                        inbound_payment_key: expanded_inbound_key,
@@ -2017,7 +2032,9 @@ where
                        return_malformed_err!("invalid ephemeral pubkey", 0x8000 | 0x4000 | 6);
                }
 
-               let shared_secret = SharedSecret::new(&msg.onion_routing_packet.public_key.unwrap(), &self.our_network_key).secret_bytes();
+               let shared_secret = self.node_signer.ecdh(
+                       Recipient::Node, &msg.onion_routing_packet.public_key.unwrap(), None
+               ).unwrap().secret_bytes();
 
                if msg.onion_routing_packet.version != 0 {
                        //TODO: Spec doesn't indicate if we should only hash hop_data here (and in other
@@ -2266,7 +2283,7 @@ where
        }
        fn get_channel_update_for_onion(&self, short_channel_id: u64, chan: &Channel<<SP::Target as SignerProvider>::Signer>) -> Result<msgs::ChannelUpdate, LightningError> {
                log_trace!(self.logger, "Generating channel update for channel {}", log_bytes!(chan.channel_id()));
-               let were_node_one = PublicKey::from_secret_key(&self.secp_ctx, &self.our_network_key).serialize()[..] < chan.get_counterparty_node_id().serialize()[..];
+               let were_node_one = self.our_network_pubkey.serialize()[..] < chan.get_counterparty_node_id().serialize()[..];
 
                let unsigned = msgs::UnsignedChannelUpdate {
                        chain_hash: self.genesis_hash,
@@ -2280,9 +2297,11 @@ where
                        fee_proportional_millionths: chan.get_fee_proportional_millionths(),
                        excess_data: Vec::new(),
                };
-
-               let msg_hash = Sha256dHash::hash(&unsigned.encode()[..]);
-               let sig = self.secp_ctx.sign_ecdsa(&hash_to_message!(&msg_hash[..]), &self.our_network_key);
+               // Panic on failure to signal LDK should be restarted to retry signing the `ChannelUpdate`.
+               // If we returned an error and the `node_signer` cannot provide a signature for whatever
+               // reason`, we wouldn't be able to receive inbound payments through the corresponding
+               // channel.
+               let sig = self.node_signer.sign_gossip_message(msgs::UnsignedGossipMessage::ChannelUpdate(&unsigned)).unwrap();
 
                Ok(msgs::ChannelUpdate {
                        signature: sig,
@@ -2923,9 +2942,9 @@ where
                                                                                        }
                                                                                }
                                                                                if let PendingHTLCRouting::Forward { onion_packet, .. } = routing {
-                                                                                       let phantom_secret_res = self.node_signer.get_node_secret(Recipient::PhantomNode);
-                                                                                       if phantom_secret_res.is_ok() && fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, short_chan_id, &self.genesis_hash) {
-                                                                                               let phantom_shared_secret = SharedSecret::new(&onion_packet.public_key.unwrap(), &phantom_secret_res.unwrap()).secret_bytes();
+                                                                                       let phantom_pubkey_res = self.node_signer.get_node_id(Recipient::PhantomNode);
+                                                                                       if phantom_pubkey_res.is_ok() && fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, short_chan_id, &self.genesis_hash) {
+                                                                                               let phantom_shared_secret = self.node_signer.ecdh(Recipient::PhantomNode, &onion_packet.public_key.unwrap(), None).unwrap().secret_bytes();
                                                                                                let next_hop = match onion_utils::decode_next_payment_hop(phantom_shared_secret, &onion_packet.hop_data, onion_packet.hmac, payment_hash) {
                                                                                                        Ok(res) => res,
                                                                                                        Err(onion_utils::OnionDecodeErr::Malformed { err_msg, err_code }) => {
@@ -3472,21 +3491,40 @@ where
        /// [`events::Event::PaymentClaimed`] events even for payments you intend to fail, especially on
        /// startup during which time claims that were in-progress at shutdown may be replayed.
        pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash) {
+               self.fail_htlc_backwards_with_reason(payment_hash, &FailureCode::IncorrectOrUnknownPaymentDetails);
+       }
+
+       /// This is a variant of [`ChannelManager::fail_htlc_backwards`] that allows you to specify the
+       /// reason for the failure.
+       ///
+       /// See [`FailureCode`] for valid failure codes.
+       pub fn fail_htlc_backwards_with_reason(&self, payment_hash: &PaymentHash, failure_code: &FailureCode) {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
                let removed_source = self.claimable_payments.lock().unwrap().claimable_htlcs.remove(payment_hash);
                if let Some((_, mut sources)) = removed_source {
                        for htlc in sources.drain(..) {
-                               let mut htlc_msat_height_data = htlc.value.to_be_bytes().to_vec();
-                               htlc_msat_height_data.extend_from_slice(&self.best_block.read().unwrap().height().to_be_bytes());
+                               let reason = self.get_htlc_fail_reason_from_failure_code(failure_code, &htlc);
                                let source = HTLCSource::PreviousHopData(htlc.prev_hop);
-                               let reason = HTLCFailReason::reason(0x4000 | 15, htlc_msat_height_data);
                                let receiver = HTLCDestination::FailedPayment { payment_hash: *payment_hash };
                                self.fail_htlc_backwards_internal(&source, &payment_hash, &reason, receiver);
                        }
                }
        }
 
+       /// Gets error data to form an [`HTLCFailReason`] given a [`FailureCode`] and [`ClaimableHTLC`].
+       fn get_htlc_fail_reason_from_failure_code(&self, failure_code: &FailureCode, htlc: &ClaimableHTLC) -> HTLCFailReason {
+               match failure_code {
+                       FailureCode::TemporaryNodeFailure => HTLCFailReason::from_failure_code(*failure_code as u16),
+                       FailureCode::RequiredNodeFeatureMissing => HTLCFailReason::from_failure_code(*failure_code as u16),
+                       FailureCode::IncorrectOrUnknownPaymentDetails => {
+                               let mut htlc_msat_height_data = htlc.value.to_be_bytes().to_vec();
+                               htlc_msat_height_data.extend_from_slice(&self.best_block.read().unwrap().height().to_be_bytes());
+                               HTLCFailReason::reason(*failure_code as u16, htlc_msat_height_data)
+                       }
+               }
+       }
+
        /// Gets an HTLC onion failure code and error data for an `UPDATE` error, given the error code
        /// that we want to return and a channel.
        ///
@@ -4060,7 +4098,7 @@ where
                                return;
                        }
 
-                       let updates = channel.get_mut().monitor_updating_restored(&self.logger, self.get_our_node_id(), self.genesis_hash, &self.default_configuration, self.best_block.read().unwrap().height());
+                       let updates = channel.get_mut().monitor_updating_restored(&self.logger, &self.node_signer, self.genesis_hash, &self.default_configuration, self.best_block.read().unwrap().height());
                        let channel_update = if updates.channel_ready.is_some() && channel.get().is_usable() {
                                // We only send a channel_update in the case where we are just now sending a
                                // channel_ready and the channel is in a usable state. We may re-send a
@@ -4397,7 +4435,7 @@ where
                let peer_state = &mut *peer_state_lock;
                match peer_state.channel_by_id.entry(msg.channel_id) {
                        hash_map::Entry::Occupied(mut chan) => {
-                               let announcement_sigs_opt = try_chan_entry!(self, chan.get_mut().channel_ready(&msg, self.get_our_node_id(),
+                               let announcement_sigs_opt = try_chan_entry!(self, chan.get_mut().channel_ready(&msg, &self.node_signer,
                                        self.genesis_hash.clone(), &self.default_configuration, &self.best_block.read().unwrap(), &self.logger), chan);
                                if let Some(announcement_sigs) = announcement_sigs_opt {
                                        log_trace!(self.logger, "Sending announcement_signatures for channel {}", log_bytes!(chan.get().channel_id()));
@@ -4879,8 +4917,8 @@ where
 
                                peer_state.pending_msg_events.push(events::MessageSendEvent::BroadcastChannelAnnouncement {
                                        msg: try_chan_entry!(self, chan.get_mut().announcement_signatures(
-                                               self.get_our_node_id(), self.genesis_hash.clone(),
-                                               self.best_block.read().unwrap().height(), msg, &self.default_configuration
+                                               &self.node_signer, self.genesis_hash.clone(), self.best_block.read().unwrap().height(),
+                                               msg, &self.default_configuration
                                        ), chan),
                                        // Note that announcement_signatures fails if the channel cannot be announced,
                                        // so get_channel_update_for_broadcast will never fail by the time we get here.
@@ -4951,7 +4989,7 @@ where
                                        // freed HTLCs to fail backwards. If in the future we no longer drop pending
                                        // add-HTLCs on disconnect, we may be handed HTLCs to fail backwards here.
                                        let responses = try_chan_entry!(self, chan.get_mut().channel_reestablish(
-                                               msg, &self.logger, self.our_network_pubkey.clone(), self.genesis_hash,
+                                               msg, &self.logger, &self.node_signer, self.genesis_hash,
                                                &self.default_configuration, &*self.best_block.read().unwrap()), chan);
                                        let mut channel_update = None;
                                        if let Some(msg) = responses.shutdown_msg {
@@ -5630,7 +5668,7 @@ where
                        *best_block = BestBlock::new(header.prev_blockhash, new_height)
                }
 
-               self.do_chain_event(Some(new_height), |channel| channel.best_block_updated(new_height, header.time, self.genesis_hash.clone(), self.get_our_node_id(), self.default_configuration.clone(), &self.logger));
+               self.do_chain_event(Some(new_height), |channel| channel.best_block_updated(new_height, header.time, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger));
        }
 }
 
@@ -5654,13 +5692,13 @@ where
                log_trace!(self.logger, "{} transactions included in block {} at height {} provided", txdata.len(), block_hash, height);
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
-               self.do_chain_event(Some(height), |channel| channel.transactions_confirmed(&block_hash, height, txdata, self.genesis_hash.clone(), self.get_our_node_id(), &self.default_configuration, &self.logger)
+               self.do_chain_event(Some(height), |channel| channel.transactions_confirmed(&block_hash, height, txdata, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger)
                        .map(|(a, b)| (a, Vec::new(), b)));
 
                let last_best_block_height = self.best_block.read().unwrap().height();
                if height < last_best_block_height {
                        let timestamp = self.highest_seen_timestamp.load(Ordering::Acquire);
-                       self.do_chain_event(Some(last_best_block_height), |channel| channel.best_block_updated(last_best_block_height, timestamp as u32, self.genesis_hash.clone(), self.get_our_node_id(), self.default_configuration.clone(), &self.logger));
+                       self.do_chain_event(Some(last_best_block_height), |channel| channel.best_block_updated(last_best_block_height, timestamp as u32, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger));
                }
        }
 
@@ -5676,7 +5714,7 @@ where
 
                *self.best_block.write().unwrap() = BestBlock::new(block_hash, height);
 
-               self.do_chain_event(Some(height), |channel| channel.best_block_updated(height, header.time, self.genesis_hash.clone(), self.get_our_node_id(), self.default_configuration.clone(), &self.logger));
+               self.do_chain_event(Some(height), |channel| channel.best_block_updated(height, header.time, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger));
 
                macro_rules! max_time {
                        ($timestamp: expr) => {
@@ -5787,7 +5825,7 @@ where
                                                                msg: announcement_sigs,
                                                        });
                                                        if let Some(height) = height_opt {
-                                                               if let Some(announcement) = channel.get_signed_channel_announcement(self.get_our_node_id(), self.genesis_hash, height, &self.default_configuration) {
+                                                               if let Some(announcement) = channel.get_signed_channel_announcement(&self.node_signer, self.genesis_hash, height, &self.default_configuration) {
                                                                        pending_msg_events.push(events::MessageSendEvent::BroadcastChannelAnnouncement {
                                                                                msg: announcement,
                                                                                // Note that announcement_signatures fails if the channel cannot be announced,
@@ -6177,7 +6215,7 @@ where
                                        }
                                } else { true };
                                if retain && chan.get_counterparty_node_id() != *counterparty_node_id {
-                                       if let Some(msg) = chan.get_signed_channel_announcement(self.get_our_node_id(), self.genesis_hash.clone(), self.best_block.read().unwrap().height(), &self.default_configuration) {
+                                       if let Some(msg) = chan.get_signed_channel_announcement(&self.node_signer, self.genesis_hash.clone(), self.best_block.read().unwrap().height(), &self.default_configuration) {
                                                if let Ok(update_msg) = self.get_channel_update_for_broadcast(chan) {
                                                        pending_msg_events.push(events::MessageSendEvent::SendChannelAnnouncement {
                                                                node_id: *counterparty_node_id,
@@ -7394,11 +7432,10 @@ where
                        pending_events_read.append(&mut channel_closures);
                }
 
-               let our_network_key = match args.node_signer.get_node_secret(Recipient::Node) {
+               let our_network_pubkey = match args.node_signer.get_node_id(Recipient::Node) {
                        Ok(key) => key,
                        Err(()) => return Err(DecodeError::InvalidValue)
                };
-               let our_network_pubkey = PublicKey::from_secret_key(&secp_ctx, &our_network_key);
                if let Some(network_pubkey) = received_network_pubkey {
                        if network_pubkey != our_network_pubkey {
                                log_error!(args.logger, "Key that was generated does not match the existing key.");
@@ -7514,7 +7551,6 @@ where
 
                        probing_cookie_secret: probing_cookie_secret.unwrap(),
 
-                       our_network_key,
                        our_network_pubkey,
                        secp_ctx,