Receive payment onions as new InboundPayload instead of OnionHopData
[rust-lightning] / lightning / src / ln / msgs.rs
index cfee16033be236232392caef89d09b72967affdc..590b26632fbe9216b632d33844a1e5121b5cef64 100644 (file)
@@ -1419,11 +1419,27 @@ mod fuzzy_internal_msgs {
        // These types aren't intended to be pub, but are exposed for direct fuzzing (as we deserialize
        // them from untrusted input):
        #[derive(Clone)]
-       pub(crate) struct FinalOnionHopData {
-               pub(crate) payment_secret: PaymentSecret,
+       pub struct FinalOnionHopData {
+               pub payment_secret: PaymentSecret,
                /// The total value, in msat, of the payment as received by the ultimate recipient.
                /// Message serialization may panic if this value is more than 21 million Bitcoin.
-               pub(crate) total_msat: u64,
+               pub total_msat: u64,
+       }
+
+       pub enum InboundOnionPayload {
+               Forward {
+                       short_channel_id: u64,
+                       /// The value, in msat, of the payment after this hop's fee is deducted.
+                       amt_to_forward: u64,
+                       outgoing_cltv_value: u32,
+               },
+               Receive {
+                       payment_data: Option<FinalOnionHopData>,
+                       payment_metadata: Option<Vec<u8>>,
+                       keysend_preimage: Option<PaymentPreimage>,
+                       amt_msat: u64,
+                       outgoing_cltv_value: u32,
+               },
        }
 
        pub(crate) enum OnionHopDataFormat {
@@ -1974,7 +1990,7 @@ impl Writeable for OnionHopData {
        }
 }
 
-impl Readable for OnionHopData {
+impl Readable for InboundOnionPayload {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
                let mut amt = HighZeroBytesDroppedBigSize(0u64);
                let mut cltv_value = HighZeroBytesDroppedBigSize(0u32);
@@ -1992,39 +2008,35 @@ impl Readable for OnionHopData {
                        (5482373484, keysend_preimage, option)
                });
 
-               let format = if let Some(short_channel_id) = short_id {
-                       if payment_data.is_some() { return Err(DecodeError::InvalidValue); }
+               if amt.0 > MAX_VALUE_MSAT { return Err(DecodeError::InvalidValue) }
+               if let Some(short_channel_id) = short_id {
+                       if payment_data.is_some() { return Err(DecodeError::InvalidValue) }
                        if payment_metadata.is_some() { return Err(DecodeError::InvalidValue); }
-                       OnionHopDataFormat::NonFinalNode {
+                       Ok(Self::Forward {
                                short_channel_id,
-                       }
+                               amt_to_forward: amt.0,
+                               outgoing_cltv_value: cltv_value.0,
+                       })
                } else {
                        if let Some(data) = &payment_data {
                                if data.total_msat > MAX_VALUE_MSAT {
                                        return Err(DecodeError::InvalidValue);
                                }
                        }
-                       OnionHopDataFormat::FinalNode {
+                       Ok(Self::Receive {
                                payment_data,
                                payment_metadata: payment_metadata.map(|w| w.0),
                                keysend_preimage,
-                       }
-               };
-
-               if amt.0 > MAX_VALUE_MSAT {
-                       return Err(DecodeError::InvalidValue);
+                               amt_msat: amt.0,
+                               outgoing_cltv_value: cltv_value.0,
+                       })
                }
-               Ok(OnionHopData {
-                       format,
-                       amt_to_forward: amt.0,
-                       outgoing_cltv_value: cltv_value.0,
-               })
        }
 }
 
 // ReadableArgs because we need onion_utils::decode_next_hop to accommodate payment packets and
 // onion message packets.
-impl ReadableArgs<()> for OnionHopData {
+impl ReadableArgs<()> for InboundOnionPayload {
        fn read<R: Read>(r: &mut R, _arg: ()) -> Result<Self, DecodeError> {
                <Self as Readable>::read(r)
        }
@@ -3525,7 +3537,7 @@ mod tests {
 
        #[test]
        fn encoding_nonfinal_onion_hop_data() {
-               let mut msg = msgs::OnionHopData {
+               let msg = msgs::OnionHopData {
                        format: OnionHopDataFormat::NonFinalNode {
                                short_channel_id: 0xdeadbeef1bad1dea,
                        },
@@ -3535,17 +3547,18 @@ mod tests {
                let encoded_value = msg.encode();
                let target_value = hex::decode("1a02080badf00d010203040404ffffffff0608deadbeef1bad1dea").unwrap();
                assert_eq!(encoded_value, target_value);
-               msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
-               if let OnionHopDataFormat::NonFinalNode { short_channel_id } = msg.format {
+
+               let inbound_msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
+               if let msgs::InboundOnionPayload::Forward { short_channel_id, amt_to_forward, outgoing_cltv_value } = inbound_msg {
                        assert_eq!(short_channel_id, 0xdeadbeef1bad1dea);
+                       assert_eq!(amt_to_forward, 0x0badf00d01020304);
+                       assert_eq!(outgoing_cltv_value, 0xffffffff);
                } else { panic!(); }
-               assert_eq!(msg.amt_to_forward, 0x0badf00d01020304);
-               assert_eq!(msg.outgoing_cltv_value, 0xffffffff);
        }
 
        #[test]
        fn encoding_final_onion_hop_data() {
-               let mut msg = msgs::OnionHopData {
+               let msg = msgs::OnionHopData {
                        format: OnionHopDataFormat::FinalNode {
                                payment_data: None,
                                payment_metadata: None,
@@ -3557,16 +3570,18 @@ mod tests {
                let encoded_value = msg.encode();
                let target_value = hex::decode("1002080badf00d010203040404ffffffff").unwrap();
                assert_eq!(encoded_value, target_value);
-               msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
-               if let OnionHopDataFormat::FinalNode { payment_data: None, .. } = msg.format { } else { panic!(); }
-               assert_eq!(msg.amt_to_forward, 0x0badf00d01020304);
-               assert_eq!(msg.outgoing_cltv_value, 0xffffffff);
+
+               let inbound_msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
+               if let msgs::InboundOnionPayload::Receive { payment_data: None, amt_msat, outgoing_cltv_value, .. } = inbound_msg {
+                       assert_eq!(amt_msat, 0x0badf00d01020304);
+                       assert_eq!(outgoing_cltv_value, 0xffffffff);
+               } else { panic!(); }
        }
 
        #[test]
        fn encoding_final_onion_hop_data_with_secret() {
                let expected_payment_secret = PaymentSecret([0x42u8; 32]);
-               let mut msg = msgs::OnionHopData {
+               let msg = msgs::OnionHopData {
                        format: OnionHopDataFormat::FinalNode {
                                payment_data: Some(FinalOnionHopData {
                                        payment_secret: expected_payment_secret,
@@ -3581,19 +3596,21 @@ mod tests {
                let encoded_value = msg.encode();
                let target_value = hex::decode("3602080badf00d010203040404ffffffff082442424242424242424242424242424242424242424242424242424242424242421badca1f").unwrap();
                assert_eq!(encoded_value, target_value);
-               msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
-               if let OnionHopDataFormat::FinalNode {
+
+               let inbound_msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
+               if let msgs::InboundOnionPayload::Receive {
                        payment_data: Some(FinalOnionHopData {
                                payment_secret,
                                total_msat: 0x1badca1f
                        }),
+                       amt_msat, outgoing_cltv_value,
                        payment_metadata: None,
                        keysend_preimage: None,
-               } = msg.format {
+               } = inbound_msg  {
                        assert_eq!(payment_secret, expected_payment_secret);
+                       assert_eq!(amt_msat, 0x0badf00d01020304);
+                       assert_eq!(outgoing_cltv_value, 0xffffffff);
                } else { panic!(); }
-               assert_eq!(msg.amt_to_forward, 0x0badf00d01020304);
-               assert_eq!(msg.outgoing_cltv_value, 0xffffffff);
        }
 
        #[test]
@@ -3743,7 +3760,7 @@ mod tests {
                // payload length to be encoded over multiple bytes rather than a single u8.
                let big_payload = encode_big_payload().unwrap();
                let mut rd = Cursor::new(&big_payload[..]);
-               <msgs::OnionHopData as Readable>::read(&mut rd).unwrap();
+               <msgs::InboundOnionPayload as Readable>::read(&mut rd).unwrap();
        }
        // see above test, needs to be a separate method for use of the serialization macros.
        fn encode_big_payload() -> Result<Vec<u8>, io::Error> {