Add `counterparty_node_id` to `short_to_id` map
authorViktor Tigerström <11711198+ViktorTigerstrom@users.noreply.github.com>
Tue, 17 May 2022 16:07:41 +0000 (18:07 +0200)
committerViktor Tigerström <11711198+ViktorTigerstrom@users.noreply.github.com>
Thu, 7 Jul 2022 11:34:18 +0000 (13:34 +0200)
lightning/src/ln/channelmanager.rs

index 36c14242d9edffab80f07c864b0a35f5b8693fa5..c87316acc7ccddcafd37219aadeedcba2e3aa216 100644 (file)
@@ -397,10 +397,12 @@ pub(super) enum RAACommitmentOrder {
 // Note this is only exposed in cfg(test):
 pub(super) struct ChannelHolder<Signer: Sign> {
        pub(super) by_id: HashMap<[u8; 32], Channel<Signer>>,
-       /// SCIDs (and outbound SCID aliases) to the real channel id. Outbound SCID aliases are added
-       /// here once the channel is available for normal use, with SCIDs being added once the funding
-       /// transaction is confirmed at the channel's required confirmation depth.
-       pub(super) short_to_id: HashMap<u64, [u8; 32]>,
+       /// SCIDs (and outbound SCID aliases) -> `counterparty_node_id`s and `channel_id`s.
+       ///
+       /// Outbound SCID aliases are added here once the channel is available for normal use, with
+       /// SCIDs being added once the funding transaction is confirmed at the channel's required
+       /// confirmation depth.
+       pub(super) short_to_id: HashMap<u64, (PublicKey, [u8; 32])>,
        /// SCID/SCID Alias -> forward infos. Key of 0 means payments received.
        ///
        /// Note that because we may have an SCID Alias as the key we can have two entries per channel,
@@ -1406,12 +1408,12 @@ macro_rules! send_channel_ready {
                });
                // Note that we may send a `channel_ready` multiple times for a channel if we reconnect, so
                // we allow collisions, but we shouldn't ever be updating the channel ID pointed to.
-               let outbound_alias_insert = $short_to_id.insert($channel.outbound_scid_alias(), $channel.channel_id());
-               assert!(outbound_alias_insert.is_none() || outbound_alias_insert.unwrap() == $channel.channel_id(),
+               let outbound_alias_insert = $short_to_id.insert($channel.outbound_scid_alias(), ($channel.get_counterparty_node_id(), $channel.channel_id()));
+               assert!(outbound_alias_insert.is_none() || outbound_alias_insert.unwrap() == ($channel.get_counterparty_node_id(), $channel.channel_id()),
                        "SCIDs should never collide - ensure you weren't behind the chain tip by a full month when creating channels");
                if let Some(real_scid) = $channel.get_short_channel_id() {
-                       let scid_insert = $short_to_id.insert(real_scid, $channel.channel_id());
-                       assert!(scid_insert.is_none() || scid_insert.unwrap() == $channel.channel_id(),
+                       let scid_insert = $short_to_id.insert(real_scid, ($channel.get_counterparty_node_id(), $channel.channel_id()));
+                       assert!(scid_insert.is_none() || scid_insert.unwrap() == ($channel.get_counterparty_node_id(), $channel.channel_id()),
                                "SCIDs should never collide - ensure you weren't behind the chain tip by a full month when creating channels");
                }
        }
@@ -2231,7 +2233,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                                                                break Some(("Don't have available channel for forwarding as requested.", 0x4000 | 10, None));
                                                        }
                                                },
-                                               Some(id) => Some(id.clone()),
+                                               Some((_cp_id, chan_id)) => Some(chan_id.clone()),
                                        };
                                        let chan_update_opt = if let Some(forwarding_id) = forwarding_id_opt {
                                                let chan = channel_state.as_mut().unwrap().by_id.get_mut(&forwarding_id).unwrap();
@@ -2414,7 +2416,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
 
                        let id = match channel_lock.short_to_id.get(&path.first().unwrap().short_channel_id) {
                                None => return Err(APIError::ChannelUnavailable{err: "No channel available with first hop!".to_owned()}),
-                               Some(id) => id.clone(),
+                               Some((_cp_id, chan_id)) => chan_id.clone(),
                        };
 
                        macro_rules! insert_outbound_payment {
@@ -3029,7 +3031,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                        for (short_chan_id, mut pending_forwards) in channel_state.forward_htlcs.drain() {
                                if short_chan_id != 0 {
                                        let forward_chan_id = match channel_state.short_to_id.get(&short_chan_id) {
-                                               Some(chan_id) => chan_id.clone(),
+                                               Some((_cp_id, chan_id)) => chan_id.clone(),
                                                None => {
                                                        for forward_info in pending_forwards.drain(..) {
                                                                match forward_info {
@@ -3445,7 +3447,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                self.process_background_events();
        }
 
-       fn update_channel_fee(&self, short_to_id: &mut HashMap<u64, [u8; 32]>, pending_msg_events: &mut Vec<events::MessageSendEvent>, chan_id: &[u8; 32], chan: &mut Channel<Signer>, new_feerate: u32) -> (bool, NotifyOption, Result<(), MsgHandleErrInternal>) {
+       fn update_channel_fee(&self, short_to_id: &mut HashMap<u64, (PublicKey, [u8; 32])>, pending_msg_events: &mut Vec<events::MessageSendEvent>, chan_id: &[u8; 32], chan: &mut Channel<Signer>, new_feerate: u32) -> (bool, NotifyOption, Result<(), MsgHandleErrInternal>) {
                if !chan.is_outbound() { return (true, NotifyOption::SkipPersist, Ok(())); }
                // If the feerate has decreased by less than half, don't bother
                if new_feerate <= chan.get_feerate() && new_feerate * 2 > chan.get_feerate() {
@@ -4065,7 +4067,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                //TODO: Delay the claimed_funds relaying just like we do outbound relay!
                let channel_state = &mut **channel_state_lock;
                let chan_id = match channel_state.short_to_id.get(&prev_hop.short_channel_id) {
-                       Some(chan_id) => chan_id.clone(),
+                       Some((_cp_id, chan_id)) => chan_id.clone(),
                        None => {
                                return ClaimFundsFromHop::PrevHopForceClosed
                        }
@@ -4987,7 +4989,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
                let mut channel_state_lock = self.channel_state.lock().unwrap();
                let channel_state = &mut *channel_state_lock;
                let chan_id = match channel_state.short_to_id.get(&msg.contents.short_channel_id) {
-                       Some(chan_id) => chan_id.clone(),
+                       Some((_cp_id, chan_id)) => chan_id.clone(),
                        None => {
                                // It's not a local channel
                                return Ok(NotifyOption::SkipPersist)
@@ -5769,8 +5771,8 @@ where
                                                        // using the real SCID at relay-time (i.e. enforce option_scid_alias
                                                        // then), and if the funding tx is ever un-confirmed we force-close the
                                                        // channel, ensuring short_to_id is always consistent.
-                                                       let scid_insert = short_to_id.insert(real_scid, channel.channel_id());
-                                                       assert!(scid_insert.is_none() || scid_insert.unwrap() == channel.channel_id(),
+                                                       let scid_insert = short_to_id.insert(real_scid, (channel.get_counterparty_node_id(), channel.channel_id()));
+                                                       assert!(scid_insert.is_none() || scid_insert.unwrap() == (channel.get_counterparty_node_id(), channel.channel_id()),
                                                                "SCIDs should never collide - ensure you weren't behind by a full {} blocks when creating channels",
                                                                fake_scid::MAX_SCID_BLOCKS_FROM_NOW);
                                                }
@@ -6814,7 +6816,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                } else {
                                        log_info!(args.logger, "Successfully loaded channel {}", log_bytes!(channel.channel_id()));
                                        if let Some(short_channel_id) = channel.get_short_channel_id() {
-                                               short_to_id.insert(short_channel_id, channel.channel_id());
+                                               short_to_id.insert(short_channel_id, (channel.get_counterparty_node_id(), channel.channel_id()));
                                        }
                                        by_id.insert(channel.channel_id(), channel);
                                }
@@ -7073,7 +7075,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
                                return Err(DecodeError::InvalidValue);
                        }
                        if chan.is_usable() {
-                               if short_to_id.insert(chan.outbound_scid_alias(), *chan_id).is_some() {
+                               if short_to_id.insert(chan.outbound_scid_alias(), (chan.get_counterparty_node_id(), *chan_id)).is_some() {
                                        // Note that in rare cases its possible to hit this while reading an older
                                        // channel if we just happened to pick a colliding outbound alias above.
                                        log_error!(args.logger, "Got duplicate outbound SCID alias; {}", chan.outbound_scid_alias());