Refuse to send and forward OMs to disconnected peers
[rust-lightning] / lightning / src / onion_message / messenger.rs
index 3a14c78d0fba724a6517776352d37fdefc6244bf..b9da9246cb103bcf44eb82050bd5870f45521ae5 100644 (file)
@@ -122,6 +122,8 @@ pub enum SendError {
        /// The provided [`Destination`] was an invalid [`BlindedRoute`], due to having fewer than two
        /// blinded hops.
        TooFewBlindedHops,
+       /// Our next-hop peer was offline or does not support onion message forwarding.
+       InvalidFirstHop,
 }
 
 impl<Signer: Sign, K: Deref, L: Deref> OnionMessenger<Signer, K, L>
@@ -165,25 +167,28 @@ impl<Signer: Sign, K: Deref, L: Deref> OnionMessenger<Signer, K, L>
                        .map_err(|e| SendError::Secp256k1(e))?;
 
                let prng_seed = self.keys_manager.get_secure_random_bytes();
-               let onion_packet = construct_onion_message_packet(
+               let onion_routing_packet = construct_onion_message_packet(
                        packet_payloads, packet_keys, prng_seed).map_err(|()| SendError::TooBigPacket)?;
 
                let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
-               let pending_msgs = pending_per_peer_msgs.entry(introduction_node_id).or_insert_with(VecDeque::new);
-               pending_msgs.push_back(
-                       msgs::OnionMessage {
-                               blinding_point,
-                               onion_routing_packet: onion_packet,
+               match pending_per_peer_msgs.entry(introduction_node_id) {
+                       hash_map::Entry::Vacant(_) => Err(SendError::InvalidFirstHop),
+                       hash_map::Entry::Occupied(mut e) => {
+                               e.get_mut().push_back(msgs::OnionMessage { blinding_point, onion_routing_packet });
+                               Ok(())
                        }
-               );
-               Ok(())
+               }
        }
 
        #[cfg(test)]
        pub(super) fn release_pending_msgs(&self) -> HashMap<PublicKey, VecDeque<msgs::OnionMessage>> {
                let mut pending_msgs = self.pending_messages.lock().unwrap();
                let mut msgs = HashMap::new();
-               core::mem::swap(&mut *pending_msgs, &mut msgs);
+               // We don't want to disconnect the peers by removing them entirely from the original map, so we
+               // swap the pending message buffers individually.
+               for (peer_node_id, pending_messages) in &mut *pending_msgs {
+                       msgs.insert(*peer_node_id, core::mem::take(pending_messages));
+               }
                msgs
        }
 }
@@ -252,32 +257,43 @@ impl<Signer: Sign, K: Deref, L: Deref> OnionMessageHandler for OnionMessenger<Si
                                };
 
                                let mut pending_per_peer_msgs = self.pending_messages.lock().unwrap();
-                               let pending_msgs = pending_per_peer_msgs.entry(next_node_id).or_insert_with(VecDeque::new);
-                               pending_msgs.push_back(
-                                       msgs::OnionMessage {
-                                               blinding_point: match next_blinding_override {
-                                                       Some(blinding_point) => blinding_point,
-                                                       None => {
-                                                               let blinding_factor = {
-                                                                       let mut sha = Sha256::engine();
-                                                                       sha.input(&msg.blinding_point.serialize()[..]);
-                                                                       sha.input(control_tlvs_ss.as_ref());
-                                                                       Sha256::from_engine(sha).into_inner()
-                                                               };
-                                                               let next_blinding_point = msg.blinding_point;
-                                                               match next_blinding_point.mul_tweak(&self.secp_ctx, &Scalar::from_be_bytes(blinding_factor).unwrap()) {
-                                                                       Ok(bp) => bp,
-                                                                       Err(e) => {
-                                                                               log_trace!(self.logger, "Failed to compute next blinding point: {}", e);
-                                                                               return
-                                                                       }
-                                                               }
-                                                       },
-                                               },
-                                               onion_routing_packet: outgoing_packet,
+
+                               #[cfg(fuzzing)]
+                               pending_per_peer_msgs.entry(next_node_id).or_insert_with(VecDeque::new);
+
+                               match pending_per_peer_msgs.entry(next_node_id) {
+                                       hash_map::Entry::Vacant(_) => {
+                                               log_trace!(self.logger, "Dropping forwarded onion message to disconnected peer {:?}", next_node_id);
+                                               return
                                        },
-                               );
-                               log_trace!(self.logger, "Forwarding an onion message to peer {}", next_node_id);
+                                       hash_map::Entry::Occupied(mut e) => {
+                                               e.get_mut().push_back(
+                                                       msgs::OnionMessage {
+                                                               blinding_point: match next_blinding_override {
+                                                                       Some(blinding_point) => blinding_point,
+                                                                       None => {
+                                                                               let blinding_factor = {
+                                                                                       let mut sha = Sha256::engine();
+                                                                                       sha.input(&msg.blinding_point.serialize()[..]);
+                                                                                       sha.input(control_tlvs_ss.as_ref());
+                                                                                       Sha256::from_engine(sha).into_inner()
+                                                                               };
+                                                                               let next_blinding_point = msg.blinding_point;
+                                                                               match next_blinding_point.mul_tweak(&self.secp_ctx, &Scalar::from_be_bytes(blinding_factor).unwrap()) {
+                                                                                       Ok(bp) => bp,
+                                                                                       Err(e) => {
+                                                                                               log_trace!(self.logger, "Failed to compute next blinding point: {}", e);
+                                                                                               return
+                                                                                       }
+                                                                               }
+                                                                       },
+                                                               },
+                                                               onion_routing_packet: outgoing_packet,
+                                                       },
+                                               );
+                                               log_trace!(self.logger, "Forwarding an onion message to peer {}", next_node_id);
+                                       }
+                               };
                        },
                        Err(e) => {
                                log_trace!(self.logger, "Errored decoding onion message packet: {:?}", e);
@@ -287,6 +303,18 @@ impl<Signer: Sign, K: Deref, L: Deref> OnionMessageHandler for OnionMessenger<Si
                        },
                };
        }
+
+       fn peer_connected(&self, their_node_id: &PublicKey, init: &msgs::Init) {
+               if init.features.supports_onion_messages() {
+                       let mut peers = self.pending_messages.lock().unwrap();
+                       peers.insert(their_node_id.clone(), VecDeque::new());
+               }
+       }
+
+       fn peer_disconnected(&self, their_node_id: &PublicKey, _no_connection_possible: bool) {
+               let mut pending_msgs = self.pending_messages.lock().unwrap();
+               pending_msgs.remove(their_node_id);
+       }
 }
 
 impl<Signer: Sign, K: Deref, L: Deref> OnionMessageProvider for OnionMessenger<Signer, K, L>