Merge pull request #2400 from TheBlueMatt/2023-07-kill-vec_type
[rust-lightning] / lightning / src / onion_message / functional_tests.rs
index e6058c459eb972710a84605d30bbc69360fd976c..d1b01b71eef0e041ef37007712bfd0eecd014ee9 100644 (file)
@@ -20,10 +20,11 @@ use crate::util::test_utils;
 use bitcoin::network::constants::Network;
 use bitcoin::secp256k1::{PublicKey, Secp256k1};
 
-use core::sync::atomic::{AtomicU16, Ordering};
 use crate::io;
 use crate::io_extras::read_to_end;
-use crate::sync::Arc;
+use crate::sync::{Arc, Mutex};
+
+use crate::prelude::*;
 
 struct MessengerNode {
        keys_manager: Arc<test_utils::TestKeysInterface>,
@@ -36,7 +37,6 @@ struct MessengerNode {
                Arc<TestCustomMessageHandler>
        >,
        custom_message_handler: Arc<TestCustomMessageHandler>,
-       logger: Arc<test_utils::TestLogger>,
 }
 
 impl MessengerNode {
@@ -49,9 +49,12 @@ struct TestMessageRouter {}
 
 impl MessageRouter for TestMessageRouter {
        fn find_path(
-               &self, _sender: PublicKey, _peers: Vec<PublicKey>, _destination: Destination
+               &self, _sender: PublicKey, _peers: Vec<PublicKey>, destination: Destination
        ) -> Result<OnionMessagePath, ()> {
-               todo!()
+               Ok(OnionMessagePath {
+                       intermediate_nodes: vec![],
+                       destination,
+               })
        }
 }
 
@@ -63,31 +66,46 @@ impl OffersMessageHandler for TestOffersMessageHandler {
        }
 }
 
-#[derive(Clone)]
-struct TestCustomMessage {}
+#[derive(Clone, Debug, PartialEq)]
+enum TestCustomMessage {
+       Request,
+       Response,
+}
 
-const CUSTOM_MESSAGE_TYPE: u64 = 4242;
-const CUSTOM_MESSAGE_CONTENTS: [u8; 32] = [42; 32];
+const CUSTOM_REQUEST_MESSAGE_TYPE: u64 = 4242;
+const CUSTOM_RESPONSE_MESSAGE_TYPE: u64 = 4343;
+const CUSTOM_REQUEST_MESSAGE_CONTENTS: [u8; 32] = [42; 32];
+const CUSTOM_RESPONSE_MESSAGE_CONTENTS: [u8; 32] = [43; 32];
 
 impl CustomOnionMessageContents for TestCustomMessage {
        fn tlv_type(&self) -> u64 {
-               CUSTOM_MESSAGE_TYPE
+               match self {
+                       TestCustomMessage::Request => CUSTOM_REQUEST_MESSAGE_TYPE,
+                       TestCustomMessage::Response => CUSTOM_RESPONSE_MESSAGE_TYPE,
+               }
        }
 }
 
 impl Writeable for TestCustomMessage {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               Ok(CUSTOM_MESSAGE_CONTENTS.write(w)?)
+               match self {
+                       TestCustomMessage::Request => Ok(CUSTOM_REQUEST_MESSAGE_CONTENTS.write(w)?),
+                       TestCustomMessage::Response => Ok(CUSTOM_RESPONSE_MESSAGE_CONTENTS.write(w)?),
+               }
        }
 }
 
 struct TestCustomMessageHandler {
-       num_messages_expected: AtomicU16,
+       expected_messages: Mutex<VecDeque<TestCustomMessage>>,
 }
 
 impl TestCustomMessageHandler {
        fn new() -> Self {
-               Self { num_messages_expected: AtomicU16::new(0) }
+               Self { expected_messages: Mutex::new(VecDeque::new()) }
+       }
+
+       fn expect_message(&self, message: TestCustomMessage) {
+               self.expected_messages.lock().unwrap().push_back(message);
        }
 }
 
@@ -98,23 +116,37 @@ impl Drop for TestCustomMessageHandler {
                                return;
                        }
                }
-               assert_eq!(self.num_messages_expected.load(Ordering::SeqCst), 0);
+               assert!(self.expected_messages.lock().unwrap().is_empty());
        }
 }
 
 impl CustomOnionMessageHandler for TestCustomMessageHandler {
        type CustomMessage = TestCustomMessage;
-       fn handle_custom_message(&self, _msg: Self::CustomMessage) -> Option<Self::CustomMessage> {
-               self.num_messages_expected.fetch_sub(1, Ordering::SeqCst);
-               None
+       fn handle_custom_message(&self, msg: Self::CustomMessage) -> Option<Self::CustomMessage> {
+               match self.expected_messages.lock().unwrap().pop_front() {
+                       Some(expected_msg) => assert_eq!(expected_msg, msg),
+                       None => panic!("Unexpected message: {:?}", msg),
+               }
+
+               match msg {
+                       TestCustomMessage::Request => Some(TestCustomMessage::Response),
+                       TestCustomMessage::Response => None,
+               }
        }
        fn read_custom_message<R: io::Read>(&self, message_type: u64, buffer: &mut R) -> Result<Option<Self::CustomMessage>, DecodeError> where Self: Sized {
-               if message_type == CUSTOM_MESSAGE_TYPE {
-                       let buf = read_to_end(buffer)?;
-                       assert_eq!(buf, CUSTOM_MESSAGE_CONTENTS);
-                       return Ok(Some(TestCustomMessage {}))
+               match message_type {
+                       CUSTOM_REQUEST_MESSAGE_TYPE => {
+                               let buf = read_to_end(buffer)?;
+                               assert_eq!(buf, CUSTOM_REQUEST_MESSAGE_CONTENTS);
+                               Ok(Some(TestCustomMessage::Request))
+                       },
+                       CUSTOM_RESPONSE_MESSAGE_TYPE => {
+                               let buf = read_to_end(buffer)?;
+                               assert_eq!(buf, CUSTOM_RESPONSE_MESSAGE_CONTENTS);
+                               Ok(Some(TestCustomMessage::Response))
+                       },
+                       _ => Ok(None),
                }
-               Ok(None)
        }
 }
 
@@ -134,7 +166,6 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
                                offers_message_handler, custom_message_handler.clone()
                        ),
                        custom_message_handler,
-                       logger,
                });
        }
        for idx in 0..num_messengers - 1 {
@@ -149,7 +180,6 @@ fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
 }
 
 fn pass_along_path(path: &Vec<MessengerNode>) {
-       path[path.len() - 1].custom_message_handler.num_messages_expected.fetch_add(1, Ordering::SeqCst);
        let mut prev_node = &path[0];
        for node in path.into_iter().skip(1) {
                let events = prev_node.messenger.release_pending_msgs();
@@ -166,33 +196,35 @@ fn pass_along_path(path: &Vec<MessengerNode>) {
 #[test]
 fn one_hop() {
        let nodes = create_nodes(2);
-       let test_msg = OnionMessageContents::Custom(TestCustomMessage {});
+       let test_msg = OnionMessageContents::Custom(TestCustomMessage::Response);
 
        let path = OnionMessagePath {
                intermediate_nodes: vec![],
                destination: Destination::Node(nodes[1].get_node_pk()),
        };
        nodes[0].messenger.send_onion_message(path, test_msg, None).unwrap();
+       nodes[1].custom_message_handler.expect_message(TestCustomMessage::Response);
        pass_along_path(&nodes);
 }
 
 #[test]
 fn two_unblinded_hops() {
        let nodes = create_nodes(3);
-       let test_msg = OnionMessageContents::Custom(TestCustomMessage {});
+       let test_msg = OnionMessageContents::Custom(TestCustomMessage::Response);
 
        let path = OnionMessagePath {
                intermediate_nodes: vec![nodes[1].get_node_pk()],
                destination: Destination::Node(nodes[2].get_node_pk()),
        };
        nodes[0].messenger.send_onion_message(path, test_msg, None).unwrap();
+       nodes[2].custom_message_handler.expect_message(TestCustomMessage::Response);
        pass_along_path(&nodes);
 }
 
 #[test]
 fn two_unblinded_two_blinded() {
        let nodes = create_nodes(5);
-       let test_msg = OnionMessageContents::Custom(TestCustomMessage {});
+       let test_msg = OnionMessageContents::Custom(TestCustomMessage::Response);
 
        let secp_ctx = Secp256k1::new();
        let blinded_path = BlindedPath::new_for_message(&[nodes[3].get_node_pk(), nodes[4].get_node_pk()], &*nodes[4].keys_manager, &secp_ctx).unwrap();
@@ -202,13 +234,14 @@ fn two_unblinded_two_blinded() {
        };
 
        nodes[0].messenger.send_onion_message(path, test_msg, None).unwrap();
+       nodes[4].custom_message_handler.expect_message(TestCustomMessage::Response);
        pass_along_path(&nodes);
 }
 
 #[test]
 fn three_blinded_hops() {
        let nodes = create_nodes(4);
-       let test_msg = OnionMessageContents::Custom(TestCustomMessage {});
+       let test_msg = OnionMessageContents::Custom(TestCustomMessage::Response);
 
        let secp_ctx = Secp256k1::new();
        let blinded_path = BlindedPath::new_for_message(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx).unwrap();
@@ -218,6 +251,7 @@ fn three_blinded_hops() {
        };
 
        nodes[0].messenger.send_onion_message(path, test_msg, None).unwrap();
+       nodes[3].custom_message_handler.expect_message(TestCustomMessage::Response);
        pass_along_path(&nodes);
 }
 
@@ -225,7 +259,7 @@ fn three_blinded_hops() {
 fn too_big_packet_error() {
        // Make sure we error as expected if a packet is too big to send.
        let nodes = create_nodes(2);
-       let test_msg = OnionMessageContents::Custom(TestCustomMessage {});
+       let test_msg = OnionMessageContents::Custom(TestCustomMessage::Response);
 
        let hop_node_id = nodes[1].get_node_pk();
        let hops = vec![hop_node_id; 400];
@@ -242,7 +276,7 @@ fn we_are_intro_node() {
        // If we are sending straight to a blinded path and we are the introduction node, we need to
        // advance the blinded path by 1 hop so the second hop is the new introduction node.
        let mut nodes = create_nodes(3);
-       let test_msg = TestCustomMessage {};
+       let test_msg = TestCustomMessage::Response;
 
        let secp_ctx = Secp256k1::new();
        let blinded_path = BlindedPath::new_for_message(&[nodes[0].get_node_pk(), nodes[1].get_node_pk(), nodes[2].get_node_pk()], &*nodes[2].keys_manager, &secp_ctx).unwrap();
@@ -252,6 +286,7 @@ fn we_are_intro_node() {
        };
 
        nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(test_msg.clone()), None).unwrap();
+       nodes[2].custom_message_handler.expect_message(TestCustomMessage::Response);
        pass_along_path(&nodes);
 
        // Try with a two-hop blinded path where we are the introduction node.
@@ -261,6 +296,7 @@ fn we_are_intro_node() {
                destination: Destination::BlindedPath(blinded_path),
        };
        nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(test_msg), None).unwrap();
+       nodes[1].custom_message_handler.expect_message(TestCustomMessage::Response);
        nodes.remove(2);
        pass_along_path(&nodes);
 }
@@ -269,7 +305,7 @@ fn we_are_intro_node() {
 fn invalid_blinded_path_error() {
        // Make sure we error as expected if a provided blinded path has 0 or 1 hops.
        let nodes = create_nodes(3);
-       let test_msg = TestCustomMessage {};
+       let test_msg = TestCustomMessage::Response;
 
        // 0 hops
        let secp_ctx = Secp256k1::new();
@@ -296,8 +332,8 @@ fn invalid_blinded_path_error() {
 
 #[test]
 fn reply_path() {
-       let nodes = create_nodes(4);
-       let test_msg = TestCustomMessage {};
+       let mut nodes = create_nodes(4);
+       let test_msg = TestCustomMessage::Request;
        let secp_ctx = Secp256k1::new();
 
        // Destination::Node
@@ -307,11 +343,12 @@ fn reply_path() {
        };
        let reply_path = BlindedPath::new_for_message(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx).unwrap();
        nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(test_msg.clone()), Some(reply_path)).unwrap();
+       nodes[3].custom_message_handler.expect_message(TestCustomMessage::Request);
        pass_along_path(&nodes);
        // Make sure the last node successfully decoded the reply path.
-       nodes[3].logger.assert_log_contains(
-               "lightning::onion_message::messenger",
-               &format!("Received an onion message with path_id None and a reply_path"), 1);
+       nodes[0].custom_message_handler.expect_message(TestCustomMessage::Response);
+       nodes.reverse();
+       pass_along_path(&nodes);
 
        // Destination::BlindedPath
        let blinded_path = BlindedPath::new_for_message(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx).unwrap();
@@ -322,10 +359,13 @@ fn reply_path() {
        let reply_path = BlindedPath::new_for_message(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx).unwrap();
 
        nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(test_msg), Some(reply_path)).unwrap();
+       nodes[3].custom_message_handler.expect_message(TestCustomMessage::Request);
+       pass_along_path(&nodes);
+
+       // Make sure the last node successfully decoded the reply path.
+       nodes[0].custom_message_handler.expect_message(TestCustomMessage::Response);
+       nodes.reverse();
        pass_along_path(&nodes);
-       nodes[3].logger.assert_log_contains(
-               "lightning::onion_message::messenger",
-               &format!("Received an onion message with path_id None and a reply_path"), 2);
 }
 
 #[test]
@@ -356,7 +396,7 @@ fn invalid_custom_message_type() {
 #[test]
 fn peer_buffer_full() {
        let nodes = create_nodes(2);
-       let test_msg = TestCustomMessage {};
+       let test_msg = TestCustomMessage::Request;
        let path = OnionMessagePath {
                intermediate_nodes: vec![],
                destination: Destination::Node(nodes[1].get_node_pk()),
@@ -374,7 +414,7 @@ fn many_hops() {
        // of size [`crate::onion_message::packet::BIG_PACKET_HOP_DATA_LEN`].
        let num_nodes: usize = 25;
        let nodes = create_nodes(num_nodes as u8);
-       let test_msg = OnionMessageContents::Custom(TestCustomMessage {});
+       let test_msg = TestCustomMessage::Response;
 
        let mut intermediate_nodes = vec![];
        for i in 1..(num_nodes-1) {
@@ -385,6 +425,7 @@ fn many_hops() {
                intermediate_nodes,
                destination: Destination::Node(nodes[num_nodes-1].get_node_pk()),
        };
-       nodes[0].messenger.send_onion_message(path, test_msg, None).unwrap();
+       nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(test_msg), None).unwrap();
+       nodes[num_nodes-1].custom_message_handler.expect_message(TestCustomMessage::Response);
        pass_along_path(&nodes);
 }