Don't construct OnionMessage while holding peer lock
[rust-lightning] / lightning / src / onion_message / messenger.rs
index 2684ab8b3e8aaac5462d4a033db67e3b18e6636c..a3320f10d21091a0faa3131f3125384cc51b871f 100644 (file)
@@ -24,7 +24,6 @@ use super::utils;
 use util::events::OnionMessageProvider;
 use util::logger::Logger;
 
-use core::mem;
 use core::ops::Deref;
 use sync::{Arc, Mutex};
 use prelude::*;
@@ -35,9 +34,7 @@ use prelude::*;
 ///
 /// # Example
 ///
-//  Needs to be `ignore` until the `onion_message` module is made public, otherwise this is a test
-//  failure.
-/// ```ignore
+/// ```
 /// # extern crate bitcoin;
 /// # use bitcoin::hashes::_export::_core::time::Duration;
 /// # use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
@@ -66,7 +63,8 @@ use prelude::*;
 ///
 /// // Send an empty onion message to a node id.
 /// let intermediate_hops = [hop_node_id1, hop_node_id2];
-/// onion_messenger.send_onion_message(&intermediate_hops, Destination::Node(destination_node_id));
+/// let reply_path = None;
+/// onion_messenger.send_onion_message(&intermediate_hops, Destination::Node(destination_node_id), reply_path);
 ///
 /// // Create a blinded route to yourself, for someone to send an onion message to.
 /// # let your_node_id = hop_node_id1;
@@ -75,7 +73,8 @@ use prelude::*;
 ///
 /// // Send an empty onion message to a blinded route.
 /// # let intermediate_hops = [hop_node_id1, hop_node_id2];
-/// onion_messenger.send_onion_message(&intermediate_hops, Destination::BlindedRoute(blinded_route));
+/// let reply_path = None;
+/// onion_messenger.send_onion_message(&intermediate_hops, Destination::BlindedRoute(blinded_route), reply_path);
 /// ```
 ///
 /// [offers]: <https://github.com/lightning/bolts/pull/798>
@@ -86,7 +85,7 @@ pub struct OnionMessenger<Signer: Sign, K: Deref, L: Deref>
 {
        keys_manager: K,
        logger: L,
-       pending_messages: Mutex<HashMap<PublicKey, Vec<msgs::OnionMessage>>>,
+       pending_messages: Mutex<HashMap<PublicKey, VecDeque<msgs::OnionMessage>>>,
        secp_ctx: Secp256k1<secp256k1::All>,
        // Coming soon:
        // invoice_handler: InvoiceHandler,
@@ -123,6 +122,8 @@ pub enum SendError {
        /// The provided [`Destination`] was an invalid [`BlindedRoute`], due to having fewer than two
        /// blinded hops.
        TooFewBlindedHops,
+       /// Our next-hop peer was offline or does not support onion message forwarding.
+       InvalidFirstHop,
 }
 
 impl<Signer: Sign, K: Deref, L: Deref> OnionMessenger<Signer, K, L>
@@ -166,25 +167,28 @@ impl<Signer: Sign, K: Deref, L: Deref> OnionMessenger<Signer, K, L>
                        .map_err(|e| SendError::Secp256k1(e))?;
 
                let prng_seed = self.keys_manager.get_secure_random_bytes();
-               let onion_packet = construct_onion_message_packet(
+               let onion_routing_packet = construct_onion_message_packet(
                        packet_payloads, packet_keys, prng_seed).map_err(|()| SendError::TooBigPacket)?;
 
                let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
-               let pending_msgs = pending_per_peer_msgs.entry(introduction_node_id).or_insert(Vec::new());
-               pending_msgs.push(
-                       msgs::OnionMessage {
-                               blinding_point,
-                               onion_routing_packet: onion_packet,
+               match pending_per_peer_msgs.entry(introduction_node_id) {
+                       hash_map::Entry::Vacant(_) => Err(SendError::InvalidFirstHop),
+                       hash_map::Entry::Occupied(mut e) => {
+                               e.get_mut().push_back(msgs::OnionMessage { blinding_point, onion_routing_packet });
+                               Ok(())
                        }
-               );
-               Ok(())
+               }
        }
 
        #[cfg(test)]
-       pub(super) fn release_pending_msgs(&self) -> HashMap<PublicKey, Vec<msgs::OnionMessage>> {
+       pub(super) fn release_pending_msgs(&self) -> HashMap<PublicKey, VecDeque<msgs::OnionMessage>> {
                let mut pending_msgs = self.pending_messages.lock().unwrap();
                let mut msgs = HashMap::new();
-               core::mem::swap(&mut *pending_msgs, &mut msgs);
+               // We don't want to disconnect the peers by removing them entirely from the original map, so we
+               // swap the pending message buffers individually.
+               for (peer_node_id, pending_messages) in &mut *pending_msgs {
+                       msgs.insert(*peer_node_id, core::mem::take(pending_messages));
+               }
                msgs
        }
 }
@@ -251,34 +255,44 @@ impl<Signer: Sign, K: Deref, L: Deref> OnionMessageHandler for OnionMessenger<Si
                                        hop_data: new_packet_bytes,
                                        hmac: next_hop_hmac,
                                };
-
-                               let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
-                               let pending_msgs = pending_per_peer_msgs.entry(next_node_id).or_insert(Vec::new());
-                               pending_msgs.push(
-                                       msgs::OnionMessage {
-                                               blinding_point: match next_blinding_override {
-                                                       Some(blinding_point) => blinding_point,
-                                                       None => {
-                                                               let blinding_factor = {
-                                                                       let mut sha = Sha256::engine();
-                                                                       sha.input(&msg.blinding_point.serialize()[..]);
-                                                                       sha.input(control_tlvs_ss.as_ref());
-                                                                       Sha256::from_engine(sha).into_inner()
-                                                               };
-                                                               let next_blinding_point = msg.blinding_point;
-                                                               match next_blinding_point.mul_tweak(&self.secp_ctx, &Scalar::from_be_bytes(blinding_factor).unwrap()) {
-                                                                       Ok(bp) => bp,
-                                                                       Err(e) => {
-                                                                               log_trace!(self.logger, "Failed to compute next blinding point: {}", e);
-                                                                               return
-                                                                       }
+                               let onion_message = msgs::OnionMessage {
+                                       blinding_point: match next_blinding_override {
+                                               Some(blinding_point) => blinding_point,
+                                               None => {
+                                                       let blinding_factor = {
+                                                               let mut sha = Sha256::engine();
+                                                               sha.input(&msg.blinding_point.serialize()[..]);
+                                                               sha.input(control_tlvs_ss.as_ref());
+                                                               Sha256::from_engine(sha).into_inner()
+                                                       };
+                                                       let next_blinding_point = msg.blinding_point;
+                                                       match next_blinding_point.mul_tweak(&self.secp_ctx, &Scalar::from_be_bytes(blinding_factor).unwrap()) {
+                                                               Ok(bp) => bp,
+                                                               Err(e) => {
+                                                                       log_trace!(self.logger, "Failed to compute next blinding point: {}", e);
+                                                                       return
                                                                }
-                                                       },
+                                                       }
                                                },
-                                               onion_routing_packet: outgoing_packet,
                                        },
-                               );
-                               log_trace!(self.logger, "Forwarding an onion message to peer {}", next_node_id);
+                                       onion_routing_packet: outgoing_packet,
+                               };
+
+                               let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
+
+                               #[cfg(fuzzing)]
+                               pending_per_peer_msgs.entry(next_node_id).or_insert_with(VecDeque::new);
+
+                               match pending_per_peer_msgs.entry(next_node_id) {
+                                       hash_map::Entry::Vacant(_) => {
+                                               log_trace!(self.logger, "Dropping forwarded onion message to disconnected peer {:?}", next_node_id);
+                                               return
+                                       },
+                                       hash_map::Entry::Occupied(mut e) => {
+                                               e.get_mut().push_back(onion_message);
+                                               log_trace!(self.logger, "Forwarding an onion message to peer {}", next_node_id);
+                                       }
+                               };
                        },
                        Err(e) => {
                                log_trace!(self.logger, "Errored decoding onion message packet: {:?}", e);
@@ -288,6 +302,18 @@ impl<Signer: Sign, K: Deref, L: Deref> OnionMessageHandler for OnionMessenger<Si
                        },
                };
        }
+
+       fn peer_connected(&self, their_node_id: &PublicKey, init: &msgs::Init) {
+               if init.features.supports_onion_messages() {
+                       let mut peers = self.pending_messages.lock().unwrap();
+                       peers.insert(their_node_id.clone(), VecDeque::new());
+               }
+       }
+
+       fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
+               let mut pending_msgs = self.pending_messages.lock().unwrap();
+               pending_msgs.remove(their_node_id);
+       }
 }
 
 impl<Signer: Sign, K: Deref, L: Deref> OnionMessageProvider for OnionMessenger<Signer, K, L>
@@ -295,6 +321,10 @@ impl<Signer: Sign, K: Deref, L: Deref> OnionMessageProvider for OnionMessenger<S
              L::Target: Logger,
 {
        fn next_onion_message_for_peer(&self, peer_node_id: PublicKey) -> Option<msgs::OnionMessage> {
+               let mut pending_msgs = self.pending_messages.lock().unwrap();
+               if let Some(msgs) = pending_msgs.get_mut(&peer_node_id) {
+                       return msgs.pop_front()
+               }
                None
        }
 }