Implement Writeable/Readable for Option<T>
[rust-lightning] / src / ln / msgs.rs
index 4cff6256f83f9176460d80f23044d93036114b7a..494abf1953616f8dd46f720d40dc735f4d31dfe2 100644 (file)
@@ -191,7 +191,7 @@ pub struct OpenChannel {
        pub(crate) htlc_basepoint: PublicKey,
        pub(crate) first_per_commitment_point: PublicKey,
        pub(crate) channel_flags: u8,
-       pub(crate) shutdown_scriptpubkey: Option<Script>,
+       pub(crate) shutdown_scriptpubkey: OptionalField<Script>,
 }
 
 /// An accept_channel message to be sent or received from a peer
@@ -211,7 +211,7 @@ pub struct AcceptChannel {
        pub(crate) delayed_payment_basepoint: PublicKey,
        pub(crate) htlc_basepoint: PublicKey,
        pub(crate) first_per_commitment_point: PublicKey,
-       pub(crate) shutdown_scriptpubkey: Option<Script>,
+       pub(crate) shutdown_scriptpubkey: OptionalField<Script>
 }
 
 /// A funding_created message to be sent or received from a peer
@@ -323,7 +323,7 @@ pub struct ChannelReestablish {
        pub(crate) channel_id: [u8; 32],
        pub(crate) next_local_commitment_number: u64,
        pub(crate) next_remote_commitment_number: u64,
-       pub(crate) data_loss_protect: Option<DataLossProtect>,
+       pub(crate) data_loss_protect: OptionalField<DataLossProtect>,
 }
 
 /// An announcement_signatures message to be sent or received from a peer
@@ -518,6 +518,18 @@ pub enum HTLCFailChannelUpdate {
        }
 }
 
+/// Messages could have optional fields to use with extended features
+/// As we wish to serialize these differently from Option<T>s (Options get a tag byte, but
+/// OptionalFeild simply gets Present if there are enough bytes to read into it), we have a
+/// separate enum type for them.
+#[derive(Clone, PartialEq)]
+pub enum OptionalField<T> {
+       /// Optional field is included in message
+       Present(T),
+       /// Optional field is absent in message
+       Absent
+}
+
 /// A trait to describe an object which can receive channel messages.
 ///
 /// Messages MAY be called in parallel when they originate from different their_node_ids, however
@@ -696,8 +708,35 @@ impl From<::std::io::Error> for DecodeError {
        }
 }
 
+impl Writeable for OptionalField<Script> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               match *self {
+                       OptionalField::Present(ref script) => {
+                               // Note that Writeable for script includes the 16-bit length tag for us
+                               script.write(w)?;
+                       },
+                       OptionalField::Absent => {}
+               }
+               Ok(())
+       }
+}
+
+impl<R: Read> Readable<R> for OptionalField<Script> {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               match <u16 as Readable<R>>::read(r) {
+                       Ok(len) => {
+                               let mut buf = vec![0; len as usize];
+                               r.read_exact(&mut buf)?;
+                               Ok(OptionalField::Present(Script::from(buf)))
+                       },
+                       Err(DecodeError::ShortRead) => Ok(OptionalField::Absent),
+                       Err(e) => Err(e)
+               }
+       }
+}
+
 impl_writeable_len_match!(AcceptChannel, {
-               {AcceptChannel{ shutdown_scriptpubkey: Some(ref script), ..}, 270 + 2 + script.len()},
+               {AcceptChannel{ shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 270 + 2 + script.len()},
                {_, 270}
        }, {
        temporary_channel_id,
@@ -726,13 +765,16 @@ impl_writeable!(AnnouncementSignatures, 32+8+64*2, {
 
 impl Writeable for ChannelReestablish {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(if self.data_loss_protect.is_some() { 32+2*8+33+32 } else { 32+2*8 });
+               w.size_hint(if let OptionalField::Present(..) = self.data_loss_protect { 32+2*8+33+32 } else { 32+2*8 });
                self.channel_id.write(w)?;
                self.next_local_commitment_number.write(w)?;
                self.next_remote_commitment_number.write(w)?;
-               if let Some(ref data_loss_protect) = self.data_loss_protect {
-                       data_loss_protect.your_last_per_commitment_secret.write(w)?;
-                       data_loss_protect.my_current_per_commitment_point.write(w)?;
+               match self.data_loss_protect {
+                       OptionalField::Present(ref data_loss_protect) => {
+                               (*data_loss_protect).your_last_per_commitment_secret.write(w)?;
+                               (*data_loss_protect).my_current_per_commitment_point.write(w)?;
+                       },
+                       OptionalField::Absent => {}
                }
                Ok(())
        }
@@ -747,11 +789,11 @@ impl<R: Read> Readable<R> for ChannelReestablish{
                        data_loss_protect: {
                                match <[u8; 32] as Readable<R>>::read(r) {
                                        Ok(your_last_per_commitment_secret) =>
-                                               Some(DataLossProtect {
+                                               OptionalField::Present(DataLossProtect {
                                                        your_last_per_commitment_secret,
                                                        my_current_per_commitment_point: Readable::read(r)?,
                                                }),
-                                       Err(DecodeError::ShortRead) => None,
+                                       Err(DecodeError::ShortRead) => OptionalField::Absent,
                                        Err(e) => return Err(e)
                                }
                        }
@@ -818,8 +860,8 @@ impl_writeable_len_match!(Init, {
 });
 
 impl_writeable_len_match!(OpenChannel, {
-               { OpenChannel { shutdown_scriptpubkey: Some(ref script), .. }, 319 + 2 + script.len() },
-               { OpenChannel { shutdown_scriptpubkey: None, .. }, 319 }
+               { OpenChannel { shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 319 + 2 + script.len() },
+               { _, 319 }
        }, {
        chain_hash,
        temporary_channel_id,
@@ -1322,6 +1364,7 @@ impl_writeable_len_match!(NodeAnnouncement, {
 mod tests {
        use hex;
        use ln::msgs;
+       use ln::msgs::OptionalField;
        use util::ser::Writeable;
        use secp256k1::key::{PublicKey,SecretKey};
        use secp256k1::Secp256k1;
@@ -1332,7 +1375,7 @@ mod tests {
                        channel_id: [4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0],
                        next_local_commitment_number: 3,
                        next_remote_commitment_number: 4,
-                       data_loss_protect: None,
+                       data_loss_protect: OptionalField::Absent,
                };
 
                let encoded_value = cr.encode();
@@ -1353,7 +1396,7 @@ mod tests {
                        channel_id: [4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0],
                        next_local_commitment_number: 3,
                        next_remote_commitment_number: 4,
-                       data_loss_protect: Some(msgs::DataLossProtect { your_last_per_commitment_secret: [9;32], my_current_per_commitment_point: public_key}),
+                       data_loss_protect: OptionalField::Present(msgs::DataLossProtect { your_last_per_commitment_secret: [9;32], my_current_per_commitment_point: public_key}),
                };
 
                let encoded_value = cr.encode();