Merge pull request #1967 from arik-so/2023-01-rename-signer-traits
diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs
index af83edc5634d37d76d18c95bf61ede56999cdf23..82640ad15f16c4a324a4880c04b33aa55525b926 100644
--- a/lightning/src/ln/channelmanager.rs
+++ b/lightning/src/ln/channelmanager.rs
@@ -26,12 +26,10 @@ use bitcoin::network::constants::Network;
 
 use bitcoin::hashes::Hash;
 use bitcoin::hashes::sha256::Hash as Sha256;
-use bitcoin::hashes::sha256d::Hash as Sha256dHash;
 use bitcoin::hash_types::{BlockHash, Txid};
 
 use bitcoin::secp256k1::{SecretKey,PublicKey};
 use bitcoin::secp256k1::Secp256k1;
-use bitcoin::secp256k1::ecdh::SharedSecret;
 use bitcoin::{LockTime, secp256k1, Sequence};
 
 use crate::chain;
@@ -57,7 +55,7 @@ use crate::ln::msgs::{ChannelMessageHandler, DecodeError, LightningError, MAX_VA
 use crate::ln::outbound_payment;
 use crate::ln::outbound_payment::{OutboundPayments, PendingOutboundPayment};
 use crate::ln::wire::Encode;
-use crate::chain::keysinterface::{EntropySource, KeysManager, NodeSigner, Recipient, Sign, SignerProvider};
+use crate::chain::keysinterface::{EntropySource, KeysManager, NodeSigner, Recipient, SignerProvider, ChannelSigner};
 use crate::util::config::{UserConfig, ChannelConfig};
 use crate::util::events::{Event, EventHandler, EventsProvider, MessageSendEvent, MessageSendEventsProvider, ClosureReason, HTLCDestination};
 use crate::util::events;
@@ -454,7 +452,7 @@ pub(crate) enum MonitorUpdateCompletionAction {
 }
 
 /// State we hold per-peer.
-pub(super) struct PeerState<Signer: Sign> {
+pub(super) struct PeerState<Signer: ChannelSigner> {
        /// `temporary_channel_id` or `channel_id` -> `channel`.
        ///
        /// Holds all channels where the peer is the counterparty. Once a channel has been assigned a
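For downstream code, the visible effect of this hunk is that generic parameters previously bound by the removed `Sign` umbrella trait are now bound by `ChannelSigner`. A minimal sketch of the migration; `MyPeerTracker` is a hypothetical consumer type, not part of LDK:

```rust
use lightning::chain::keysinterface::ChannelSigner;

// Hypothetical downstream struct mirroring the `PeerState` change: the
// signer type parameter is bound by `ChannelSigner` where it was
// previously bound by the retired `Sign` trait.
struct MyPeerTracker<Signer: ChannelSigner> {
    signers: Vec<Signer>,
}
```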
@@ -721,7 +719,6 @@ where
        #[cfg(not(test))]
        short_to_chan_info: FairRwLock<HashMap<u64, (PublicKey, [u8; 32])>>,
 
-       our_network_key: SecretKey,
        our_network_pubkey: PublicKey,
 
        inbound_payment_key: inbound_payment::ExpandedKey,
@@ -1153,12 +1150,12 @@ macro_rules! handle_error {
                match $internal {
                        Ok(msg) => Ok(msg),
                        Err(MsgHandleErrInternal { err, chan_id, shutdown_finish }) => {
-                               #[cfg(debug_assertions)]
+                               #[cfg(any(feature = "_test_utils", test))]
                                {
                                        // In testing, ensure there are no deadlocks where the lock is already held upon
                                        // entering the macro.
-                                       assert!($self.pending_events.try_lock().is_ok());
-                                       assert!($self.per_peer_state.try_write().is_ok());
+                                       debug_assert!($self.pending_events.try_lock().is_ok());
+                                       debug_assert!($self.per_peer_state.try_write().is_ok());
                                }
 
                                let mut msg_events = Vec::with_capacity(2);
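Two things change in these assertions: they now compile only under test configurations (`test` or the `_test_utils` feature) rather than in every debug build, and they use `debug_assert!` so they are no-ops if debug assertions happen to be disabled. A standalone sketch of the try-lock deadlock check, assuming a crate that exposes a `_test_utils` feature as LDK does:

```rust
use std::sync::Mutex;

// Sketch of the check above: `try_lock` fails when the mutex is already
// held, so in a single-threaded test an `Err` means the caller re-entered
// this code while holding the lock, i.e. a would-be deadlock.
fn check_not_held<T>(lock: &Mutex<T>) {
    #[cfg(any(feature = "_test_utils", test))]
    debug_assert!(lock.try_lock().is_ok());
    let _ = lock; // keep `lock` used when the cfg-gated check is compiled out
}
```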
@@ -1193,7 +1190,7 @@ macro_rules! handle_error {
                                                let mut peer_state = peer_state_mutex.lock().unwrap();
                                                peer_state.pending_msg_events.append(&mut msg_events);
                                        }
-                                       #[cfg(debug_assertions)]
+                                       #[cfg(any(feature = "_test_utils", test))]
                                        {
                                                if let None = per_peer_state.get(&$counterparty_node_id) {
                                                        // This shouldn't occur in tests unless an unknown counterparty_node_id
@@ -1206,10 +1203,10 @@ macro_rules! handle_error {
                                                                => {
                                                                        assert_eq!(*data, expected_error_str);
                                                                        if let Some((err_channel_id, _user_channel_id)) = chan_id {
-                                                                               assert_eq!(*channel_id, err_channel_id);
+                                                                               debug_assert_eq!(*channel_id, err_channel_id);
                                                                        }
                                                                }
-                                                               _ => panic!("Unexpected event"),
+                                                               _ => debug_assert!(false, "Unexpected event"),
                                                        }
                                                }
                                        }
@@ -1457,8 +1454,7 @@ where
                        id_to_peer: Mutex::new(HashMap::new()),
                        short_to_chan_info: FairRwLock::new(HashMap::new()),
 
-                       our_network_key: node_signer.get_node_secret(Recipient::Node).unwrap(),
-                       our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &node_signer.get_node_secret(Recipient::Node).unwrap()),
+                       our_network_pubkey: node_signer.get_node_id(Recipient::Node).unwrap(),
                        secp_ctx,
 
                        inbound_payment_key: expanded_inbound_key,
@@ -2017,7 +2013,9 @@ where
                        return_malformed_err!("invalid ephemeral pubkey", 0x8000 | 0x4000 | 6);
                }
 
-               let shared_secret = SharedSecret::new(&msg.onion_routing_packet.public_key.unwrap(), &self.our_network_key).secret_bytes();
+               let shared_secret = self.node_signer.ecdh(
+                       Recipient::Node, &msg.onion_routing_packet.public_key.unwrap(), None
+               ).unwrap().secret_bytes();
 
                if msg.onion_routing_packet.version != 0 {
                        //TODO: Spec doesn't indicate if we should only hash hop_data here (and in other
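The removed line computed the onion shared secret directly from `our_network_key`, which required the secret to live inside `ChannelManager`; the replacement asks the `NodeSigner` for the ECDH result instead. As a reference point, a purely software signer's `ecdh` with no tweak reduces to exactly what the old code did. A sketch using the same secp256k1 types:

```rust
use bitcoin::secp256k1::{ecdh::SharedSecret, PublicKey, SecretKey};

// What an in-memory `NodeSigner::ecdh` boils down to for the no-tweak
// case: plain ECDH between the node secret and the ephemeral pubkey from
// the onion packet. External or hardware signers can now answer this
// query without the node secret ever entering this crate's memory.
fn software_ecdh_sketch(node_secret: &SecretKey, ephemeral: &PublicKey) -> [u8; 32] {
    SharedSecret::new(ephemeral, node_secret).secret_bytes()
}
```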
@@ -2266,7 +2264,7 @@ where
        }
        fn get_channel_update_for_onion(&self, short_channel_id: u64, chan: &Channel<<SP::Target as SignerProvider>::Signer>) -> Result<msgs::ChannelUpdate, LightningError> {
                log_trace!(self.logger, "Generating channel update for channel {}", log_bytes!(chan.channel_id()));
-               let were_node_one = PublicKey::from_secret_key(&self.secp_ctx, &self.our_network_key).serialize()[..] < chan.get_counterparty_node_id().serialize()[..];
+               let were_node_one = self.our_network_pubkey.serialize()[..] < chan.get_counterparty_node_id().serialize()[..];
 
                let unsigned = msgs::UnsignedChannelUpdate {
                        chain_hash: self.genesis_hash,
@@ -2280,9 +2278,11 @@ where
                        fee_proportional_millionths: chan.get_fee_proportional_millionths(),
                        excess_data: Vec::new(),
                };
-
-               let msg_hash = Sha256dHash::hash(&unsigned.encode()[..]);
-               let sig = self.secp_ctx.sign_ecdsa(&hash_to_message!(&msg_hash[..]), &self.our_network_key);
+               // Panic on failure to signal that LDK should be restarted to retry signing the
+               // `ChannelUpdate`. If we instead returned an error when the `node_signer` cannot
+               // provide a signature for whatever reason, we wouldn't be able to receive inbound
+               // payments through the corresponding channel.
+               let sig = self.node_signer.sign_gossip_message(msgs::UnsignedGossipMessage::ChannelUpdate(&unsigned)).unwrap();
 
                Ok(msgs::ChannelUpdate {
                        signature: sig,
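With `our_network_key` gone, the `were_node_one` direction bit is computed from the cached `our_network_pubkey`, and the `ChannelUpdate` signature is produced via `NodeSigner::sign_gossip_message` rather than inline ECDSA. For reference, the removed lines signed the double-SHA256 of the unsigned message's encoding; a hedged sketch of that inline path, where `encoded` stands in for `unsigned.encode()`:

```rust
use bitcoin::hashes::{sha256d, Hash};
use bitcoin::secp256k1::{ecdsa::Signature, All, Message, Secp256k1, SecretKey};

// Sketch of the signing the removed lines performed inline: gossip
// messages are ECDSA-signed over the double-SHA256 of their encoding.
fn sign_channel_update_sketch(
    secp_ctx: &Secp256k1<All>, encoded: &[u8], node_key: &SecretKey,
) -> Signature {
    let msg_hash = sha256d::Hash::hash(encoded);
    secp_ctx.sign_ecdsa(&Message::from_slice(&msg_hash[..]).unwrap(), node_key)
}
```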
@@ -2923,9 +2923,9 @@ where
                                                                                        }
                                                                                }
                                                                                if let PendingHTLCRouting::Forward { onion_packet, .. } = routing {
-                                                                                       let phantom_secret_res = self.node_signer.get_node_secret(Recipient::PhantomNode);
-                                                                                       if phantom_secret_res.is_ok() && fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, short_chan_id, &self.genesis_hash) {
-                                                                                               let phantom_shared_secret = SharedSecret::new(&onion_packet.public_key.unwrap(), &phantom_secret_res.unwrap()).secret_bytes();
+                                                                                       let phantom_pubkey_res = self.node_signer.get_node_id(Recipient::PhantomNode);
+                                                                                       if phantom_pubkey_res.is_ok() && fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, short_chan_id, &self.genesis_hash) {
+                                                                                               let phantom_shared_secret = self.node_signer.ecdh(Recipient::PhantomNode, &onion_packet.public_key.unwrap(), None).unwrap().secret_bytes();
                                                                                                let next_hop = match onion_utils::decode_next_payment_hop(phantom_shared_secret, &onion_packet.hop_data, onion_packet.hmac, payment_hash) {
                                                                                                        Ok(res) => res,
                                                                                                        Err(onion_utils::OnionDecodeErr::Malformed { err_msg, err_code }) => {
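The phantom-node path gets the same treatment: instead of pulling the phantom secret out of the signer, the code checks that a phantom node id exists and delegates the ECDH. A condensed sketch of that gate, where `ephemeral` stands in for `onion_packet.public_key.unwrap()`:

```rust
use bitcoin::secp256k1::PublicKey;
use lightning::chain::keysinterface::{NodeSigner, Recipient};

// Condensed sketch of the phantom gate: the *availability* of a phantom
// node id decides whether phantom forwards are supported, and the shared
// secret itself comes from the signer's `ecdh`.
fn phantom_shared_secret_sketch<NS: NodeSigner>(
    node_signer: &NS, ephemeral: &PublicKey,
) -> Option<[u8; 32]> {
    if node_signer.get_node_id(Recipient::PhantomNode).is_err() { return None; }
    node_signer.ecdh(Recipient::PhantomNode, ephemeral, None)
        .ok().map(|ss| ss.secret_bytes())
}
```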
@@ -3565,7 +3565,7 @@ where
        /// Fails an HTLC backwards to the node that sent it to us.
        /// Note that we do not assume that channels corresponding to failed HTLCs are still available.
        fn fail_htlc_backwards_internal(&self, source: &HTLCSource, payment_hash: &PaymentHash, onion_error: &HTLCFailReason, destination: HTLCDestination) {
-               #[cfg(debug_assertions)]
+               #[cfg(any(feature = "_test_utils", test))]
                {
                        // Ensure that no peer state channel storage lock is held when calling this
                        // function.
@@ -3574,7 +3574,7 @@ where
                        // this function with any `per_peer_state` peer lock acquired would.
                        let per_peer_state = self.per_peer_state.read().unwrap();
                        for (_, peer) in per_peer_state.iter() {
-                               assert!(peer.try_lock().is_ok());
+                               debug_assert!(peer.try_lock().is_ok());
                        }
                }
 
@@ -4060,7 +4060,7 @@ where
                                return;
                        }
 
-                       let updates = channel.get_mut().monitor_updating_restored(&self.logger, self.get_our_node_id(), self.genesis_hash, &self.default_configuration, self.best_block.read().unwrap().height());
+                       let updates = channel.get_mut().monitor_updating_restored(&self.logger, &self.node_signer, self.genesis_hash, &self.default_configuration, self.best_block.read().unwrap().height());
                        let channel_update = if updates.channel_ready.is_some() && channel.get().is_usable() {
                                // We only send a channel_update in the case where we are just now sending a
                                // channel_ready and the channel is in a usable state. We may re-send a
@@ -4397,7 +4397,7 @@ where
                let peer_state = &mut *peer_state_lock;
                match peer_state.channel_by_id.entry(msg.channel_id) {
                        hash_map::Entry::Occupied(mut chan) => {
-                               let announcement_sigs_opt = try_chan_entry!(self, chan.get_mut().channel_ready(&msg, self.get_our_node_id(),
+                               let announcement_sigs_opt = try_chan_entry!(self, chan.get_mut().channel_ready(&msg, &self.node_signer,
                                        self.genesis_hash.clone(), &self.default_configuration, &self.best_block.read().unwrap(), &self.logger), chan);
                                if let Some(announcement_sigs) = announcement_sigs_opt {
                                        log_trace!(self.logger, "Sending announcement_signatures for channel {}", log_bytes!(chan.get().channel_id()));
@@ -4879,8 +4879,8 @@ where
 
                                peer_state.pending_msg_events.push(events::MessageSendEvent::BroadcastChannelAnnouncement {
                                        msg: try_chan_entry!(self, chan.get_mut().announcement_signatures(
-                                               self.get_our_node_id(), self.genesis_hash.clone(),
-                                               self.best_block.read().unwrap().height(), msg, &self.default_configuration
+                                               &self.node_signer, self.genesis_hash.clone(), self.best_block.read().unwrap().height(),
+                                               msg, &self.default_configuration
                                        ), chan),
                                        // Note that announcement_signatures fails if the channel cannot be announced,
                                        // so get_channel_update_for_broadcast will never fail by the time we get here.
@@ -4951,7 +4951,7 @@ where
                                        // freed HTLCs to fail backwards. If in the future we no longer drop pending
                                        // add-HTLCs on disconnect, we may be handed HTLCs to fail backwards here.
                                        let responses = try_chan_entry!(self, chan.get_mut().channel_reestablish(
-                                               msg, &self.logger, self.our_network_pubkey.clone(), self.genesis_hash,
+                                               msg, &self.logger, &self.node_signer, self.genesis_hash,
                                                &self.default_configuration, &*self.best_block.read().unwrap()), chan);
                                        let mut channel_update = None;
                                        if let Some(msg) = responses.shutdown_msg {
@@ -5630,7 +5630,7 @@ where
                        *best_block = BestBlock::new(header.prev_blockhash, new_height)
                }
 
-               self.do_chain_event(Some(new_height), |channel| channel.best_block_updated(new_height, header.time, self.genesis_hash.clone(), self.get_our_node_id(), self.default_configuration.clone(), &self.logger));
+               self.do_chain_event(Some(new_height), |channel| channel.best_block_updated(new_height, header.time, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger));
        }
 }
 
@@ -5654,13 +5654,13 @@ where
                log_trace!(self.logger, "{} transactions included in block {} at height {} provided", txdata.len(), block_hash, height);
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
-               self.do_chain_event(Some(height), |channel| channel.transactions_confirmed(&block_hash, height, txdata, self.genesis_hash.clone(), self.get_our_node_id(), &self.default_configuration, &self.logger)
+               self.do_chain_event(Some(height), |channel| channel.transactions_confirmed(&block_hash, height, txdata, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger)
                        .map(|(a, b)| (a, Vec::new(), b)));
 
                let last_best_block_height = self.best_block.read().unwrap().height();
                if height < last_best_block_height {
                        let timestamp = self.highest_seen_timestamp.load(Ordering::Acquire);
-                       self.do_chain_event(Some(last_best_block_height), |channel| channel.best_block_updated(last_best_block_height, timestamp as u32, self.genesis_hash.clone(), self.get_our_node_id(), self.default_configuration.clone(), &self.logger));
+                       self.do_chain_event(Some(last_best_block_height), |channel| channel.best_block_updated(last_best_block_height, timestamp as u32, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger));
                }
        }
 
@@ -5676,7 +5676,7 @@ where
 
                *self.best_block.write().unwrap() = BestBlock::new(block_hash, height);
 
-               self.do_chain_event(Some(height), |channel| channel.best_block_updated(height, header.time, self.genesis_hash.clone(), self.get_our_node_id(), self.default_configuration.clone(), &self.logger));
+               self.do_chain_event(Some(height), |channel| channel.best_block_updated(height, header.time, self.genesis_hash.clone(), &self.node_signer, &self.default_configuration, &self.logger));
 
                macro_rules! max_time {
                        ($timestamp: expr) => {
@@ -5787,7 +5787,7 @@ where
                                                                msg: announcement_sigs,
                                                        });
                                                        if let Some(height) = height_opt {
-                                                               if let Some(announcement) = channel.get_signed_channel_announcement(self.get_our_node_id(), self.genesis_hash, height, &self.default_configuration) {
+                                                               if let Some(announcement) = channel.get_signed_channel_announcement(&self.node_signer, self.genesis_hash, height, &self.default_configuration) {
                                                                        pending_msg_events.push(events::MessageSendEvent::BroadcastChannelAnnouncement {
                                                                                msg: announcement,
                                                                                // Note that announcement_signatures fails if the channel cannot be announced,
@@ -6177,7 +6177,7 @@ where
                                        }
                                } else { true };
                                if retain && chan.get_counterparty_node_id() != *counterparty_node_id {
-                                       if let Some(msg) = chan.get_signed_channel_announcement(self.get_our_node_id(), self.genesis_hash.clone(), self.best_block.read().unwrap().height(), &self.default_configuration) {
+                                       if let Some(msg) = chan.get_signed_channel_announcement(&self.node_signer, self.genesis_hash.clone(), self.best_block.read().unwrap().height(), &self.default_configuration) {
                                                if let Ok(update_msg) = self.get_channel_update_for_broadcast(chan) {
                                                        pending_msg_events.push(events::MessageSendEvent::SendChannelAnnouncement {
                                                                node_id: *counterparty_node_id,
@@ -7394,11 +7394,10 @@ where
                        pending_events_read.append(&mut channel_closures);
                }
 
-               let our_network_key = match args.node_signer.get_node_secret(Recipient::Node) {
+               let our_network_pubkey = match args.node_signer.get_node_id(Recipient::Node) {
                        Ok(key) => key,
                        Err(()) => return Err(DecodeError::InvalidValue)
                };
-               let our_network_pubkey = PublicKey::from_secret_key(&secp_ctx, &our_network_key);
                if let Some(network_pubkey) = received_network_pubkey {
                        if network_pubkey != our_network_pubkey {
                                log_error!(args.logger, "Key that was generated does not match the existing key.");
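The read path mirrors the constructor change earlier in the diff: the node id is fetched from the signer rather than derived from a secret, then checked against the pubkey persisted with the manager. A condensed restatement as a free function, hedged and simplified from the `match` used above:

```rust
use bitcoin::secp256k1::PublicKey;
use lightning::chain::keysinterface::{NodeSigner, Recipient};
use lightning::ln::msgs::DecodeError;

// The signer must be able to produce a node id, and it must match any
// pubkey that was persisted alongside the `ChannelManager`.
fn check_node_id<NS: NodeSigner>(
    node_signer: &NS, received_network_pubkey: Option<PublicKey>,
) -> Result<PublicKey, DecodeError> {
    let our_network_pubkey = node_signer.get_node_id(Recipient::Node)
        .map_err(|()| DecodeError::InvalidValue)?;
    if let Some(network_pubkey) = received_network_pubkey {
        if network_pubkey != our_network_pubkey {
            return Err(DecodeError::InvalidValue);
        }
    }
    Ok(our_network_pubkey)
}
```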
@@ -7514,7 +7513,6 @@ where
 
                        probing_cookie_secret: probing_cookie_secret.unwrap(),
 
-                       our_network_key,
                        our_network_pubkey,
                        secp_ctx,