Fix bug in onion payment payload decode
[rust-lightning] / lightning / src / ln / msgs.rs
index 9886bfeea5e511fd180d3decd9f3158cc64f916f..4db1525718998bab0c29556f382810c4418b1f0b 100644 (file)
@@ -42,7 +42,7 @@ use io_extras::read_to_end;
 
 use util::events::MessageSendEventsProvider;
 use util::logger;
-use util::ser::{LengthReadable, Readable, ReadableArgs, Writeable, Writer, FixedLengthReader, HighZeroBytesDroppedVarInt, Hostname};
+use util::ser::{BigSize, LengthReadable, Readable, ReadableArgs, Writeable, Writer, FixedLengthReader, HighZeroBytesDroppedVarInt, Hostname};
 
 use ln::{PaymentPreimage, PaymentHash, PaymentSecret};
 
@@ -1418,16 +1418,11 @@ impl Writeable for OnionHopData {
 }
 
 impl Readable for OnionHopData {
-       fn read<R: Read>(mut r: &mut R) -> Result<Self, DecodeError> {
-               use bitcoin::consensus::encode::{Decodable, Error, VarInt};
-               let v: VarInt = Decodable::consensus_decode(&mut r)
-                       .map_err(|e| match e {
-                               Error::Io(ioe) => DecodeError::from(ioe),
-                               _ => DecodeError::InvalidValue
-                       })?;
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let b: BigSize = Readable::read(r)?;
                const LEGACY_ONION_HOP_FLAG: u64 = 0;
-               let (format, amt, cltv_value) = if v.0 != LEGACY_ONION_HOP_FLAG {
-                       let mut rd = FixedLengthReader::new(r, v.0);
+               let (format, amt, cltv_value) = if b.0 != LEGACY_ONION_HOP_FLAG {
+                       let mut rd = FixedLengthReader::new(r, b.0);
                        let mut amt = HighZeroBytesDroppedVarInt(0u64);
                        let mut cltv_value = HighZeroBytesDroppedVarInt(0u32);
                        let mut short_id: Option<u64> = None;
@@ -1913,7 +1908,7 @@ mod tests {
        use bitcoin::secp256k1::{PublicKey,SecretKey};
        use bitcoin::secp256k1::{Secp256k1, Message};
 
-       use io::Cursor;
+       use io::{self, Cursor};
        use prelude::*;
        use core::convert::TryFrom;
 
@@ -2824,4 +2819,40 @@ mod tests {
                assert_eq!(gossip_timestamp_filter.first_timestamp, 1590000000);
                assert_eq!(gossip_timestamp_filter.timestamp_range, 0xffff_ffff);
        }
+
+       #[test]
+       fn decode_onion_hop_data_len_as_bigsize() {
+               // Tests that we can decode an onion payload that is >253 bytes.
+               // Previously, receiving a payload of this size could've caused us to fail to decode a valid
+               // payload, because we were decoding the length (a BigSize, big-endian) as a VarInt
+               // (little-endian).
+
+               // Encode a test onion payload with a big custom TLV such that it's >253 bytes, forcing the
+               // payload length to be encoded over multiple bytes rather than a single u8.
+               let big_payload = encode_big_payload().unwrap();
+               let mut rd = Cursor::new(&big_payload[..]);
+               <msgs::OnionHopData as Readable>::read(&mut rd).unwrap();
+       }
+       // see above test, needs to be a separate method for use of the serialization macros.
+       fn encode_big_payload() -> Result<Vec<u8>, io::Error> {
+               use util::ser::HighZeroBytesDroppedVarInt;
+               let payload = msgs::OnionHopData {
+                       format: OnionHopDataFormat::NonFinalNode {
+                               short_channel_id: 0xdeadbeef1bad1dea,
+                       },
+                       amt_to_forward: 1000,
+                       outgoing_cltv_value: 0xffffffff,
+               };
+               let mut encoded_payload = Vec::new();
+               let test_bytes = vec![42u8; 1000];
+               if let OnionHopDataFormat::NonFinalNode { short_channel_id } = payload.format {
+                       encode_varint_length_prefixed_tlv!(&mut encoded_payload, {
+                               (1, test_bytes, vec_type),
+                               (2, HighZeroBytesDroppedVarInt(payload.amt_to_forward), required),
+                               (4, HighZeroBytesDroppedVarInt(payload.outgoing_cltv_value), required),
+                               (6, short_channel_id, required)
+                       });
+               }
+               Ok(encoded_payload)
+       }
 }