Split TestCustomMessage into Request and Response
[rust-lightning] / lightning / src / onion_message / functional_tests.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 //! Onion message testing and test utilities live here.
11
12 use crate::blinded_path::BlindedPath;
13 use crate::sign::{NodeSigner, Recipient};
14 use crate::ln::features::InitFeatures;
15 use crate::ln::msgs::{self, DecodeError, OnionMessageHandler};
16 use super::{CustomOnionMessageContents, CustomOnionMessageHandler, Destination, MessageRouter, OffersMessage, OffersMessageHandler, OnionMessageContents, OnionMessagePath, OnionMessenger, SendError};
17 use crate::util::ser::{Writeable, Writer};
18 use crate::util::test_utils;
19
20 use bitcoin::network::constants::Network;
21 use bitcoin::secp256k1::{PublicKey, Secp256k1};
22
23 use core::sync::atomic::{AtomicU16, Ordering};
24 use crate::io;
25 use crate::io_extras::read_to_end;
26 use crate::sync::Arc;
27
28 struct MessengerNode {
29         keys_manager: Arc<test_utils::TestKeysInterface>,
30         messenger: OnionMessenger<
31                 Arc<test_utils::TestKeysInterface>,
32                 Arc<test_utils::TestKeysInterface>,
33                 Arc<test_utils::TestLogger>,
34                 Arc<TestMessageRouter>,
35                 Arc<TestOffersMessageHandler>,
36                 Arc<TestCustomMessageHandler>
37         >,
38         custom_message_handler: Arc<TestCustomMessageHandler>,
39         logger: Arc<test_utils::TestLogger>,
40 }
41
42 impl MessengerNode {
43         fn get_node_pk(&self) -> PublicKey {
44                 self.keys_manager.get_node_id(Recipient::Node).unwrap()
45         }
46 }
47
48 struct TestMessageRouter {}
49
50 impl MessageRouter for TestMessageRouter {
51         fn find_path(
52                 &self, _sender: PublicKey, _peers: Vec<PublicKey>, _destination: Destination
53         ) -> Result<OnionMessagePath, ()> {
54                 todo!()
55         }
56 }
57
58 struct TestOffersMessageHandler {}
59
60 impl OffersMessageHandler for TestOffersMessageHandler {
61         fn handle_message(&self, _message: OffersMessage) -> Option<OffersMessage> {
62                 None
63         }
64 }
65
/// Custom onion message used by the tests in this file.
#[derive(Clone)]
enum TestCustomMessage {
	/// A request; the test handler answers it with [`TestCustomMessage::Response`].
	Request,
	/// A response; the test handler does not reply to it.
	Response,
}
71
/// TLV type used for the request variant of the test custom message.
const CUSTOM_REQUEST_MESSAGE_TYPE: u64 = 4242;
/// TLV type used for the response variant of the test custom message.
const CUSTOM_RESPONSE_MESSAGE_TYPE: u64 = 4343;
/// Fixed payload written for (and expected when reading) a request.
const CUSTOM_REQUEST_MESSAGE_CONTENTS: [u8; 32] = [42u8; 32];
/// Fixed payload written for (and expected when reading) a response.
const CUSTOM_RESPONSE_MESSAGE_CONTENTS: [u8; 32] = [43u8; 32];
76
77 impl CustomOnionMessageContents for TestCustomMessage {
78         fn tlv_type(&self) -> u64 {
79                 match self {
80                         TestCustomMessage::Request => CUSTOM_REQUEST_MESSAGE_TYPE,
81                         TestCustomMessage::Response => CUSTOM_RESPONSE_MESSAGE_TYPE,
82                 }
83         }
84 }
85
86 impl Writeable for TestCustomMessage {
87         fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
88                 match self {
89                         TestCustomMessage::Request => Ok(CUSTOM_REQUEST_MESSAGE_CONTENTS.write(w)?),
90                         TestCustomMessage::Response => Ok(CUSTOM_RESPONSE_MESSAGE_CONTENTS.write(w)?),
91                 }
92         }
93 }
94
/// Custom message handler that tracks how many messages a test still expects it to receive.
struct TestCustomMessageHandler {
	/// Number of custom messages that are expected but have not been handled yet.
	num_messages_expected: AtomicU16,
}
98
99 impl TestCustomMessageHandler {
100         fn new() -> Self {
101                 Self { num_messages_expected: AtomicU16::new(0) }
102         }
103 }
104
105 impl Drop for TestCustomMessageHandler {
106         fn drop(&mut self) {
107                 #[cfg(feature = "std")] {
108                         if std::thread::panicking() {
109                                 return;
110                         }
111                 }
112                 assert_eq!(self.num_messages_expected.load(Ordering::SeqCst), 0);
113         }
114 }
115
116 impl CustomOnionMessageHandler for TestCustomMessageHandler {
117         type CustomMessage = TestCustomMessage;
118         fn handle_custom_message(&self, msg: Self::CustomMessage) -> Option<Self::CustomMessage> {
119                 self.num_messages_expected.fetch_sub(1, Ordering::SeqCst);
120                 match msg {
121                         TestCustomMessage::Request => Some(TestCustomMessage::Response),
122                         TestCustomMessage::Response => None,
123                 }
124         }
125         fn read_custom_message<R: io::Read>(&self, message_type: u64, buffer: &mut R) -> Result<Option<Self::CustomMessage>, DecodeError> where Self: Sized {
126                 match message_type {
127                         CUSTOM_REQUEST_MESSAGE_TYPE => {
128                                 let buf = read_to_end(buffer)?;
129                                 assert_eq!(buf, CUSTOM_REQUEST_MESSAGE_CONTENTS);
130                                 Ok(Some(TestCustomMessage::Request))
131                         },
132                         CUSTOM_RESPONSE_MESSAGE_TYPE => {
133                                 let buf = read_to_end(buffer)?;
134                                 assert_eq!(buf, CUSTOM_RESPONSE_MESSAGE_CONTENTS);
135                                 Ok(Some(TestCustomMessage::Response))
136                         },
137                         _ => Ok(None),
138                 }
139         }
140 }
141
142 fn create_nodes(num_messengers: u8) -> Vec<MessengerNode> {
143         let mut nodes = Vec::new();
144         for i in 0..num_messengers {
145                 let logger = Arc::new(test_utils::TestLogger::with_id(format!("node {}", i)));
146                 let seed = [i as u8; 32];
147                 let keys_manager = Arc::new(test_utils::TestKeysInterface::new(&seed, Network::Testnet));
148                 let message_router = Arc::new(TestMessageRouter {});
149                 let offers_message_handler = Arc::new(TestOffersMessageHandler {});
150                 let custom_message_handler = Arc::new(TestCustomMessageHandler::new());
151                 nodes.push(MessengerNode {
152                         keys_manager: keys_manager.clone(),
153                         messenger: OnionMessenger::new(
154                                 keys_manager.clone(), keys_manager, logger.clone(), message_router,
155                                 offers_message_handler, custom_message_handler.clone()
156                         ),
157                         custom_message_handler,
158                         logger,
159                 });
160         }
161         for idx in 0..num_messengers - 1 {
162                 let i = idx as usize;
163                 let mut features = InitFeatures::empty();
164                 features.set_onion_messages_optional();
165                 let init_msg = msgs::Init { features, networks: None, remote_network_address: None };
166                 nodes[i].messenger.peer_connected(&nodes[i + 1].get_node_pk(), &init_msg.clone(), true).unwrap();
167                 nodes[i + 1].messenger.peer_connected(&nodes[i].get_node_pk(), &init_msg.clone(), false).unwrap();
168         }
169         nodes
170 }
171
172 fn pass_along_path(path: &Vec<MessengerNode>) {
173         path[path.len() - 1].custom_message_handler.num_messages_expected.fetch_add(1, Ordering::SeqCst);
174         let mut prev_node = &path[0];
175         for node in path.into_iter().skip(1) {
176                 let events = prev_node.messenger.release_pending_msgs();
177                 let onion_msg =  {
178                         let msgs = events.get(&node.get_node_pk()).unwrap();
179                         assert_eq!(msgs.len(), 1);
180                         msgs[0].clone()
181                 };
182                 node.messenger.handle_onion_message(&prev_node.get_node_pk(), &onion_msg);
183                 prev_node = node;
184         }
185 }
186
/// Sends a custom message directly from node 0 to its peer, node 1.
#[test]
fn one_hop() {
	let nodes = create_nodes(2);
	let (sender, recipient) = (&nodes[0], &nodes[1]);

	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::Node(recipient.get_node_pk()),
	};
	let message = OnionMessageContents::Custom(TestCustomMessage::Response);

	sender.messenger.send_onion_message(path, message, None).unwrap();
	pass_along_path(&nodes);
}
199
/// Sends a custom message from node 0 to node 2 through one unblinded intermediate node.
#[test]
fn two_unblinded_hops() {
	let nodes = create_nodes(3);
	let (sender, relay, recipient) = (&nodes[0], &nodes[1], &nodes[2]);

	let path = OnionMessagePath {
		intermediate_nodes: vec![relay.get_node_pk()],
		destination: Destination::Node(recipient.get_node_pk()),
	};
	let message = OnionMessageContents::Custom(TestCustomMessage::Response);

	sender.messenger.send_onion_message(path, message, None).unwrap();
	pass_along_path(&nodes);
}
212
/// Sends a custom message over two unblinded hops followed by a two-hop blinded path.
#[test]
fn two_unblinded_two_blinded() {
	let nodes = create_nodes(5);
	let message = OnionMessageContents::Custom(TestCustomMessage::Response);

	// Nodes 3 and 4 form the blinded tail of the route; node 4 supplies the entropy.
	let secp_ctx = Secp256k1::new();
	let blinded_hops = [nodes[3].get_node_pk(), nodes[4].get_node_pk()];
	let blinded_path =
		BlindedPath::new_for_message(&blinded_hops, &*nodes[4].keys_manager, &secp_ctx).unwrap();

	let path = OnionMessagePath {
		intermediate_nodes: vec![nodes[1].get_node_pk(), nodes[2].get_node_pk()],
		destination: Destination::BlindedPath(blinded_path),
	};

	nodes[0].messenger.send_onion_message(path, message, None).unwrap();
	pass_along_path(&nodes);
}
228
/// Sends a custom message from node 0 over a blinded path covering nodes 1, 2 and 3.
#[test]
fn three_blinded_hops() {
	let nodes = create_nodes(4);
	let message = OnionMessageContents::Custom(TestCustomMessage::Response);

	let secp_ctx = Secp256k1::new();
	let blinded_hops = [nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()];
	let blinded_path =
		BlindedPath::new_for_message(&blinded_hops, &*nodes[3].keys_manager, &secp_ctx).unwrap();

	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::BlindedPath(blinded_path),
	};

	nodes[0].messenger.send_onion_message(path, message, None).unwrap();
	pass_along_path(&nodes);
}
244
/// Checks that sending fails with [`SendError::TooBigPacket`] when the route is too long to
/// fit into an onion packet.
#[test]
fn too_big_packet_error() {
	let nodes = create_nodes(2);
	let message = OnionMessageContents::Custom(TestCustomMessage::Response);

	// Repeat the same hop 400 times to blow past the packet size limit.
	let hop_node_id = nodes[1].get_node_pk();
	let path = OnionMessagePath {
		intermediate_nodes: vec![hop_node_id; 400],
		destination: Destination::Node(hop_node_id),
	};

	let send_result = nodes[0].messenger.send_onion_message(path, message, None);
	assert_eq!(send_result.unwrap_err(), SendError::TooBigPacket);
}
260
/// Covers sending straight to a blinded path whose introduction node is the sender itself.
///
/// In that case the blinded path has to be advanced by one hop so that the second hop
/// becomes the new introduction node.
#[test]
fn we_are_intro_node() {
	let mut nodes = create_nodes(3);
	let secp_ctx = Secp256k1::new();
	let message = TestCustomMessage::Response;

	// Three-hop blinded path starting at the sender.
	let three_hops = [nodes[0].get_node_pk(), nodes[1].get_node_pk(), nodes[2].get_node_pk()];
	let blinded_path =
		BlindedPath::new_for_message(&three_hops, &*nodes[2].keys_manager, &secp_ctx).unwrap();
	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::BlindedPath(blinded_path),
	};
	let contents = OnionMessageContents::Custom(message.clone());
	nodes[0].messenger.send_onion_message(path, contents, None).unwrap();
	pass_along_path(&nodes);

	// Try with a two-hop blinded path where we are the introduction node.
	let two_hops = [nodes[0].get_node_pk(), nodes[1].get_node_pk()];
	let blinded_path =
		BlindedPath::new_for_message(&two_hops, &*nodes[1].keys_manager, &secp_ctx).unwrap();
	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::BlindedPath(blinded_path),
	};
	let contents = OnionMessageContents::Custom(message);
	nodes[0].messenger.send_onion_message(path, contents, None).unwrap();

	// Drop the third node so the relayed path ends at node 1.
	nodes.truncate(2);
	pass_along_path(&nodes);
}
288
/// Checks that a blinded path with 0 or 1 hops is rejected with
/// [`SendError::TooFewBlindedHops`].
#[test]
fn invalid_blinded_path_error() {
	let nodes = create_nodes(3);
	let message = TestCustomMessage::Response;
	let secp_ctx = Secp256k1::new();
	let hop_pks = [nodes[1].get_node_pk(), nodes[2].get_node_pk()];

	// 0 hops
	let mut no_hops =
		BlindedPath::new_for_message(&hop_pks, &*nodes[2].keys_manager, &secp_ctx).unwrap();
	no_hops.blinded_hops.clear();
	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::BlindedPath(no_hops),
	};
	let contents = OnionMessageContents::Custom(message.clone());
	let send_result = nodes[0].messenger.send_onion_message(path, contents, None);
	assert_eq!(send_result.unwrap_err(), SendError::TooFewBlindedHops);

	// 1 hop
	let mut single_hop =
		BlindedPath::new_for_message(&hop_pks, &*nodes[2].keys_manager, &secp_ctx).unwrap();
	single_hop.blinded_hops.remove(0);
	assert_eq!(single_hop.blinded_hops.len(), 1);
	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::BlindedPath(single_hop),
	};
	let contents = OnionMessageContents::Custom(message);
	let send_result = nodes[0].messenger.send_onion_message(path, contents, None);
	assert_eq!(send_result.unwrap_err(), SendError::TooFewBlindedHops);
}
317
/// Checks that a reply path attached to an onion message is decoded by the final recipient,
/// both when the destination is a plain node and when it is a blinded path.
///
/// Decoding is verified through the recipient's log output: the expected line must have been
/// logged once after the first send and twice after the second.
#[test]
fn reply_path() {
	// Log line emitted by the messenger when a message with a reply path was received.
	const RECEIVED_WITH_REPLY_PATH: &str =
		"Received an onion message with path_id None and a reply_path";

	let nodes = create_nodes(4);
	let test_msg = TestCustomMessage::Response;
	let secp_ctx = Secp256k1::new();

	// Destination::Node
	let path = OnionMessagePath {
		intermediate_nodes: vec![nodes[1].get_node_pk(), nodes[2].get_node_pk()],
		destination: Destination::Node(nodes[3].get_node_pk()),
	};
	let reply_path = BlindedPath::new_for_message(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx).unwrap();
	nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(test_msg.clone()), Some(reply_path)).unwrap();
	pass_along_path(&nodes);
	// Make sure the last node successfully decoded the reply path.
	nodes[3].logger.assert_log_contains(
		"lightning::onion_message::messenger", RECEIVED_WITH_REPLY_PATH, 1);

	// Destination::BlindedPath
	let blinded_path = BlindedPath::new_for_message(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx).unwrap();
	let path = OnionMessagePath {
		intermediate_nodes: vec![],
		destination: Destination::BlindedPath(blinded_path),
	};
	let reply_path = BlindedPath::new_for_message(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx).unwrap();

	nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(test_msg), Some(reply_path)).unwrap();
	pass_along_path(&nodes);
	// The same line must now have been logged a second time.
	nodes[3].logger.assert_log_contains(
		"lightning::onion_message::messenger", RECEIVED_WITH_REPLY_PATH, 2);
}
351
/// Checks that a custom message with a TLV type below 64 is rejected with
/// [`SendError::InvalidMessage`].
#[test]
fn invalid_custom_message_type() {
	/// Custom message whose TLV type is deliberately out of the allowed range.
	struct InvalidCustomMessage {}

	impl CustomOnionMessageContents for InvalidCustomMessage {
		fn tlv_type(&self) -> u64 {
			// Onion message contents must have a TLV >= 64.
			63
		}
	}

	impl Writeable for InvalidCustomMessage {
		// Sending is expected to fail before the message is ever serialized.
		fn write<W: Writer>(&self, _w: &mut W) -> Result<(), io::Error> { unreachable!() }
	}

	let nodes = create_nodes(2);
	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::Node(nodes[1].get_node_pk()),
	};
	let message = OnionMessageContents::Custom(InvalidCustomMessage {});

	let send_result = nodes[0].messenger.send_onion_message(path, message, None);
	assert_eq!(send_result.unwrap_err(), SendError::InvalidMessage);
}
376
/// Checks that sending fails with [`SendError::BufferFull`] once the per-peer outbound buffer
/// has been filled.
#[test]
fn peer_buffer_full() {
	// Based on MAX_PER_PEER_BUFFER_SIZE in OnionMessenger
	const MESSAGES_UNTIL_FULL: usize = 188;

	let nodes = create_nodes(2);
	let sender = &nodes[0];
	let message = TestCustomMessage::Response;
	let path = OnionMessagePath {
		intermediate_nodes: Vec::new(),
		destination: Destination::Node(nodes[1].get_node_pk()),
	};

	for _ in 0..MESSAGES_UNTIL_FULL {
		let contents = OnionMessageContents::Custom(message.clone());
		sender.messenger.send_onion_message(path.clone(), contents, None).unwrap();
	}

	// The buffer is full now, so one more message must be refused.
	let overflow = sender.messenger.send_onion_message(path, OnionMessageContents::Custom(message), None);
	assert_eq!(overflow.unwrap_err(), SendError::BufferFull);
}
391
/// Check we can send over a route with many hops. This will exercise our logic for onion
/// messages of size [`crate::onion_message::packet::BIG_PACKET_HOP_DATA_LEN`].
#[test]
fn many_hops() {
	let num_nodes: usize = 25;
	let nodes = create_nodes(num_nodes as u8);
	let message = TestCustomMessage::Response;

	// Every node strictly between the sender (first) and the recipient (last) relays.
	let last = num_nodes - 1;
	let path = OnionMessagePath {
		intermediate_nodes: nodes[1..last].iter().map(|node| node.get_node_pk()).collect(),
		destination: Destination::Node(nodes[last].get_node_pk()),
	};

	nodes[0].messenger.send_onion_message(path, OnionMessageContents::Custom(message), None).unwrap();
	pass_along_path(&nodes);
}