Return correct SendSuccess in OnionMessenger
[rust-lightning] / lightning / src / onion_message / messenger.rs
index 52acc20d477614fc536ca9db770b18d96715ea76..21a1b302da09b76fa3f73cbc70953fa46919c3c6 100644 (file)
@@ -87,7 +87,7 @@ pub(super) const MAX_TIMER_TICKS: usize = 2;
 /// #         Ok(OnionMessagePath {
 /// #             intermediate_nodes: vec![hop_node_id1, hop_node_id2],
 /// #             destination,
-/// #             addresses: None,
+/// #             first_node_addresses: None,
 /// #         })
 /// #     }
 /// # }
@@ -223,6 +223,13 @@ impl OnionMessageRecipient {
                        *self = OnionMessageRecipient::ConnectedPeer(new_pending_messages);
                }
        }
+
+       fn is_connected(&self) -> bool {
+               match self {
+                       OnionMessageRecipient::ConnectedPeer(..) => true,
+                       OnionMessageRecipient::PendingConnection(..) => false,
+               }
+       }
 }
 
 /// An [`OnionMessage`] for [`OnionMessenger`] to send.
@@ -292,7 +299,9 @@ where
        ) -> Result<OnionMessagePath, ()> {
                let first_node = destination.first_node();
                if peers.contains(&first_node) {
-                       Ok(OnionMessagePath { intermediate_nodes: vec![], destination, addresses: None })
+                       Ok(OnionMessagePath {
+                               intermediate_nodes: vec![], destination, first_node_addresses: None
+                       })
                } else {
                        let network_graph = self.network_graph.deref().read_only();
                        let node_announcement = network_graph
@@ -303,8 +312,10 @@ where
 
                        match node_announcement {
                                Some(node_announcement) if node_announcement.features.supports_onion_messages() => {
-                                       let addresses = Some(node_announcement.addresses.clone());
-                                       Ok(OnionMessagePath { intermediate_nodes: vec![], destination, addresses })
+                                       let first_node_addresses = Some(node_announcement.addresses.clone());
+                                       Ok(OnionMessagePath {
+                                               intermediate_nodes: vec![], destination, first_node_addresses
+                                       })
                                },
                                _ => Err(()),
                        }
@@ -325,7 +336,7 @@ pub struct OnionMessagePath {
        ///
        /// Only needs to be set if a connection to the node is required. [`OnionMessenger`] may use
        /// this to initiate such a connection.
-       pub addresses: Option<Vec<SocketAddress>>,
+       pub first_node_addresses: Option<Vec<SocketAddress>>,
 }
 
 impl OnionMessagePath {
@@ -469,7 +480,7 @@ where
        ES::Target: EntropySource,
        NS::Target: NodeSigner,
 {
-       let OnionMessagePath { intermediate_nodes, mut destination, addresses } = path;
+       let OnionMessagePath { intermediate_nodes, mut destination, first_node_addresses } = path;
        if let Destination::BlindedPath(BlindedPath { ref blinded_hops, .. }) = destination {
                if blinded_hops.is_empty() {
                        return Err(SendError::TooFewBlindedHops);
@@ -511,7 +522,7 @@ where
                packet_payloads, packet_keys, prng_seed).map_err(|()| SendError::TooBigPacket)?;
 
        let message = OnionMessage { blinding_point, onion_routing_packet };
-       Ok((first_node_id, message, addresses))
+       Ok((first_node_id, message, first_node_addresses))
 }
 
 /// Decode one layer of an incoming [`OnionMessage`].
@@ -725,7 +736,11 @@ where
                        },
                        hash_map::Entry::Occupied(mut e) => {
                                e.get_mut().enqueue_message(onion_message);
-                               Ok(SendSuccess::Buffered)
+                               if e.get().is_connected() {
+                                       Ok(SendSuccess::Buffered)
+                               } else {
+                                       Ok(SendSuccess::BufferedAwaitingConnection(first_node_id))
+                               }
                        },
                }
        }