Generalize onion message ForwardTlvs::next_node_id
[rust-lightning] / lightning / src / onion_message / packet.rs
index d9349fdadbfaba6c1f0d08c265c48bf14da49bb7..510f0ea025a0d615b0f54292d865602ac7e103a6 100644 (file)
@@ -13,7 +13,7 @@ use bitcoin::secp256k1::PublicKey;
 use bitcoin::secp256k1::ecdh::SharedSecret;
 
 use crate::blinded_path::BlindedPath;
-use crate::blinded_path::message::{ForwardTlvs, ReceiveTlvs};
+use crate::blinded_path::message::{ForwardTlvs, NextHop, ReceiveTlvs};
 use crate::blinded_path::utils::Padding;
 use crate::ln::msgs::DecodeError;
 use crate::ln::onion_utils;
@@ -284,20 +284,26 @@ impl Readable for ControlTlvs {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
                _init_and_read_tlv_stream!(r, {
                        (1, _padding, option),
-                       (2, _short_channel_id, option),
+                       (2, short_channel_id, option),
                        (4, next_node_id, option),
                        (6, path_id, option),
                        (8, next_blinding_override, option),
                });
                let _padding: Option<Padding> = _padding;
-               let _short_channel_id: Option<u64> = _short_channel_id;
 
-               let valid_fwd_fmt  = next_node_id.is_some() && path_id.is_none();
-               let valid_recv_fmt = next_node_id.is_none() && next_blinding_override.is_none();
+               let next_hop = match (short_channel_id, next_node_id) {
+                       (Some(_), Some(_)) => return Err(DecodeError::InvalidValue),
+                       (Some(scid), None) => Some(NextHop::ShortChannelId(scid)),
+                       (None, Some(pubkey)) => Some(NextHop::NodeId(pubkey)),
+                       (None, None) => None,
+               };
+
+               let valid_fwd_fmt = next_hop.is_some() && path_id.is_none();
+               let valid_recv_fmt = next_hop.is_none() && next_blinding_override.is_none();
 
                let payload_fmt = if valid_fwd_fmt {
                        ControlTlvs::Forward(ForwardTlvs {
-                               next_node_id: next_node_id.unwrap(),
+                               next_hop: next_hop.unwrap(),
                                next_blinding_override,
                        })
                } else if valid_recv_fmt {