Merge pull request #97 from TheBlueMatt/2018-07-no-useless-preimages
[rust-lightning] / src / ln / channelmanager.rs
index af2e5e600993df37a2beac3925aa55dac2302ed5..6c349fd0ebe2c031c0cbddfb8b3c41354b56778d 100644 (file)
@@ -137,13 +137,15 @@ impl ChannelHolder {
                        by_id: &mut self.by_id,
                        short_to_id: &mut self.short_to_id,
                        next_forward: &mut self.next_forward,
-                       /// short channel id -> forward infos. Key of 0 means payments received
                        forward_htlcs: &mut self.forward_htlcs,
                        claimable_htlcs: &mut self.claimable_htlcs,
                }
        }
 }
 
+#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
+const ERR: () = "You need at least 32 bit pointers (well, usize, but we'll assume they're the same) for ChannelManager::latest_block_height";
+
 /// Manager which keeps track of a number of channels and sends messages to the appropriate
 /// channel, also tracking HTLC preimages and forwarding onion packets appropriately.
 /// Implements ChannelMessageHandler, handling the multi-channel parts and passing things through
@@ -157,7 +159,7 @@ pub struct ChannelManager {
 
        announce_channels_publicly: bool,
        fee_proportional_millionths: u32,
-       latest_block_height: AtomicUsize, //TODO: Compile-time assert this is at least 32-bits long
+       latest_block_height: AtomicUsize,
        secp_ctx: Secp256k1,
 
        channel_state: Mutex<ChannelHolder>,
@@ -388,7 +390,7 @@ impl ChannelManager {
                let mut chan = {
                        let mut channel_state_lock = self.channel_state.lock().unwrap();
                        let channel_state = channel_state_lock.borrow_parts();
-                       if let Some(mut chan) = channel_state.by_id.remove(channel_id) {
+                       if let Some(chan) = channel_state.by_id.remove(channel_id) {
                                if let Some(short_id) = chan.get_short_channel_id() {
                                        channel_state.short_to_id.remove(&short_id);
                                }
@@ -1135,9 +1137,9 @@ impl ChainListener for ChannelManager {
                let mut new_events = Vec::new();
                let mut failed_channels = Vec::new();
                {
-                       let mut channel_state = self.channel_state.lock().unwrap();
-                       let mut short_to_ids_to_insert = Vec::new();
-                       let mut short_to_ids_to_remove = Vec::new();
+                       let mut channel_lock = self.channel_state.lock().unwrap();
+                       let channel_state = channel_lock.borrow_parts();
+                       let short_to_id = channel_state.short_to_id;
                        channel_state.by_id.retain(|_, channel| {
                                if let Some(funding_locked) = channel.block_connected(header, height, txn_matched, indexes_of_txn_matched) {
                                        let announcement_sigs = match self.get_announcement_sigs(channel) {
@@ -1152,14 +1154,14 @@ impl ChainListener for ChannelManager {
                                                msg: funding_locked,
                                                announcement_sigs: announcement_sigs
                                        });
-                                       short_to_ids_to_insert.push((channel.get_short_channel_id().unwrap(), channel.channel_id()));
+                                       short_to_id.insert(channel.get_short_channel_id().unwrap(), channel.channel_id());
                                }
                                if let Some(funding_txo) = channel.get_funding_txo() {
                                        for tx in txn_matched {
                                                for inp in tx.input.iter() {
                                                        if inp.prev_hash == funding_txo.txid && inp.prev_index == funding_txo.index as u32 {
                                                                if let Some(short_id) = channel.get_short_channel_id() {
-                                                                       short_to_ids_to_remove.push(short_id);
+                                                                       short_to_id.remove(&short_id);
                                                                }
                                                                // It looks like our counterparty went on-chain. We go ahead and
                                                                // broadcast our latest local state as well here, just in case its
@@ -1177,7 +1179,7 @@ impl ChainListener for ChannelManager {
                                }
                                if channel.channel_monitor().would_broadcast_at_height(height) {
                                        if let Some(short_id) = channel.get_short_channel_id() {
-                                               short_to_ids_to_remove.push(short_id);
+                                               short_to_id.remove(&short_id);
                                        }
                                        failed_channels.push(channel.force_shutdown());
                                        // If would_broadcast_at_height() is true, the channel_monitor will broadcast
@@ -1193,12 +1195,6 @@ impl ChainListener for ChannelManager {
                                }
                                true
                        });
-                       for to_remove in short_to_ids_to_remove {
-                               channel_state.short_to_id.remove(&to_remove);
-                       }
-                       for to_insert in short_to_ids_to_insert {
-                               channel_state.short_to_id.insert(to_insert.0, to_insert.1);
-                       }
                }
                for failure in failed_channels.drain(..) {
                        self.finish_force_close_channel(failure);
@@ -1628,20 +1624,14 @@ impl ChannelMessageHandler for ChannelManager {
                // destination. That's OK since those nodes are probably busted or trying to do network
                // mapping through repeated loops. In either case, we want them to stop talking to us, so
                // we send permanent_node_failure.
-               match &claimable_htlcs_entry {
-                       &hash_map::Entry::Occupied(ref e) => {
-                               let mut acceptable_cycle = false;
-                               match e.get() {
-                                       &PendingOutboundHTLC::OutboundRoute { .. } => {
-                                               acceptable_cycle = pending_forward_info.short_channel_id == 0;
-                                       },
-                                       _ => {},
-                               }
-                               if !acceptable_cycle {
-                                       return_err!("Payment looped through us twice", 0x4000 | 0x2000 | 2, &[0;0]);
-                               }
-                       },
-                       _ => {},
+               if let &hash_map::Entry::Occupied(ref e) = &claimable_htlcs_entry {
+                       let mut acceptable_cycle = false;
+                       if let &PendingOutboundHTLC::OutboundRoute { .. } = e.get() {
+                               acceptable_cycle = pending_forward_info.short_channel_id == 0;
+                       }
+                       if !acceptable_cycle {
+                               return_err!("Payment looped through us twice", 0x4000 | 0x2000 | 2, &[0;0]);
+                       }
                }
 
                let (source_short_channel_id, res) = match channel_state.by_id.get_mut(&msg.channel_id) {
@@ -1692,22 +1682,16 @@ impl ChannelMessageHandler for ChannelManager {
                // is broken, we may have enough info to get our own money!
                self.claim_funds_internal(msg.payment_preimage.clone(), false);
 
-               let monitor = {
-                       let mut channel_state = self.channel_state.lock().unwrap();
-                       match channel_state.by_id.get_mut(&msg.channel_id) {
-                               Some(chan) => {
-                                       if chan.get_their_node_id() != *their_node_id {
-                                               return Err(HandleError{err: "Got a message for a channel from the wrong node!", action: None})
-                                       }
-                                       chan.update_fulfill_htlc(&msg)?
-                               },
-                               None => return Err(HandleError{err: "Failed to find corresponding channel", action: None})
-                       }
-               };
-               if let Err(_e) = self.monitor.add_update_monitor(monitor.get_funding_txo().unwrap(), monitor) {
-                       unimplemented!();
+               let mut channel_state = self.channel_state.lock().unwrap();
+               match channel_state.by_id.get_mut(&msg.channel_id) {
+                       Some(chan) => {
+                               if chan.get_their_node_id() != *their_node_id {
+                                       return Err(HandleError{err: "Got a message for a channel from the wrong node!", action: None})
+                               }
+                               chan.update_fulfill_htlc(&msg)
+                       },
+                       None => return Err(HandleError{err: "Failed to find corresponding channel", action: None})
                }
-               Ok(())
        }
 
        fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) -> Result<Option<msgs::HTLCFailChannelUpdate>, HandleError> {
@@ -2502,10 +2486,9 @@ mod tests {
                                        {
                                                let mut added_monitors = $node.chan_monitor.added_monitors.lock().unwrap();
                                                if $last_node {
-                                                       assert_eq!(added_monitors.len(), 1);
+                                                       assert_eq!(added_monitors.len(), 0);
                                                } else {
-                                                       assert_eq!(added_monitors.len(), 2);
-                                                       assert!(added_monitors[0].0 != added_monitors[1].0);
+                                                       assert_eq!(added_monitors.len(), 1);
                                                }
                                                added_monitors.clear();
                                        }