X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=lightning%2Fsrc%2Fln%2Fchannelmanager.rs;h=21ee32b0f0be2dfec5535fa0f8e01f6f708d19e1;hb=ff00f6f8861419b73269e6c51d75ac9de75f1d1f;hp=621013724bc4195cd58b1dac9fb8c6aee915ef0b;hpb=c35002fa9c16042badfa5e7bf819df5f1d2ae60a;p=rust-lightning diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 62101372..21ee32b0 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -46,7 +46,7 @@ use ln::msgs; use ln::msgs::NetAddress; use ln::onion_utils; use ln::msgs::{ChannelMessageHandler, DecodeError, LightningError, OptionalField}; -use chain::keysinterface::{ChannelKeys, KeysInterface, KeysManager, InMemoryChannelKeys}; +use chain::keysinterface::{Sign, KeysInterface, KeysManager, InMemorySigner}; use util::config::UserConfig; use util::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsProvider}; use util::{byte_utils, events}; @@ -58,9 +58,11 @@ use util::errors::APIError; use std::{cmp, mem}; use std::collections::{HashMap, hash_map, HashSet}; use std::io::{Cursor, Read}; -use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use std::sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; +#[cfg(any(test, feature = "allow_wallclock_use"))] +use std::time::Instant; use std::marker::{Sync, Send}; use std::ops::Deref; use bitcoin::hashes::hex::ToHex; @@ -312,8 +314,8 @@ pub(super) enum RAACommitmentOrder { } // Note this is only exposed in cfg(test): -pub(super) struct ChannelHolder { - pub(super) by_id: HashMap<[u8; 32], Channel>, +pub(super) struct ChannelHolder { + pub(super) by_id: HashMap<[u8; 32], Channel>, pub(super) short_to_id: HashMap, /// short channel id -> forward infos. Key of 0 means payments received /// Note that while this is held in the same mutex as the channels themselves, no consistency @@ -347,7 +349,7 @@ const ERR: () = "You need at least 32 bit pointers (well, usize, but we'll assum /// issues such as overly long function definitions. Note that the ChannelManager can take any /// type that implements KeysInterface for its keys manager, but this type alias chooses the /// concrete type of the KeysManager. -pub type SimpleArcChannelManager = Arc, Arc, Arc, Arc, Arc>>; +pub type SimpleArcChannelManager = Arc, Arc, Arc, Arc, Arc>>; /// SimpleRefChannelManager is a type alias for a ChannelManager reference, and is the reference /// counterpart to the SimpleArcChannelManager type alias. Use this type by default when you don't @@ -357,7 +359,7 @@ pub type SimpleArcChannelManager = Arc = ChannelManager; +pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L> = ChannelManager; /// Manager which keeps track of a number of channels and sends messages to the appropriate /// channel, also tracking HTLC preimages and forwarding onion packets appropriately. @@ -395,10 +397,10 @@ pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, M, T, F, L> = ChannelManage /// essentially you should default to using a SimpleRefChannelManager, and use a /// SimpleArcChannelManager when you require a ChannelManager with a static lifetime, such as when /// you're using lightning-net-tokio. -pub struct ChannelManager - where M::Target: chain::Watch, +pub struct ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -416,9 +418,9 @@ pub struct ChannelManager, #[cfg(any(test, feature = "_test_utils"))] - pub(super) channel_state: Mutex>, + pub(super) channel_state: Mutex>, #[cfg(not(any(test, feature = "_test_utils")))] - channel_state: Mutex>, + channel_state: Mutex>, our_network_key: SecretKey, /// Used to track the last value sent in a node_announcement "timestamp" field. We ensure this @@ -437,13 +439,46 @@ pub struct ChannelManager, + persistence_notifier: PersistenceNotifier, + keys_manager: K, logger: L, } +/// Whenever we release the `ChannelManager`'s `total_consistency_lock`, from read mode, it is +/// desirable to notify any listeners on `wait_timeout`/`wait` that new updates are available for +/// persistence. Therefore, this struct is responsible for locking the total consistency lock and, +/// upon going out of scope, sending the aforementioned notification (since the lock being released +/// indicates that the updates are ready for persistence). +struct PersistenceNotifierGuard<'a> { + persistence_notifier: &'a PersistenceNotifier, + // We hold onto this result so the lock doesn't get released immediately. + _read_guard: RwLockReadGuard<'a, ()>, +} + +impl<'a> PersistenceNotifierGuard<'a> { + fn new(lock: &'a RwLock<()>, notifier: &'a PersistenceNotifier) -> Self { + let read_guard = lock.read().unwrap(); + + Self { + persistence_notifier: notifier, + _read_guard: read_guard, + } + } +} + +impl<'a> Drop for PersistenceNotifierGuard<'a> { + fn drop(&mut self) { + self.persistence_notifier.notify(); + } +} + /// The amount of time we require our counterparty wait to claim their money (ie time between when /// we, or our watchtower, must check for them having broadcast a theft transaction). pub(crate) const BREAKDOWN_TIMEOUT: u16 = 6 * 24; @@ -514,7 +549,7 @@ pub struct ChannelDetails { /// If a payment fails to send, it can be in one of several states. This enum is returned as the /// Err() type describing which state the payment is in, see the description of individual enum /// states for more. -#[derive(Debug)] +#[derive(Clone, Debug)] pub enum PaymentSendFailure { /// A parameter which was passed to send_payment was invalid, preventing us from attempting to /// send the payment at all. No channel state has been changed or messages sent to peers, and @@ -709,10 +744,10 @@ macro_rules! maybe_break_monitor_err { } } -impl ChannelManager - where M::Target: chain::Watch, +impl ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -759,6 +794,7 @@ impl pending_events: Mutex::new(Vec::new()), total_consistency_lock: RwLock::new(()), + persistence_notifier: PersistenceNotifier::new(), keys_manager, @@ -787,7 +823,10 @@ impl let channel = Channel::new_outbound(&self.fee_estimator, &self.keys_manager, their_network_key, channel_value_satoshis, push_msat, user_id, config)?; let res = channel.get_open_channel(self.genesis_hash.clone()); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); + // We want to make sure the lock is actually acquired by PersistenceNotifierGuard. + debug_assert!(&self.total_consistency_lock.try_write().is_err()); + let mut channel_state = self.channel_state.lock().unwrap(); match channel_state.by_id.entry(channel.channel_id()) { hash_map::Entry::Occupied(_) => { @@ -806,7 +845,7 @@ impl Ok(()) } - fn list_channels_with_filter)) -> bool>(&self, f: Fn) -> Vec { + fn list_channels_with_filter)) -> bool>(&self, f: Fn) -> Vec { let mut res = Vec::new(); { let channel_state = self.channel_state.lock().unwrap(); @@ -859,7 +898,7 @@ impl /// /// May generate a SendShutdown message event on success, which should be relayed. pub fn close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let (mut failed_htlcs, chan_option) = { let mut channel_state_lock = self.channel_state.lock().unwrap(); @@ -951,7 +990,7 @@ impl /// Force closes a channel, immediately broadcasting the latest local commitment transaction to /// the chain and rejecting new HTLCs on the given channel. Fails if channel_id is unknown to the manager. pub fn force_close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); self.force_close_channel_with_peer(channel_id, None) } @@ -963,7 +1002,7 @@ impl } } - fn decode_update_add_htlc_onion(&self, msg: &msgs::UpdateAddHTLC) -> (PendingHTLCStatus, MutexGuard>) { + fn decode_update_add_htlc_onion(&self, msg: &msgs::UpdateAddHTLC) -> (PendingHTLCStatus, MutexGuard>) { macro_rules! return_malformed_err { ($msg: expr, $err_code: expr) => { { @@ -1235,7 +1274,7 @@ impl /// only fails if the channel does not yet have an assigned short_id /// May be called with channel_state already locked! - fn get_channel_update(&self, chan: &Channel) -> Result { + fn get_channel_update(&self, chan: &Channel) -> Result { let short_channel_id = match chan.get_short_channel_id() { None => return Err(LightningError{err: "Channel not yet established".to_owned(), action: msgs::ErrorAction::IgnoreError}), Some(id) => id, @@ -1279,7 +1318,7 @@ impl } let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let err: Result<(), _> = loop { let mut channel_lock = self.channel_state.lock().unwrap(); @@ -1447,7 +1486,7 @@ impl /// May panic if the funding_txo is duplicative with some other channel (note that this should /// be trivially prevented by using unique funding transaction keys per-channel). pub fn funding_transaction_generated(&self, temporary_channel_id: &[u8; 32], funding_txo: OutPoint) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let (chan, msg) = { let (res, chan) = match self.channel_state.lock().unwrap().by_id.remove(temporary_channel_id) { @@ -1483,7 +1522,7 @@ impl } } - fn get_announcement_sigs(&self, chan: &Channel) -> Option { + fn get_announcement_sigs(&self, chan: &Channel) -> Option { if !chan.should_announce() { log_trace!(self.logger, "Can't send announcement_signatures for private channel {}", log_bytes!(chan.channel_id())); return None @@ -1530,7 +1569,7 @@ impl /// /// Panics if addresses is absurdly large (more than 500). pub fn broadcast_node_announcement(&self, rgb: [u8; 3], alias: [u8; 32], addresses: Vec) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); if addresses.len() > 500 { panic!("More than half the message size was taken up by public addresses!"); @@ -1560,7 +1599,7 @@ impl /// Should only really ever be called in response to a PendingHTLCsForwardable event. /// Will likely generate further events. pub fn process_pending_htlc_forwards(&self) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut new_events = Vec::new(); let mut failed_forwards = Vec::new(); @@ -1820,7 +1859,7 @@ impl /// /// This method handles all the details, and must be called roughly once per minute. pub fn timer_chan_freshness_every_min(&self) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut channel_state_lock = self.channel_state.lock().unwrap(); let channel_state = &mut *channel_state_lock; for (_, chan) in channel_state.by_id.iter_mut() { @@ -1845,7 +1884,7 @@ impl /// Returns false if no payment was found to fail backwards, true if the process of failing the /// HTLC backwards has been started. pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash, payment_secret: &Option) -> bool { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut channel_state = Some(self.channel_state.lock().unwrap()); let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&(*payment_hash, *payment_secret)); @@ -1908,7 +1947,7 @@ impl /// to fail and take the channel_state lock for each iteration (as we take ownership and may /// drop it). In other words, no assumptions are made that entries in claimable_htlcs point to /// still-available channels. - fn fail_htlc_backwards_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason) { + fn fail_htlc_backwards_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason) { //TODO: There is a timing attack here where if a node fails an HTLC back to us they can //identify whether we sent it or not based on the (I presume) very different runtime //between the branches here. We should make this async and move it into the forward HTLCs @@ -2024,7 +2063,7 @@ impl pub fn claim_funds(&self, payment_preimage: PaymentPreimage, payment_secret: &Option, expected_amount: u64) -> bool { let payment_hash = PaymentHash(Sha256::hash(&payment_preimage.0).into_inner()); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut channel_state = Some(self.channel_state.lock().unwrap()); let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&(payment_hash, *payment_secret)); @@ -2102,7 +2141,7 @@ impl } else { false } } - fn claim_funds_from_hop(&self, channel_state_lock: &mut MutexGuard>, prev_hop: HTLCPreviousHopData, payment_preimage: PaymentPreimage) -> Result<(), Option<(PublicKey, MsgHandleErrInternal)>> { + fn claim_funds_from_hop(&self, channel_state_lock: &mut MutexGuard>, prev_hop: HTLCPreviousHopData, payment_preimage: PaymentPreimage) -> Result<(), Option<(PublicKey, MsgHandleErrInternal)>> { //TODO: Delay the claimed_funds relaying just like we do outbound relay! let channel_state = &mut **channel_state_lock; let chan_id = match channel_state.short_to_id.get(&prev_hop.short_channel_id) { @@ -2155,7 +2194,7 @@ impl } else { unreachable!(); } } - fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_preimage: PaymentPreimage) { + fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_preimage: PaymentPreimage) { match source { HTLCSource::OutboundRoute { .. } => { mem::drop(channel_state_lock); @@ -2220,7 +2259,7 @@ impl /// 4) once all remote copies are updated, you call this function with the update_id that /// completed, and once it is the latest the Channel will be re-enabled. pub fn channel_monitor_updated(&self, funding_txo: &OutPoint, highest_applied_update_id: u64) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut close_results = Vec::new(); let mut htlc_forwards = Vec::new(); @@ -2580,7 +2619,7 @@ impl return Err(MsgHandleErrInternal::send_err_msg_no_close("Got a message for a channel from the wrong node!".to_owned(), msg.channel_id)); } - let create_pending_htlc_status = |chan: &Channel, pending_forward_info: PendingHTLCStatus, error_code: u16| { + let create_pending_htlc_status = |chan: &Channel, pending_forward_info: PendingHTLCStatus, error_code: u16| { // Ensure error_code has the UPDATE flag set, since by default we send a // channel update along as part of failing the HTLC. assert!((error_code & 0x1000) != 0); @@ -2971,7 +3010,7 @@ impl /// (C-not exported) Cause its doc(hidden) anyway #[doc(hidden)] pub fn update_fee(&self, channel_id: [u8;32], feerate_per_kw: u32) -> Result<(), APIError> { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let counterparty_node_id; let err: Result<(), _> = loop { let mut channel_state_lock = self.channel_state.lock().unwrap(); @@ -3062,10 +3101,10 @@ impl } } -impl MessageSendEventsProvider for ChannelManager - where M::Target: chain::Watch, +impl MessageSendEventsProvider for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3081,10 +3120,10 @@ impl } } -impl EventsProvider for ChannelManager - where M::Target: chain::Watch, +impl EventsProvider for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3100,10 +3139,10 @@ impl } } -impl ChannelManager - where M::Target: chain::Watch, +impl ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3111,7 +3150,7 @@ impl pub fn block_connected(&self, header: &BlockHeader, txdata: &TransactionData, height: u32) { let header_hash = header.block_hash(); log_trace!(self.logger, "Block {} at height {} connected", header_hash, height); - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut failed_channels = Vec::new(); let mut timed_out_htlcs = Vec::new(); { @@ -3224,7 +3263,7 @@ impl /// If necessary, the channel may be force-closed without letting the counterparty participate /// in the shutdown. pub fn block_disconnected(&self, header: &BlockHeader) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut failed_channels = Vec::new(); { let mut channel_lock = self.channel_state.lock().unwrap(); @@ -3254,98 +3293,121 @@ impl self.latest_block_height.fetch_sub(1, Ordering::AcqRel); *self.last_block_hash.try_lock().expect("block_(dis)connected must not be called in parallel") = header.block_hash(); } + + /// Blocks until ChannelManager needs to be persisted or a timeout is reached. It returns a bool + /// indicating whether persistence is necessary. Only one listener on `wait_timeout` is + /// guaranteed to be woken up. + /// Note that the feature `allow_wallclock_use` must be enabled to use this function. + #[cfg(any(test, feature = "allow_wallclock_use"))] + pub fn wait_timeout(&self, max_wait: Duration) -> bool { + self.persistence_notifier.wait_timeout(max_wait) + } + + /// Blocks until ChannelManager needs to be persisted. Only one listener on `wait` is + /// guaranteed to be woken up. + pub fn wait(&self) { + self.persistence_notifier.wait() + } + + #[cfg(any(test, feature = "_test_utils"))] + pub fn get_persistence_condvar_value(&self) -> bool { + let mutcond = &self.persistence_notifier.persistence_lock; + let &(ref mtx, _) = mutcond; + let guard = mtx.lock().unwrap(); + *guard + } } -impl - ChannelMessageHandler for ChannelManager - where M::Target: chain::Watch, +impl + ChannelMessageHandler for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { fn handle_open_channel(&self, counterparty_node_id: &PublicKey, their_features: InitFeatures, msg: &msgs::OpenChannel) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_open_channel(counterparty_node_id, their_features, msg), *counterparty_node_id); } fn handle_accept_channel(&self, counterparty_node_id: &PublicKey, their_features: InitFeatures, msg: &msgs::AcceptChannel) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_accept_channel(counterparty_node_id, their_features, msg), *counterparty_node_id); } fn handle_funding_created(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingCreated) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_funding_created(counterparty_node_id, msg), *counterparty_node_id); } fn handle_funding_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingSigned) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_funding_signed(counterparty_node_id, msg), *counterparty_node_id); } fn handle_funding_locked(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingLocked) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_funding_locked(counterparty_node_id, msg), *counterparty_node_id); } fn handle_shutdown(&self, counterparty_node_id: &PublicKey, msg: &msgs::Shutdown) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_shutdown(counterparty_node_id, msg), *counterparty_node_id); } fn handle_closing_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::ClosingSigned) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_closing_signed(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_add_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateAddHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_add_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fulfill_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFulfillHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fulfill_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fail_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fail_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fail_malformed_htlc(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fail_malformed_htlc(counterparty_node_id, msg), *counterparty_node_id); } fn handle_commitment_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::CommitmentSigned) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_commitment_signed(counterparty_node_id, msg), *counterparty_node_id); } fn handle_revoke_and_ack(&self, counterparty_node_id: &PublicKey, msg: &msgs::RevokeAndACK) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_revoke_and_ack(counterparty_node_id, msg), *counterparty_node_id); } fn handle_update_fee(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFee) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_update_fee(counterparty_node_id, msg), *counterparty_node_id); } fn handle_announcement_signatures(&self, counterparty_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_announcement_signatures(counterparty_node_id, msg), *counterparty_node_id); } fn handle_channel_reestablish(&self, counterparty_node_id: &PublicKey, msg: &msgs::ChannelReestablish) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let _ = handle_error!(self, self.internal_channel_reestablish(counterparty_node_id, msg), *counterparty_node_id); } fn peer_disconnected(&self, counterparty_node_id: &PublicKey, no_connection_possible: bool) { - let _consistency_lock = self.total_consistency_lock.read().unwrap(); + let _persistence_guard = PersistenceNotifierGuard::new(&self.total_consistency_lock, &self.persistence_notifier); let mut failed_channels = Vec::new(); let mut failed_payments = Vec::new(); let mut no_channels_remain = true; @@ -3438,7 +3500,7 @@ impl, Condvar), +} + +impl PersistenceNotifier { + fn new() -> Self { + Self { + persistence_lock: (Mutex::new(false), Condvar::new()), + } + } + + fn wait(&self) { + loop { + let &(ref mtx, ref cvar) = &self.persistence_lock; + let mut guard = mtx.lock().unwrap(); + guard = cvar.wait(guard).unwrap(); + let result = *guard; + if result { + *guard = false; + return + } + } + } + + #[cfg(any(test, feature = "allow_wallclock_use"))] + fn wait_timeout(&self, max_wait: Duration) -> bool { + let current_time = Instant::now(); + loop { + let &(ref mtx, ref cvar) = &self.persistence_lock; + let mut guard = mtx.lock().unwrap(); + guard = cvar.wait_timeout(guard, max_wait).unwrap().0; + // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the + // desired wait time has actually passed, and if not then restart the loop with a reduced wait + // time. Note that this logic can be highly simplified through the use of + // `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to + // 1.42.0. + let elapsed = current_time.elapsed(); + let result = *guard; + if result || elapsed >= max_wait { + *guard = false; + return result; + } + match max_wait.checked_sub(elapsed) { + None => return result, + Some(_) => continue + } + } + } + + // Signal to the ChannelManager persister that there are updates necessitating persisting to disk. + fn notify(&self) { + let &(ref persist_mtx, ref cnd) = &self.persistence_lock; + let mut persistence_lock = persist_mtx.lock().unwrap(); + *persistence_lock = true; + mem::drop(persistence_lock); + cnd.notify_all(); + } +} + const SERIALIZATION_VERSION: u8 = 1; const MIN_SERIALIZATION_VERSION: u8 = 1; @@ -3707,10 +3832,10 @@ impl Readable for HTLCForwardInfo { } } -impl Writeable for ChannelManager - where M::Target: chain::Watch, +impl Writeable for ChannelManager + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3790,10 +3915,10 @@ impl /// 4) Reconnect blocks on your ChannelMonitors. /// 5) Move the ChannelMonitors into your local chain::Watch. /// 6) Disconnect/connect blocks on the ChannelManager. -pub struct ChannelManagerReadArgs<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - where M::Target: chain::Watch, +pub struct ChannelManagerReadArgs<'a, Signer: 'a + Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3836,14 +3961,14 @@ pub struct ChannelManagerReadArgs<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: /// this struct. /// /// (C-not exported) because we have no HashMap bindings - pub channel_monitors: HashMap>, + pub channel_monitors: HashMap>, } -impl<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - ChannelManagerReadArgs<'a, ChanSigner, M, T, K, F, L> - where M::Target: chain::Watch, +impl<'a, Signer: 'a + Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + ChannelManagerReadArgs<'a, Signer, M, T, K, F, L> + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { @@ -3851,7 +3976,7 @@ impl<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L /// HashMap for you. This is primarily useful for C bindings where it is not practical to /// populate a HashMap directly from C. pub fn new(keys_manager: K, fee_estimator: F, chain_monitor: M, tx_broadcaster: T, logger: L, default_config: UserConfig, - mut channel_monitors: Vec<&'a mut ChannelMonitor>) -> Self { + mut channel_monitors: Vec<&'a mut ChannelMonitor>) -> Self { Self { keys_manager, fee_estimator, chain_monitor, tx_broadcaster, logger, default_config, channel_monitors: channel_monitors.drain(..).map(|monitor| { (monitor.get_funding_txo().0, monitor) }).collect() @@ -3861,29 +3986,29 @@ impl<'a, ChanSigner: 'a + ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L // Implement ReadableArgs for an Arc'd ChannelManager to make it a bit easier to work with the // SipmleArcChannelManager type: -impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - ReadableArgs> for (BlockHash, Arc>) - where M::Target: chain::Watch, +impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + ReadableArgs> for (BlockHash, Arc>) + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { - fn read(reader: &mut R, args: ChannelManagerReadArgs<'a, ChanSigner, M, T, K, F, L>) -> Result { - let (blockhash, chan_manager) = <(BlockHash, ChannelManager)>::read(reader, args)?; + fn read(reader: &mut R, args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result { + let (blockhash, chan_manager) = <(BlockHash, ChannelManager)>::read(reader, args)?; Ok((blockhash, Arc::new(chan_manager))) } } -impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> - ReadableArgs> for (BlockHash, ChannelManager) - where M::Target: chain::Watch, +impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> + ReadableArgs> for (BlockHash, ChannelManager) + where M::Target: chain::Watch, T::Target: BroadcasterInterface, - K::Target: KeysInterface, + K::Target: KeysInterface, F::Target: FeeEstimator, L::Target: Logger, { - fn read(reader: &mut R, mut args: ChannelManagerReadArgs<'a, ChanSigner, M, T, K, F, L>) -> Result { + fn read(reader: &mut R, mut args: ChannelManagerReadArgs<'a, Signer, M, T, K, F, L>) -> Result { let _ver: u8 = Readable::read(reader)?; let min_ver: u8 = Readable::read(reader)?; if min_ver > SERIALIZATION_VERSION { @@ -3901,7 +4026,7 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der let mut by_id = HashMap::with_capacity(cmp::min(channel_count as usize, 128)); let mut short_to_id = HashMap::with_capacity(cmp::min(channel_count as usize, 128)); for _ in 0..channel_count { - let mut channel: Channel = Channel::read(reader, &args.keys_manager)?; + let mut channel: Channel = Channel::read(reader, &args.keys_manager)?; if channel.last_block_connected != Default::default() && channel.last_block_connected != last_block_hash { return Err(DecodeError::InvalidValue); } @@ -4011,6 +4136,8 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der pending_events: Mutex::new(pending_events_read), total_consistency_lock: RwLock::new(()), + persistence_notifier: PersistenceNotifier::new(), + keys_manager: args.keys_manager, logger: args.logger, default_configuration: args.default_config, @@ -4026,3 +4153,54 @@ impl<'a, ChanSigner: ChannelKeys, M: Deref, T: Deref, K: Deref, F: Deref, L: Der Ok((last_block_hash.clone(), channel_manager)) } } + +#[cfg(test)] +mod tests { + use ln::channelmanager::PersistenceNotifier; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::thread; + use std::time::Duration; + + #[test] + fn test_wait_timeout() { + let persistence_notifier = Arc::new(PersistenceNotifier::new()); + let thread_notifier = Arc::clone(&persistence_notifier); + + let exit_thread = Arc::new(AtomicBool::new(false)); + let exit_thread_clone = exit_thread.clone(); + thread::spawn(move || { + loop { + let &(ref persist_mtx, ref cnd) = &thread_notifier.persistence_lock; + let mut persistence_lock = persist_mtx.lock().unwrap(); + *persistence_lock = true; + cnd.notify_all(); + + if exit_thread_clone.load(Ordering::SeqCst) { + break + } + } + }); + + // Check that we can block indefinitely until updates are available. + let _ = persistence_notifier.wait(); + + // Check that the PersistenceNotifier will return after the given duration if updates are + // available. + loop { + if persistence_notifier.wait_timeout(Duration::from_millis(100)) { + break + } + } + + exit_thread.store(true, Ordering::SeqCst); + + // Check that the PersistenceNotifier will return after the given duration even if no updates + // are available. + loop { + if !persistence_notifier.wait_timeout(Duration::from_millis(100)) { + break + } + } + } +}