Merge pull request #1019 from jkczyz/2021-07-shutdown-pubkey
[rust-lightning] / lightning / src / ln / channelmanager.rs
index 888724c0f064a82f885203ef0f7a0d76b7f885e8..09383d40301b9f387cd00ec6060dffc5d6674374 100644 (file)
@@ -60,10 +60,11 @@ use util::chacha20::{ChaCha20, ChaChaReader};
 use util::logger::{Logger, Level};
 use util::errors::APIError;
 
+use io;
 use prelude::*;
 use core::{cmp, mem};
 use core::cell::RefCell;
-use std::io::{Cursor, Read};
+use io::{Cursor, Read};
 use sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard};
 use core::sync::atomic::{AtomicUsize, Ordering};
 use core::time::Duration;
@@ -490,6 +491,8 @@ pub struct ChannelManager<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref,
        /// Because adding or removing an entry is rare, we usually take an outer read lock and then
        /// operate on the inner value freely. Sadly, this prevents parallel operation when opening a
        /// new channel.
+       ///
+       /// If also holding `channel_state` lock, must lock `channel_state` prior to `per_peer_state`.
        per_peer_state: RwLock<HashMap<PublicKey, Mutex<PeerState>>>,
 
        pending_events: Mutex<Vec<events::Event>>,
@@ -870,6 +873,18 @@ macro_rules! try_chan_entry {
        }
 }
 
+macro_rules! remove_channel {
+       ($channel_state: expr, $entry: expr) => {
+               {
+                       let channel = $entry.remove_entry().1;
+                       if let Some(short_id) = channel.get_short_channel_id() {
+                               $channel_state.short_to_id.remove(&short_id);
+                       }
+                       channel
+               }
+       }
+}
+
 macro_rules! handle_monitor_err {
        ($self: ident, $err: expr, $channel_state: expr, $entry: expr, $action_type: path, $resend_raa: expr, $resend_commitment: expr) => {
                handle_monitor_err!($self, $err, $channel_state, $entry, $action_type, $resend_raa, $resend_commitment, Vec::new(), Vec::new())
@@ -1164,8 +1179,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        return Err(APIError::APIMisuseError { err: format!("Channel value must be at least 1000 satoshis. It was {}", channel_value_satoshis) });
                }
 
-               let config = if override_config.is_some() { override_config.as_ref().unwrap() } else { &self.default_configuration };
-               let channel = Channel::new_outbound(&self.fee_estimator, &self.keys_manager, their_network_key, channel_value_satoshis, push_msat, user_id, config)?;
+               let channel = {
+                       let per_peer_state = self.per_peer_state.read().unwrap();
+                       match per_peer_state.get(&their_network_key) {
+                               Some(peer_state) => {
+                                       let peer_state = peer_state.lock().unwrap();
+                                       let their_features = &peer_state.latest_features;
+                                       let config = if override_config.is_some() { override_config.as_ref().unwrap() } else { &self.default_configuration };
+                                       Channel::new_outbound(&self.fee_estimator, &self.keys_manager, their_network_key, their_features, channel_value_satoshis, push_msat, user_id, config)?
+                               },
+                               None => return Err(APIError::ChannelUnavailable { err: format!("Not connected to node: {}", their_network_key) }),
+                       }
+               };
                let res = channel.get_open_channel(self.genesis_hash.clone());
 
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
@@ -1259,40 +1284,61 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        pub fn close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
 
-               let (mut failed_htlcs, chan_option) = {
+               let counterparty_node_id;
+               let mut failed_htlcs: Vec<(HTLCSource, PaymentHash)>;
+               let result: Result<(), _> = loop {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_state_lock;
                        match channel_state.by_id.entry(channel_id.clone()) {
                                hash_map::Entry::Occupied(mut chan_entry) => {
-                                       let (shutdown_msg, failed_htlcs) = chan_entry.get_mut().get_shutdown()?;
+                                       counterparty_node_id = chan_entry.get().get_counterparty_node_id();
+                                       let per_peer_state = self.per_peer_state.read().unwrap();
+                                       let (shutdown_msg, monitor_update, htlcs) = match per_peer_state.get(&counterparty_node_id) {
+                                               Some(peer_state) => {
+                                                       let peer_state = peer_state.lock().unwrap();
+                                                       let their_features = &peer_state.latest_features;
+                                                       chan_entry.get_mut().get_shutdown(&self.keys_manager, their_features)?
+                                               },
+                                               None => return Err(APIError::ChannelUnavailable { err: format!("Not connected to node: {}", counterparty_node_id) }),
+                                       };
+                                       failed_htlcs = htlcs;
+
+                                       // Update the monitor with the shutdown script if necessary.
+                                       if let Some(monitor_update) = monitor_update {
+                                               if let Err(e) = self.chain_monitor.update_channel(chan_entry.get().get_funding_txo().unwrap(), monitor_update) {
+                                                       let (result, is_permanent) =
+                                                               handle_monitor_err!(self, e, channel_state.short_to_id, chan_entry.get_mut(), RAACommitmentOrder::CommitmentFirst, false, false, Vec::new(), Vec::new(), chan_entry.key());
+                                                       if is_permanent {
+                                                               remove_channel!(channel_state, chan_entry);
+                                                               break result;
+                                                       }
+                                               }
+                                       }
+
                                        channel_state.pending_msg_events.push(events::MessageSendEvent::SendShutdown {
-                                               node_id: chan_entry.get().get_counterparty_node_id(),
+                                               node_id: counterparty_node_id,
                                                msg: shutdown_msg
                                        });
+
                                        if chan_entry.get().is_shutdown() {
-                                               if let Some(short_id) = chan_entry.get().get_short_channel_id() {
-                                                       channel_state.short_to_id.remove(&short_id);
+                                               let channel = remove_channel!(channel_state, chan_entry);
+                                               if let Ok(channel_update) = self.get_channel_update_for_broadcast(&channel) {
+                                                       channel_state.pending_msg_events.push(events::MessageSendEvent::BroadcastChannelUpdate {
+                                                               msg: channel_update
+                                                       });
                                                }
-                                               (failed_htlcs, Some(chan_entry.remove_entry().1))
-                                       } else { (failed_htlcs, None) }
+                                       }
+                                       break Ok(());
                                },
                                hash_map::Entry::Vacant(_) => return Err(APIError::ChannelUnavailable{err: "No such channel".to_owned()})
                        }
                };
+
                for htlc_source in failed_htlcs.drain(..) {
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
                }
-               let chan_update = if let Some(chan) = chan_option {
-                       self.get_channel_update_for_broadcast(&chan).ok()
-               } else { None };
-
-               if let Some(update) = chan_update {
-                       let mut channel_state = self.channel_state.lock().unwrap();
-                       channel_state.pending_msg_events.push(events::MessageSendEvent::BroadcastChannelUpdate {
-                               msg: update
-                       });
-               }
 
+               let _ = handle_error!(self, result, counterparty_node_id);
                Ok(())
        }
 
@@ -2804,7 +2850,13 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                        } else { errs.push((pk, err)); }
                                                },
                                                ClaimFundsFromHop::PrevHopForceClosed => unreachable!("We already checked for channel existence, we can't fail here!"),
-                                               _ => claimed_any_htlcs = true,
+                                               ClaimFundsFromHop::DuplicateClaim => {
+                                                       // While we should never get here in most cases, if we do, it likely
+                                                       // indicates that the HTLC was timed out some time ago and is no longer
+                                                       // available to be claimed. Thus, it does not make sense to set
+                                                       // `claimed_any_htlcs`.
+                                               },
+                                               ClaimFundsFromHop::Success(_) => claimed_any_htlcs = true,
                                        }
                                }
                        }
@@ -3020,7 +3072,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        return Err(MsgHandleErrInternal::send_err_msg_no_close("Unknown genesis block hash".to_owned(), msg.temporary_channel_id.clone()));
                }
 
-               let channel = Channel::new_from_req(&self.fee_estimator, &self.keys_manager, counterparty_node_id.clone(), their_features, msg, 0, &self.default_configuration)
+               let channel = Channel::new_from_req(&self.fee_estimator, &self.keys_manager, counterparty_node_id.clone(), &their_features, msg, 0, &self.default_configuration)
                        .map_err(|e| MsgHandleErrInternal::from_chan_no_close(e, msg.temporary_channel_id))?;
                let mut channel_state_lock = self.channel_state.lock().unwrap();
                let channel_state = &mut *channel_state_lock;
@@ -3046,7 +3098,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        if chan.get().get_counterparty_node_id() != *counterparty_node_id {
                                                return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.temporary_channel_id));
                                        }
-                                       try_chan_entry!(self, chan.get_mut().accept_channel(&msg, &self.default_configuration, their_features), channel_state, chan);
+                                       try_chan_entry!(self, chan.get_mut().accept_channel(&msg, &self.default_configuration, &their_features), channel_state, chan);
                                        (chan.get().get_value_satoshis(), chan.get().get_funding_redeemscript().to_v0_p2wsh(), chan.get().get_user_id())
                                },
                                hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close("Failed to find corresponding channel".to_owned(), msg.temporary_channel_id))
@@ -3183,7 +3235,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
        }
 
        fn internal_shutdown(&self, counterparty_node_id: &PublicKey, their_features: &InitFeatures, msg: &msgs::Shutdown) -> Result<(), MsgHandleErrInternal> {
-               let (mut dropped_htlcs, chan_option) = {
+               let mut dropped_htlcs: Vec<(HTLCSource, PaymentHash)>;
+               let result: Result<(), _> = loop {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = &mut *channel_state_lock;
 
@@ -3192,25 +3245,37 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                        if chan_entry.get().get_counterparty_node_id() != *counterparty_node_id {
                                                return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.channel_id));
                                        }
-                                       let (shutdown, closing_signed, dropped_htlcs) = try_chan_entry!(self, chan_entry.get_mut().shutdown(&self.fee_estimator, &their_features, &msg), channel_state, chan_entry);
+
+                                       let (shutdown, closing_signed, monitor_update, htlcs) = try_chan_entry!(self, chan_entry.get_mut().shutdown(&self.fee_estimator, &self.keys_manager, &their_features, &msg), channel_state, chan_entry);
+                                       dropped_htlcs = htlcs;
+
+                                       // Update the monitor with the shutdown script if necessary.
+                                       if let Some(monitor_update) = monitor_update {
+                                               if let Err(e) = self.chain_monitor.update_channel(chan_entry.get().get_funding_txo().unwrap(), monitor_update) {
+                                                       let (result, is_permanent) =
+                                                               handle_monitor_err!(self, e, channel_state.short_to_id, chan_entry.get_mut(), RAACommitmentOrder::CommitmentFirst, false, false, Vec::new(), Vec::new(), chan_entry.key());
+                                                       if is_permanent {
+                                                               remove_channel!(channel_state, chan_entry);
+                                                               break result;
+                                                       }
+                                               }
+                                       }
+
                                        if let Some(msg) = shutdown {
                                                channel_state.pending_msg_events.push(events::MessageSendEvent::SendShutdown {
-                                                       node_id: counterparty_node_id.clone(),
+                                                       node_id: *counterparty_node_id,
                                                        msg,
                                                });
                                        }
                                        if let Some(msg) = closing_signed {
+                                               // TODO: Do not send this if the monitor update failed.
                                                channel_state.pending_msg_events.push(events::MessageSendEvent::SendClosingSigned {
-                                                       node_id: counterparty_node_id.clone(),
+                                                       node_id: *counterparty_node_id,
                                                        msg,
                                                });
                                        }
-                                       if chan_entry.get().is_shutdown() {
-                                               if let Some(short_id) = chan_entry.get().get_short_channel_id() {
-                                                       channel_state.short_to_id.remove(&short_id);
-                                               }
-                                               (dropped_htlcs, Some(chan_entry.remove_entry().1))
-                                       } else { (dropped_htlcs, None) }
+
+                                       break Ok(());
                                },
                                hash_map::Entry::Vacant(_) => return Err(MsgHandleErrInternal::send_err_msg_no_close("Failed to find corresponding channel".to_owned(), msg.channel_id))
                        }
@@ -3218,14 +3283,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                for htlc_source in dropped_htlcs.drain(..) {
                        self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() });
                }
-               if let Some(chan) = chan_option {
-                       if let Ok(update) = self.get_channel_update_for_broadcast(&chan) {
-                               let mut channel_state = self.channel_state.lock().unwrap();
-                               channel_state.pending_msg_events.push(events::MessageSendEvent::BroadcastChannelUpdate {
-                                       msg: update
-                               });
-                       }
-               }
+
+               let _ = handle_error!(self, result, *counterparty_node_id);
                Ok(())
        }
 
@@ -4023,7 +4082,7 @@ where
                                result = NotifyOption::DoPersist;
                        }
 
-                       let mut pending_events = std::mem::replace(&mut *self.pending_events.lock().unwrap(), vec![]);
+                       let mut pending_events = mem::replace(&mut *self.pending_events.lock().unwrap(), vec![]);
                        if !pending_events.is_empty() {
                                result = NotifyOption::DoPersist;
                        }
@@ -4643,7 +4702,7 @@ impl_writeable_tlv_based!(HTLCPreviousHopData, {
 });
 
 impl Writeable for ClaimableHTLC {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
                let payment_data = match &self.onion_payload {
                        OnionPayload::Invoice(data) => Some(data.clone()),
                        _ => None,
@@ -4747,7 +4806,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> Writeable f
         F::Target: FeeEstimator,
         L::Target: Logger,
 {
-       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), io::Error> {
                let _consistency_lock = self.total_consistency_lock.write().unwrap();
 
                write_ver_prefix!(writer, SERIALIZATION_VERSION, MIN_SERIALIZATION_VERSION);
@@ -4943,7 +5002,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
         F::Target: FeeEstimator,
         L::Target: Logger,
 {
-       fn read<R: ::std::io::Read>(reader: &mut R, args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result<Self, DecodeError> {
+       fn read<R: io::Read>(reader: &mut R, args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result<Self, DecodeError> {
                let (blockhash, chan_manager) = <(BlockHash, ChannelManager<Signer, M, T, K, F, L>)>::read(reader, args)?;
                Ok((blockhash, Arc::new(chan_manager)))
        }
@@ -4957,7 +5016,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
         F::Target: FeeEstimator,
         L::Target: Logger,
 {
-       fn read<R: ::std::io::Read>(reader: &mut R, mut args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result<Self, DecodeError> {
+       fn read<R: io::Read>(reader: &mut R, mut args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result<Self, DecodeError> {
                let _ver = read_ver_prefix!(reader, SERIALIZATION_VERSION);
 
                let genesis_hash: BlockHash = Readable::read(reader)?;
@@ -4993,6 +5052,10 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                                channel.get_cur_counterparty_commitment_transaction_number() > monitor.get_cur_counterparty_commitment_number() ||
                                                channel.get_latest_monitor_update_id() < monitor.get_latest_update_id() {
                                        // But if the channel is behind of the monitor, close the channel:
+                                       log_error!(args.logger, "A ChannelManager is stale compared to the current ChannelMonitor!");
+                                       log_error!(args.logger, " The channel will be force-closed and the latest commitment transaction from the ChannelMonitor broadcast.");
+                                       log_error!(args.logger, " The ChannelMonitor for channel {} is at update_id {} but the ChannelManager is at update_id {}.",
+                                               log_bytes!(channel.channel_id()), monitor.get_latest_update_id(), channel.get_latest_monitor_update_id());
                                        let (_, mut new_failed_htlcs) = channel.force_shutdown(true);
                                        failed_htlcs.append(&mut new_failed_htlcs);
                                        monitor.broadcast_latest_holder_commitment_txn(&args.tx_broadcaster, &args.logger);
@@ -5147,10 +5210,8 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
 mod tests {
        use bitcoin::hashes::Hash;
        use bitcoin::hashes::sha256::Hash as Sha256;
-       use core::sync::atomic::{AtomicBool, Ordering};
        use core::time::Duration;
        use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
-       use ln::channelmanager::PersistenceNotifier;
        use ln::features::{InitFeatures, InvoiceFeatures};
        use ln::functional_test_utils::*;
        use ln::msgs;
@@ -5158,12 +5219,15 @@ mod tests {
        use routing::router::{get_keysend_route, get_route};
        use util::events::{Event, MessageSendEvent, MessageSendEventsProvider};
        use util::test_utils;
-       use std::sync::Arc;
-       use std::thread;
 
        #[cfg(feature = "std")]
        #[test]
        fn test_wait_timeout() {
+               use ln::channelmanager::PersistenceNotifier;
+               use sync::Arc;
+               use core::sync::atomic::{AtomicBool, Ordering};
+               use std::thread;
+
                let persistence_notifier = Arc::new(PersistenceNotifier::new());
                let thread_notifier = Arc::clone(&persistence_notifier);
 
@@ -5213,6 +5277,12 @@ mod tests {
                let node_chanmgrs = create_node_chanmgrs(3, &node_cfgs, &[None, None, None]);
                let nodes = create_network(3, &node_cfgs, &node_chanmgrs);
 
+               // All nodes start with a persistable update pending as `create_network` connects each node
+               // with all other nodes to make most tests simpler.
+               assert!(nodes[0].node.await_persistable_update_timeout(Duration::from_millis(1)));
+               assert!(nodes[1].node.await_persistable_update_timeout(Duration::from_millis(1)));
+               assert!(nodes[2].node.await_persistable_update_timeout(Duration::from_millis(1)));
+
                let mut chan = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known());
 
                // We check that the channel info nodes have doesn't change too early, even though we try
@@ -5545,7 +5615,7 @@ pub mod bench {
        use ln::channelmanager::{BestBlock, ChainParameters, ChannelManager, PaymentHash, PaymentPreimage};
        use ln::features::{InitFeatures, InvoiceFeatures};
        use ln::functional_test_utils::*;
-       use ln::msgs::ChannelMessageHandler;
+       use ln::msgs::{ChannelMessageHandler, Init};
        use routing::network_graph::NetworkGraph;
        use routing::router::get_route;
        use util::test_utils;
@@ -5608,6 +5678,8 @@ pub mod bench {
                });
                let node_b_holder = NodeHolder { node: &node_b };
 
+               node_a.peer_connected(&node_b.get_our_node_id(), &Init { features: InitFeatures::known() });
+               node_b.peer_connected(&node_a.get_our_node_id(), &Init { features: InitFeatures::known() });
                node_a.create_channel(node_b.get_our_node_id(), 8_000_000, 100_000_000, 42, None).unwrap();
                node_b.handle_open_channel(&node_a.get_our_node_id(), InitFeatures::known(), &get_event_msg!(node_a_holder, MessageSendEvent::SendOpenChannel, node_b.get_our_node_id()));
                node_a.handle_accept_channel(&node_b.get_our_node_id(), InitFeatures::known(), &get_event_msg!(node_b_holder, MessageSendEvent::SendAcceptChannel, node_a.get_our_node_id()));