Expand the Route object to include multiple paths.
[rust-lightning] / lightning / src / ln / msgs.rs
index 03a2917802d370b014860b319f03a02cde855127..cd93e236a00a5d662c2ebf564a439eb60142fea4 100644 (file)
@@ -33,6 +33,9 @@ use util::ser::{Readable, Writeable, Writer, FixedLengthReader, HighZeroBytesDro
 
 use ln::channelmanager::{PaymentPreimage, PaymentHash};
 
+/// 21 million * 10^8 * 1000
+pub(crate) const MAX_VALUE_MSAT: u64 = 21_000_000_0000_0000_000;
+
 /// An error in decoding a message or struct.
 #[derive(Debug)]
 pub enum DecodeError {
@@ -661,6 +664,11 @@ pub trait RoutingMessageHandler : Send + Sync {
 mod fuzzy_internal_msgs {
        // These types aren't intended to be pub, but are exposed for direct fuzzing (as we deserialize
        // them from untrusted input):
+       #[derive(Clone)]
+       pub(crate) struct FinalOnionHopData {
+               pub(crate) payment_secret: [u8; 32],
+               pub(crate) total_msat: u64,
+       }
 
        pub(crate) enum OnionHopDataFormat {
                Legacy { // aka Realm-0
@@ -669,7 +677,9 @@ mod fuzzy_internal_msgs {
                NonFinalNode {
                        short_channel_id: u64,
                },
-               FinalNode,
+               FinalNode {
+                       payment_data: Option<FinalOnionHopData>,
+               },
        }
 
        pub struct OnionHopData {
@@ -1013,6 +1023,11 @@ impl_writeable!(UpdateAddHTLC, 32+8+8+32+4+1366, {
        onion_routing_packet
 });
 
+impl_writeable!(FinalOnionHopData, 32+8, {
+       payment_secret,
+       total_msat
+});
+
 impl Writeable for OnionHopData {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
                w.size_hint(33);
@@ -1030,11 +1045,19 @@ impl Writeable for OnionHopData {
                                        (6, short_channel_id)
                                });
                        },
-                       OnionHopDataFormat::FinalNode => {
-                               encode_varint_length_prefixed_tlv!(w, {
-                                       (2, HighZeroBytesDroppedVarInt(self.amt_to_forward)),
-                                       (4, HighZeroBytesDroppedVarInt(self.outgoing_cltv_value))
-                               });
+                       OnionHopDataFormat::FinalNode { ref payment_data } => {
+                               if let &Some(ref final_data) = payment_data {
+                                       encode_varint_length_prefixed_tlv!(w, {
+                                               (2, HighZeroBytesDroppedVarInt(self.amt_to_forward)),
+                                               (4, HighZeroBytesDroppedVarInt(self.outgoing_cltv_value)),
+                                               (8, final_data)
+                                       });
+                               } else {
+                                       encode_varint_length_prefixed_tlv!(w, {
+                                               (2, HighZeroBytesDroppedVarInt(self.amt_to_forward)),
+                                               (4, HighZeroBytesDroppedVarInt(self.outgoing_cltv_value))
+                                       });
+                               }
                        },
                }
                match self.format {
@@ -1060,19 +1083,29 @@ impl<R: Read> Readable<R> for OnionHopData {
                        let mut amt = HighZeroBytesDroppedVarInt(0u64);
                        let mut cltv_value = HighZeroBytesDroppedVarInt(0u32);
                        let mut short_id: Option<u64> = None;
+                       let mut payment_data: Option<FinalOnionHopData> = None;
                        decode_tlv!(&mut rd, {
                                (2, amt),
                                (4, cltv_value)
                        }, {
-                               (6, short_id)
+                               (6, short_id),
+                               (8, payment_data)
                        });
                        rd.eat_remaining().map_err(|_| DecodeError::ShortRead)?;
                        let format = if let Some(short_channel_id) = short_id {
+                               if payment_data.is_some() { return Err(DecodeError::InvalidValue); }
                                OnionHopDataFormat::NonFinalNode {
                                        short_channel_id,
                                }
                        } else {
-                               OnionHopDataFormat::FinalNode
+                               if let &Some(ref data) = &payment_data {
+                                       if data.total_msat > MAX_VALUE_MSAT {
+                                               return Err(DecodeError::InvalidValue);
+                                       }
+                               }
+                               OnionHopDataFormat::FinalNode {
+                                       payment_data
+                               }
                        };
                        (format, amt.0, cltv_value.0)
                } else {
@@ -1081,6 +1114,9 @@ impl<R: Read> Readable<R> for OnionHopData {
                        };
                        let amt: u64 = Readable::read(r)?;
                        let cltv_value: u32 = Readable::read(r)?;
+                       if amt > MAX_VALUE_MSAT {
+                               return Err(DecodeError::InvalidValue);
+                       }
                        r.read_exact(&mut [0; 12])?;
                        (format, amt, cltv_value)
                };
@@ -2067,7 +2103,9 @@ mod tests {
        #[test]
        fn encoding_final_onion_hop_data() {
                let mut msg = msgs::OnionHopData {
-                       format: OnionHopDataFormat::FinalNode,
+                       format: OnionHopDataFormat::FinalNode {
+                               payment_data: None,
+                       },
                        amt_to_forward: 0x0badf00d01020304,
                        outgoing_cltv_value: 0xffffffff,
                };
@@ -2075,7 +2113,7 @@ mod tests {
                let target_value = hex::decode("1002080badf00d010203040404ffffffff").unwrap();
                assert_eq!(encoded_value, target_value);
                msg = Readable::read(&mut Cursor::new(&target_value[..])).unwrap();
-               if let OnionHopDataFormat::FinalNode = msg.format { } else { panic!(); }
+               if let OnionHopDataFormat::FinalNode { payment_data: None } = msg.format { } else { panic!(); }
                assert_eq!(msg.amt_to_forward, 0x0badf00d01020304);
                assert_eq!(msg.outgoing_cltv_value, 0xffffffff);
        }