Use clear helper on flags copy to mask off bits
[rust-lightning] / lightning / src / ln / channel.rs
index ab3bdb98c0f5c21a1afefad0d7f2932aa1437831..fd1bf4f30a4da75094508d373a657052aec52b41 100644 (file)
@@ -299,6 +299,10 @@ macro_rules! define_state_flags {
 
                        #[allow(unused)]
                        fn is_set(&self, flag: Self) -> bool { *self & flag == flag }
+                       #[allow(unused)]
+                       fn set(&mut self, flag: Self) { *self |= flag }
+                       #[allow(unused)]
+                       fn clear(&mut self, flag: Self) -> Self { self.0 &= !flag.0; *self }
                }
 
                impl core::ops::Not for $flag_type {
@@ -323,6 +327,16 @@ macro_rules! define_state_flags {
        ($flag_type_doc: expr, $flag_type: ident, $flags: tt) => {
                define_state_flags!($flag_type_doc, $flag_type, $flags, 0);
        };
+       ($flag_type: ident, $flag: expr, $get: ident, $set: ident, $clear: ident) => {
+               impl $flag_type {
+                       #[allow(unused)]
+                       fn $get(&self) -> bool { self.is_set($flag_type::new() | $flag) }
+                       #[allow(unused)]
+                       fn $set(&mut self) { self.set($flag_type::new() | $flag) }
+                       #[allow(unused)]
+                       fn $clear(&mut self) -> Self { self.clear($flag_type::new() | $flag) }
+               }
+       };
        ($flag_type_doc: expr, FUNDED_STATE, $flag_type: ident, $flags: tt) => {
                define_state_flags!($flag_type_doc, $flag_type, $flags, FundedStateFlags::ALL.0);
                impl core::ops::BitOr<FundedStateFlags> for $flag_type {
@@ -525,14 +539,14 @@ impl ChannelState {
                }
        }
 
-       fn should_force_holding_cell(&self) -> bool {
+       fn can_generate_new_commitment(&self) -> bool {
                match self {
                        ChannelState::ChannelReady(flags) =>
-                               flags.is_set(ChannelReadyFlags::AWAITING_REMOTE_REVOKE) ||
-                                       flags.is_set(FundedStateFlags::MONITOR_UPDATE_IN_PROGRESS.into()) ||
-                                       flags.is_set(FundedStateFlags::PEER_DISCONNECTED.into()),
+                               !flags.is_set(ChannelReadyFlags::AWAITING_REMOTE_REVOKE) &&
+                                       !flags.is_set(FundedStateFlags::MONITOR_UPDATE_IN_PROGRESS.into()) &&
+                                       !flags.is_set(FundedStateFlags::PEER_DISCONNECTED.into()),
                        _ => {
-                               debug_assert!(false, "The holding cell is only valid within ChannelReady");
+                               debug_assert!(false, "Can only generate new commitment within ChannelReady");
                                false
                        },
                }
@@ -2389,11 +2403,7 @@ impl<SP: Deref> ChannelContext<SP> where SP::Target: SignerProvider  {
                        // funding transaction, don't return a funding txo (which prevents providing the
                        // monitor update to the user, even if we return one).
                        // See test_duplicate_chan_id and test_pre_lockin_no_chan_closed_update for more.
-                       let generate_monitor_update = match self.channel_state {
-                               ChannelState::AwaitingChannelReady(_)|ChannelState::ChannelReady(_)|ChannelState::ShutdownComplete => true,
-                               _ => false,
-                       };
-                       if generate_monitor_update {
+                       if !self.channel_state.is_pre_funded_state() {
                                self.latest_monitor_update_id = CLOSED_CHANNEL_UPDATE_ID;
                                Some((self.get_counterparty_node_id(), funding_txo, ChannelMonitorUpdate {
                                        update_id: self.latest_monitor_update_id,
@@ -2709,7 +2719,7 @@ impl<SP: Deref> Channel<SP> where
        where L::Target: Logger {
                // Assert that we'll add the HTLC claim to the holding cell in `get_update_fulfill_htlc`
                // (see equivalent if condition there).
-               assert!(self.context.channel_state.should_force_holding_cell());
+               assert!(!self.context.channel_state.can_generate_new_commitment());
                let mon_update_id = self.context.latest_monitor_update_id; // Forget the ChannelMonitor update
                let fulfill_resp = self.get_update_fulfill_htlc(htlc_id_arg, payment_preimage_arg, logger);
                self.context.latest_monitor_update_id = mon_update_id;
@@ -2779,7 +2789,7 @@ impl<SP: Deref> Channel<SP> where
                        }],
                };
 
-               if self.context.channel_state.should_force_holding_cell() {
+               if !self.context.channel_state.can_generate_new_commitment() {
                        // Note that this condition is the same as the assertion in
                        // `claim_htlc_while_disconnected_dropping_mon_update` and must match exactly -
                        // `claim_htlc_while_disconnected_dropping_mon_update` would not work correctly if we
@@ -2953,7 +2963,7 @@ impl<SP: Deref> Channel<SP> where
                        return Ok(None);
                }
 
-               if self.context.channel_state.should_force_holding_cell() {
+               if !self.context.channel_state.can_generate_new_commitment() {
                        debug_assert!(force_holding_cell, "!force_holding_cell is only called when emptying the holding cell, so we shouldn't end up back in it!");
                        force_holding_cell = true;
                }
@@ -3049,12 +3059,12 @@ impl<SP: Deref> Channel<SP> where
                let mut check_reconnection = false;
                match &self.context.channel_state {
                        ChannelState::AwaitingChannelReady(flags) => {
-                               let flags = *flags & !FundedStateFlags::ALL;
+                               let flags = flags.clone().clear(FundedStateFlags::ALL.into());
                                debug_assert!(!flags.is_set(AwaitingChannelReadyFlags::OUR_CHANNEL_READY) || !flags.is_set(AwaitingChannelReadyFlags::WAITING_FOR_BATCH));
-                               if flags & !AwaitingChannelReadyFlags::WAITING_FOR_BATCH == AwaitingChannelReadyFlags::THEIR_CHANNEL_READY {
+                               if flags.clone().clear(AwaitingChannelReadyFlags::WAITING_FOR_BATCH) == AwaitingChannelReadyFlags::THEIR_CHANNEL_READY {
                                        // If we reconnected before sending our `channel_ready` they may still resend theirs.
                                        check_reconnection = true;
-                               } else if (flags & !AwaitingChannelReadyFlags::WAITING_FOR_BATCH).is_empty() {
+                               } else if flags.clone().clear(AwaitingChannelReadyFlags::WAITING_FOR_BATCH).is_empty() {
                                        self.context.channel_state.set_their_channel_ready();
                                } else if flags == AwaitingChannelReadyFlags::OUR_CHANNEL_READY {
                                        self.context.channel_state = ChannelState::ChannelReady(self.context.channel_state.with_funded_state_flags_mask().into());
@@ -3570,7 +3580,7 @@ impl<SP: Deref> Channel<SP> where
        ) -> (Option<ChannelMonitorUpdate>, Vec<(HTLCSource, PaymentHash)>)
        where F::Target: FeeEstimator, L::Target: Logger
        {
-               if matches!(self.context.channel_state, ChannelState::ChannelReady(_)) && !self.context.channel_state.should_force_holding_cell() {
+               if matches!(self.context.channel_state, ChannelState::ChannelReady(_)) && self.context.channel_state.can_generate_new_commitment() {
                        self.free_holding_cell_htlcs(fee_estimator, logger)
                } else { (None, Vec::new()) }
        }
@@ -4180,8 +4190,8 @@ impl<SP: Deref> Channel<SP> where
                // first received the funding_signed.
                let mut funding_broadcastable =
                        if self.context.is_outbound() &&
-                               matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(flags) if !flags.is_set(AwaitingChannelReadyFlags::WAITING_FOR_BATCH)) ||
-                               matches!(self.context.channel_state, ChannelState::ChannelReady(_))
+                               (matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(flags) if !flags.is_set(AwaitingChannelReadyFlags::WAITING_FOR_BATCH)) ||
+                               matches!(self.context.channel_state, ChannelState::ChannelReady(_)))
                        {
                                self.context.funding_transaction.take()
                        } else { None };
@@ -5189,7 +5199,7 @@ impl<SP: Deref> Channel<SP> where
                if !self.is_awaiting_monitor_update() { return false; }
                if matches!(
                        self.context.channel_state, ChannelState::AwaitingChannelReady(flags)
-                       if (flags & !(AwaitingChannelReadyFlags::THEIR_CHANNEL_READY | FundedStateFlags::PEER_DISCONNECTED | FundedStateFlags::MONITOR_UPDATE_IN_PROGRESS | AwaitingChannelReadyFlags::WAITING_FOR_BATCH)).is_empty()
+                       if flags.clone().clear(AwaitingChannelReadyFlags::THEIR_CHANNEL_READY | FundedStateFlags::PEER_DISCONNECTED | FundedStateFlags::MONITOR_UPDATE_IN_PROGRESS | AwaitingChannelReadyFlags::WAITING_FOR_BATCH).is_empty()
                ) {
                        // If we're not a 0conf channel, we'll be waiting on a monitor update with only
                        // AwaitingChannelReady set, though our peer could have sent their channel_ready.
@@ -5275,14 +5285,14 @@ impl<SP: Deref> Channel<SP> where
 
                // Note that we don't include ChannelState::WaitingForBatch as we don't want to send
                // channel_ready until the entire batch is ready.
-               let need_commitment_update = if matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(f) if (f & !FundedStateFlags::ALL).is_empty()) {
+               let need_commitment_update = if matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(f) if f.clone().clear(FundedStateFlags::ALL.into()).is_empty()) {
                        self.context.channel_state.set_our_channel_ready();
                        true
-               } else if matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(f) if f & !FundedStateFlags::ALL == AwaitingChannelReadyFlags::THEIR_CHANNEL_READY) {
+               } else if matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(f) if f.clone().clear(FundedStateFlags::ALL.into()) == AwaitingChannelReadyFlags::THEIR_CHANNEL_READY) {
                        self.context.channel_state = ChannelState::ChannelReady(self.context.channel_state.with_funded_state_flags_mask().into());
                        self.context.update_time_counter += 1;
                        true
-               } else if matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(f) if f & !FundedStateFlags::ALL == AwaitingChannelReadyFlags::OUR_CHANNEL_READY) {
+               } else if matches!(self.context.channel_state, ChannelState::AwaitingChannelReady(f) if f.clone().clear(FundedStateFlags::ALL.into()) == AwaitingChannelReadyFlags::OUR_CHANNEL_READY) {
                        // We got a reorg but not enough to trigger a force close, just ignore.
                        false
                } else {
@@ -5857,7 +5867,7 @@ impl<SP: Deref> Channel<SP> where
                        return Err(ChannelError::Ignore("Cannot send an HTLC while disconnected from channel counterparty".to_owned()));
                }
 
-               let need_holding_cell = self.context.channel_state.should_force_holding_cell();
+               let need_holding_cell = !self.context.channel_state.can_generate_new_commitment();
                log_debug!(logger, "Pushing new outbound HTLC with hash {} for {} msat {}",
                        payment_hash, amount_msat,
                        if force_holding_cell { "into holding cell" }
@@ -7475,6 +7485,8 @@ impl<SP: Deref> Writeable for Channel<SP> where SP::Target: SignerProvider {
                        let mut channel_state = self.context.channel_state;
                        if matches!(channel_state, ChannelState::AwaitingChannelReady(_)|ChannelState::ChannelReady(_)) {
                                channel_state.set_peer_disconnected();
+                       } else {
+                               debug_assert!(false, "Pre-funded/shutdown channels should not be written");
                        }
                        channel_state.to_u32().write(writer)?;
                }
@@ -8895,17 +8907,34 @@ mod tests {
        fn blinding_point_skimmed_fee_malformed_ser() {
                // Ensure that channel blinding points, skimmed fees, and malformed HTLCs are (de)serialized
                // properly.
+               let logger = test_utils::TestLogger::new();
                let feeest = LowerBoundedFeeEstimator::new(&TestFeeEstimator{fee_est: 15000});
                let secp_ctx = Secp256k1::new();
                let seed = [42; 32];
                let network = Network::Testnet;
+               let best_block = BestBlock::from_network(network);
                let keys_provider = test_utils::TestKeysInterface::new(&seed, network);
 
                let node_b_node_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[42; 32]).unwrap());
                let config = UserConfig::default();
                let features = channelmanager::provided_init_features(&config);
-               let outbound_chan = OutboundV1Channel::<&TestKeysInterface>::new(&feeest, &&keys_provider, &&keys_provider, node_b_node_id, &features, 10000000, 100000, 42, &config, 0, 42, None).unwrap();
-               let mut chan = Channel { context: outbound_chan.context };
+               let mut outbound_chan = OutboundV1Channel::<&TestKeysInterface>::new(
+                       &feeest, &&keys_provider, &&keys_provider, node_b_node_id, &features, 10000000, 100000, 42, &config, 0, 42, None
+               ).unwrap();
+               let inbound_chan = InboundV1Channel::<&TestKeysInterface>::new(
+                       &feeest, &&keys_provider, &&keys_provider, node_b_node_id, &channelmanager::provided_channel_type_features(&config),
+                       &features, &outbound_chan.get_open_channel(ChainHash::using_genesis_block(network)), 7, &config, 0, &&logger, false
+               ).unwrap();
+               outbound_chan.accept_channel(&inbound_chan.get_accept_channel_message(), &config.channel_handshake_limits, &features).unwrap();
+               let tx = Transaction { version: 1, lock_time: LockTime::ZERO, input: Vec::new(), output: vec![TxOut {
+                       value: 10000000, script_pubkey: outbound_chan.context.get_funding_redeemscript(),
+               }]};
+               let funding_outpoint = OutPoint{ txid: tx.txid(), index: 0 };
+               let funding_created = outbound_chan.get_funding_created(tx.clone(), funding_outpoint, false, &&logger).map_err(|_| ()).unwrap().unwrap();
+               let mut chan = match inbound_chan.funding_created(&funding_created, best_block, &&keys_provider, &&logger) {
+                       Ok((chan, _, _)) => chan,
+                       Err((_, e)) => panic!("{}", e),
+               };
 
                let dummy_htlc_source = HTLCSource::OutboundRoute {
                        path: Path {