Hold channel_state lock into fail_htlc_backwards_internal
authorMatt Corallo <git@bluematt.me>
Fri, 23 Mar 2018 20:57:22 +0000 (16:57 -0400)
committerMatt Corallo <git@bluematt.me>
Fri, 23 Mar 2018 20:57:22 +0000 (16:57 -0400)
src/ln/channelmanager.rs

index 519a6c20ddfb5399cce5741b89fcb3c564d637af..4e243935f06edbe9e8dc49e3d576863da0a18b35 100644 (file)
@@ -26,11 +26,10 @@ use crypto::digest::Digest;
 use crypto::symmetriccipher::SynchronousStreamCipher;
 use crypto::chacha20::ChaCha20;
 
-use std::sync::{Mutex,Arc};
+use std::sync::{Mutex,MutexGuard,Arc};
 use std::collections::HashMap;
 use std::collections::hash_map;
-use std::ptr;
-use std::mem;
+use std::{ptr, mem};
 use std::time::{Instant,Duration};
 
 /// Stores the info we will need to send when we want to forward an HTLC onwards
@@ -651,11 +650,10 @@ impl ChannelManager {
 
        /// Indicates that the preimage for payment_hash is unknown after a PaymentReceived event.
        pub fn fail_htlc_backwards(&self, payment_hash: &[u8; 32]) -> bool {
-               self.fail_htlc_backwards_internal(payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 15 })
+               self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 15 })
        }
 
-       fn fail_htlc_backwards_internal(&self, payment_hash: &[u8; 32], onion_error: HTLCFailReason) -> bool {
-               let mut channel_state = self.channel_state.lock().unwrap();
+       fn fail_htlc_backwards_internal(&self, mut channel_state: MutexGuard<ChannelHolder>, payment_hash: &[u8; 32], onion_error: HTLCFailReason) -> bool {
                let mut pending_htlc = {
                        match channel_state.claimable_htlcs.remove(payment_hash) {
                                Some(pending_htlc) => pending_htlc,
@@ -674,6 +672,7 @@ impl ChannelManager {
                        PendingOutboundHTLC::CycledRoute { .. } => { panic!("WAT"); },
                        PendingOutboundHTLC::OutboundRoute { .. } => {
                                //TODO: DECRYPT route from OutboundRoute
+                               mem::drop(channel_state);
                                let mut pending_events = self.pending_events.lock().unwrap();
                                pending_events.push(events::Event::PaymentFailed {
                                        payment_hash: payment_hash.clone()
@@ -707,6 +706,7 @@ impl ChannelManager {
                                        }
                                };
 
+                               mem::drop(channel_state);
                                let mut pending_events = self.pending_events.lock().unwrap();
                                pending_events.push(events::Event::SendFailHTLC {
                                        node_id,
@@ -1217,36 +1217,34 @@ impl ChannelMessageHandler for ChannelManager {
        }
 
        fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) -> Result<Option<(Vec<msgs::UpdateAddHTLC>, msgs::CommitmentSigned)>, HandleError> {
-               let res = {
-                       let mut channel_state = self.channel_state.lock().unwrap();
-                       match channel_state.by_id.get_mut(&msg.channel_id) {
-                               Some(chan) => {
-                                       if chan.get_their_node_id() != *their_node_id {
-                                               return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
-                                       }
-                                       chan.update_fail_htlc(&msg)?
-                               },
-                               None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
-                       }
-               };
-               self.fail_htlc_backwards_internal(&res.0, HTLCFailReason::ErrorPacket { err: &msg.reason });
+               let mut channel_state = self.channel_state.lock().unwrap();
+               let res;
+               match channel_state.by_id.get_mut(&msg.channel_id) {
+                       Some(chan) => {
+                               if chan.get_their_node_id() != *their_node_id {
+                                       return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
+                               }
+                               res = chan.update_fail_htlc(&msg)?;
+                       },
+                       None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
+               }
+               self.fail_htlc_backwards_internal(channel_state, &res.0, HTLCFailReason::ErrorPacket { err: &msg.reason });
                Ok(res.1)
        }
 
        fn handle_update_fail_malformed_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailMalformedHTLC) -> Result<Option<(Vec<msgs::UpdateAddHTLC>, msgs::CommitmentSigned)>, HandleError> {
-               let res = {
-                       let mut channel_state = self.channel_state.lock().unwrap();
-                       match channel_state.by_id.get_mut(&msg.channel_id) {
-                               Some(chan) => {
-                                       if chan.get_their_node_id() != *their_node_id {
-                                               return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
-                                       }
-                                       chan.update_fail_malformed_htlc(&msg)?
-                               },
-                               None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
-                       }
-               };
-               self.fail_htlc_backwards_internal(&res.0, HTLCFailReason::Reason { failure_code: msg.failure_code });
+               let mut channel_state = self.channel_state.lock().unwrap();
+               let res;
+               match channel_state.by_id.get_mut(&msg.channel_id) {
+                       Some(chan) => {
+                               if chan.get_their_node_id() != *their_node_id {
+                                       return Err(HandleError{err: "Got a message for a channel from the wrong node!", msg: None})
+                               }
+                               res = chan.update_fail_malformed_htlc(&msg)?;
+                       },
+                       None => return Err(HandleError{err: "Failed to find corresponding channel", msg: None})
+               }
+               self.fail_htlc_backwards_internal(channel_state, &res.0, HTLCFailReason::Reason { failure_code: msg.failure_code });
                Ok(res.1)
        }