Create and use methods for counting channels
authorDuncan Dean <git@dunxen.dev>
Wed, 7 Jun 2023 17:52:21 +0000 (19:52 +0200)
committerDuncan Dean <git@dunxen.dev>
Thu, 15 Jun 2023 10:55:40 +0000 (12:55 +0200)
This commit also adds two new maps to `PeerState` for keeping track
of `OutboundV1Channel`s and `InboundV1Channel`s so that further
commits are a bit easier to review.

lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs

index bebc1a8896ea4287b3a9de614d3ba03eddb06442..e07c26eaf0f137a48e25f956aed2c33188660a85 100644 (file)
@@ -613,6 +613,18 @@ pub(super) struct PeerState<Signer: ChannelSigner> {
        /// `channel_id`, the `temporary_channel_id` key in the map is updated and is replaced by the
        /// `channel_id`.
        pub(super) channel_by_id: HashMap<[u8; 32], Channel<Signer>>,
+       /// `temporary_channel_id` -> `OutboundV1Channel`.
+       ///
+       /// Holds all outbound V1 channels where the peer is the counterparty. Once an outbound channel has
+       /// been assigned a `channel_id`, the entry in this map is removed and one is created in
+       /// `channel_by_id`.
+       pub(super) outbound_v1_channel_by_id: HashMap<[u8; 32], OutboundV1Channel<Signer>>,
+       /// `temporary_channel_id` -> `InboundV1Channel`.
+       ///
+       /// Holds all inbound V1 channels where the peer is the counterparty. Once an inbound channel has
+       /// been assigned a `channel_id`, the entry in this map is removed and one is created in
+       /// `channel_by_id`.
+       pub(super) inbound_v1_channel_by_id: HashMap<[u8; 32], InboundV1Channel<Signer>>,
        /// The latest `InitFeatures` we heard from the peer.
        latest_features: InitFeatures,
        /// Messages to send to the peer - pushed to in the same lock that they are generated in (except
@@ -654,6 +666,20 @@ impl <Signer: ChannelSigner> PeerState<Signer> {
                }
                self.channel_by_id.is_empty() && self.monitor_update_blocked_actions.is_empty()
        }
+
+       // Returns a count of all channels we have with this peer, including pending channels.
+       fn total_channel_count(&self) -> usize {
+               self.channel_by_id.len() +
+                       self.outbound_v1_channel_by_id.len() +
+                       self.inbound_v1_channel_by_id.len()
+       }
+
+       // Returns a bool indicating if the given `channel_id` matches a channel we have with this peer.
+       fn has_channel(&self, channel_id: &[u8; 32]) -> bool {
+               self.channel_by_id.contains_key(channel_id) ||
+                       self.outbound_v1_channel_by_id.contains_key(channel_id) ||
+                       self.inbound_v1_channel_by_id.contains_key(channel_id)
+       }
 }
 
 /// Stores a PaymentSecret and any other data we may need to validate an inbound payment is
@@ -4765,13 +4791,14 @@ where
        fn do_accept_inbound_channel(&self, temporary_channel_id: &[u8; 32], counterparty_node_id: &PublicKey, accept_0conf: bool, user_channel_id: u128) -> Result<(), APIError> {
                let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(self);
 
-               let peers_without_funded_channels = self.peers_without_funded_channels(|peer| !peer.channel_by_id.is_empty());
+               let peers_without_funded_channels =
+                       self.peers_without_funded_channels(|peer| { peer.total_channel_count() > 0 });
                let per_peer_state = self.per_peer_state.read().unwrap();
                let peer_state_mutex = per_peer_state.get(counterparty_node_id)
                        .ok_or_else(|| APIError::ChannelUnavailable { err: format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id) })?;
                let mut peer_state_lock = peer_state_mutex.lock().unwrap();
                let peer_state = &mut *peer_state_lock;
-               let is_only_peer_channel = peer_state.channel_by_id.len() == 1;
+               let is_only_peer_channel = peer_state.total_channel_count() == 1;
                match peer_state.channel_by_id.entry(temporary_channel_id.clone()) {
                        hash_map::Entry::Occupied(mut channel) => {
                                if !channel.get().inbound_is_awaiting_accept() {
@@ -4833,7 +4860,7 @@ where
                                let peer = peer_mtx.lock().unwrap();
                                if !maybe_count_peer(&*peer) { continue; }
                                let num_unfunded_channels = Self::unfunded_channel_count(&peer, best_block_height);
-                               if num_unfunded_channels == peer.channel_by_id.len() {
+                               if num_unfunded_channels == peer.total_channel_count() {
                                        peers_without_funded_channels += 1;
                                }
                        }
@@ -4912,33 +4939,31 @@ where
                        },
                        Ok(res) => res
                };
-               match peer_state.channel_by_id.entry(channel.context.channel_id()) {
-                       hash_map::Entry::Occupied(_) => {
-                               self.outbound_scid_aliases.lock().unwrap().remove(&outbound_scid_alias);
-                               return Err(MsgHandleErrInternal::send_err_msg_no_close("temporary_channel_id collision for the same peer!".to_owned(), msg.temporary_channel_id.clone()))
-                       },
-                       hash_map::Entry::Vacant(entry) => {
-                               if !self.default_configuration.manually_accept_inbound_channels {
-                                       if channel.context.get_channel_type().requires_zero_conf() {
-                                               return Err(MsgHandleErrInternal::send_err_msg_no_close("No zero confirmation channels accepted".to_owned(), msg.temporary_channel_id.clone()));
-                                       }
-                                       peer_state.pending_msg_events.push(events::MessageSendEvent::SendAcceptChannel {
-                                               node_id: counterparty_node_id.clone(),
-                                               msg: channel.accept_inbound_channel(user_channel_id),
-                                       });
-                               } else {
-                                       let mut pending_events = self.pending_events.lock().unwrap();
-                                       pending_events.push_back((events::Event::OpenChannelRequest {
-                                               temporary_channel_id: msg.temporary_channel_id.clone(),
-                                               counterparty_node_id: counterparty_node_id.clone(),
-                                               funding_satoshis: msg.funding_satoshis,
-                                               push_msat: msg.push_msat,
-                                               channel_type: channel.context.get_channel_type().clone(),
-                                       }, None));
+               let channel_id = channel.context.channel_id();
+               let channel_exists = peer_state.has_channel(&channel_id);
+               if channel_exists {
+                       self.outbound_scid_aliases.lock().unwrap().remove(&outbound_scid_alias);
+                       return Err(MsgHandleErrInternal::send_err_msg_no_close("temporary_channel_id collision for the same peer!".to_owned(), msg.temporary_channel_id.clone()))
+               } else {
+                       if !self.default_configuration.manually_accept_inbound_channels {
+                               if channel.context.get_channel_type().requires_zero_conf() {
+                                       return Err(MsgHandleErrInternal::send_err_msg_no_close("No zero confirmation channels accepted".to_owned(), msg.temporary_channel_id.clone()));
                                }
-
-                               entry.insert(channel);
+                               peer_state.pending_msg_events.push(events::MessageSendEvent::SendAcceptChannel {
+                                       node_id: counterparty_node_id.clone(),
+                                       msg: channel.accept_inbound_channel(user_channel_id),
+                               });
+                       } else {
+                               let mut pending_events = self.pending_events.lock().unwrap();
+                               pending_events.push_back((events::Event::OpenChannelRequest {
+                                       temporary_channel_id: msg.temporary_channel_id.clone(),
+                                       counterparty_node_id: counterparty_node_id.clone(),
+                                       funding_satoshis: msg.funding_satoshis,
+                                       push_msat: msg.push_msat,
+                                       channel_type: channel.context.get_channel_type().clone(),
+                               }, None));
                        }
+                       peer_state.channel_by_id.insert(channel_id, channel);
                }
                Ok(())
        }
@@ -6878,6 +6903,8 @@ where
                                        }
                                        e.insert(Mutex::new(PeerState {
                                                channel_by_id: HashMap::new(),
+                                               outbound_v1_channel_by_id: HashMap::new(),
+                                               inbound_v1_channel_by_id: HashMap::new(),
                                                latest_features: init_msg.features.clone(),
                                                pending_msg_events: Vec::new(),
                                                monitor_update_blocked_actions: BTreeMap::new(),
@@ -8081,6 +8108,8 @@ where
                        let peer_pubkey = Readable::read(reader)?;
                        let peer_state = PeerState {
                                channel_by_id: peer_channels.remove(&peer_pubkey).unwrap_or(HashMap::new()),
+                               outbound_v1_channel_by_id: HashMap::new(),
+                               inbound_v1_channel_by_id: HashMap::new(),
                                latest_features: Readable::read(reader)?,
                                pending_msg_events: Vec::new(),
                                monitor_update_blocked_actions: BTreeMap::new(),
index a582836a70e844c379a34ffbb6f8a94e7c266263..7209c4a0b83cb986c67e660a6863ad64adeaf22a 100644 (file)
@@ -783,6 +783,28 @@ macro_rules! get_channel_ref {
        }
 }
 
+#[cfg(test)]
+macro_rules! get_inbound_v1_channel_ref {
+       ($node: expr, $counterparty_node: expr, $per_peer_state_lock: ident, $peer_state_lock: ident, $channel_id: expr) => {
+               {
+                       $per_peer_state_lock = $node.node.per_peer_state.read().unwrap();
+                       $peer_state_lock = $per_peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+                       $peer_state_lock.inbound_v1_channel_by_id.get_mut(&$channel_id).unwrap()
+               }
+       }
+}
+
+#[cfg(test)]
+macro_rules! get_outbound_v1_channel_ref {
+       ($node: expr, $counterparty_node: expr, $per_peer_state_lock: ident, $peer_state_lock: ident, $channel_id: expr) => {
+               {
+                       $per_peer_state_lock = $node.node.per_peer_state.read().unwrap();
+                       $peer_state_lock = $per_peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+                       $peer_state_lock.outbound_v1_channel_by_id.get_mut(&$channel_id).unwrap()
+               }
+       }
+}
+
 #[cfg(test)]
 macro_rules! get_feerate {
        ($node: expr, $counterparty_node: expr, $channel_id: expr) => {