]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Add a BIG lock to ChannelManager
authorMatt Corallo <git@bluematt.me>
Sat, 20 Oct 2018 22:46:03 +0000 (18:46 -0400)
committerMatt Corallo <git@bluematt.me>
Sat, 27 Oct 2018 13:42:04 +0000 (09:42 -0400)
During normal operation we should never need to take this, so we
use a RwLock that allows normal parallelism until we want to
serialize out our ChannelManager, at which point we can take the
write-mode lock.

src/ln/channelmanager.rs

index 08ed2515f0d2aee35774ac7474d142be6adec237..f0fd7ad2d30f5d0673fb34fd9d6a82ddd54aa7ee 100644 (file)
@@ -45,7 +45,7 @@ use std::{ptr, mem};
 use std::collections::HashMap;
 use std::collections::hash_map;
 use std::io::Cursor;
-use std::sync::{Mutex,MutexGuard,Arc};
+use std::sync::{Arc, Mutex, MutexGuard, RwLock};
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::time::{Instant,Duration};
 
@@ -317,6 +317,10 @@ pub struct ChannelManager {
        our_network_key: SecretKey,
 
        pending_events: Mutex<Vec<events::Event>>,
+       /// Used when we have to take a BIG lock to make sure everything is self-consistent.
+       /// Essentially just when we're serializing ourselves out.
+       /// Taken first everywhere where we are making changes before any other locks.
+       total_consistency_lock: RwLock<()>,
 
        keys_manager: Arc<KeysInterface>,
 
@@ -418,6 +422,7 @@ impl ChannelManager {
                        our_network_key: keys_manager.get_node_secret(),
 
                        pending_events: Mutex::new(Vec::new()),
+                       total_consistency_lock: RwLock::new(()),
 
                        keys_manager,
 
@@ -442,6 +447,8 @@ impl ChannelManager {
        pub fn create_channel(&self, their_network_key: PublicKey, channel_value_satoshis: u64, push_msat: u64, user_id: u64) -> Result<(), APIError> {
                let channel = Channel::new_outbound(&*self.fee_estimator, &self.keys_manager, their_network_key, channel_value_satoshis, push_msat, self.announce_channels_publicly, user_id, Arc::clone(&self.logger))?;
                let res = channel.get_open_channel(self.genesis_hash.clone(), &*self.fee_estimator);
+
+               let _ = self.total_consistency_lock.read().unwrap();
                let mut channel_state = self.channel_state.lock().unwrap();
                match channel_state.by_id.entry(channel.channel_id()) {
                        hash_map::Entry::Occupied(_) => {
@@ -505,6 +512,8 @@ impl ChannelManager {
        ///
        /// May generate a SendShutdown message event on success, which should be relayed.
        pub fn close_channel(&self, channel_id: &[u8; 32]) -> Result<(), APIError> {
+               let _ = self.total_consistency_lock.read().unwrap();
+
                let (mut failed_htlcs, chan_option) = {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = channel_state_lock.borrow_parts();
@@ -567,6 +576,8 @@ impl ChannelManager {
        /// Force closes a channel, immediately broadcasting the latest local commitment transaction to
        /// the chain and rejecting new HTLCs on the given channel.
        pub fn force_close_channel(&self, channel_id: &[u8; 32]) {
+               let _ = self.total_consistency_lock.read().unwrap();
+
                let mut chan = {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = channel_state_lock.borrow_parts();
@@ -1147,6 +1158,7 @@ impl ChannelManager {
                let (onion_payloads, htlc_msat, htlc_cltv) = ChannelManager::build_onion_payloads(&route, cur_height)?;
                let onion_packet = ChannelManager::construct_onion_packet(onion_payloads, onion_keys, &payment_hash);
 
+               let _ = self.total_consistency_lock.read().unwrap();
                let mut channel_state = self.channel_state.lock().unwrap();
 
                let id = match channel_state.short_to_id.get(&route.hops.first().unwrap().short_channel_id) {
@@ -1203,6 +1215,8 @@ impl ChannelManager {
        /// May panic if the funding_txo is duplicative with some other channel (note that this should
        /// be trivially prevented by using unique funding transaction keys per-channel).
        pub fn funding_transaction_generated(&self, temporary_channel_id: &[u8; 32], funding_txo: OutPoint) {
+               let _ = self.total_consistency_lock.read().unwrap();
+
                let (chan, msg, chan_monitor) = {
                        let mut channel_state = self.channel_state.lock().unwrap();
                        match channel_state.by_id.remove(temporary_channel_id) {
@@ -1268,6 +1282,8 @@ impl ChannelManager {
        /// Should only really ever be called in response to an PendingHTLCsForwardable event.
        /// Will likely generate further events.
        pub fn process_pending_htlc_forwards(&self) {
+               let _ = self.total_consistency_lock.read().unwrap();
+
                let mut new_events = Vec::new();
                let mut failed_forwards = Vec::new();
                {
@@ -1389,6 +1405,8 @@ impl ChannelManager {
 
        /// Indicates that the preimage for payment_hash is unknown or the received amount is incorrect after a PaymentReceived event.
        pub fn fail_htlc_backwards(&self, payment_hash: &[u8; 32], reason: PaymentFailReason) -> bool {
+               let _ = self.total_consistency_lock.read().unwrap();
+
                let mut channel_state = Some(self.channel_state.lock().unwrap());
                let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(payment_hash);
                if let Some(mut sources) = removed_source {
@@ -1484,6 +1502,8 @@ impl ChannelManager {
                let mut payment_hash = [0; 32];
                sha.result(&mut payment_hash);
 
+               let _ = self.total_consistency_lock.read().unwrap();
+
                let mut channel_state = Some(self.channel_state.lock().unwrap());
                let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&payment_hash);
                if let Some(mut sources) = removed_source {
@@ -1562,6 +1582,7 @@ impl ChannelManager {
                let mut close_results = Vec::new();
                let mut htlc_forwards = Vec::new();
                let mut htlc_failures = Vec::new();
+               let _ = self.total_consistency_lock.read().unwrap();
 
                {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -2366,6 +2387,7 @@ impl ChannelManager {
        /// Note: This API is likely to change!
        #[doc(hidden)]
        pub fn update_fee(&self, channel_id: [u8;32], feerate_per_kw: u64) -> Result<(), APIError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                let mut channel_state_lock = self.channel_state.lock().unwrap();
                let channel_state = channel_state_lock.borrow_parts();
 
@@ -2423,6 +2445,7 @@ impl events::EventsProvider for ChannelManager {
 
 impl ChainListener for ChannelManager {
        fn block_connected(&self, header: &BlockHeader, height: u32, txn_matched: &[&Transaction], indexes_of_txn_matched: &[u32]) {
+               let _ = self.total_consistency_lock.read().unwrap();
                let mut failed_channels = Vec::new();
                {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -2500,6 +2523,7 @@ impl ChainListener for ChannelManager {
 
        /// We force-close the channel without letting our counterparty participate in the shutdown
        fn block_disconnected(&self, header: &BlockHeader) {
+               let _ = self.total_consistency_lock.read().unwrap();
                let mut failed_channels = Vec::new();
                {
                        let mut channel_lock = self.channel_state.lock().unwrap();
@@ -2565,70 +2589,87 @@ macro_rules! handle_error {
 impl ChannelMessageHandler for ChannelManager {
        //TODO: Handle errors and close channel (or so)
        fn handle_open_channel(&self, their_node_id: &PublicKey, msg: &msgs::OpenChannel) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_open_channel(their_node_id, msg), their_node_id)
        }
 
        fn handle_accept_channel(&self, their_node_id: &PublicKey, msg: &msgs::AcceptChannel) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_accept_channel(their_node_id, msg), their_node_id)
        }
 
        fn handle_funding_created(&self, their_node_id: &PublicKey, msg: &msgs::FundingCreated) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_funding_created(their_node_id, msg), their_node_id)
        }
 
        fn handle_funding_signed(&self, their_node_id: &PublicKey, msg: &msgs::FundingSigned) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_funding_signed(their_node_id, msg), their_node_id)
        }
 
        fn handle_funding_locked(&self, their_node_id: &PublicKey, msg: &msgs::FundingLocked) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_funding_locked(their_node_id, msg), their_node_id)
        }
 
        fn handle_shutdown(&self, their_node_id: &PublicKey, msg: &msgs::Shutdown) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_shutdown(their_node_id, msg), their_node_id)
        }
 
        fn handle_closing_signed(&self, their_node_id: &PublicKey, msg: &msgs::ClosingSigned) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_closing_signed(their_node_id, msg), their_node_id)
        }
 
        fn handle_update_add_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateAddHTLC) -> Result<(), msgs::HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_update_add_htlc(their_node_id, msg), their_node_id)
        }
 
        fn handle_update_fulfill_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFulfillHTLC) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_update_fulfill_htlc(their_node_id, msg), their_node_id)
        }
 
        fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_update_fail_htlc(their_node_id, msg), their_node_id)
        }
 
        fn handle_update_fail_malformed_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_update_fail_malformed_htlc(their_node_id, msg), their_node_id)
        }
 
        fn handle_commitment_signed(&self, their_node_id: &PublicKey, msg: &msgs::CommitmentSigned) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_commitment_signed(their_node_id, msg), their_node_id)
        }
 
        fn handle_revoke_and_ack(&self, their_node_id: &PublicKey, msg: &msgs::RevokeAndACK) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_revoke_and_ack(their_node_id, msg), their_node_id)
        }
 
        fn handle_update_fee(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFee) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_update_fee(their_node_id, msg), their_node_id)
        }
 
        fn handle_announcement_signatures(&self, their_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_announcement_signatures(their_node_id, msg), their_node_id)
        }
 
        fn handle_channel_reestablish(&self, their_node_id: &PublicKey, msg: &msgs::ChannelReestablish) -> Result<(), HandleError> {
+               let _ = self.total_consistency_lock.read().unwrap();
                handle_error!(self, self.internal_channel_reestablish(their_node_id, msg), their_node_id)
        }
 
        fn peer_disconnected(&self, their_node_id: &PublicKey, no_connection_possible: bool) {
+               let _ = self.total_consistency_lock.read().unwrap();
                let mut failed_channels = Vec::new();
                let mut failed_payments = Vec::new();
                {
@@ -2684,6 +2725,7 @@ impl ChannelMessageHandler for ChannelManager {
        }
 
        fn peer_connected(&self, their_node_id: &PublicKey) {
+               let _ = self.total_consistency_lock.read().unwrap();
                let mut channel_state_lock = self.channel_state.lock().unwrap();
                let channel_state = channel_state_lock.borrow_parts();
                let pending_msg_events = channel_state.pending_msg_events;
@@ -2708,6 +2750,8 @@ impl ChannelMessageHandler for ChannelManager {
        }
 
        fn handle_error(&self, their_node_id: &PublicKey, msg: &msgs::ErrorMessage) {
+               let _ = self.total_consistency_lock.read().unwrap();
+
                if msg.channel_id == [0; 32] {
                        for chan in self.list_channels() {
                                if chan.remote_network_id == *their_node_id {