Relicense as dual Apache-2.0 + MIT
[rust-lightning] / lightning / src / ln / msgs.rs
index 6936a7e93c45b96018d6542bdfe99a527f04e882..4fccba37212ce8584af1ce83e8bddc697a68f5ae 100644 (file)
@@ -1,3 +1,12 @@
+// This file is Copyright its original authors, visible in version control
+// history.
+//
+// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
+// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
+// You may not use this file except in accordance with one or both of these
+// licenses.
+
 //! Wire messages, traits representing wire message handlers, and a few error types live here.
 //!
 //! For a normal node you probably don't need to use anything here, however, if you wish to split a
@@ -25,7 +34,6 @@ use ln::features::{ChannelFeatures, InitFeatures, NodeFeatures};
 
 use std::{cmp, fmt};
 use std::io::Read;
-use std::result::Result;
 
 use util::events;
 use util::ser::{Readable, Writeable, Writer, FixedLengthReader, HighZeroBytesDroppedVarInt};
@@ -428,9 +436,10 @@ pub(crate) struct UnsignedChannelUpdate {
        pub(crate) chain_hash: BlockHash,
        pub(crate) short_channel_id: u64,
        pub(crate) timestamp: u32,
-       pub(crate) flags: u16,
+       pub(crate) flags: u8,
        pub(crate) cltv_expiry_delta: u16,
        pub(crate) htlc_minimum_msat: u64,
+       pub(crate) htlc_maximum_msat: OptionalField<u64>,
        pub(crate) fee_base_msat: u32,
        pub(crate) fee_proportional_millionths: u32,
        pub(crate) excess_data: Vec<u8>,
@@ -462,7 +471,7 @@ pub enum ErrorAction {
 /// An Err type for failure to process messages.
 pub struct LightningError {
        /// A human-readable message describing the error
-       pub err: &'static str,
+       pub err: String,
        /// The action which should be taken against the offending peer.
        pub action: ErrorAction,
 }
@@ -518,7 +527,7 @@ pub enum HTLCFailChannelUpdate {
 /// As we wish to serialize these differently from Option<T>s (Options get a tag byte, but
 /// OptionalFeild simply gets Present if there are enough bytes to read into it), we have a
 /// separate enum type for them.
-#[derive(Clone, PartialEq)]
+#[derive(Clone, PartialEq, Debug)]
 pub enum OptionalField<T> {
        /// Optional field is included in message
        Present(T),
@@ -702,7 +711,7 @@ impl fmt::Display for DecodeError {
 
 impl fmt::Debug for LightningError {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-               f.write_str(self.err)
+               f.write_str(self.err.as_str())
        }
 }
 
@@ -743,6 +752,26 @@ impl Readable for OptionalField<Script> {
        }
 }
 
+impl Writeable for OptionalField<u64> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               match *self {
+                       OptionalField::Present(ref value) => {
+                               value.write(w)?;
+                       },
+                       OptionalField::Absent => {}
+               }
+               Ok(())
+       }
+}
+
+impl Readable for OptionalField<u64> {
+       fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let value: u64 = Readable::read(r)?;
+               Ok(OptionalField::Present(value))
+       }
+}
+
+
 impl_writeable_len_match!(AcceptChannel, {
                {AcceptChannel{ shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 270 + 2 + script.len()},
                {_, 270}
@@ -1181,15 +1210,23 @@ impl_writeable_len_match!(ChannelAnnouncement, {
 
 impl Writeable for UnsignedChannelUpdate {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(64 + self.excess_data.len());
+               let mut size = 64 + self.excess_data.len();
+               let mut message_flags: u8 = 0;
+               if let OptionalField::Present(_) = self.htlc_maximum_msat {
+                       size += 8;
+                       message_flags = 1;
+               }
+               w.size_hint(size);
                self.chain_hash.write(w)?;
                self.short_channel_id.write(w)?;
                self.timestamp.write(w)?;
-               self.flags.write(w)?;
+               let all_flags = self.flags as u16 | ((message_flags as u16) << 8);
+               all_flags.write(w)?;
                self.cltv_expiry_delta.write(w)?;
                self.htlc_minimum_msat.write(w)?;
                self.fee_base_msat.write(w)?;
                self.fee_proportional_millionths.write(w)?;
+               self.htlc_maximum_msat.write(w)?;
                w.write_all(&self.excess_data[..])?;
                Ok(())
        }
@@ -1197,15 +1234,22 @@ impl Writeable for UnsignedChannelUpdate {
 
 impl Readable for UnsignedChannelUpdate {
        fn read<R: Read>(r: &mut R) -> Result<Self, DecodeError> {
+               let has_htlc_maximum_msat;
                Ok(Self {
                        chain_hash: Readable::read(r)?,
                        short_channel_id: Readable::read(r)?,
                        timestamp: Readable::read(r)?,
-                       flags: Readable::read(r)?,
+                       flags: {
+                               let flags: u16 = Readable::read(r)?;
+                               let message_flags = flags >> 8;
+                               has_htlc_maximum_msat = (message_flags as i32 & 1) == 1;
+                               flags as u8
+                       },
                        cltv_expiry_delta: Readable::read(r)?,
                        htlc_minimum_msat: Readable::read(r)?,
                        fee_base_msat: Readable::read(r)?,
                        fee_proportional_millionths: Readable::read(r)?,
+                       htlc_maximum_msat: if has_htlc_maximum_msat { Readable::read(r)? } else { OptionalField::Absent },
                        excess_data: {
                                let mut excess_data = vec![];
                                r.read_to_end(&mut excess_data)?;
@@ -1598,7 +1642,7 @@ mod tests {
                do_encoding_node_announcement(false, false, true, false, true, false, false);
        }
 
-       fn do_encoding_channel_update(direction: bool, disable: bool, htlc_maximum_msat: bool) {
+       fn do_encoding_channel_update(direction: bool, disable: bool, htlc_maximum_msat: bool, excess_data: bool) {
                let secp_ctx = Secp256k1::new();
                let (privkey_1, _) = get_keys_from!("0101010101010101010101010101010101010101010101010101010101010101", secp_ctx);
                let sig_1 = get_sig_on!(privkey_1, secp_ctx, String::from("01010101010101010101010101010101"));
@@ -1606,12 +1650,13 @@ mod tests {
                        chain_hash: BlockHash::from_hex("6fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d6190000000000").unwrap(),
                        short_channel_id: 2316138423780173,
                        timestamp: 20190119,
-                       flags: if direction { 1 } else { 0 } | if disable { 1 << 1 } else { 0 } | if htlc_maximum_msat { 1 << 8 } else { 0 },
+                       flags: if direction { 1 } else { 0 } | if disable { 1 << 1 } else { 0 },
                        cltv_expiry_delta: 144,
                        htlc_minimum_msat: 1000000,
+                       htlc_maximum_msat: if htlc_maximum_msat { OptionalField::Present(131355275467161) } else { OptionalField::Absent },
                        fee_base_msat: 10000,
                        fee_proportional_millionths: 20,
-                       excess_data: if htlc_maximum_msat { vec![0, 0, 0, 0, 59, 154, 202, 0] } else { Vec::new() }
+                       excess_data: if excess_data { vec![0, 0, 0, 0, 59, 154, 202, 0] } else { Vec::new() }
                };
                let channel_update = msgs::ChannelUpdate {
                        signature: sig_1,
@@ -1637,6 +1682,9 @@ mod tests {
                }
                target_value.append(&mut hex::decode("009000000000000f42400000271000000014").unwrap());
                if htlc_maximum_msat {
+                       target_value.append(&mut hex::decode("0000777788889999").unwrap());
+               }
+               if excess_data {
                        target_value.append(&mut hex::decode("000000003b9aca00").unwrap());
                }
                assert_eq!(encoded_value, target_value);
@@ -1644,11 +1692,16 @@ mod tests {
 
        #[test]
        fn encoding_channel_update() {
-               do_encoding_channel_update(false, false, false);
-               do_encoding_channel_update(true, false, false);
-               do_encoding_channel_update(false, true, false);
-               do_encoding_channel_update(false, false, true);
-               do_encoding_channel_update(true, true, true);
+               do_encoding_channel_update(false, false, false, false);
+               do_encoding_channel_update(false, false, false, true);
+               do_encoding_channel_update(true, false, false, false);
+               do_encoding_channel_update(true, false, false, true);
+               do_encoding_channel_update(false, true, false, false);
+               do_encoding_channel_update(false, true, false, true);
+               do_encoding_channel_update(false, false, true, false);
+               do_encoding_channel_update(false, false, true, true);
+               do_encoding_channel_update(true, true, true, false);
+               do_encoding_channel_update(true, true, true, true);
        }
 
        fn do_encoding_open_channel(random_bit: bool, shutdown: bool) {