Address custom HTLC TLV fixups
[rust-lightning] / lightning / src / ln / msgs.rs
index 0c23d4425f8663ebe6fce310692b2dc11bc89e08..515dcbba2d2e136846697dc990d582270e07e1d4 100644 (file)
@@ -43,7 +43,7 @@ use crate::io_extras::read_to_end;
 
 use crate::events::{MessageSendEventsProvider, OnionMessageProvider};
 use crate::util::logger;
-use crate::util::ser::{LengthReadable, Readable, ReadableArgs, Writeable, Writer, WithoutLength, FixedLengthReader, HighZeroBytesDroppedBigSize, Hostname, TransactionU16LenLimited};
+use crate::util::ser::{LengthReadable, Readable, ReadableArgs, Writeable, Writer, WithoutLength, FixedLengthReader, HighZeroBytesDroppedBigSize, Hostname, TransactionU16LenLimited, BigSize};
 
 use crate::ln::{PaymentPreimage, PaymentHash, PaymentSecret};
 
@@ -1423,30 +1423,45 @@ mod fuzzy_internal_msgs {
        // These types aren't intended to be pub, but are exposed for direct fuzzing (as we deserialize
        // them from untrusted input):
        #[derive(Clone)]
-       pub(crate) struct FinalOnionHopData {
-               pub(crate) payment_secret: PaymentSecret,
+       pub struct FinalOnionHopData {
+               pub payment_secret: PaymentSecret,
                /// The total value, in msat, of the payment as received by the ultimate recipient.
                /// Message serialization may panic if this value is more than 21 million Bitcoin.
-               pub(crate) total_msat: u64,
+               pub total_msat: u64,
        }
 
-       pub(crate) enum OnionHopDataFormat {
-               NonFinalNode {
+       pub enum InboundOnionPayload {
+               Forward {
                        short_channel_id: u64,
+                       /// The value, in msat, of the payment after this hop's fee is deducted.
+                       amt_to_forward: u64,
+                       outgoing_cltv_value: u32,
                },
-               FinalNode {
+               Receive {
                        payment_data: Option<FinalOnionHopData>,
                        payment_metadata: Option<Vec<u8>>,
                        keysend_preimage: Option<PaymentPreimage>,
+                       custom_tlvs: Vec<(u64, Vec<u8>)>,
+                       amt_msat: u64,
+                       outgoing_cltv_value: u32,
                },
        }
 
-       pub struct OnionHopData {
-               pub(crate) format: OnionHopDataFormat,
-               /// The value, in msat, of the payment after this hop's fee is deducted.
-               /// Message serialization may panic if this value is more than 21 million Bitcoin.
-               pub(crate) amt_to_forward: u64,
-               pub(crate) outgoing_cltv_value: u32,
+       pub(crate) enum OutboundOnionPayload {
+               Forward {
+                       short_channel_id: u64,
+                       /// The value, in msat, of the payment after this hop's fee is deducted.
+                       amt_to_forward: u64,
+                       outgoing_cltv_value: u32,
+               },
+               Receive {
+                       payment_data: Option<FinalOnionHopData>,
+                       payment_metadata: Option<Vec<u8>>,
+                       keysend_preimage: Option<PaymentPreimage>,
+                       custom_tlvs: Vec<(u64, Vec<u8>)>,
+                       amt_msat: u64,
+                       outgoing_cltv_value: u32,
+               },
        }
 
        pub struct DecodedOnionErrorPacket {
@@ -1955,31 +1970,39 @@ impl Readable for FinalOnionHopData {
        }
 }
 
-impl Writeable for OnionHopData {
+impl Writeable for OutboundOnionPayload {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), io::Error> {
-               match self.format {
-                       OnionHopDataFormat::NonFinalNode { short_channel_id } => {
+               match self {
+                       Self::Forward { short_channel_id, amt_to_forward, outgoing_cltv_value } => {
                                _encode_varint_length_prefixed_tlv!(w, {
-                                       (2, HighZeroBytesDroppedBigSize(self.amt_to_forward), required),
-                                       (4, HighZeroBytesDroppedBigSize(self.outgoing_cltv_value), required),
+                                       (2, HighZeroBytesDroppedBigSize(*amt_to_forward), required),
+                                       (4, HighZeroBytesDroppedBigSize(*outgoing_cltv_value), required),
                                        (6, short_channel_id, required)
                                });
                        },
-                       OnionHopDataFormat::FinalNode { ref payment_data, ref payment_metadata, ref keysend_preimage } => {
+                       Self::Receive {
+                               ref payment_data, ref payment_metadata, ref keysend_preimage, amt_msat,
+                               outgoing_cltv_value, ref custom_tlvs,
+                       } => {
+                               // We need to update [`ln::outbound_payment::RecipientOnionFields::with_custom_tlvs`]
+                               // to reject any reserved types in the experimental range if new ones are ever
+                               // standardized.
+                               let keysend_tlv = keysend_preimage.map(|preimage| (5482373484, preimage.encode()));
+                               let mut custom_tlvs: Vec<&(u64, Vec<u8>)> = custom_tlvs.iter().chain(keysend_tlv.iter()).collect();
+                               custom_tlvs.sort_unstable_by_key(|(typ, _)| *typ);
                                _encode_varint_length_prefixed_tlv!(w, {
-                                       (2, HighZeroBytesDroppedBigSize(self.amt_to_forward), required),
-                                       (4, HighZeroBytesDroppedBigSize(self.outgoing_cltv_value), required),
+                                       (2, HighZeroBytesDroppedBigSize(*amt_msat), required),
+                                       (4, HighZeroBytesDroppedBigSize(*outgoing_cltv_value), required),
                                        (8, payment_data, option),
-                                       (16, payment_metadata.as_ref().map(|m| WithoutLength(m)), option),
-                                       (5482373484, keysend_preimage, option)
-                               });
+                                       (16, payment_metadata.as_ref().map(|m| WithoutLength(m)), option)
+                               }, custom_tlvs.iter());
                        },
                }
                Ok(())
        }
 }
 
-impl Readable for OnionHopData {
+impl Readable for InboundOnionPayload {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
                let mut amt = HighZeroBytesDroppedBigSize(0u64);
                let mut cltv_value = HighZeroBytesDroppedBigSize(0u32);
@@ -1987,7 +2010,11 @@ impl Readable for OnionHopData {
                let mut payment_data: Option<FinalOnionHopData> = None;
                let mut payment_metadata: Option<WithoutLength<Vec<u8>>> = None;
                let mut keysend_preimage: Option<PaymentPreimage> = None;
-               read_tlv_fields!(r, {
+               let mut custom_tlvs = Vec::new();
+
+               let tlv_len = BigSize::read(r)?;
+               let rd = FixedLengthReader::new(r, tlv_len.0);
+               decode_tlv_stream_with_custom_tlv_decode!(rd, {
                        (2, amt, required),
                        (4, cltv_value, required),
                        (6, short_id, option),
@@ -1995,41 +2022,44 @@ impl Readable for OnionHopData {
                        (16, payment_metadata, option),
                        // See https://github.com/lightning/blips/blob/master/blip-0003.md
                        (5482373484, keysend_preimage, option)
+               }, |msg_type: u64, msg_reader: &mut FixedLengthReader<_>| -> Result<bool, DecodeError> {
+                       if msg_type < 1 << 16 { return Ok(false) }
+                       let mut value = Vec::new();
+                       msg_reader.read_to_end(&mut value)?;
+                       custom_tlvs.push((msg_type, value));
+                       Ok(true)
                });
 
-               let format = if let Some(short_channel_id) = short_id {
-                       if payment_data.is_some() { return Err(DecodeError::InvalidValue); }
+               if amt.0 > MAX_VALUE_MSAT { return Err(DecodeError::InvalidValue) }
+               if let Some(short_channel_id) = short_id {
+                       if payment_data.is_some() { return Err(DecodeError::InvalidValue) }
                        if payment_metadata.is_some() { return Err(DecodeError::InvalidValue); }
-                       OnionHopDataFormat::NonFinalNode {
+                       Ok(Self::Forward {
                                short_channel_id,
-                       }
+                               amt_to_forward: amt.0,
+                               outgoing_cltv_value: cltv_value.0,
+                       })
                } else {
                        if let Some(data) = &payment_data {
                                if data.total_msat > MAX_VALUE_MSAT {
                                        return Err(DecodeError::InvalidValue);
                                }
                        }
-                       OnionHopDataFormat::FinalNode {
+                       Ok(Self::Receive {
                                payment_data,
                                payment_metadata: payment_metadata.map(|w| w.0),
                                keysend_preimage,
-                       }
-               };
-
-               if amt.0 > MAX_VALUE_MSAT {
-                       return Err(DecodeError::InvalidValue);
+                               amt_msat: amt.0,
+                               outgoing_cltv_value: cltv_value.0,
+                               custom_tlvs,
+                       })
                }
-               Ok(OnionHopData {
-                       format,
-                       amt_to_forward: amt.0,
-                       outgoing_cltv_value: cltv_value.0,
-               })
        }
 }
 
 // ReadableArgs because we need onion_utils::decode_next_hop to accommodate payment packets and
 // onion message packets.
-impl ReadableArgs<()> for OnionHopData {
+impl ReadableArgs<()> for InboundOnionPayload {
        fn read<R: Read>(r: &mut R, _arg: ()) -> Result<Self, DecodeError> {
                <Self as Readable>::read(r)
        }
@@ -2447,7 +2477,7 @@ mod tests {
        use hex;
        use crate::ln::{PaymentPreimage, PaymentHash, PaymentSecret};
        use crate::ln::features::{ChannelFeatures, ChannelTypeFeatures, InitFeatures, NodeFeatures};
-       use crate::ln::msgs::{self, FinalOnionHopData, OnionErrorPacket, OnionHopDataFormat};
+       use crate::ln::msgs::{self, FinalOnionHopData, OnionErrorPacket};
        use crate::routing::gossip::{NodeAlias, NodeId};
        use crate::util::ser::{Writeable, Readable, Hostname, TransactionU16LenLimited};
 
@@ -3530,75 +3560,144 @@ mod tests {
 
        #[test]
        fn encoding_nonfinal_onion_hop_data() {
-               let mut msg = msgs::OnionHopData {
-                       format: OnionHopDataFormat::NonFinalNode {
-                               short_channel_id: 0xdeadbeef1bad1dea,
-                       },
+               let outbound_msg = msgs::OutboundOnionPayload::Forward {
+                       short_channel_id: 0xdeadbeef1bad1dea,
                        amt_to_forward: 0x0badf00d01020304,
                        outgoing_cltv_value: 0xffffffff,
                };
-               let encoded_value = msg.encode();
+               let encoded_value = outbound_msg.encode();
                let target_value = hex::decode("1a02080badf00d010203040404ffffffff0608deadbeef1bad1dea").unwrap();
                assert_eq!(encoded_value, target_value);
-               msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
-               if let OnionHopDataFormat::NonFinalNode { short_channel_id } = msg.format {
+
+               let inbound_msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
+               if let msgs::InboundOnionPayload::Forward { short_channel_id, amt_to_forward, outgoing_cltv_value } = inbound_msg {
                        assert_eq!(short_channel_id, 0xdeadbeef1bad1dea);
+                       assert_eq!(amt_to_forward, 0x0badf00d01020304);
+                       assert_eq!(outgoing_cltv_value, 0xffffffff);
                } else { panic!(); }
-               assert_eq!(msg.amt_to_forward, 0x0badf00d01020304);
-               assert_eq!(msg.outgoing_cltv_value, 0xffffffff);
        }
 
        #[test]
        fn encoding_final_onion_hop_data() {
-               let mut msg = msgs::OnionHopData {
-                       format: OnionHopDataFormat::FinalNode {
-                               payment_data: None,
-                               payment_metadata: None,
-                               keysend_preimage: None,
-                       },
-                       amt_to_forward: 0x0badf00d01020304,
+               let outbound_msg = msgs::OutboundOnionPayload::Receive {
+                       payment_data: None,
+                       payment_metadata: None,
+                       keysend_preimage: None,
+                       amt_msat: 0x0badf00d01020304,
                        outgoing_cltv_value: 0xffffffff,
+                       custom_tlvs: vec![],
                };
-               let encoded_value = msg.encode();
+               let encoded_value = outbound_msg.encode();
                let target_value = hex::decode("1002080badf00d010203040404ffffffff").unwrap();
                assert_eq!(encoded_value, target_value);
-               msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
-               if let OnionHopDataFormat::FinalNode { payment_data: None, .. } = msg.format { } else { panic!(); }
-               assert_eq!(msg.amt_to_forward, 0x0badf00d01020304);
-               assert_eq!(msg.outgoing_cltv_value, 0xffffffff);
+
+               let inbound_msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
+               if let msgs::InboundOnionPayload::Receive { payment_data: None, amt_msat, outgoing_cltv_value, .. } = inbound_msg {
+                       assert_eq!(amt_msat, 0x0badf00d01020304);
+                       assert_eq!(outgoing_cltv_value, 0xffffffff);
+               } else { panic!(); }
        }
 
        #[test]
        fn encoding_final_onion_hop_data_with_secret() {
                let expected_payment_secret = PaymentSecret([0x42u8; 32]);
-               let mut msg = msgs::OnionHopData {
-                       format: OnionHopDataFormat::FinalNode {
-                               payment_data: Some(FinalOnionHopData {
-                                       payment_secret: expected_payment_secret,
-                                       total_msat: 0x1badca1f
-                               }),
-                               payment_metadata: None,
-                               keysend_preimage: None,
-                       },
-                       amt_to_forward: 0x0badf00d01020304,
+               let outbound_msg = msgs::OutboundOnionPayload::Receive {
+                       payment_data: Some(FinalOnionHopData {
+                               payment_secret: expected_payment_secret,
+                               total_msat: 0x1badca1f
+                       }),
+                       payment_metadata: None,
+                       keysend_preimage: None,
+                       amt_msat: 0x0badf00d01020304,
                        outgoing_cltv_value: 0xffffffff,
+                       custom_tlvs: vec![],
                };
-               let encoded_value = msg.encode();
+               let encoded_value = outbound_msg.encode();
                let target_value = hex::decode("3602080badf00d010203040404ffffffff082442424242424242424242424242424242424242424242424242424242424242421badca1f").unwrap();
                assert_eq!(encoded_value, target_value);
-               msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
-               if let OnionHopDataFormat::FinalNode {
+
+               let inbound_msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
+               if let msgs::InboundOnionPayload::Receive {
                        payment_data: Some(FinalOnionHopData {
                                payment_secret,
                                total_msat: 0x1badca1f
                        }),
+                       amt_msat, outgoing_cltv_value,
                        payment_metadata: None,
                        keysend_preimage: None,
-               } = msg.format {
+                       custom_tlvs,
+               } = inbound_msg  {
                        assert_eq!(payment_secret, expected_payment_secret);
+                       assert_eq!(amt_msat, 0x0badf00d01020304);
+                       assert_eq!(outgoing_cltv_value, 0xffffffff);
+                       assert_eq!(custom_tlvs, vec![]);
+               } else { panic!(); }
+       }
+
+       #[test]
+       fn encoding_final_onion_hop_data_with_bad_custom_tlvs() {
+               // If custom TLVs have type number within the range reserved for protocol, treat them as if
+               // they're unknown
+               let bad_type_range_tlvs = vec![
+                       ((1 << 16) - 4, vec![42]),
+                       ((1 << 16) - 2, vec![42; 32]),
+               ];
+               let mut msg = msgs::OutboundOnionPayload::Receive {
+                       payment_data: None,
+                       payment_metadata: None,
+                       keysend_preimage: None,
+                       custom_tlvs: bad_type_range_tlvs,
+                       amt_msat: 0x0badf00d01020304,
+                       outgoing_cltv_value: 0xffffffff,
+               };
+               let encoded_value = msg.encode();
+               assert!(msgs::InboundOnionPayload::read(&mut Cursor::new(&encoded_value[..])).is_err());
+               let good_type_range_tlvs = vec![
+                       ((1 << 16) - 3, vec![42]),
+                       ((1 << 16) - 1, vec![42; 32]),
+               ];
+               if let msgs::OutboundOnionPayload::Receive { ref mut custom_tlvs, .. } = msg {
+                       *custom_tlvs = good_type_range_tlvs.clone();
+               }
+               let encoded_value = msg.encode();
+               let inbound_msg = Readable::read(&mut Cursor::new(&encoded_value[..])).unwrap();
+               match inbound_msg {
+                       msgs::InboundOnionPayload::Receive { custom_tlvs, .. } => assert!(custom_tlvs.is_empty()),
+                       _ => panic!(),
+               }
+       }
+
+       #[test]
+       fn encoding_final_onion_hop_data_with_custom_tlvs() {
+               let expected_custom_tlvs = vec![
+                       (5482373483, vec![0x12, 0x34]),
+                       (5482373487, vec![0x42u8; 8]),
+               ];
+               let msg = msgs::OutboundOnionPayload::Receive {
+                       payment_data: None,
+                       payment_metadata: None,
+                       keysend_preimage: None,
+                       custom_tlvs: expected_custom_tlvs.clone(),
+                       amt_msat: 0x0badf00d01020304,
+                       outgoing_cltv_value: 0xffffffff,
+               };
+               let encoded_value = msg.encode();
+               let target_value = hex::decode("2e02080badf00d010203040404ffffffffff0000000146c6616b021234ff0000000146c6616f084242424242424242").unwrap();
+               assert_eq!(encoded_value, target_value);
+               let inbound_msg: msgs::InboundOnionPayload = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
+               if let msgs::InboundOnionPayload::Receive {
+                       payment_data: None,
+                       payment_metadata: None,
+                       keysend_preimage: None,
+                       custom_tlvs,
+                       amt_msat,
+                       outgoing_cltv_value,
+                       ..
+               } = inbound_msg {
+                       assert_eq!(custom_tlvs, expected_custom_tlvs);
+                       assert_eq!(amt_msat, 0x0badf00d01020304);
+                       assert_eq!(outgoing_cltv_value, 0xffffffff);
                } else { panic!(); }
-               assert_eq!(msg.amt_to_forward, 0x0badf00d01020304);
-               assert_eq!(msg.outgoing_cltv_value, 0xffffffff);
        }
 
        #[test]
@@ -3748,25 +3847,23 @@ mod tests {
                // payload length to be encoded over multiple bytes rather than a single u8.
                let big_payload = encode_big_payload().unwrap();
                let mut rd = Cursor::new(&big_payload[..]);
-               <msgs::OnionHopData as Readable>::read(&mut rd).unwrap();
+               <msgs::InboundOnionPayload as Readable>::read(&mut rd).unwrap();
        }
        // see above test, needs to be a separate method for use of the serialization macros.
        fn encode_big_payload() -> Result<Vec<u8>, io::Error> {
                use crate::util::ser::HighZeroBytesDroppedBigSize;
-               let payload = msgs::OnionHopData {
-                       format: OnionHopDataFormat::NonFinalNode {
-                               short_channel_id: 0xdeadbeef1bad1dea,
-                       },
+               let payload = msgs::OutboundOnionPayload::Forward {
+                       short_channel_id: 0xdeadbeef1bad1dea,
                        amt_to_forward: 1000,
                        outgoing_cltv_value: 0xffffffff,
                };
                let mut encoded_payload = Vec::new();
                let test_bytes = vec![42u8; 1000];
-               if let OnionHopDataFormat::NonFinalNode { short_channel_id } = payload.format {
+               if let msgs::OutboundOnionPayload::Forward { short_channel_id, amt_to_forward, outgoing_cltv_value } = payload {
                        _encode_varint_length_prefixed_tlv!(&mut encoded_payload, {
                                (1, test_bytes, required_vec),
-                               (2, HighZeroBytesDroppedBigSize(payload.amt_to_forward), required),
-                               (4, HighZeroBytesDroppedBigSize(payload.outgoing_cltv_value), required),
+                               (2, HighZeroBytesDroppedBigSize(amt_to_forward), required),
+                               (4, HighZeroBytesDroppedBigSize(outgoing_cltv_value), required),
                                (6, short_channel_id, required)
                        });
                }