Introduce traits to make test utils generic across the `CM` Holder
authorMatt Corallo <git@bluematt.me>
Fri, 17 Mar 2023 20:50:19 +0000 (20:50 +0000)
committerMatt Corallo <git@bluematt.me>
Thu, 13 Apr 2023 18:40:46 +0000 (18:40 +0000)
In our test utilities, we generally refer to a `Node` struct which
holds a `ChannelManager` and a number of other structs. However, we
use the same utilities in benchmarking, where we have a different
`Node`-like struct. This made moving from macros to functions
entirely impossible, as we end up needing multiple types in a given
context.

Thus, here, we take the pain and introduce some wrapper traits
which encapsulte what we need from `Node`, swapping some of our
macros to functions.

lightning/src/ln/channelmanager.rs
lightning/src/ln/functional_test_utils.rs

index 5f315b8fd4a0e2f61052db9a4168d331834d5cb5..df9cecba756971b03bcba8540250a0f686177d14 100644 (file)
@@ -618,6 +618,61 @@ pub type SimpleArcChannelManager<M, T, F, L> = ChannelManager<
 /// This is not exported to bindings users as Arcs don't make sense in bindings
 pub type SimpleRefChannelManager<'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, M, T, F, L> = ChannelManager<&'a M, &'b T, &'c KeysManager, &'c KeysManager, &'c KeysManager, &'d F, &'e DefaultRouter<&'f NetworkGraph<&'g L>, &'g L, &'h Mutex<ProbabilisticScorer<&'f NetworkGraph<&'g L>, &'g L>>>, &'g L>;
 
+/// A trivial trait which describes any [`ChannelManager`] used in testing.
+#[cfg(any(test, feature = "_test_utils"))]
+pub trait AChannelManager {
+       type Watch: chain::Watch<Self::Signer>;
+       type M: Deref<Target = Self::Watch>;
+       type Broadcaster: BroadcasterInterface;
+       type T: Deref<Target = Self::Broadcaster>;
+       type EntropySource: EntropySource;
+       type ES: Deref<Target = Self::EntropySource>;
+       type NodeSigner: NodeSigner;
+       type NS: Deref<Target = Self::NodeSigner>;
+       type Signer: WriteableEcdsaChannelSigner;
+       type SignerProvider: SignerProvider<Signer = Self::Signer>;
+       type SP: Deref<Target = Self::SignerProvider>;
+       type FeeEstimator: FeeEstimator;
+       type F: Deref<Target = Self::FeeEstimator>;
+       type Router: Router;
+       type R: Deref<Target = Self::Router>;
+       type Logger: Logger;
+       type L: Deref<Target = Self::Logger>;
+       fn get_cm(&self) -> &ChannelManager<Self::M, Self::T, Self::ES, Self::NS, Self::SP, Self::F, Self::R, Self::L>;
+}
+#[cfg(any(test, feature = "_test_utils"))]
+impl<M: Deref, T: Deref, ES: Deref, NS: Deref, SP: Deref, F: Deref, R: Deref, L: Deref> AChannelManager
+for ChannelManager<M, T, ES, NS, SP, F, R, L>
+where
+       M::Target: chain::Watch<<SP::Target as SignerProvider>::Signer> + Sized,
+       T::Target: BroadcasterInterface + Sized,
+       ES::Target: EntropySource + Sized,
+       NS::Target: NodeSigner + Sized,
+       SP::Target: SignerProvider + Sized,
+       F::Target: FeeEstimator + Sized,
+       R::Target: Router + Sized,
+       L::Target: Logger + Sized,
+{
+       type Watch = M::Target;
+       type M = M;
+       type Broadcaster = T::Target;
+       type T = T;
+       type EntropySource = ES::Target;
+       type ES = ES;
+       type NodeSigner = NS::Target;
+       type NS = NS;
+       type Signer = <SP::Target as SignerProvider>::Signer;
+       type SignerProvider = SP::Target;
+       type SP = SP;
+       type FeeEstimator = F::Target;
+       type F = F;
+       type Router = R::Target;
+       type R = R;
+       type Logger = L::Target;
+       type L = L;
+       fn get_cm(&self) -> &ChannelManager<M, T, ES, NS, SP, F, R, L> { self }
+}
+
 /// Manager which keeps track of a number of channels and sends messages to the appropriate
 /// channel, also tracking HTLC preimages and forwarding onion packets appropriately.
 ///
@@ -8839,14 +8894,23 @@ pub mod bench {
 
        use test::Bencher;
 
-       struct NodeHolder<'a, P: Persist<InMemorySigner>> {
-               node: &'a ChannelManager<
-                       &'a ChainMonitor<InMemorySigner, &'a test_utils::TestChainSource,
-                               &'a test_utils::TestBroadcaster, &'a test_utils::TestFeeEstimator,
-                               &'a test_utils::TestLogger, &'a P>,
-                       &'a test_utils::TestBroadcaster, &'a KeysManager, &'a KeysManager, &'a KeysManager,
-                       &'a test_utils::TestFeeEstimator, &'a test_utils::TestRouter<'a>,
-                       &'a test_utils::TestLogger>,
+       type Manager<'a, P> = ChannelManager<
+               &'a ChainMonitor<InMemorySigner, &'a test_utils::TestChainSource,
+                       &'a test_utils::TestBroadcaster, &'a test_utils::TestFeeEstimator,
+                       &'a test_utils::TestLogger, &'a P>,
+               &'a test_utils::TestBroadcaster, &'a KeysManager, &'a KeysManager, &'a KeysManager,
+               &'a test_utils::TestFeeEstimator, &'a test_utils::TestRouter<'a>,
+               &'a test_utils::TestLogger>;
+
+       struct ANodeHolder<'a, P: Persist<InMemorySigner>> {
+               node: &'a Manager<'a, P>,
+       }
+       impl<'a, P: Persist<InMemorySigner>> NodeHolder for ANodeHolder<'a, P> {
+               type CM = Manager<'a, P>;
+               #[inline]
+               fn node(&self) -> &Manager<'a, P> { self.node }
+               #[inline]
+               fn chain_monitor(&self) -> Option<&test_utils::TestChainMonitor> { None }
        }
 
        #[cfg(test)]
@@ -8877,7 +8941,7 @@ pub mod bench {
                        network,
                        best_block: BestBlock::from_network(network),
                });
-               let node_a_holder = NodeHolder { node: &node_a };
+               let node_a_holder = ANodeHolder { node: &node_a };
 
                let logger_b = test_utils::TestLogger::with_id("node a".to_owned());
                let chain_monitor_b = ChainMonitor::new(None, &tx_broadcaster, &logger_a, &fee_estimator, &persister_b);
@@ -8887,7 +8951,7 @@ pub mod bench {
                        network,
                        best_block: BestBlock::from_network(network),
                });
-               let node_b_holder = NodeHolder { node: &node_b };
+               let node_b_holder = ANodeHolder { node: &node_b };
 
                node_a.peer_connected(&node_b.get_our_node_id(), &Init { features: node_b.init_features(), remote_network_address: None }, true).unwrap();
                node_b.peer_connected(&node_a.get_our_node_id(), &Init { features: node_a.init_features(), remote_network_address: None }, false).unwrap();
@@ -8983,15 +9047,15 @@ pub mod bench {
                                let payment_event = SendEvent::from_event($node_a.get_and_clear_pending_msg_events().pop().unwrap());
                                $node_b.handle_update_add_htlc(&$node_a.get_our_node_id(), &payment_event.msgs[0]);
                                $node_b.handle_commitment_signed(&$node_a.get_our_node_id(), &payment_event.commitment_msg);
-                               let (raa, cs) = do_get_revoke_commit_msgs!(NodeHolder { node: &$node_b }, &$node_a.get_our_node_id());
+                               let (raa, cs) = get_revoke_commit_msgs(&ANodeHolder { node: &$node_b }, &$node_a.get_our_node_id());
                                $node_a.handle_revoke_and_ack(&$node_b.get_our_node_id(), &raa);
                                $node_a.handle_commitment_signed(&$node_b.get_our_node_id(), &cs);
-                               $node_b.handle_revoke_and_ack(&$node_a.get_our_node_id(), &get_event_msg!(NodeHolder { node: &$node_a }, MessageSendEvent::SendRevokeAndACK, $node_b.get_our_node_id()));
+                               $node_b.handle_revoke_and_ack(&$node_a.get_our_node_id(), &get_event_msg!(ANodeHolder { node: &$node_a }, MessageSendEvent::SendRevokeAndACK, $node_b.get_our_node_id()));
 
-                               expect_pending_htlcs_forwardable!(NodeHolder { node: &$node_b });
-                               expect_payment_claimable!(NodeHolder { node: &$node_b }, payment_hash, payment_secret, 10_000);
+                               expect_pending_htlcs_forwardable!(ANodeHolder { node: &$node_b });
+                               expect_payment_claimable!(ANodeHolder { node: &$node_b }, payment_hash, payment_secret, 10_000);
                                $node_b.claim_funds(payment_preimage);
-                               expect_payment_claimed!(NodeHolder { node: &$node_b }, payment_hash, 10_000);
+                               expect_payment_claimed!(ANodeHolder { node: &$node_b }, payment_hash, 10_000);
 
                                match $node_b.get_and_clear_pending_msg_events().pop().unwrap() {
                                        MessageSendEvent::UpdateHTLCs { node_id, updates } => {
@@ -9002,12 +9066,12 @@ pub mod bench {
                                        _ => panic!("Failed to generate claim event"),
                                }
 
-                               let (raa, cs) = do_get_revoke_commit_msgs!(NodeHolder { node: &$node_a }, &$node_b.get_our_node_id());
+                               let (raa, cs) = get_revoke_commit_msgs(&ANodeHolder { node: &$node_a }, &$node_b.get_our_node_id());
                                $node_b.handle_revoke_and_ack(&$node_a.get_our_node_id(), &raa);
                                $node_b.handle_commitment_signed(&$node_a.get_our_node_id(), &cs);
-                               $node_a.handle_revoke_and_ack(&$node_b.get_our_node_id(), &get_event_msg!(NodeHolder { node: &$node_b }, MessageSendEvent::SendRevokeAndACK, $node_a.get_our_node_id()));
+                               $node_a.handle_revoke_and_ack(&$node_b.get_our_node_id(), &get_event_msg!(ANodeHolder { node: &$node_b }, MessageSendEvent::SendRevokeAndACK, $node_a.get_our_node_id()));
 
-                               expect_payment_sent!(NodeHolder { node: &$node_a }, payment_preimage);
+                               expect_payment_sent!(ANodeHolder { node: &$node_a }, payment_preimage);
                        }
                }
 
index b98db4bb957bb5b4ba1ee705ec35b23938c640cd..39b88bd3b2002d8cbc0a32af39f25a1dd10f1465 100644 (file)
@@ -15,7 +15,7 @@ use crate::chain::channelmonitor::ChannelMonitor;
 use crate::chain::transaction::OutPoint;
 use crate::events::{ClosureReason, Event, HTLCDestination, MessageSendEvent, MessageSendEventsProvider, PathFailure, PaymentPurpose, PaymentFailureReason};
 use crate::ln::{PaymentPreimage, PaymentHash, PaymentSecret};
-use crate::ln::channelmanager::{ChainParameters, ChannelManager, ChannelManagerReadArgs, RAACommitmentOrder, PaymentSendFailure, RecipientOnionFields, PaymentId, MIN_CLTV_EXPIRY_DELTA};
+use crate::ln::channelmanager::{AChannelManager, ChainParameters, ChannelManager, ChannelManagerReadArgs, RAACommitmentOrder, PaymentSendFailure, RecipientOnionFields, PaymentId, MIN_CLTV_EXPIRY_DELTA};
 use crate::routing::gossip::{P2PGossipSync, NetworkGraph, NetworkUpdate};
 use crate::routing::router::{self, PaymentParameters, Route};
 use crate::ln::features::InitFeatures;
@@ -324,6 +324,8 @@ pub struct NodeCfg<'a> {
        pub override_init_features: Rc<RefCell<Option<InitFeatures>>>,
 }
 
+type TestChannelManager<'a, 'b, 'c> = ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger>;
+
 pub struct Node<'a, 'b: 'a, 'c: 'b> {
        pub chain_source: &'c test_utils::TestChainSource,
        pub tx_broadcaster: &'c test_utils::TestBroadcaster,
@@ -331,7 +333,7 @@ pub struct Node<'a, 'b: 'a, 'c: 'b> {
        pub router: &'b test_utils::TestRouter<'c>,
        pub chain_monitor: &'b test_utils::TestChainMonitor<'c>,
        pub keys_manager: &'b test_utils::TestKeysInterface,
-       pub node: &'a ChannelManager<&'b TestChainMonitor<'c>, &'c test_utils::TestBroadcaster, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'b test_utils::TestKeysInterface, &'c test_utils::TestFeeEstimator, &'b test_utils::TestRouter<'c>, &'c test_utils::TestLogger>,
+       pub node: &'a TestChannelManager<'a, 'b, 'c>,
        pub network_graph: &'a NetworkGraph<&'c test_utils::TestLogger>,
        pub gossip_sync: P2PGossipSync<&'b NetworkGraph<&'c test_utils::TestLogger>, &'c test_utils::TestChainSource, &'c test_utils::TestLogger>,
        pub node_seed: [u8; 32],
@@ -367,6 +369,39 @@ impl NodePtr {
 unsafe impl Send for NodePtr {}
 unsafe impl Sync for NodePtr {}
 
+
+pub trait NodeHolder {
+       type CM: AChannelManager;
+       fn node(&self) -> &ChannelManager<
+               <Self::CM as AChannelManager>::M,
+               <Self::CM as AChannelManager>::T,
+               <Self::CM as AChannelManager>::ES,
+               <Self::CM as AChannelManager>::NS,
+               <Self::CM as AChannelManager>::SP,
+               <Self::CM as AChannelManager>::F,
+               <Self::CM as AChannelManager>::R,
+               <Self::CM as AChannelManager>::L>;
+       fn chain_monitor(&self) -> Option<&test_utils::TestChainMonitor>;
+}
+impl<H: NodeHolder> NodeHolder for &H {
+       type CM = H::CM;
+       fn node(&self) -> &ChannelManager<
+               <Self::CM as AChannelManager>::M,
+               <Self::CM as AChannelManager>::T,
+               <Self::CM as AChannelManager>::ES,
+               <Self::CM as AChannelManager>::NS,
+               <Self::CM as AChannelManager>::SP,
+               <Self::CM as AChannelManager>::F,
+               <Self::CM as AChannelManager>::R,
+               <Self::CM as AChannelManager>::L> { (*self).node() }
+       fn chain_monitor(&self) -> Option<&test_utils::TestChainMonitor> { (*self).chain_monitor() }
+}
+impl<'a, 'b: 'a, 'c: 'b> NodeHolder for Node<'a, 'b, 'c> {
+       type CM = TestChannelManager<'a, 'b, 'c>;
+       fn node(&self) -> &TestChannelManager<'a, 'b, 'c> { &self.node }
+       fn chain_monitor(&self) -> Option<&test_utils::TestChainMonitor> { Some(self.chain_monitor) }
+}
+
 impl<'a, 'b, 'c> Drop for Node<'a, 'b, 'c> {
        fn drop(&mut self) {
                if !panicking() {
@@ -486,36 +521,27 @@ pub fn create_chan_between_nodes_with_value<'a, 'b, 'c, 'd>(node_a: &'a Node<'b,
 }
 
 /// Gets an RAA and CS which were sent in response to a commitment update
-///
-/// Should only be used directly when the `$node` is not actually a [`Node`].
-macro_rules! do_get_revoke_commit_msgs {
-       ($node: expr, $recipient: expr) => { {
-               let events = $node.node.get_and_clear_pending_msg_events();
-               assert_eq!(events.len(), 2);
-               (match events[0] {
-                       MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => {
-                               assert_eq!(node_id, $recipient);
-                               (*msg).clone()
-                       },
-                       _ => panic!("Unexpected event"),
-               }, match events[1] {
-                       MessageSendEvent::UpdateHTLCs { ref node_id, ref updates } => {
-                               assert_eq!(node_id, $recipient);
-                               assert!(updates.update_add_htlcs.is_empty());
-                               assert!(updates.update_fulfill_htlcs.is_empty());
-                               assert!(updates.update_fail_htlcs.is_empty());
-                               assert!(updates.update_fail_malformed_htlcs.is_empty());
-                               assert!(updates.update_fee.is_none());
-                               updates.commitment_signed.clone()
-                       },
-                       _ => panic!("Unexpected event"),
-               })
-       } }
-}
-
-/// Gets an RAA and CS which were sent in response to a commitment update
-pub fn get_revoke_commit_msgs(node: &Node, recipient: &PublicKey) -> (msgs::RevokeAndACK, msgs::CommitmentSigned) {
-       do_get_revoke_commit_msgs!(node, recipient)
+pub fn get_revoke_commit_msgs<CM: AChannelManager, H: NodeHolder<CM=CM>>(node: &H, recipient: &PublicKey) -> (msgs::RevokeAndACK, msgs::CommitmentSigned) {
+       let events = node.node().get_and_clear_pending_msg_events();
+       assert_eq!(events.len(), 2);
+       (match events[0] {
+               MessageSendEvent::SendRevokeAndACK { ref node_id, ref msg } => {
+                       assert_eq!(node_id, recipient);
+                       (*msg).clone()
+               },
+               _ => panic!("Unexpected event"),
+       }, match events[1] {
+               MessageSendEvent::UpdateHTLCs { ref node_id, ref updates } => {
+                       assert_eq!(node_id, recipient);
+                       assert!(updates.update_add_htlcs.is_empty());
+                       assert!(updates.update_fulfill_htlcs.is_empty());
+                       assert!(updates.update_fail_htlcs.is_empty());
+                       assert!(updates.update_fail_malformed_htlcs.is_empty());
+                       assert!(updates.update_fee.is_none());
+                       updates.commitment_signed.clone()
+               },
+               _ => panic!("Unexpected event"),
+       })
 }
 
 #[macro_export]
@@ -774,10 +800,12 @@ macro_rules! unwrap_send_err {
 }
 
 /// Check whether N channel monitor(s) have been added.
-pub fn check_added_monitors(node: &Node, count: usize) {
-       let mut added_monitors = node.chain_monitor.added_monitors.lock().unwrap();
-       assert_eq!(added_monitors.len(), count);
-       added_monitors.clear();
+pub fn check_added_monitors<CM: AChannelManager, H: NodeHolder<CM=CM>>(node: &H, count: usize) {
+       if let Some(chain_monitor) = node.chain_monitor() {
+               let mut added_monitors = chain_monitor.added_monitors.lock().unwrap();
+               assert_eq!(added_monitors.len(), count);
+               added_monitors.clear();
+       }
 }
 
 /// Check whether N channel monitor(s) have been added.