Refactor: move channel checks for HTLC adds into Channel
[rust-lightning] / lightning / src / ln / channel.rs
index 0c30a7e96cb53bfec845174a2417e002aedce211..f792a0a5e3963f255947eb8c894ac2eef3b10e28 100644 (file)
@@ -1095,10 +1095,10 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
        fn build_local_transaction_keys(&self, commitment_number: u64) -> Result<TxCreationKeys, ChannelError> {
                let per_commitment_point = PublicKey::from_secret_key(&self.secp_ctx, &self.build_local_commitment_secret(commitment_number));
                let delayed_payment_base = &self.local_keys.pubkeys().delayed_payment_basepoint;
-               let htlc_basepoint = PublicKey::from_secret_key(&self.secp_ctx, self.local_keys.htlc_base_key());
+               let htlc_basepoint = &self.local_keys.pubkeys().htlc_basepoint;
                let their_pubkeys = self.their_pubkeys.as_ref().unwrap();
 
-               Ok(secp_check!(TxCreationKeys::new(&self.secp_ctx, &per_commitment_point, delayed_payment_base, &htlc_basepoint, &their_pubkeys.revocation_basepoint, &their_pubkeys.htlc_basepoint), "Local tx keys generation got bogus keys"))
+               Ok(secp_check!(TxCreationKeys::new(&self.secp_ctx, &per_commitment_point, delayed_payment_base, htlc_basepoint, &their_pubkeys.revocation_basepoint, &their_pubkeys.htlc_basepoint), "Local tx keys generation got bogus keys"))
        }
 
        #[inline]
@@ -1109,10 +1109,10 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
                //TODO: Ensure that the payment_key derived here ends up in the library users' wallet as we
                //may see payments to it!
                let revocation_basepoint = &self.local_keys.pubkeys().revocation_basepoint;
-               let htlc_basepoint = PublicKey::from_secret_key(&self.secp_ctx, self.local_keys.htlc_base_key());
+               let htlc_basepoint = &self.local_keys.pubkeys().htlc_basepoint;
                let their_pubkeys = self.their_pubkeys.as_ref().unwrap();
 
-               Ok(secp_check!(TxCreationKeys::new(&self.secp_ctx, &self.their_cur_commitment_point.unwrap(), &their_pubkeys.delayed_payment_basepoint, &their_pubkeys.htlc_basepoint, revocation_basepoint, &htlc_basepoint), "Remote tx keys generation got bogus keys"))
+               Ok(secp_check!(TxCreationKeys::new(&self.secp_ctx, &self.their_cur_commitment_point.unwrap(), &their_pubkeys.delayed_payment_basepoint, &their_pubkeys.htlc_basepoint, revocation_basepoint, htlc_basepoint), "Remote tx keys generation got bogus keys"))
        }
 
        /// Gets the redeemscript for the funding transaction output (ie the funding transaction output
@@ -1663,8 +1663,16 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
                cmp::min(self.value_to_self_msat as i64 - self.get_outbound_pending_htlc_stats().1 as i64, 0) as u64)
        }
 
-       pub fn update_add_htlc(&mut self, msg: &msgs::UpdateAddHTLC, pending_forward_state: PendingHTLCStatus) -> Result<(), ChannelError> {
-               if (self.channel_state & (ChannelState::ChannelFunded as u32 | ChannelState::RemoteShutdownSent as u32)) != (ChannelState::ChannelFunded as u32) {
+       pub fn update_add_htlc<F>(&mut self, msg: &msgs::UpdateAddHTLC, mut pending_forward_status: PendingHTLCStatus, create_pending_htlc_status: F) -> Result<(), ChannelError>
+       where F: for<'a> Fn(&'a Self, PendingHTLCStatus, u16) -> PendingHTLCStatus {
+               // We can't accept HTLCs sent after we've sent a shutdown.
+               let local_sent_shutdown = (self.channel_state & (ChannelState::ChannelFunded as u32 | ChannelState::LocalShutdownSent as u32)) != (ChannelState::ChannelFunded as u32);
+               if local_sent_shutdown {
+                       pending_forward_status = create_pending_htlc_status(self, pending_forward_status, 0x1000|20);
+               }
+               // If the remote has sent a shutdown prior to adding this HTLC, then they are in violation of the spec.
+               let remote_sent_shutdown = (self.channel_state & (ChannelState::ChannelFunded as u32 | ChannelState::RemoteShutdownSent as u32)) != (ChannelState::ChannelFunded as u32);
+               if remote_sent_shutdown {
                        return Err(ChannelError::Close("Got add HTLC message when channel was not in an operational state"));
                }
                if self.channel_state & (ChannelState::PeerDisconnected as u32) == ChannelState::PeerDisconnected as u32 {
@@ -1719,7 +1727,7 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
                }
 
                if self.channel_state & ChannelState::LocalShutdownSent as u32 != 0 {
-                       if let PendingHTLCStatus::Forward(_) = pending_forward_state {
+                       if let PendingHTLCStatus::Forward(_) = pending_forward_status {
                                panic!("ChannelManager shouldn't be trying to add a forwardable HTLC after we've started closing");
                        }
                }
@@ -1731,7 +1739,7 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
                        amount_msat: msg.amount_msat,
                        payment_hash: msg.payment_hash,
                        cltv_expiry: msg.cltv_expiry,
-                       state: InboundHTLCState::RemoteAnnounced(pending_forward_state),
+                       state: InboundHTLCState::RemoteAnnounced(pending_forward_status),
                });
                Ok(())
        }
@@ -3318,7 +3326,7 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
                        revocation_basepoint: local_keys.revocation_basepoint,
                        payment_point: local_keys.payment_point,
                        delayed_payment_basepoint: local_keys.delayed_payment_basepoint,
-                       htlc_basepoint: PublicKey::from_secret_key(&self.secp_ctx, self.local_keys.htlc_base_key()),
+                       htlc_basepoint: local_keys.htlc_basepoint,
                        first_per_commitment_point: PublicKey::from_secret_key(&self.secp_ctx, &local_commitment_secret),
                        channel_flags: if self.config.announced_channel {1} else {0},
                        shutdown_scriptpubkey: OptionalField::Present(if self.config.commit_upfront_shutdown_pubkey { self.get_closing_scriptpubkey() } else { Builder::new().into_script() })
@@ -3352,7 +3360,7 @@ impl<ChanSigner: ChannelKeys> Channel<ChanSigner> {
                        revocation_basepoint: local_keys.revocation_basepoint,
                        payment_point: local_keys.payment_point,
                        delayed_payment_basepoint: local_keys.delayed_payment_basepoint,
-                       htlc_basepoint: PublicKey::from_secret_key(&self.secp_ctx, self.local_keys.htlc_base_key()),
+                       htlc_basepoint: local_keys.htlc_basepoint,
                        first_per_commitment_point: PublicKey::from_secret_key(&self.secp_ctx, &local_commitment_secret),
                        shutdown_scriptpubkey: OptionalField::Present(if self.config.commit_upfront_shutdown_pubkey { self.get_closing_scriptpubkey() } else { Builder::new().into_script() })
                }
@@ -4479,8 +4487,8 @@ mod tests {
                let delayed_payment_base = &chan.local_keys.pubkeys().delayed_payment_basepoint;
                let per_commitment_secret = SecretKey::from_slice(&hex::decode("1f1e1d1c1b1a191817161514131211100f0e0d0c0b0a09080706050403020100").unwrap()[..]).unwrap();
                let per_commitment_point = PublicKey::from_secret_key(&secp_ctx, &per_commitment_secret);
-               let htlc_basepoint = PublicKey::from_secret_key(&secp_ctx, chan.local_keys.htlc_base_key());
-               let keys = TxCreationKeys::new(&secp_ctx, &per_commitment_point, delayed_payment_base, &htlc_basepoint, &their_pubkeys.revocation_basepoint, &their_pubkeys.htlc_basepoint).unwrap();
+               let htlc_basepoint = &chan.local_keys.pubkeys().htlc_basepoint;
+               let keys = TxCreationKeys::new(&secp_ctx, &per_commitment_point, delayed_payment_base, htlc_basepoint, &their_pubkeys.revocation_basepoint, &their_pubkeys.htlc_basepoint).unwrap();
 
                chan.their_pubkeys = Some(their_pubkeys);