Add htlc_maximum_msat field
[rust-lightning] / lightning / src / ln / msgs.rs
index bd5d23504b19961b63224c301853ce26971fd68e..554cf7290721951fd3f8eaf00d820e723c48da81 100644 (file)
@@ -427,9 +427,10 @@ pub(crate) struct UnsignedChannelUpdate {
        pub(crate) chain_hash: BlockHash,
        pub(crate) short_channel_id: u64,
        pub(crate) timestamp: u32,
-       pub(crate) flags: u16,
+       pub(crate) flags: u8,
        pub(crate) cltv_expiry_delta: u16,
        pub(crate) htlc_minimum_msat: u64,
+       pub(crate) htlc_maximum_msat: OptionalField<u64>,
        pub(crate) fee_base_msat: u32,
        pub(crate) fee_proportional_millionths: u32,
        pub(crate) excess_data: Vec<u8>,
@@ -517,7 +518,7 @@ pub enum HTLCFailChannelUpdate {
 /// As we wish to serialize these differently from Option<T>s (Options get a tag byte, but
 /// OptionalFeild simply gets Present if there are enough bytes to read into it), we have a
 /// separate enum type for them.
-#[derive(Clone, PartialEq)]
+#[derive(Clone, PartialEq, Debug)]
 pub enum OptionalField<T> {
        /// Optional field is included in message
        Present(T),
@@ -742,6 +743,26 @@ impl Readable for OptionalField<Script> {
        }
 }
 
+impl Writeable for OptionalField<u64> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               match *self {
+                       OptionalField::Present(ref value) => {
+                               value.write(w)?;
+                       },
+                       OptionalField::Absent => {}
+               }
+               Ok(())
+       }
+}
+
+impl Readable for OptionalField<u64> {
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let value: u64 = Readable::read(r)?;
+               Ok(OptionalField::Present(value))
+       }
+}
+
+
 impl_writeable_len_match!(AcceptChannel, {
                {AcceptChannel{ shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 270 + 2 + script.len()},
                {_, 270}
@@ -1180,15 +1201,23 @@ impl_writeable_len_match!(ChannelAnnouncement, {
 
 impl Writeable for UnsignedChannelUpdate {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(64 + self.excess_data.len());
+               let mut size = 64 + self.excess_data.len();
+               let mut message_flags: u8 = 0;
+               if let OptionalField::Present(_) = self.htlc_maximum_msat {
+                       size += 8;
+                       message_flags = 1;
+               }
+               w.size_hint(size);
                self.chain_hash.write(w)?;
                self.short_channel_id.write(w)?;
                self.timestamp.write(w)?;
-               self.flags.write(w)?;
+               let all_flags = self.flags as u16 | ((message_flags as u16) << 8);
+               all_flags.write(w)?;
                self.cltv_expiry_delta.write(w)?;
                self.htlc_minimum_msat.write(w)?;
                self.fee_base_msat.write(w)?;
                self.fee_proportional_millionths.write(w)?;
+               self.htlc_maximum_msat.write(w)?;
                w.write_all(&self.excess_data[..])?;
                Ok(())
        }
@@ -1196,15 +1225,22 @@ impl Writeable for UnsignedChannelUpdate {
 
 impl Readable for UnsignedChannelUpdate {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let has_htlc_maximum_msat;
                Ok(Self {
                        chain_hash: Readable::read(r)?,
                        short_channel_id: Readable::read(r)?,
                        timestamp: Readable::read(r)?,
-                       flags: Readable::read(r)?,
+                       flags: {
+                               let flags: u16 = Readable::read(r)?;
+                               let message_flags = flags >> 8;
+                               has_htlc_maximum_msat = (message_flags as i32 & 1) == 1;
+                               flags as u8
+                       },
                        cltv_expiry_delta: Readable::read(r)?,
                        htlc_minimum_msat: Readable::read(r)?,
                        fee_base_msat: Readable::read(r)?,
                        fee_proportional_millionths: Readable::read(r)?,
+                       htlc_maximum_msat: if has_htlc_maximum_msat { Readable::read(r)? } else { OptionalField::Absent },
                        excess_data: {
                                let mut excess_data = vec![];
                                r.read_to_end(&mut excess_data)?;
@@ -1597,7 +1633,7 @@ mod tests {
                do_encoding_node_announcement(false, false, true, false, true, false, false);
        }
 
-       fn do_encoding_channel_update(direction: bool, disable: bool, htlc_maximum_msat: bool) {
+       fn do_encoding_channel_update(direction: bool, disable: bool, htlc_maximum_msat: bool, excess_data: bool) {
                let secp_ctx = Secp256k1::new();
                let (privkey_1, _) = get_keys_from!("0101010101010101010101010101010101010101010101010101010101010101", secp_ctx);
                let sig_1 = get_sig_on!(privkey_1, secp_ctx, String::from("01010101010101010101010101010101"));
@@ -1605,12 +1641,13 @@ mod tests {
                        chain_hash: BlockHash::from_hex("6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000").unwrap(),
                        short_channel_id: 2316138423780173,
                        timestamp: 20190119,
-                       flags: if direction { 1 } else { 0 } | if disable { 1 << 1 } else { 0 } | if htlc_maximum_msat { 1 << 8 } else { 0 },
+                       flags: if direction { 1 } else { 0 } | if disable { 1 << 1 } else { 0 },
                        cltv_expiry_delta: 144,
                        htlc_minimum_msat: 1000000,
+                       htlc_maximum_msat: if htlc_maximum_msat { OptionalField::Present(131355275467161) } else { OptionalField::Absent },
                        fee_base_msat: 10000,
                        fee_proportional_millionths: 20,
-                       excess_data: if htlc_maximum_msat { vec![0, 0, 0, 0, 59, 154, 202, 0] } else { Vec::new() }
+                       excess_data: if excess_data { vec![0, 0, 0, 0, 59, 154, 202, 0] } else { Vec::new() }
                };
                let channel_update = msgs::ChannelUpdate {
                        signature: sig_1,
@@ -1636,6 +1673,9 @@ mod tests {
                }
                target_value.append(&mut hex::decode("009000000000000f42400000271000000014").unwrap());
                if htlc_maximum_msat {
+                       target_value.append(&mut hex::decode("0000777788889999").unwrap());
+               }
+               if excess_data {
                        target_value.append(&mut hex::decode("000000003b9aca00").unwrap());
                }
                assert_eq!(encoded_value, target_value);
@@ -1643,11 +1683,16 @@ mod tests {
 
        #[test]
        fn encoding_channel_update() {
-               do_encoding_channel_update(false, false, false);
-               do_encoding_channel_update(true, false, false);
-               do_encoding_channel_update(false, true, false);
-               do_encoding_channel_update(false, false, true);
-               do_encoding_channel_update(true, true, true);
+               do_encoding_channel_update(false, false, false, false);
+               do_encoding_channel_update(false, false, false, true);
+               do_encoding_channel_update(true, false, false, false);
+               do_encoding_channel_update(true, false, false, true);
+               do_encoding_channel_update(false, true, false, false);
+               do_encoding_channel_update(false, true, false, true);
+               do_encoding_channel_update(false, false, true, false);
+               do_encoding_channel_update(false, false, true, true);
+               do_encoding_channel_update(true, true, true, false);
+               do_encoding_channel_update(true, true, true, true);
        }
 
        fn do_encoding_open_channel(random_bit: bool, shutdown: bool) {