X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Fln%2Fchannelmanager.rs;h=3109b8529057661cdafa06383d967ead3ac923b4;hb=feca062072de1df83fe0b61262173d61f611df19;hp=e353700be16f5bff37627e53b06cd45a8a585114;hpb=f151c029756622d06303252f918fce1b26809d23;p=rust-lightning diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index e353700b..3109b852 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -18,7 +18,7 @@ //! imply it needs to fail HTLCs/payments/channels it manages). //! -use bitcoin::blockdata::block::BlockHeader; +use bitcoin::blockdata::block::{Block, BlockHeader}; use bitcoin::blockdata::constants::genesis_block; use bitcoin::network::constants::Network; @@ -46,7 +46,7 @@ use ln::msgs; use ln::msgs::NetAddress; use ln::onion_utils; use ln::msgs::{ChannelMessageHandler, DecodeError, LightningError, OptionalField}; -use chain::keysinterface::{ChannelKeys, KeysInterface, KeysManager, InMemoryChannelKeys}; +use chain::keysinterface::{Sign, KeysInterface, KeysManager, InMemorySigner}; use util::config::UserConfig; use util::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsProvider}; use util::{byte_utils, events}; @@ -58,9 +58,11 @@ use util::errors::APIError; use std::{cmp, mem}; use std::collections::{HashMap, hash_map, HashSet}; use std::io::{Cursor, Read}; -use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; +#[cfg(any(test, feature = "allow_wallclock_use"))] +use std::time::Instant; use std::marker::{Sync, Send}; use std::ops::Deref; use bitcoin::hashes::hex::ToHex; @@ -312,8 +314,8 @@ pub(super) enum RAACommitmentOrder { } // Note this is only exposed in cfg(test): -pub(super) struct ChannelHolder { - pub(super) by_id: HashMap<[u8; 32], Channel>, +pub(super) struct ChannelHolder { + pub(super) by_id: HashMap<[u8; 32], Channel>, pub(super) short_to_id: HashMap, /// short channel id -> forward infos. Key of 0 means payments received /// Note that while this is held in the same mutex as the channels themselves, no consistency @@ -347,7 +349,7 @@ const ERR: () = "You need at least 32 bit pointers (well, usize, but we'll assum /// issues such as overly long function definitions. Note that the ChannelManager can take any /// type that implements KeysInterface for its keys manager, but this type alias chooses the /// concrete type of the KeysManager. -pub type SimpleArcChannelManager = Arc, Arc, Arc, Arc, Arc>>; +pub type SimpleArcChannelManager = ChannelManager, Arc, Arc, Arc, Arc>; /// SimpleRefChannelManager is a type alias for a ChannelManager reference, and is the reference /// counterpart to the SimpleArcChannelManager type alias. Use this type by default when you don't @@ -357,7 +359,7 @@ pub type SimpleArcChannelManager = Arc = ChannelManager; +pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L> = ChannelManager; /// Manager which keeps track of a number of channels and sends messages to the appropriate /// channel, also tracking HTLC preimages and forwarding onion packets appropriately. @@ -378,7 +380,7 @@ pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L> = ChannelManage /// ChannelMonitors passed by reference to read(), those channels will be force-closed based on the /// ChannelMonitor state and no funds will be lost (mod on-chain transaction fees). /// -/// Note that the deserializer is only implemented for (Sha256dHash, ChannelManager), which +/// Note that the deserializer is only implemented for (Option, ChannelManager), which /// tells you the last block hash which was block_connect()ed. You MUST rescan any blocks along /// the "reorg path" (ie call block_disconnected() until you get to a common block and then call /// block_connected() to step towards your best block) upon deserialization before using the @@ -395,10 +397,10 @@ pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L> = ChannelManage /// essentially you should default to using a SimpleRefChannelManager, and use a /// SimpleArcChannelManager when you require a ChannelManager with a static lifetime, such as when /// you're using lightning-net-tokio. -pub struct ChannelManager - where M::Target: chain::Watch, +pub struct ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -416,9 +418,9 @@ pub struct ChannelManager, #[cfg(any(test, feature = "_test_utils"))] - pub(super) channel_state: Mutex>, + pub(super) channel_state: Mutex>, #[cfg(not(any(test, feature = "_test_utils")))] - channel_state: Mutex>, + channel_state: Mutex>, our_network_key: SecretKey, /// Used to track the last value sent in a node_announcement "timestamp" field. We ensure this @@ -437,13 +439,46 @@ pub struct ChannelManager, + persistence_notifier: PersistenceNotifier, + keys_manager: K, logger: L, } +/// Whenever we release the `ChannelManager`'s `total_consistency_lock`, from read mode, it is +/// desirable to notify any listeners on `wait_timeout`/`wait` that new updates are available for +/// persistence. Therefore, this struct is responsible for locking the total consistency lock and, +/// upon going out of scope, sending the aforementioned notification (since the lock being released +/// indicates that the updates are ready for persistence). +struct PersistenceNotifierGuard<'a> { + persistence_notifier: &'a PersistenceNotifier, + // We hold onto this result so the lock doesn't get released immediately. + _read_guard: RwLockReadGuard<'a, ()>, +} + +impl<'a> PersistenceNotifierGuard<'a> { + fn new(lock: &'a RwLock<()>, notifier: &'a PersistenceNotifier) -> Self { + let read_guard = lock.read().unwrap(); + + Self { + persistence_notifier: notifier, + _read_guard: read_guard, + } + } +} + +impl<'a> Drop for PersistenceNotifierGuard<'a> { + fn drop(&mut self) { + self.persistence_notifier.notify(); + } +} + /// The amount of time we require our counterparty wait to claim their money (ie time between when /// we, or our watchtower, must check for them having broadcast a theft transaction). pub(crate) const BREAKDOWN_TIMEOUT: u16 = 6 * 24; @@ -514,7 +549,7 @@ pub struct ChannelDetails { /// If a payment fails to send, it can be in one of several states. This enum is returned as the /// Err() type describing which state the payment is in, see the description of individual enum /// states for more. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum PaymentSendFailure { /// A parameter which was passed to send_payment was invalid, preventing us from attempting to /// send the payment at all. No channel state has been changed or messages sent to peers, and @@ -709,10 +744,10 @@ macro_rules! maybe_break_monitor_err { } } -impl ChannelManager - where M::Target: chain::Watch, +impl ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -731,7 +766,8 @@ impl /// Users need to notify the new ChannelManager when a new block is connected or /// disconnected using its `block_connected` and `block_disconnected` methods. pub fn new(network: Network, fee_est: F, chain_monitor: M, tx_broadcaster: T, logger: L, keys_manager: K, config: UserConfig, current_blockchain_height: usize) -> Self { - let secp_ctx = Secp256k1::new(); + let mut secp_ctx = Secp256k1::new(); + secp_ctx.seeded_randomize(&keys_manager.get_secure_random_bytes()); ChannelManager { default_configuration: config.clone(), @@ -759,6 +795,7 @@ impl pending_events: Mutex::new(Vec::new()), total_consistency_lock: RwLock::new(()), + persistence_notifier: PersistenceNotifier::new(), keys_manager, @@ -787,7 +824,10 @@ impl let channel = Channel::new_outbound(&self.fee_estimator, &self.keys_manager, their_network_key, channel_value_satoshis, push_msat, user_id, config)?; let res = channel.get_open_channel(self.genesis_hash.clone()); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); + // We want to make sure the lock is actually acquired by PersistenceNotifierGuard. + debug_assert!(&self.total_consistency_lock.try_write().is_err()); + let mut channel_state = self.channel_state.lock().unwrap(); match channel_state.by_id.entry(channel.channel_id()) { hash_map::Entry::Occupied(_) => { @@ -806,7 +846,7 @@ impl Ok(()) } - fn list_channels_with_filter)) -> bool>(&self, f: Fn) -> Vec { + fn list_channels_with_filter)) -> bool>(&self, f: Fn) -> Vec { let mut res = Vec::new(); { let channel_state = self.channel_state.lock().unwrap(); @@ -859,7 +899,7 @@ impl /// /// May generate a SendShutdown message event on success, which should be relayed. pub fn close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let (mut failed_htlcs, chan_option) = { let mut channel_state_lock = self.channel_state.lock().unwrap(); @@ -916,19 +956,22 @@ impl } } - /// Force closes a channel, immediately broadcasting the latest local commitment transaction to - /// the chain and rejecting new HTLCs on the given channel. Fails if channel_id is unknown to the manager. - pub fn force_close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError>{ - let _consistency_lock = self.total_consistency_lock.read().unwrap(); - + fn force_close_channel_with_peer(&self, channel_id: &[u8; 32], peer_node_id: Option<&PublicKey>) -> Result<(), APIError> { let mut chan = { let mut channel_state_lock = self.channel_state.lock().unwrap(); let channel_state = &mut *channel_state_lock; - if let Some(chan) = channel_state.by_id.remove(channel_id) { - if let Some(short_id) = chan.get_short_channel_id() { + if let hash_map::Entry::Occupied(chan) = channel_state.by_id.entry(channel_id.clone()) { + if let Some(node_id) = peer_node_id { + if chan.get().get_counterparty_node_id() != *node_id { + // Error or Ok here doesn't matter - the result is only exposed publicly + // when peer_node_id is None anyway. + return Ok(()); + } + } + if let Some(short_id) = chan.get().get_short_channel_id() { channel_state.short_to_id.remove(&short_id); } - chan + chan.remove_entry().1 } else { return Err(APIError::ChannelUnavailable{err: "No such channel".to_owned()}); } @@ -945,6 +988,13 @@ impl Ok(()) } + /// Force closes a channel, immediately broadcasting the latest local commitment transaction to + /// the chain and rejecting new HTLCs on the given channel. Fails if channel_id is unknown to the manager. + pub fn force_close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> { + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); + self.force_close_channel_with_peer(channel_id, None) + } + /// Force close all channels, immediately broadcasting the latest local commitment transaction /// for each to the chain and rejecting new HTLCs on each. pub fn force_close_all_channels(&self) { @@ -953,7 +1003,7 @@ impl } } - fn decode_update_add_htlc_onion(&self, msg: &msgs::UpdateAddHTLC) -> (PendingHTLCStatus, MutexGuard>) { + fn decode_update_add_htlc_onion(&self, msg: &msgs::UpdateAddHTLC) -> (PendingHTLCStatus, MutexGuard>) { macro_rules! return_malformed_err { ($msg: expr, $err_code: expr) => { { @@ -1225,7 +1275,7 @@ impl /// only fails if the channel does not yet have an assigned short_id /// May be called with channel_state already locked! - fn get_channel_update(&self, chan: &Channel) -> Result { + fn get_channel_update(&self, chan: &Channel) -> Result { let short_channel_id = match chan.get_short_channel_id() { None => return Err(LightningError{err: "Channel not yet established".to_owned(), action: msgs::ErrorAction::IgnoreError}), Some(id) => id, @@ -1269,7 +1319,7 @@ impl } let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let err: Result<(), _> = loop { let mut channel_lock = self.channel_state.lock().unwrap(); @@ -1437,7 +1487,7 @@ impl /// May panic if the funding_txo is duplicative with some other channel (note that this should /// be trivially prevented by using unique funding transaction keys per-channel). pub fn funding_transaction_generated(&self, temporary_channel_id: &[u8; 32], funding_txo: OutPoint) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let (chan, msg) = { let (res, chan) = match self.channel_state.lock().unwrap().by_id.remove(temporary_channel_id) { @@ -1473,7 +1523,7 @@ impl } } - fn get_announcement_sigs(&self, chan: &Channel) -> Option { + fn get_announcement_sigs(&self, chan: &Channel) -> Option { if !chan.should_announce() { log_trace!(self.logger, "Can't send announcement_signatures for private channel {}", log_bytes!(chan.channel_id())); return None @@ -1520,7 +1570,7 @@ impl /// /// Panics if addresses is absurdly large (more than 500). pub fn broadcast_node_announcement(&self, rgb: [u8; 3], alias: [u8; 32], addresses: Vec) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); if addresses.len() > 500 { panic!("More than half the message size was taken up by public addresses!"); @@ -1550,7 +1600,7 @@ impl /// Should only really ever be called in response to a PendingHTLCsForwardable event. /// Will likely generate further events. pub fn process_pending_htlc_forwards(&self) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut new_events = Vec::new(); let mut failed_forwards = Vec::new(); @@ -1810,7 +1860,7 @@ impl /// /// This method handles all the details, and must be called roughly once per minute. pub fn timer_chan_freshness_every_min(&self) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut channel_state_lock = self.channel_state.lock().unwrap(); let channel_state = &mut *channel_state_lock; for (_, chan) in channel_state.by_id.iter_mut() { @@ -1835,7 +1885,7 @@ impl /// Returns false if no payment was found to fail backwards, true if the process of failing the /// HTLC backwards has been started. pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash, payment_secret: &Option) -> bool { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut channel_state = Some(self.channel_state.lock().unwrap()); let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&(*payment_hash, *payment_secret)); @@ -1898,7 +1948,7 @@ impl /// to fail and take the channel_state lock for each iteration (as we take ownership and may /// drop it). In other words, no assumptions are made that entries in claimable_htlcs point to /// still-available channels. - fn fail_htlc_backwards_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason) { + fn fail_htlc_backwards_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason) { //TODO: There is a timing attack here where if a node fails an HTLC back to us they can //identify whether we sent it or not based on the (I presume) very different runtime //between the branches here. We should make this async and move it into the forward HTLCs @@ -2014,7 +2064,7 @@ impl pub fn claim_funds(&self, payment_preimage: PaymentPreimage, payment_secret: &Option, expected_amount: u64) -> bool { let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner()); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut channel_state = Some(self.channel_state.lock().unwrap()); let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&(payment_hash, *payment_secret)); @@ -2092,7 +2142,7 @@ impl } else { false } } - fn claim_funds_from_hop(&self, channel_state_lock: &mut MutexGuard>, prev_hop: HTLCPreviousHopData, payment_preimage: PaymentPreimage) -> Result<(), Option<(PublicKey, MsgHandleErrInternal)>> { + fn claim_funds_from_hop(&self, channel_state_lock: &mut MutexGuard>, prev_hop: HTLCPreviousHopData, payment_preimage: PaymentPreimage) -> Result<(), Option<(PublicKey, MsgHandleErrInternal)>> { //TODO: Delay the claimed_funds relaying just like we do outbound relay! let channel_state = &mut **channel_state_lock; let chan_id = match channel_state.short_to_id.get(&prev_hop.short_channel_id) { @@ -2145,7 +2195,7 @@ impl } else { unreachable!(); } } - fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_preimage: PaymentPreimage) { + fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_preimage: PaymentPreimage) { match source { HTLCSource::OutboundRoute { .. } => { mem::drop(channel_state_lock); @@ -2210,7 +2260,7 @@ impl /// 4) once all remote copies are updated, you call this function with the update_id that /// completed, and once it is the latest the Channel will be re-enabled. pub fn channel_monitor_updated(&self, funding_txo: &OutPoint, highest_applied_update_id: u64) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut close_results = Vec::new(); let mut htlc_forwards = Vec::new(); @@ -2458,7 +2508,7 @@ impl } } - fn internal_shutdown(&self, counterparty_node_id: &PublicKey, msg: &msgs::Shutdown) -> Result<(), MsgHandleErrInternal> { + fn internal_shutdown(&self, counterparty_node_id: &PublicKey, their_features: &InitFeatures, msg: &msgs::Shutdown) -> Result<(), MsgHandleErrInternal> { let (mut dropped_htlcs, chan_option) = { let mut channel_state_lock = self.channel_state.lock().unwrap(); let channel_state = &mut *channel_state_lock; @@ -2468,7 +2518,7 @@ impl if chan_entry.get().get_counterparty_node_id() != *counterparty_node_id { return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.channel_id)); } - let (shutdown, closing_signed, dropped_htlcs) = try_chan_entry!(self, chan_entry.get_mut().shutdown(&self.fee_estimator, &msg), channel_state, chan_entry); + let (shutdown, closing_signed, dropped_htlcs) = try_chan_entry!(self, chan_entry.get_mut().shutdown(&self.fee_estimator, &their_features, &msg), channel_state, chan_entry); if let Some(msg) = shutdown { channel_state.pending_msg_events.push(events::MessageSendEvent::SendShutdown { node_id: counterparty_node_id.clone(), @@ -2570,7 +2620,7 @@ impl return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.channel_id)); } - let create_pending_htlc_status = |chan: &Channel, pending_forward_info: PendingHTLCStatus, error_code: u16| { + let create_pending_htlc_status = |chan: &Channel, pending_forward_info: PendingHTLCStatus, error_code: u16| { // Ensure error_code has the UPDATE flag set, since by default we send a // channel update along as part of failing the HTLC. assert!((error_code & 0x1000) != 0); @@ -2961,7 +3011,7 @@ impl /// (C-not exported) Cause its doc(hidden) anyway #[doc(hidden)] pub fn update_fee(&self, channel_id: [u8;32], feerate_per_kw: u32) -> Result<(), APIError> { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let counterparty_node_id; let err: Result<(), _> = loop { let mut channel_state_lock = self.channel_state.lock().unwrap(); @@ -3052,10 +3102,10 @@ impl } } -impl MessageSendEventsProvider for ChannelManager - where M::Target: chain::Watch, +impl MessageSendEventsProvider for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3071,10 +3121,10 @@ impl } } -impl EventsProvider for ChannelManager - where M::Target: chain::Watch, +impl EventsProvider for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3090,10 +3140,28 @@ impl } } -impl ChannelManager - where M::Target: chain::Watch, +impl chain::Listen for ChannelManager +where + M::Target: chain::Watch, + T::Target: BroadcasterInterface, + K::Target: KeysInterface, + F::Target: FeeEstimator, + L::Target: Logger, +{ + fn block_connected(&self, block: &Block, height: u32) { + let txdata: Vec<_> = block.txdata.iter().enumerate().collect(); + ChannelManager::block_connected(self, &block.header, &txdata, height); + } + + fn block_disconnected(&self, header: &BlockHeader, _height: u32) { + ChannelManager::block_disconnected(self, header); + } +} + +impl ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3101,7 +3169,7 @@ impl pub fn block_connected(&self, header: &BlockHeader, txdata: &TransactionData, height: u32) { let header_hash = header.block_hash(); log_trace!(self.logger, "Block {} at height {} connected", header_hash, height); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut failed_channels = Vec::new(); let mut timed_out_htlcs = Vec::new(); { @@ -3214,7 +3282,7 @@ impl /// If necessary, the channel may be force-closed without letting the counterparty participate /// in the shutdown. pub fn block_disconnected(&self, header: &BlockHeader) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut failed_channels = Vec::new(); { let mut channel_lock = self.channel_state.lock().unwrap(); @@ -3244,98 +3312,121 @@ impl self.latest_block_height.fetch_sub(1, Ordering::AcqRel); *self.last_block_hash.try_lock().expect("block_(dis)connected must not be called in parallel") = header.block_hash(); } + + /// Blocks until ChannelManager needs to be persisted or a timeout is reached. It returns a bool + /// indicating whether persistence is necessary. Only one listener on `wait_timeout` is + /// guaranteed to be woken up. + /// Note that the feature `allow_wallclock_use` must be enabled to use this function. + #[cfg(any(test, feature = "allow_wallclock_use"))] + pub fn wait_timeout(&self, max_wait: Duration) -> bool { + self.persistence_notifier.wait_timeout(max_wait) + } + + /// Blocks until ChannelManager needs to be persisted. Only one listener on `wait` is + /// guaranteed to be woken up. + pub fn wait(&self) { + self.persistence_notifier.wait() + } + + #[cfg(any(test, feature = "_test_utils"))] + pub fn get_persistence_condvar_value(&self) -> bool { + let mutcond = &self.persistence_notifier.persistence_lock; + let &(ref mtx, _) = mutcond; + let guard = mtx.lock().unwrap(); + *guard + } } -impl - ChannelMessageHandler for ChannelManager - where M::Target: chain::Watch, +impl + ChannelMessageHandler for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { fn handle_open_channel(&self, counterparty_node_id: &PublicKey, their_features: InitFeatures, msg: &msgs::OpenChannel) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_open_channel(counterparty_node_id, their_features, msg), *counterparty_node_id); } fn handle_accept_channel(&self, counterparty_node_id: &PublicKey, their_features: InitFeatures, msg: &msgs::AcceptChannel) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_accept_channel(counterparty_node_id, their_features, msg), *counterparty_node_id); } fn handle_funding_created(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingCreated) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_funding_created(counterparty_node_id, msg), *counterparty_node_id); } fn handle_funding_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingSigned) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_funding_signed(counterparty_node_id, msg), *counterparty_node_id); } fn handle_funding_locked(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingLocked) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_funding_locked(counterparty_node_id, msg), *counterparty_node_id); } - fn handle_shutdown(&self, counterparty_node_id: &PublicKey, msg: &msgs::Shutdown) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); - let _ = handle_error!(self, self.internal_shutdown(counterparty_node_id, msg), *counterparty_node_id); + fn handle_shutdown(&self, counterparty_node_id: &PublicKey, their_features: &InitFeatures, msg: &msgs::Shutdown) { + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); + let _ = handle_error!(self, self.internal_shutdown(counterparty_node_id, their_features, msg), *counterparty_node_id); } fn handle_closing_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::ClosingSigned) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_closing_signed(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_add_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateAddHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_add_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fulfill_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFulfillHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fulfill_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fail_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fail_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fail_malformed_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fail_malformed_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_commitment_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::CommitmentSigned) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_commitment_signed(counterparty_node_id, msg), *counterparty_node_id); } fn handle_revoke_and_ack(&self, counterparty_node_id: &PublicKey, msg: &msgs::RevokeAndACK) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_revoke_and_ack(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fee(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFee) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fee(counterparty_node_id, msg), *counterparty_node_id); } fn handle_announcement_signatures(&self, counterparty_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_announcement_signatures(counterparty_node_id, msg), *counterparty_node_id); } fn handle_channel_reestablish(&self, counterparty_node_id: &PublicKey, msg: &msgs::ChannelReestablish) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_channel_reestablish(counterparty_node_id, msg), *counterparty_node_id); } fn peer_disconnected(&self, counterparty_node_id: &PublicKey, no_connection_possible: bool) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut failed_channels = Vec::new(); let mut failed_payments = Vec::new(); let mut no_channels_remain = true; @@ -3408,6 +3499,7 @@ impl true, &events::MessageSendEvent::SendChannelRangeQuery { .. } => false, &events::MessageSendEvent::SendShortIdsQuery { .. } => false, + &events::MessageSendEvent::SendReplyChannelRange { .. } => false, } }); } @@ -3428,7 +3520,7 @@ impl, Condvar), +} + +impl PersistenceNotifier { + fn new() -> Self { + Self { + persistence_lock: (Mutex::new(false), Condvar::new()), + } + } + + fn wait(&self) { + loop { + let &(ref mtx, ref cvar) = &self.persistence_lock; + let mut guard = mtx.lock().unwrap(); + guard = cvar.wait(guard).unwrap(); + let result = *guard; + if result { + *guard = false; + return + } + } + } + + #[cfg(any(test, feature = "allow_wallclock_use"))] + fn wait_timeout(&self, max_wait: Duration) -> bool { + let current_time = Instant::now(); + loop { + let &(ref mtx, ref cvar) = &self.persistence_lock; + let mut guard = mtx.lock().unwrap(); + guard = cvar.wait_timeout(guard, max_wait).unwrap().0; + // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the + // desired wait time has actually passed, and if not then restart the loop with a reduced wait + // time. Note that this logic can be highly simplified through the use of + // `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to + // 1.42.0. + let elapsed = current_time.elapsed(); + let result = *guard; + if result || elapsed >= max_wait { + *guard = false; + return result; + } + match max_wait.checked_sub(elapsed) { + None => return result, + Some(_) => continue + } } } + + // Signal to the ChannelManager persister that there are updates necessitating persisting to disk. + fn notify(&self) { + let &(ref persist_mtx, ref cnd) = &self.persistence_lock; + let mut persistence_lock = persist_mtx.lock().unwrap(); + *persistence_lock = true; + mem::drop(persistence_lock); + cnd.notify_all(); + } } const SERIALIZATION_VERSION: u8 = 1; @@ -3697,10 +3852,10 @@ impl Readable for HTLCForwardInfo { } } -impl Writeable for ChannelManager - where M::Target: chain::Watch, +impl Writeable for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3771,7 +3926,7 @@ impl /// At a high-level, the process for deserializing a ChannelManager and resuming normal operation /// is: /// 1) Deserialize all stored ChannelMonitors. -/// 2) Deserialize the ChannelManager by filling in this struct and calling <(Sha256dHash, +/// 2) Deserialize the ChannelManager by filling in this struct and calling <(Option, /// ChannelManager)>::read(reader, args). /// This may result in closing some Channels if the ChannelMonitor is newer than the stored /// ChannelManager state to ensure no loss of funds. Thus, transactions may be broadcasted. @@ -3780,10 +3935,10 @@ impl /// 4) Reconnect blocks on your ChannelMonitors. /// 5) Move the ChannelMonitors into your local chain::Watch. /// 6) Disconnect/connect blocks on the ChannelManager. -pub struct ChannelManagerReadArgs<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - where M::Target: chain::Watch, +pub struct ChannelManagerReadArgs<'a, Signer: 'a + Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3826,14 +3981,14 @@ pub struct ChannelManagerReadArgs<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: /// this struct. /// /// (C-not exported) because we have no HashMap bindings - pub channel_monitors: HashMap>, + pub channel_monitors: HashMap>, } -impl<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - ChannelManagerReadArgs<'a, ChanSigner, M, T, K, F, L> - where M::Target: chain::Watch, +impl<'a, Signer: 'a + Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + ChannelManagerReadArgs<'a, Signer, M, T, K, F, L> + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3841,7 +3996,7 @@ impl<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L /// HashMap for you. This is primarily useful for C bindings where it is not practical to /// populate a HashMap directly from C. pub fn new(keys_manager: K, fee_estimator: F, chain_monitor: M, tx_broadcaster: T, logger: L, default_config: UserConfig, - mut channel_monitors: Vec<&'a mut ChannelMonitor>) -> Self { + mut channel_monitors: Vec<&'a mut ChannelMonitor>) -> Self { Self { keys_manager, fee_estimator, chain_monitor, tx_broadcaster, logger, default_config, channel_monitors: channel_monitors.drain(..).map(|monitor| { (monitor.get_funding_txo().0, monitor) }).collect() @@ -3851,29 +4006,29 @@ impl<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L // Implement ReadableArgs for an Arc'd ChannelManager to make it a bit easier to work with the // SipmleArcChannelManager type: -impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - ReadableArgs> for (BlockHash, Arc>) - where M::Target: chain::Watch, +impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + ReadableArgs> for (Option, Arc>) + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { - fn read(reader: &mut R, args: ChannelManagerReadArgs<'a, ChanSigner, M, T, K, F, L>) -> Result { - let (blockhash, chan_manager) = <(BlockHash, ChannelManager)>::read(reader, args)?; + fn read(reader: &mut R, args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result { + let (blockhash, chan_manager) = <(Option, ChannelManager)>::read(reader, args)?; Ok((blockhash, Arc::new(chan_manager))) } } -impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - ReadableArgs> for (BlockHash, ChannelManager) - where M::Target: chain::Watch, +impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + ReadableArgs> for (Option, ChannelManager) + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { - fn read(reader: &mut R, mut args: ChannelManagerReadArgs<'a, ChanSigner, M, T, K, F, L>) -> Result { + fn read(reader: &mut R, mut args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result { let _ver: u8 = Readable::read(reader)?; let min_ver: u8 = Readable::read(reader)?; if min_ver > SERIALIZATION_VERSION { @@ -3891,7 +4046,7 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der let mut by_id = HashMap::with_capacity(cmp::min(channel_count as usize, 128)); let mut short_to_id = HashMap::with_capacity(cmp::min(channel_count as usize, 128)); for _ in 0..channel_count { - let mut channel: Channel = Channel::read(reader, &args.keys_manager)?; + let mut channel: Channel = Channel::read(reader, &args.keys_manager)?; if channel.last_block_connected != Default::default() && channel.last_block_connected != last_block_hash { return Err(DecodeError::InvalidValue); } @@ -3976,6 +4131,9 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der let last_node_announcement_serial: u32 = Readable::read(reader)?; + let mut secp_ctx = Secp256k1::new(); + secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes()); + let channel_manager = ChannelManager { genesis_hash, fee_estimator: args.fee_estimator, @@ -3984,7 +4142,7 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der latest_block_height: AtomicUsize::new(latest_block_height as usize), last_block_hash: Mutex::new(last_block_hash), - secp_ctx: Secp256k1::new(), + secp_ctx, channel_state: Mutex::new(ChannelHolder { by_id, @@ -4001,6 +4159,8 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der pending_events: Mutex::new(pending_events_read), total_consistency_lock: RwLock::new(()), + persistence_notifier: PersistenceNotifier::new(), + keys_manager: args.keys_manager, logger: args.logger, default_configuration: args.default_config, @@ -4013,6 +4173,62 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der //TODO: Broadcast channel update for closed channels, but only after we've made a //connection or two. - Ok((last_block_hash.clone(), channel_manager)) + let last_seen_block_hash = if last_block_hash == Default::default() { + None + } else { + Some(last_block_hash) + }; + Ok((last_seen_block_hash, channel_manager)) + } +} + +#[cfg(test)] +mod tests { + use ln::channelmanager::PersistenceNotifier; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::thread; + use std::time::Duration; + + #[test] + fn test_wait_timeout() { + let persistence_notifier = Arc::new(PersistenceNotifier::new()); + let thread_notifier = Arc::clone(&persistence_notifier); + + let exit_thread = Arc::new(AtomicBool::new(false)); + let exit_thread_clone = exit_thread.clone(); + thread::spawn(move || { + loop { + let &(ref persist_mtx, ref cnd) = &thread_notifier.persistence_lock; + let mut persistence_lock = persist_mtx.lock().unwrap(); + *persistence_lock = true; + cnd.notify_all(); + + if exit_thread_clone.load(Ordering::SeqCst) { + break + } + } + }); + + // Check that we can block indefinitely until updates are available. + let _ = persistence_notifier.wait(); + + // Check that the PersistenceNotifier will return after the given duration if updates are + // available. + loop { + if persistence_notifier.wait_timeout(Duration::from_millis(100)) { + break + } + } + + exit_thread.store(true, Ordering::SeqCst); + + // Check that the PersistenceNotifier will return after the given duration even if no updates + // are available. + loop { + if !persistence_notifier.wait_timeout(Duration::from_millis(100)) { + break + } + } } }