Merge pull request #3115 from alecchendev/2024-06-specific-async-sign
[rust-lightning] / lightning / src / util / test_channel_signer.rs
index 2009e4d0b08a9a22fd68941b8290bfabffe75203..884acc2c2eac1835ca07f6949afd172effe2e62e 100644 (file)
@@ -18,7 +18,7 @@ use crate::sign::ecdsa::EcdsaChannelSigner;
 #[allow(unused_imports)]
 use crate::prelude::*;
 
-use core::cmp;
+use core::{cmp, fmt};
 use crate::sync::{Mutex, Arc};
 #[cfg(test)] use crate::sync::MutexGuard;
 
@@ -71,9 +71,46 @@ pub struct TestChannelSigner {
        /// Channel state used for policy enforcement
        pub state: Arc<Mutex<EnforcementState>>,
        pub disable_revocation_policy_check: bool,
-       /// When `true` (the default), the signer will respond immediately with signatures. When `false`,
-       /// the signer will return an error indicating that it is unavailable.
-       pub available: Arc<Mutex<bool>>,
+       /// Set of signer operations that are disabled. If an operation is disabled,
+       /// the signer will return `Err` when the corresponding method is called.
+       pub disabled_signer_ops: Arc<Mutex<HashSet<SignerOp>>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum SignerOp {
+       GetPerCommitmentPoint,
+       ReleaseCommitmentSecret,
+       ValidateHolderCommitment,
+       SignCounterpartyCommitment,
+       ValidateCounterpartyRevocation,
+       SignHolderCommitment,
+       SignJusticeRevokedOutput,
+       SignJusticeRevokedHtlc,
+       SignHolderHtlcTransaction,
+       SignCounterpartyHtlcTransaction,
+       SignClosingTransaction,
+       SignHolderAnchorInput,
+       SignChannelAnnouncementWithFundingKey,
+}
+
+impl fmt::Display for SignerOp {
+       fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+               match self {
+                       SignerOp::GetPerCommitmentPoint => write!(f, "get_per_commitment_point"),
+                       SignerOp::ReleaseCommitmentSecret => write!(f, "release_commitment_secret"),
+                       SignerOp::ValidateHolderCommitment => write!(f, "validate_holder_commitment"),
+                       SignerOp::SignCounterpartyCommitment => write!(f, "sign_counterparty_commitment"),
+                       SignerOp::ValidateCounterpartyRevocation => write!(f, "validate_counterparty_revocation"),
+                       SignerOp::SignHolderCommitment => write!(f, "sign_holder_commitment"),
+                       SignerOp::SignJusticeRevokedOutput => write!(f, "sign_justice_revoked_output"),
+                       SignerOp::SignJusticeRevokedHtlc => write!(f, "sign_justice_revoked_htlc"),
+                       SignerOp::SignHolderHtlcTransaction => write!(f, "sign_holder_htlc_transaction"),
+                       SignerOp::SignCounterpartyHtlcTransaction => write!(f, "sign_counterparty_htlc_transaction"),
+                       SignerOp::SignClosingTransaction => write!(f, "sign_closing_transaction"),
+                       SignerOp::SignHolderAnchorInput => write!(f, "sign_holder_anchor_input"),
+                       SignerOp::SignChannelAnnouncementWithFundingKey => write!(f, "sign_channel_announcement_with_funding_key"),
+               }
+       }
 }
 
 impl PartialEq for TestChannelSigner {
@@ -90,7 +127,7 @@ impl TestChannelSigner {
                        inner,
                        state,
                        disable_revocation_policy_check: false,
-                       available: Arc::new(Mutex::new(true)),
+                       disabled_signer_ops: Arc::new(Mutex::new(new_hash_set())),
                }
        }
 
@@ -104,7 +141,7 @@ impl TestChannelSigner {
                        inner,
                        state,
                        disable_revocation_policy_check,
-                       available: Arc::new(Mutex::new(true)),
+                       disabled_signer_ops: Arc::new(Mutex::new(new_hash_set())),
                }
        }
 
@@ -115,13 +152,16 @@ impl TestChannelSigner {
                self.state.lock().unwrap()
        }
 
-       /// Marks the signer's availability.
-       ///
-       /// When `true`, methods are forwarded to the underlying signer as normal. When `false`, some
-       /// methods will return `Err` indicating that the signer is unavailable. Intended to be used for
-       /// testing asynchronous signing.
-       pub fn set_available(&self, available: bool) {
-               *self.available.lock().unwrap() = available;
+       pub fn enable_op(&mut self, signer_op: SignerOp) {
+               self.disabled_signer_ops.lock().unwrap().remove(&signer_op);
+       }
+
+       pub fn disable_op(&mut self, signer_op: SignerOp) {
+               self.disabled_signer_ops.lock().unwrap().insert(signer_op);
+       }
+
+       fn is_signer_available(&self, signer_op: SignerOp) -> bool {
+               !self.disabled_signer_ops.lock().unwrap().contains(&signer_op)
        }
 }
 
@@ -149,7 +189,7 @@ impl ChannelSigner for TestChannelSigner {
        }
 
        fn validate_counterparty_revocation(&self, idx: u64, _secret: &SecretKey) -> Result<(), ()> {
-               if !*self.available.lock().unwrap() {
+               if !self.is_signer_available(SignerOp::ValidateCounterpartyRevocation) {
                        return Err(());
                }
                let mut state = self.state.lock().unwrap();
@@ -172,7 +212,7 @@ impl EcdsaChannelSigner for TestChannelSigner {
                self.verify_counterparty_commitment_tx(commitment_tx, secp_ctx);
 
                {
-                       if !*self.available.lock().unwrap() {
+                       if !self.is_signer_available(SignerOp::SignCounterpartyCommitment) {
                                return Err(());
                        }
                        let mut state = self.state.lock().unwrap();
@@ -191,7 +231,7 @@ impl EcdsaChannelSigner for TestChannelSigner {
        }
 
        fn sign_holder_commitment(&self, commitment_tx: &HolderCommitmentTransaction, secp_ctx: &Secp256k1<secp256k1::All>) -> Result<Signature, ()> {
-               if !*self.available.lock().unwrap() {
+               if !self.is_signer_available(SignerOp::SignHolderCommitment) {
                        return Err(());
                }
                let trusted_tx = self.verify_holder_commitment_tx(commitment_tx, secp_ctx);
@@ -212,14 +252,14 @@ impl EcdsaChannelSigner for TestChannelSigner {
        }
 
        fn sign_justice_revoked_output(&self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, secp_ctx: &Secp256k1<secp256k1::All>) -> Result<Signature, ()> {
-               if !*self.available.lock().unwrap() {
+               if !self.is_signer_available(SignerOp::SignJusticeRevokedOutput) {
                        return Err(());
                }
                Ok(EcdsaChannelSigner::sign_justice_revoked_output(&self.inner, justice_tx, input, amount, per_commitment_key, secp_ctx).unwrap())
        }
 
        fn sign_justice_revoked_htlc(&self, justice_tx: &Transaction, input: usize, amount: u64, per_commitment_key: &SecretKey, htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1<secp256k1::All>) -> Result<Signature, ()> {
-               if !*self.available.lock().unwrap() {
+               if !self.is_signer_available(SignerOp::SignJusticeRevokedHtlc) {
                        return Err(());
                }
                Ok(EcdsaChannelSigner::sign_justice_revoked_htlc(&self.inner, justice_tx, input, amount, per_commitment_key, htlc, secp_ctx).unwrap())
@@ -229,7 +269,7 @@ impl EcdsaChannelSigner for TestChannelSigner {
                &self, htlc_tx: &Transaction, input: usize, htlc_descriptor: &HTLCDescriptor,
                secp_ctx: &Secp256k1<secp256k1::All>
        ) -> Result<Signature, ()> {
-               if !*self.available.lock().unwrap() {
+               if !self.is_signer_available(SignerOp::SignHolderHtlcTransaction) {
                        return Err(());
                }
                let state = self.state.lock().unwrap();
@@ -265,7 +305,7 @@ impl EcdsaChannelSigner for TestChannelSigner {
        }
 
        fn sign_counterparty_htlc_transaction(&self, htlc_tx: &Transaction, input: usize, amount: u64, per_commitment_point: &PublicKey, htlc: &HTLCOutputInCommitment, secp_ctx: &Secp256k1<secp256k1::All>) -> Result<Signature, ()> {
-               if !*self.available.lock().unwrap() {
+               if !self.is_signer_available(SignerOp::SignCounterpartyHtlcTransaction) {
                        return Err(());
                }
                Ok(EcdsaChannelSigner::sign_counterparty_htlc_transaction(&self.inner, htlc_tx, input, amount, per_commitment_point, htlc, secp_ctx).unwrap())
@@ -284,7 +324,7 @@ impl EcdsaChannelSigner for TestChannelSigner {
                // As long as our minimum dust limit is enforced and is greater than our anchor output
                // value, an anchor output can only have an index within [0, 1].
                assert!(anchor_tx.input[input].previous_output.vout == 0 || anchor_tx.input[input].previous_output.vout == 1);
-               if !*self.available.lock().unwrap() {
+               if !self.is_signer_available(SignerOp::SignHolderAnchorInput) {
                        return Err(());
                }
                EcdsaChannelSigner::sign_holder_anchor_input(&self.inner, anchor_tx, input, secp_ctx)