Move PersistenceNotifier to a new util module
[rust-lightning] / lightning / src / ln / channelmanager.rs
index d011d6b42c8621314dfd44fe49709bac702b6620..577c3d82cfe2ef82167ed97367e0fa2ffdab1cc3 100644 (file)
@@ -54,6 +54,8 @@ use chain::keysinterface::{Sign, KeysInterface, KeysManager, InMemorySigner, Rec
 use util::config::{UserConfig, ChannelConfig};
 use util::events::{EventHandler, EventsProvider, MessageSendEvent, MessageSendEventsProvider, ClosureReason, HTLCDestination};
 use util::{byte_utils, events};
+use util::crypto::sign;
+use util::wakers::PersistenceNotifier;
 use util::scid_utils::fake_scid;
 use util::ser::{BigSize, FixedLengthReader, Readable, ReadableArgs, MaybeReadable, Writeable, Writer, VecWriter};
 use util::logger::{Level, Logger};
@@ -64,15 +66,11 @@ use prelude::*;
 use core::{cmp, mem};
 use core::cell::RefCell;
 use io::Read;
-use sync::{Arc, Condvar, Mutex, MutexGuard, RwLock, RwLockReadGuard};
+use sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
 use core::sync::atomic::{AtomicUsize, Ordering};
 use core::time::Duration;
 use core::ops::Deref;
 
-#[cfg(any(test, feature = "std"))]
-use std::time::Instant;
-use util::crypto::sign;
-
 // We hold various information about HTLC relay in the HTLC objects in Channel itself:
 //
 // Upon receipt of an HTLC from a peer, we'll give it a PendingHTLCStatus indicating if it should
@@ -1635,7 +1633,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                }
        }
 
-       /// Gets the current configuration applied to all new channels,  as
+       /// Gets the current configuration applied to all new channels.
        pub fn get_current_default_configuration(&self) -> &UserConfig {
                &self.default_configuration
        }
@@ -2144,17 +2142,17 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                })
        }
 
-       fn decode_update_add_htlc_onion(&self, msg: &msgs::UpdateAddHTLC) -> (PendingHTLCStatus, MutexGuard<ChannelHolder<Signer>>) {
+       fn decode_update_add_htlc_onion(&self, msg: &msgs::UpdateAddHTLC) -> PendingHTLCStatus {
                macro_rules! return_malformed_err {
                        ($msg: expr, $err_code: expr) => {
                                {
                                        log_info!(self.logger, "Failed to accept/forward incoming HTLC: {}", $msg);
-                                       return (PendingHTLCStatus::Fail(HTLCFailureMsg::Malformed(msgs::UpdateFailMalformedHTLC {
+                                       return PendingHTLCStatus::Fail(HTLCFailureMsg::Malformed(msgs::UpdateFailMalformedHTLC {
                                                channel_id: msg.channel_id,
                                                htlc_id: msg.htlc_id,
                                                sha256_of_onion: Sha256::hash(&msg.onion_routing_packet.hop_data).into_inner(),
                                                failure_code: $err_code,
-                                       })), self.channel_state.lock().unwrap());
+                                       }));
                                }
                        }
                }
@@ -2174,25 +2172,20 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        //node knows the HMAC matched, so they already know what is there...
                        return_malformed_err!("Unknown onion packet version", 0x8000 | 0x4000 | 4);
                }
-
-               let mut channel_state = None;
                macro_rules! return_err {
                        ($msg: expr, $err_code: expr, $data: expr) => {
                                {
                                        log_info!(self.logger, "Failed to accept/forward incoming HTLC: {}", $msg);
-                                       if channel_state.is_none() {
-                                               channel_state = Some(self.channel_state.lock().unwrap());
-                                       }
-                                       return (PendingHTLCStatus::Fail(HTLCFailureMsg::Relay(msgs::UpdateFailHTLC {
+                                       return PendingHTLCStatus::Fail(HTLCFailureMsg::Relay(msgs::UpdateFailHTLC {
                                                channel_id: msg.channel_id,
                                                htlc_id: msg.htlc_id,
                                                reason: onion_utils::build_first_hop_failure_packet(&shared_secret, $err_code, $data),
-                                       })), channel_state.unwrap());
+                                       }));
                                }
                        }
                }
 
-               let next_hop = match onion_utils::decode_next_hop(shared_secret, &msg.onion_routing_packet.hop_data[..], msg.onion_routing_packet.hmac, msg.payment_hash) {
+               let next_hop = match onion_utils::decode_next_payment_hop(shared_secret, &msg.onion_routing_packet.hop_data[..], msg.onion_routing_packet.hmac, msg.payment_hash) {
                        Ok(res) => res,
                        Err(onion_utils::OnionDecodeErr::Malformed { err_msg, err_code }) => {
                                return_malformed_err!(err_msg, err_code);
@@ -2246,14 +2239,14 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        }
                };
 
-               channel_state = Some(self.channel_state.lock().unwrap());
                if let &PendingHTLCStatus::Forward(PendingHTLCInfo { ref routing, ref amt_to_forward, ref outgoing_cltv_value, .. }) = &pending_forward_info {
                        // If short_channel_id is 0 here, we'll reject the HTLC as there cannot be a channel
                        // with a short_channel_id of 0. This is important as various things later assume
                        // short_channel_id is non-0 in any ::Forward.
                        if let &PendingHTLCRouting::Forward { ref short_channel_id, .. } = routing {
-                               let id_option = channel_state.as_ref().unwrap().short_to_chan_info.get(&short_channel_id).cloned();
                                if let Some((err, code, chan_update)) = loop {
+                                       let mut channel_state = self.channel_state.lock().unwrap();
+                                       let id_option = channel_state.short_to_chan_info.get(&short_channel_id).cloned();
                                        let forwarding_id_opt = match id_option {
                                                None => { // unknown_next_peer
                                                        // Note that this is likely a timing oracle for detecting whether an scid is a
@@ -2267,7 +2260,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                Some((_cp_id, chan_id)) => Some(chan_id.clone()),
                                        };
                                        let chan_update_opt = if let Some(forwarding_id) = forwarding_id_opt {
-                                               let chan = channel_state.as_mut().unwrap().by_id.get_mut(&forwarding_id).unwrap();
+                                               let chan = channel_state.by_id.get_mut(&forwarding_id).unwrap();
                                                if !chan.should_announce() && !self.default_configuration.accept_forwards_to_priv_channels {
                                                        // Note that the behavior here should be identical to the above block - we
                                                        // should NOT reveal the existence or non-existence of a private channel if
@@ -2353,7 +2346,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        }
                }
 
-               (pending_forward_info, channel_state.unwrap())
+               pending_forward_info
        }
 
        /// Gets the current channel_update for the given channel. This first checks if the channel is
@@ -3153,7 +3146,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                                                let phantom_secret_res = self.keys_manager.get_node_secret(Recipient::PhantomNode);
                                                                                                if phantom_secret_res.is_ok() && fake_scid::is_valid_phantom(&self.fake_scid_rand_bytes, short_chan_id) {
                                                                                                        let phantom_shared_secret = SharedSecret::new(&onion_packet.public_key.unwrap(), &phantom_secret_res.unwrap()).secret_bytes();
-                                                                                                       let next_hop = match onion_utils::decode_next_hop(phantom_shared_secret, &onion_packet.hop_data, onion_packet.hmac, payment_hash) {
+                                                                                                       let next_hop = match onion_utils::decode_next_payment_hop(phantom_shared_secret, &onion_packet.hop_data, onion_packet.hmac, payment_hash) {
                                                                                                                Ok(res) => res,
                                                                                                                Err(onion_utils::OnionDecodeErr::Malformed { err_msg, err_code }) => {
                                                                                                                        let sha256_of_onion = Sha256::hash(&onion_packet.hop_data).into_inner();
@@ -4850,7 +4843,8 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                //encrypted with the same key. It's not immediately obvious how to usefully exploit that,
                //but we should prevent it anyway.
 
-               let (pending_forward_info, mut channel_state_lock) = self.decode_update_add_htlc_onion(msg);
+               let pending_forward_info = self.decode_update_add_htlc_onion(msg);
+               let mut channel_state_lock = self.channel_state.lock().unwrap();
                let channel_state = &mut *channel_state_lock;
 
                match channel_state.by_id.entry(msg.channel_id) {
@@ -5996,10 +5990,7 @@ where
 
        #[cfg(any(test, feature = "_test_utils"))]
        pub fn get_persistence_condvar_value(&self) -> bool {
-               let mutcond = &self.persistence_notifier.persistence_lock;
-               let &(ref mtx, _) = mutcond;
-               let guard = mtx.lock().unwrap();
-               *guard
+               self.persistence_notifier.needs_persist()
        }
 
        /// Gets the latest best block which was connected either via the [`chain::Listen`] or
@@ -6241,77 +6232,6 @@ impl<Signer: Sign, M: Deref , T: Deref , K: Deref , F: Deref , L: Deref >
        }
 }
 
-/// Used to signal to the ChannelManager persister that the manager needs to be re-persisted to
-/// disk/backups, through `await_persistable_update_timeout` and `await_persistable_update`.
-struct PersistenceNotifier {
-       /// Users won't access the persistence_lock directly, but rather wait on its bool using
-       /// `wait_timeout` and `wait`.
-       persistence_lock: (Mutex<bool>, Condvar),
-}
-
-impl PersistenceNotifier {
-       fn new() -> Self {
-               Self {
-                       persistence_lock: (Mutex::new(false), Condvar::new()),
-               }
-       }
-
-       fn wait(&self) {
-               loop {
-                       let &(ref mtx, ref cvar) = &self.persistence_lock;
-                       let mut guard = mtx.lock().unwrap();
-                       if *guard {
-                               *guard = false;
-                               return;
-                       }
-                       guard = cvar.wait(guard).unwrap();
-                       let result = *guard;
-                       if result {
-                               *guard = false;
-                               return
-                       }
-               }
-       }
-
-       #[cfg(any(test, feature = "std"))]
-       fn wait_timeout(&self, max_wait: Duration) -> bool {
-               let current_time = Instant::now();
-               loop {
-                       let &(ref mtx, ref cvar) = &self.persistence_lock;
-                       let mut guard = mtx.lock().unwrap();
-                       if *guard {
-                               *guard = false;
-                               return true;
-                       }
-                       guard = cvar.wait_timeout(guard, max_wait).unwrap().0;
-                       // Due to spurious wakeups that can happen on `wait_timeout`, here we need to check if the
-                       // desired wait time has actually passed, and if not then restart the loop with a reduced wait
-                       // time. Note that this logic can be highly simplified through the use of
-                       // `Condvar::wait_while` and `Condvar::wait_timeout_while`, if and when our MSRV is raised to
-                       // 1.42.0.
-                       let elapsed = current_time.elapsed();
-                       let result = *guard;
-                       if result || elapsed >= max_wait {
-                               *guard = false;
-                               return result;
-                       }
-                       match max_wait.checked_sub(elapsed) {
-                               None => return result,
-                               Some(_) => continue
-                       }
-               }
-       }
-
-       // Signal to the ChannelManager persister that there are updates necessitating persisting to disk.
-       fn notify(&self) {
-               let &(ref persist_mtx, ref cnd) = &self.persistence_lock;
-               let mut persistence_lock = persist_mtx.lock().unwrap();
-               *persistence_lock = true;
-               mem::drop(persistence_lock);
-               cnd.notify_all();
-       }
-}
-
 const SERIALIZATION_VERSION: u8 = 1;
 const MIN_SERIALIZATION_VERSION: u8 = 1;
 
@@ -7359,54 +7279,6 @@ mod tests {
        use util::test_utils;
        use chain::keysinterface::KeysInterface;
 
-       #[cfg(feature = "std")]
-       #[test]
-       fn test_wait_timeout() {
-               use ln::channelmanager::PersistenceNotifier;
-               use sync::Arc;
-               use core::sync::atomic::AtomicBool;
-               use std::thread;
-
-               let persistence_notifier = Arc::new(PersistenceNotifier::new());
-               let thread_notifier = Arc::clone(&persistence_notifier);
-
-               let exit_thread = Arc::new(AtomicBool::new(false));
-               let exit_thread_clone = exit_thread.clone();
-               thread::spawn(move || {
-                       loop {
-                               let &(ref persist_mtx, ref cnd) = &thread_notifier.persistence_lock;
-                               let mut persistence_lock = persist_mtx.lock().unwrap();
-                               *persistence_lock = true;
-                               cnd.notify_all();
-
-                               if exit_thread_clone.load(Ordering::SeqCst) {
-                                       break
-                               }
-                       }
-               });
-
-               // Check that we can block indefinitely until updates are available.
-               let _ = persistence_notifier.wait();
-
-               // Check that the PersistenceNotifier will return after the given duration if updates are
-               // available.
-               loop {
-                       if persistence_notifier.wait_timeout(Duration::from_millis(100)) {
-                               break
-                       }
-               }
-
-               exit_thread.store(true, Ordering::SeqCst);
-
-               // Check that the PersistenceNotifier will return after the given duration even if no updates
-               // are available.
-               loop {
-                       if !persistence_notifier.wait_timeout(Duration::from_millis(100)) {
-                               break
-                       }
-               }
-       }
-
        #[test]
        fn test_notify_limits() {
                // Check that a few cases which don't require the persistence of a new ChannelManager,