X-Git-Url: http://git.bitcoin.ninja/index.cgi?a=blobdiff_plain;f=src%2Fln%2Fchannelmanager.rs;fp=src%2Fln%2Fchannelmanager.rs;h=6c349fd0ebe2c031c0cbddfb8b3c41354b56778d;hb=cf6532e3e6a52154e4d09f64c17a46b34e441da0;hp=af2e5e600993df37a2beac3925aa55dac2302ed5;hpb=e571c7ac4011c458970e780fa1b7b92136f06f2d;p=rust-lightning diff --git a/src/ln/channelmanager.rs b/src/ln/channelmanager.rs index af2e5e60..6c349fd0 100644 --- a/src/ln/channelmanager.rs +++ b/src/ln/channelmanager.rs @@ -137,13 +137,15 @@ impl ChannelHolder { by_id: &mut self.by_id, short_to_id: &mut self.short_to_id, next_forward: &mut self.next_forward, - /// short channel id -> forward infos. Key of 0 means payments received forward_htlcs: &mut self.forward_htlcs, claimable_htlcs: &mut self.claimable_htlcs, } } } +#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))] +const ERR: () = "You need at least 32 bit pointers (well, usize, but we'll assume they're the same) for ChannelManager::latest_block_height"; + /// Manager which keeps track of a number of channels and sends messages to the appropriate /// channel, also tracking HTLC preimages and forwarding onion packets appropriately. /// Implements ChannelMessageHandler, handling the multi-channel parts and passing things through @@ -157,7 +159,7 @@ pub struct ChannelManager { announce_channels_publicly: bool, fee_proportional_millionths: u32, - latest_block_height: AtomicUsize, //TODO: Compile-time assert this is at least 32-bits long + latest_block_height: AtomicUsize, secp_ctx: Secp256k1, channel_state: Mutex, @@ -388,7 +390,7 @@ impl ChannelManager { let mut chan = { let mut channel_state_lock = self.channel_state.lock().unwrap(); let channel_state = channel_state_lock.borrow_parts(); - if let Some(mut chan) = channel_state.by_id.remove(channel_id) { + if let Some(chan) = channel_state.by_id.remove(channel_id) { if let Some(short_id) = chan.get_short_channel_id() { channel_state.short_to_id.remove(&short_id); } @@ -1135,9 +1137,9 @@ impl ChainListener for ChannelManager { let mut new_events = Vec::new(); let mut failed_channels = Vec::new(); { - let mut channel_state = self.channel_state.lock().unwrap(); - let mut short_to_ids_to_insert = Vec::new(); - let mut short_to_ids_to_remove = Vec::new(); + let mut channel_lock = self.channel_state.lock().unwrap(); + let channel_state = channel_lock.borrow_parts(); + let short_to_id = channel_state.short_to_id; channel_state.by_id.retain(|_, channel| { if let Some(funding_locked) = channel.block_connected(header, height, txn_matched, indexes_of_txn_matched) { let announcement_sigs = match self.get_announcement_sigs(channel) { @@ -1152,14 +1154,14 @@ impl ChainListener for ChannelManager { msg: funding_locked, announcement_sigs: announcement_sigs }); - short_to_ids_to_insert.push((channel.get_short_channel_id().unwrap(), channel.channel_id())); + short_to_id.insert(channel.get_short_channel_id().unwrap(), channel.channel_id()); } if let Some(funding_txo) = channel.get_funding_txo() { for tx in txn_matched { for inp in tx.input.iter() { if inp.prev_hash == funding_txo.txid && inp.prev_index == funding_txo.index as u32 { if let Some(short_id) = channel.get_short_channel_id() { - short_to_ids_to_remove.push(short_id); + short_to_id.remove(&short_id); } // It looks like our counterparty went on-chain. We go ahead and // broadcast our latest local state as well here, just in case its @@ -1177,7 +1179,7 @@ impl ChainListener for ChannelManager { } if channel.channel_monitor().would_broadcast_at_height(height) { if let Some(short_id) = channel.get_short_channel_id() { - short_to_ids_to_remove.push(short_id); + short_to_id.remove(&short_id); } failed_channels.push(channel.force_shutdown()); // If would_broadcast_at_height() is true, the channel_monitor will broadcast @@ -1193,12 +1195,6 @@ impl ChainListener for ChannelManager { } true }); - for to_remove in short_to_ids_to_remove { - channel_state.short_to_id.remove(&to_remove); - } - for to_insert in short_to_ids_to_insert { - channel_state.short_to_id.insert(to_insert.0, to_insert.1); - } } for failure in failed_channels.drain(..) { self.finish_force_close_channel(failure); @@ -1628,20 +1624,14 @@ impl ChannelMessageHandler for ChannelManager { // destination. That's OK since those nodes are probably busted or trying to do network // mapping through repeated loops. In either case, we want them to stop talking to us, so // we send permanent_node_failure. - match &claimable_htlcs_entry { - &hash_map::Entry::Occupied(ref e) => { - let mut acceptable_cycle = false; - match e.get() { - &PendingOutboundHTLC::OutboundRoute { .. } => { - acceptable_cycle = pending_forward_info.short_channel_id == 0; - }, - _ => {}, - } - if !acceptable_cycle { - return_err!("Payment looped through us twice", 0x4000 | 0x2000 | 2, &[0;0]); - } - }, - _ => {}, + if let &hash_map::Entry::Occupied(ref e) = &claimable_htlcs_entry { + let mut acceptable_cycle = false; + if let &PendingOutboundHTLC::OutboundRoute { .. } = e.get() { + acceptable_cycle = pending_forward_info.short_channel_id == 0; + } + if !acceptable_cycle { + return_err!("Payment looped through us twice", 0x4000 | 0x2000 | 2, &[0;0]); + } } let (source_short_channel_id, res) = match channel_state.by_id.get_mut(&msg.channel_id) { @@ -1692,22 +1682,16 @@ impl ChannelMessageHandler for ChannelManager { // is broken, we may have enough info to get our own money! self.claim_funds_internal(msg.payment_preimage.clone(), false); - let monitor = { - let mut channel_state = self.channel_state.lock().unwrap(); - match channel_state.by_id.get_mut(&msg.channel_id) { - Some(chan) => { - if chan.get_their_node_id() != *their_node_id { - return Err(HandleError{err: "Got a message for a channel from the wrong node!", action: None}) - } - chan.update_fulfill_htlc(&msg)? - }, - None => return Err(HandleError{err: "Failed to find corresponding channel", action: None}) - } - }; - if let Err(_e) = self.monitor.add_update_monitor(monitor.get_funding_txo().unwrap(), monitor) { - unimplemented!(); + let mut channel_state = self.channel_state.lock().unwrap(); + match channel_state.by_id.get_mut(&msg.channel_id) { + Some(chan) => { + if chan.get_their_node_id() != *their_node_id { + return Err(HandleError{err: "Got a message for a channel from the wrong node!", action: None}) + } + chan.update_fulfill_htlc(&msg) + }, + None => return Err(HandleError{err: "Failed to find corresponding channel", action: None}) } - Ok(()) } fn handle_update_fail_htlc(&self, their_node_id: &PublicKey, msg: &msgs::UpdateFailHTLC) -> Result, HandleError> { @@ -2502,10 +2486,9 @@ mod tests { { let mut added_monitors = $node.chan_monitor.added_monitors.lock().unwrap(); if $last_node { - assert_eq!(added_monitors.len(), 1); + assert_eq!(added_monitors.len(), 0); } else { - assert_eq!(added_monitors.len(), 2); - assert!(added_monitors[0].0 != added_monitors[1].0); + assert_eq!(added_monitors.len(), 1); } added_monitors.clear(); }