Pass `counterparty_node_id` to `force_close_channel`
[rust-lightning] / lightning / src / ln / channelmanager.rs
index a28d6347332694af4dab7fc0737dd13a2a38c834..f89aef6163c5d526f25fd52a446a87de34963a14 100644 (file)
@@ -28,7 +28,7 @@ use bitcoin::hashes::sha256::Hash as Sha256;
 use bitcoin::hashes::sha256d::Hash as Sha256dHash;
 use bitcoin::hash_types::{BlockHash, Txid};
 
-use bitcoin::secp256k1::key::{SecretKey,PublicKey};
+use bitcoin::secp256k1::{SecretKey,PublicKey};
 use bitcoin::secp256k1::Secp256k1;
 use bitcoin::secp256k1::ecdh::SharedSecret;
 use bitcoin::secp256k1;
@@ -48,6 +48,7 @@ use ln::msgs;
 use ln::msgs::NetAddress;
 use ln::onion_utils;
 use ln::msgs::{ChannelMessageHandler, DecodeError, LightningError, MAX_VALUE_MSAT, OptionalField};
+use ln::wire::Encode;
 use chain::keysinterface::{Sign, KeysInterface, KeysManager, InMemorySigner, Recipient};
 use util::config::UserConfig;
 use util::events::{EventHandler, EventsProvider, MessageSendEvent, MessageSendEventsProvider, ClosureReason};
@@ -1768,17 +1769,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                });
        }
 
-       fn close_channel_internal(&self, channel_id: &[u8; 32], target_feerate_sats_per_1000_weight: Option<u32>) -> Result<(), APIError> {
+       fn close_channel_internal(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey, target_feerate_sats_per_1000_weight: Option<u32>) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let counterparty_node_id;
                let mut failed_htlcs: Vec<(HTLCSource, PaymentHash)>;
                let result: Result<(), _> = loop {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_state_lock;
                        match channel_state.by_id.entry(channel_id.clone()) {
                                hash_map::Entry::Occupied(mut chan_entry) => {
-                                       counterparty_node_id = chan_entry.get().get_counterparty_node_id();
+                                       if *counterparty_node_id != chan_entry.get().get_counterparty_node_id(){
+                                               return Err(APIError::APIMisuseError { err: "The passed counterparty_node_id doesn't match the channel's counterparty node_id".to_owned() });
+                                       }
                                        let per_peer_state = self.per_peer_state.read().unwrap();
                                        let (shutdown_msg, monitor_update, htlcs) = match per_peer_state.get(&counterparty_node_id) {
                                                Some(peer_state) => {
@@ -1803,7 +1805,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        }
 
                                        channel_state.pending_msg_events.push(events::MessageSendEvent::SendShutdown {
-                                               node_id: counterparty_node_id,
+                                               node_id: *counterparty_node_id,
                                                msg: shutdown_msg
                                        });
 
@@ -1826,7 +1828,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
                }
 
-               let _ = handle_error!(self, result, counterparty_node_id);
+               let _ = handle_error!(self, result, *counterparty_node_id);
                Ok(())
        }
 
@@ -1847,8 +1849,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// [`ChannelConfig::force_close_avoidance_max_fee_satoshis`]: crate::util::config::ChannelConfig::force_close_avoidance_max_fee_satoshis
        /// [`Background`]: crate::chain::chaininterface::ConfirmationTarget::Background
        /// [`Normal`]: crate::chain::chaininterface::ConfirmationTarget::Normal
-       pub fn close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> {
-               self.close_channel_internal(channel_id, None)
+       pub fn close_channel(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey) -> Result<(), APIError> {
+               self.close_channel_internal(channel_id, counterparty_node_id, None)
        }
 
        /// Begins the process of closing a channel. After this call (plus some timeout), no new HTLCs
@@ -1870,8 +1872,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// [`ChannelConfig::force_close_avoidance_max_fee_satoshis`]: crate::util::config::ChannelConfig::force_close_avoidance_max_fee_satoshis
        /// [`Background`]: crate::chain::chaininterface::ConfirmationTarget::Background
        /// [`Normal`]: crate::chain::chaininterface::ConfirmationTarget::Normal
-       pub fn close_channel_with_target_feerate(&self, channel_id: &[u8; 32], target_feerate_sats_per_1000_weight: u32) -> Result<(), APIError> {
-               self.close_channel_internal(channel_id, Some(target_feerate_sats_per_1000_weight))
+       pub fn close_channel_with_target_feerate(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey, target_feerate_sats_per_1000_weight: u32) -> Result<(), APIError> {
+               self.close_channel_internal(channel_id, counterparty_node_id, Some(target_feerate_sats_per_1000_weight))
        }
 
        #[inline]
@@ -1890,22 +1892,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                }
        }
 
-       /// `peer_node_id` should be set when we receive a message from a peer, but not set when the
+       /// `peer_msg` should be set when we receive a message from a peer, but not set when the
        /// user closes, which will be re-exposed as the `ChannelClosed` reason.
-       fn force_close_channel_with_peer(&self, channel_id: &[u8; 32], peer_node_id: Option<&PublicKey>, peer_msg: Option<&String>) -> Result<PublicKey, APIError> {
+       fn force_close_channel_with_peer(&self, channel_id: &[u8; 32], peer_node_id: &PublicKey, peer_msg: Option<&String>) -> Result<PublicKey, APIError> {
                let mut chan = {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_state_lock;
                        if let hash_map::Entry::Occupied(chan) = channel_state.by_id.entry(channel_id.clone()) {
-                               if let Some(node_id) = peer_node_id {
-                                       if chan.get().get_counterparty_node_id() != *node_id {
-                                               return Err(APIError::ChannelUnavailable{err: "No such channel".to_owned()});
-                                       }
+                               if chan.get().get_counterparty_node_id() != *peer_node_id {
+                                       return Err(APIError::ChannelUnavailable{err: "No such channel".to_owned()});
                                }
-                               if peer_node_id.is_some() {
-                                       if let Some(peer_msg) = peer_msg {
-                                               self.issue_channel_close_events(chan.get(),ClosureReason::CounterpartyForceClosed { peer_msg: peer_msg.to_string() });
-                                       }
+                               if let Some(peer_msg) = peer_msg {
+                                       self.issue_channel_close_events(chan.get(),ClosureReason::CounterpartyForceClosed { peer_msg: peer_msg.to_string() });
                                } else {
                                        self.issue_channel_close_events(chan.get(),ClosureReason::HolderForceClosed);
                                }
@@ -1927,10 +1925,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        }
 
        /// Force closes a channel, immediately broadcasting the latest local commitment transaction to
-       /// the chain and rejecting new HTLCs on the given channel. Fails if channel_id is unknown to the manager.
-       pub fn force_close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> {
+       /// the chain and rejecting new HTLCs on the given channel. Fails if `channel_id` is unknown to
+       /// the manager, or if the `counterparty_node_id` isn't the counterparty of the corresponding
+       /// channel.
+       pub fn force_close_channel(&self, channel_id: &[u8; 32], counterparty_node_id: &PublicKey) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
-               match self.force_close_channel_with_peer(channel_id, None, None) {
+               match self.force_close_channel_with_peer(channel_id, counterparty_node_id, None) {
                        Ok(counterparty_node_id) => {
                                self.channel_state.lock().unwrap().pending_msg_events.push(
                                        events::MessageSendEvent::HandleError {
@@ -1950,7 +1950,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        /// for each to the chain and rejecting new HTLCs on each.
        pub fn force_close_all_channels(&self) {
                for chan in self.list_channels() {
-                       let _ = self.force_close_channel(&chan.channel_id);
+                       let _ = self.force_close_channel(&chan.channel_id, &chan.counterparty.node_id);
                }
        }
 
@@ -2070,11 +2070,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        return_malformed_err!("invalid ephemeral pubkey", 0x8000 | 0x4000 | 6);
                }
 
-               let shared_secret = {
-                       let mut arr = [0; 32];
-                       arr.copy_from_slice(&SharedSecret::new(&msg.onion_routing_packet.public_key.unwrap(), &self.our_network_key)[..]);
-                       arr
-               };
+               let shared_secret = SharedSecret::new(&msg.onion_routing_packet.public_key.unwrap(), &self.our_network_key).secret_bytes();
 
                if msg.onion_routing_packet.version != 0 {
                        //TODO: Spec doesn't indicate if we should only hash hop_data here (and in other
@@ -2253,7 +2249,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        break None;
                                }
                                {
-                                       let mut res = VecWriter(Vec::with_capacity(chan_update.serialized_length() + 8 + 2));
+                                       let mut res = VecWriter(Vec::with_capacity(chan_update.serialized_length() + 2 + 8 + 2));
                                        if let Some(chan_update) = chan_update {
                                                if code == 0x1000 | 11 || code == 0x1000 | 12 {
                                                        msg.amount_msat.write(&mut res).expect("Writes cannot fail");
@@ -2265,7 +2261,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                        // TODO: underspecified, follow https://github.com/lightningnetwork/lightning-rfc/issues/791
                                                        0u16.write(&mut res).expect("Writes cannot fail");
                                                }
-                                               (chan_update.serialized_length() as u16).write(&mut res).expect("Writes cannot fail");
+                                               (chan_update.serialized_length() as u16 + 2).write(&mut res).expect("Writes cannot fail");
+                                               msgs::ChannelUpdate::TYPE.write(&mut res).expect("Writes cannot fail");
                                                chan_update.write(&mut res).expect("Writes cannot fail");
                                        }
                                        return_err!(err, code, &res.0[..]);
@@ -2324,7 +2321,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                };
 
                let msg_hash = Sha256dHash::hash(&unsigned.encode()[..]);
-               let sig = self.secp_ctx.sign(&hash_to_message!(&msg_hash[..]), &self.our_network_key);
+               let sig = self.secp_ctx.sign_ecdsa(&hash_to_message!(&msg_hash[..]), &self.our_network_key);
 
                Ok(msgs::ChannelUpdate {
                        signature: sig,
@@ -2925,11 +2922,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                                        if let PendingHTLCRouting::Forward { onion_packet, .. } = routing {
                                                                                                let phantom_secret_res = self.keys_manager.get_node_secret(Recipient::PhantomNode);
                                                                                                if phantom_secret_res.is_ok() && fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, short_chan_id) {
-                                                                                                       let phantom_shared_secret = {
-                                                                                                               let mut arr = [0; 32];
-                                                                                                               arr.copy_from_slice(&SharedSecret::new(&onion_packet.public_key.unwrap(), &phantom_secret_res.unwrap())[..]);
-                                                                                                               arr
-                                                                                                       };
+                                                                                                       let phantom_shared_secret = SharedSecret::new(&onion_packet.public_key.unwrap(), &phantom_secret_res.unwrap()).secret_bytes();
                                                                                                        let next_hop = match onion_utils::decode_next_hop(phantom_shared_secret, &onion_packet.hop_data, onion_packet.hmac, payment_hash) {
                                                                                                                Ok(res) => res,
                                                                                                                Err(onion_utils::OnionDecodeErr::Malformed { err_msg, err_code }) => {
@@ -3551,12 +3544,13 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        fn get_htlc_temp_fail_err_and_data(&self, desired_err_code: u16, scid: u64, chan: &Channel<Signer>) -> (u16, Vec<u8>) {
                debug_assert_eq!(desired_err_code & 0x1000, 0x1000);
                if let Ok(upd) = self.get_channel_update_for_onion(scid, chan) {
-                       let mut enc = VecWriter(Vec::with_capacity(upd.serialized_length() + 4));
+                       let mut enc = VecWriter(Vec::with_capacity(upd.serialized_length() + 6));
                        if desired_err_code == 0x1000 | 20 {
                                // TODO: underspecified, follow https://github.com/lightning/bolts/issues/791
                                0u16.write(&mut enc).expect("Writes cannot fail");
                        }
-                       (upd.serialized_length() as u16).write(&mut enc).expect("Writes cannot fail");
+                       (upd.serialized_length() as u16 + 2).write(&mut enc).expect("Writes cannot fail");
+                       msgs::ChannelUpdate::TYPE.write(&mut enc).expect("Writes cannot fail");
                        upd.write(&mut enc).expect("Writes cannot fail");
                        (desired_err_code, enc.0)
                } else {
@@ -4217,6 +4211,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut pending_events = self.pending_events.lock().unwrap();
                pending_events.push(events::Event::FundingGenerationReady {
                        temporary_channel_id: msg.temporary_channel_id,
+                       counterparty_node_id: *counterparty_node_id,
                        channel_value_satoshis: value,
                        output_script,
                        user_channel_id: user_id,
@@ -5819,7 +5814,7 @@ impl<Signer: Sign, M: Deref , T: Deref , K: Deref , F: Deref , L: Deref >
                        for chan in self.list_channels() {
                                if chan.counterparty.node_id == *counterparty_node_id {
                                        // Untrusted messages from peer, we throw away the error if id points to a non-existent channel
-                                       let _ = self.force_close_channel_with_peer(&chan.channel_id, Some(counterparty_node_id), Some(&msg.data));
+                                       let _ = self.force_close_channel_with_peer(&chan.channel_id, counterparty_node_id, Some(&msg.data));
                                }
                        }
                } else {
@@ -5841,7 +5836,7 @@ impl<Signer: Sign, M: Deref , T: Deref , K: Deref , F: Deref , L: Deref >
                        }
 
                        // Untrusted messages from peer, we throw away the error if id points to a non-existent channel
-                       let _ = self.force_close_channel_with_peer(&msg.channel_id, Some(counterparty_node_id), Some(&msg.data));
+                       let _ = self.force_close_channel_with_peer(&msg.channel_id, counterparty_node_id, Some(&msg.data));
                }
        }
 }