Implement Writeable/Readable for Option<T>
authorAntoine Riard <ariard@student.42.fr>
Wed, 23 Jan 2019 16:26:32 +0000 (11:26 -0500)
committerMatt Corallo <git@bluematt.me>
Wed, 23 Jan 2019 16:31:26 +0000 (11:31 -0500)
Add OptionalField in OpenChannel, AcceptChannel
ChannelReestablish to avoid serialization implementation
conflicts

src/ln/channel.rs
src/ln/channelmanager.rs
src/ln/channelmonitor.rs
src/ln/msgs.rs
src/util/ser.rs

index 237ccd840f561a86d89d915bf2c4c635157f37b7..28b28f539ddd0f663e3ecfa7f6142e5b56b32026 100644 (file)
@@ -15,7 +15,7 @@ use secp256k1::{Secp256k1,Message,Signature};
 use secp256k1;
 
 use ln::msgs;
-use ln::msgs::DecodeError;
+use ln::msgs::{DecodeError, OptionalField};
 use ln::channelmonitor::ChannelMonitor;
 use ln::channelmanager::{PendingHTLCStatus, HTLCSource, HTLCFailReason, HTLCFailureMsg, PendingForwardHTLCInfo, RAACommitmentOrder, PaymentPreimage, PaymentHash};
 use ln::chan_utils::{TxCreationKeys,HTLCOutputInCommitment,HTLC_SUCCESS_TX_WEIGHT,HTLC_TIMEOUT_TX_WEIGHT};
@@ -2946,7 +2946,7 @@ impl Channel {
                        htlc_basepoint: PublicKey::from_secret_key(&self.secp_ctx, &self.local_keys.htlc_base_key),
                        first_per_commitment_point: PublicKey::from_secret_key(&self.secp_ctx, &local_commitment_secret),
                        channel_flags: if self.config.announced_channel {1} else {0},
-                       shutdown_scriptpubkey: None,
+                       shutdown_scriptpubkey: OptionalField::Absent
                }
        }
 
@@ -2978,7 +2978,7 @@ impl Channel {
                        delayed_payment_basepoint: PublicKey::from_secret_key(&self.secp_ctx, &self.local_keys.delayed_payment_base_key),
                        htlc_basepoint: PublicKey::from_secret_key(&self.secp_ctx, &self.local_keys.htlc_base_key),
                        first_per_commitment_point: PublicKey::from_secret_key(&self.secp_ctx, &local_commitment_secret),
-                       shutdown_scriptpubkey: None,
+                       shutdown_scriptpubkey: OptionalField::Absent
                }
        }
 
@@ -3106,7 +3106,7 @@ impl Channel {
                        // dropped this channel on disconnect as it hasn't yet reached FundingSent so we can't
                        // overflow here.
                        next_remote_commitment_number: INITIAL_COMMITMENT_NUMBER - self.cur_remote_commitment_transaction_number - 1,
-                       data_loss_protect: None,
+                       data_loss_protect: OptionalField::Absent,
                }
        }
 
@@ -3691,14 +3691,6 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                        });
                }
 
-               macro_rules! read_option { () => {
-                       match <u8 as Readable<R>>::read(reader)? {
-                               0 => None,
-                               1 => Some(Readable::read(reader)?),
-                               _ => return Err(DecodeError::InvalidValue),
-                       }
-               } }
-
                let pending_outbound_htlc_count: u64 = Readable::read(reader)?;
                let mut pending_outbound_htlcs = Vec::with_capacity(cmp::min(pending_outbound_htlc_count as usize, OUR_MAX_HTLCS as usize));
                for _ in 0..pending_outbound_htlc_count {
@@ -3708,7 +3700,7 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                                cltv_expiry: Readable::read(reader)?,
                                payment_hash: Readable::read(reader)?,
                                source: Readable::read(reader)?,
-                               fail_reason: read_option!(),
+                               fail_reason: Readable::read(reader)?,
                                state: match <u8 as Readable<R>>::read(reader)? {
                                        0 => OutboundHTLCState::LocalAnnounced(Box::new(Readable::read(reader)?)),
                                        1 => OutboundHTLCState::Committed,
@@ -3766,8 +3758,8 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                        monitor_pending_failures.push((Readable::read(reader)?, Readable::read(reader)?, Readable::read(reader)?));
                }
 
-               let pending_update_fee = read_option!();
-               let holding_cell_update_fee = read_option!();
+               let pending_update_fee = Readable::read(reader)?;
+               let holding_cell_update_fee = Readable::read(reader)?;
 
                let next_local_htlc_id = Readable::read(reader)?;
                let next_remote_htlc_id = Readable::read(reader)?;
@@ -3789,8 +3781,8 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                        _ => return Err(DecodeError::InvalidValue),
                };
 
-               let funding_tx_confirmed_in = read_option!();
-               let short_channel_id = read_option!();
+               let funding_tx_confirmed_in = Readable::read(reader)?;
+               let short_channel_id = Readable::read(reader)?;
 
                let last_block_connected = Readable::read(reader)?;
                let funding_tx_confirmations = Readable::read(reader)?;
@@ -3805,17 +3797,17 @@ impl<R : ::std::io::Read> ReadableArgs<R, Arc<Logger>> for Channel {
                let their_max_accepted_htlcs = Readable::read(reader)?;
                let minimum_depth = Readable::read(reader)?;
 
-               let their_funding_pubkey = read_option!();
-               let their_revocation_basepoint = read_option!();
-               let their_payment_basepoint = read_option!();
-               let their_delayed_payment_basepoint = read_option!();
-               let their_htlc_basepoint = read_option!();
-               let their_cur_commitment_point = read_option!();
+               let their_funding_pubkey = Readable::read(reader)?;
+               let their_revocation_basepoint = Readable::read(reader)?;
+               let their_payment_basepoint = Readable::read(reader)?;
+               let their_delayed_payment_basepoint = Readable::read(reader)?;
+               let their_htlc_basepoint = Readable::read(reader)?;
+               let their_cur_commitment_point = Readable::read(reader)?;
 
-               let their_prev_commitment_point = read_option!();
+               let their_prev_commitment_point = Readable::read(reader)?;
                let their_node_id = Readable::read(reader)?;
 
-               let their_shutdown_scriptpubkey = read_option!();
+               let their_shutdown_scriptpubkey = Readable::read(reader)?;
                let (monitor_last_block, channel_monitor) = ReadableArgs::read(reader, logger.clone())?;
                // We drop the ChannelMonitor's last block connected hash cause we don't actually bother
                // doing full block connection operations on the internal CHannelMonitor copies
index 0b7def9ad87bf8b02a9ae22565d9854de2617144..766571cd8a3c74c2dcdae8a9681feae688583325 100644 (file)
@@ -2612,12 +2612,7 @@ const MIN_SERIALIZATION_VERSION: u8 = 1;
 
 impl Writeable for PendingForwardHTLCInfo {
        fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
-               if let &Some(ref onion) = &self.onion_packet {
-                       1u8.write(writer)?;
-                       onion.write(writer)?;
-               } else {
-                       0u8.write(writer)?;
-               }
+               self.onion_packet.write(writer)?;
                self.incoming_shared_secret.write(writer)?;
                self.payment_hash.write(writer)?;
                self.short_channel_id.write(writer)?;
@@ -2629,13 +2624,8 @@ impl Writeable for PendingForwardHTLCInfo {
 
 impl<R: ::std::io::Read> Readable<R> for PendingForwardHTLCInfo {
        fn read(reader: &mut R) -> Result<PendingForwardHTLCInfo, DecodeError> {
-               let onion_packet = match <u8 as Readable<R>>::read(reader)? {
-                       0 => None,
-                       1 => Some(msgs::OnionPacket::read(reader)?),
-                       _ => return Err(DecodeError::InvalidValue),
-               };
                Ok(PendingForwardHTLCInfo {
-                       onion_packet,
+                       onion_packet: Readable::read(reader)?,
                        incoming_shared_secret: Readable::read(reader)?,
                        payment_hash: Readable::read(reader)?,
                        short_channel_id: Readable::read(reader)?,
index 3cd76203439b73417b79b222bbb916265e4f219a..573ec7f72f547d2258a31506d6bc004e078881c8 100644 (file)
@@ -1960,13 +1960,6 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                }
                        }
                }
-               macro_rules! read_option { () => {
-                       match <u8 as Readable<R>>::read(reader)? {
-                               0 => None,
-                               1 => Some(Readable::read(reader)?),
-                               _ => return Err(DecodeError::InvalidValue),
-                       }
-               } }
 
                let _ver: u8 = Readable::read(reader)?;
                let min_ver: u8 = Readable::read(reader)?;
@@ -1983,16 +1976,8 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                let delayed_payment_base_key = Readable::read(reader)?;
                                let payment_base_key = Readable::read(reader)?;
                                let shutdown_pubkey = Readable::read(reader)?;
-                               let prev_latest_per_commitment_point = match <u8 as Readable<R>>::read(reader)? {
-                                       0 => None,
-                                       1 => Some(Readable::read(reader)?),
-                                       _ => return Err(DecodeError::InvalidValue),
-                               };
-                               let latest_per_commitment_point = match <u8 as Readable<R>>::read(reader)? {
-                                       0 => None,
-                                       1 => Some(Readable::read(reader)?),
-                                       _ => return Err(DecodeError::InvalidValue),
-                               };
+                               let prev_latest_per_commitment_point = Readable::read(reader)?;
+                               let latest_per_commitment_point = Readable::read(reader)?;
                                // Technically this can fail and serialize fail a round-trip, but only for serialization of
                                // barely-init'd ChannelMonitors that we can't do anything with.
                                let outpoint = OutPoint {
@@ -2000,8 +1985,8 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                        index: Readable::read(reader)?,
                                };
                                let funding_info = Some((outpoint, Readable::read(reader)?));
-                               let current_remote_commitment_txid = read_option!();
-                               let prev_remote_commitment_txid = read_option!();
+                               let current_remote_commitment_txid = Readable::read(reader)?;
+                               let prev_remote_commitment_txid = Readable::read(reader)?;
                                Storage::Local {
                                        revocation_base_key,
                                        htlc_base_key,
@@ -2052,7 +2037,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                        let amount_msat: u64 = Readable::read(reader)?;
                                        let cltv_expiry: u32 = Readable::read(reader)?;
                                        let payment_hash: PaymentHash = Readable::read(reader)?;
-                                       let transaction_output_index: Option<u32> = read_option!();
+                                       let transaction_output_index: Option<u32> = Readable::read(reader)?;
 
                                        HTLCOutputInCommitment {
                                                offered, amount_msat, cltv_expiry, payment_hash, transaction_output_index
@@ -2068,7 +2053,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                        let htlcs_count: u64 = Readable::read(reader)?;
                        let mut htlcs = Vec::with_capacity(cmp::min(htlcs_count as usize, MAX_ALLOC_SIZE / 32));
                        for _ in 0..htlcs_count {
-                               htlcs.push((read_htlc_in_commitment!(), read_option!().map(|o: HTLCSource| Box::new(o))));
+                               htlcs.push((read_htlc_in_commitment!(), <Option<HTLCSource> as Readable<R>>::read(reader)?.map(|o: HTLCSource| Box::new(o))));
                        }
                        if let Some(_) = remote_claimable_outpoints.insert(txid, htlcs) {
                                return Err(DecodeError::InvalidValue);
@@ -2131,7 +2116,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for (Sha256dHash, ChannelM
                                                        1 => Some((Readable::read(reader)?, Readable::read(reader)?)),
                                                        _ => return Err(DecodeError::InvalidValue),
                                                };
-                                               htlcs.push((htlc, sigs, read_option!()));
+                                               htlcs.push((htlc, sigs, Readable::read(reader)?));
                                        }
 
                                        LocalSignedTx {
index 4cff6256f83f9176460d80f23044d93036114b7a..494abf1953616f8dd46f720d40dc735f4d31dfe2 100644 (file)
@@ -191,7 +191,7 @@ pub struct OpenChannel {
        pub(crate) htlc_basepoint: PublicKey,
        pub(crate) first_per_commitment_point: PublicKey,
        pub(crate) channel_flags: u8,
-       pub(crate) shutdown_scriptpubkey: Option<Script>,
+       pub(crate) shutdown_scriptpubkey: OptionalField<Script>,
 }
 
 /// An accept_channel message to be sent or received from a peer
@@ -211,7 +211,7 @@ pub struct AcceptChannel {
        pub(crate) delayed_payment_basepoint: PublicKey,
        pub(crate) htlc_basepoint: PublicKey,
        pub(crate) first_per_commitment_point: PublicKey,
-       pub(crate) shutdown_scriptpubkey: Option<Script>,
+       pub(crate) shutdown_scriptpubkey: OptionalField<Script>
 }
 
 /// A funding_created message to be sent or received from a peer
@@ -323,7 +323,7 @@ pub struct ChannelReestablish {
        pub(crate) channel_id: [u8; 32],
        pub(crate) next_local_commitment_number: u64,
        pub(crate) next_remote_commitment_number: u64,
-       pub(crate) data_loss_protect: Option<DataLossProtect>,
+       pub(crate) data_loss_protect: OptionalField<DataLossProtect>,
 }
 
 /// An announcement_signatures message to be sent or received from a peer
@@ -518,6 +518,18 @@ pub enum HTLCFailChannelUpdate {
        }
 }
 
+/// Messages could have optional fields to use with extended features
+/// As we wish to serialize these differently from Option<T>s (Options get a tag byte, but
+/// OptionalFeild simply gets Present if there are enough bytes to read into it), we have a
+/// separate enum type for them.
+#[derive(Clone, PartialEq)]
+pub enum OptionalField<T> {
+       /// Optional field is included in message
+       Present(T),
+       /// Optional field is absent in message
+       Absent
+}
+
 /// A trait to describe an object which can receive channel messages.
 ///
 /// Messages MAY be called in parallel when they originate from different their_node_ids, however
@@ -696,8 +708,35 @@ impl From<::std::io::Error> for DecodeError {
        }
 }
 
+impl Writeable for OptionalField<Script> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               match *self {
+                       OptionalField::Present(ref script) => {
+                               // Note that Writeable for script includes the 16-bit length tag for us
+                               script.write(w)?;
+                       },
+                       OptionalField::Absent => {}
+               }
+               Ok(())
+       }
+}
+
+impl<R: Read> Readable<R> for OptionalField<Script> {
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               match <u16 as Readable<R>>::read(r) {
+                       Ok(len) => {
+                               let mut buf = vec![0; len as usize];
+                               r.read_exact(&mut buf)?;
+                               Ok(OptionalField::Present(Script::from(buf)))
+                       },
+                       Err(DecodeError::ShortRead) => Ok(OptionalField::Absent),
+                       Err(e) => Err(e)
+               }
+       }
+}
+
 impl_writeable_len_match!(AcceptChannel, {
-               {AcceptChannel{ shutdown_scriptpubkey: Some(ref script), ..}, 270 + 2 + script.len()},
+               {AcceptChannel{ shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 270 + 2 + script.len()},
                {_, 270}
        }, {
        temporary_channel_id,
@@ -726,13 +765,16 @@ impl_writeable!(AnnouncementSignatures, 32+8+64*2, {
 
 impl Writeable for ChannelReestablish {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               w.size_hint(if self.data_loss_protect.is_some() { 32+2*8+33+32 } else { 32+2*8 });
+               w.size_hint(if let OptionalField::Present(..) = self.data_loss_protect { 32+2*8+33+32 } else { 32+2*8 });
                self.channel_id.write(w)?;
                self.next_local_commitment_number.write(w)?;
                self.next_remote_commitment_number.write(w)?;
-               if let Some(ref data_loss_protect) = self.data_loss_protect {
-                       data_loss_protect.your_last_per_commitment_secret.write(w)?;
-                       data_loss_protect.my_current_per_commitment_point.write(w)?;
+               match self.data_loss_protect {
+                       OptionalField::Present(ref data_loss_protect) => {
+                               (*data_loss_protect).your_last_per_commitment_secret.write(w)?;
+                               (*data_loss_protect).my_current_per_commitment_point.write(w)?;
+                       },
+                       OptionalField::Absent => {}
                }
                Ok(())
        }
@@ -747,11 +789,11 @@ impl<R: Read> Readable<R> for ChannelReestablish{
                        data_loss_protect: {
                                match <[u8; 32] as Readable<R>>::read(r) {
                                        Ok(your_last_per_commitment_secret) =>
-                                               Some(DataLossProtect {
+                                               OptionalField::Present(DataLossProtect {
                                                        your_last_per_commitment_secret,
                                                        my_current_per_commitment_point: Readable::read(r)?,
                                                }),
-                                       Err(DecodeError::ShortRead) => None,
+                                       Err(DecodeError::ShortRead) => OptionalField::Absent,
                                        Err(e) => return Err(e)
                                }
                        }
@@ -818,8 +860,8 @@ impl_writeable_len_match!(Init, {
 });
 
 impl_writeable_len_match!(OpenChannel, {
-               { OpenChannel { shutdown_scriptpubkey: Some(ref script), .. }, 319 + 2 + script.len() },
-               { OpenChannel { shutdown_scriptpubkey: None, .. }, 319 }
+               { OpenChannel { shutdown_scriptpubkey: OptionalField::Present(ref script), .. }, 319 + 2 + script.len() },
+               { _, 319 }
        }, {
        chain_hash,
        temporary_channel_id,
@@ -1322,6 +1364,7 @@ impl_writeable_len_match!(NodeAnnouncement, {
 mod tests {
        use hex;
        use ln::msgs;
+       use ln::msgs::OptionalField;
        use util::ser::Writeable;
        use secp256k1::key::{PublicKey,SecretKey};
        use secp256k1::Secp256k1;
@@ -1332,7 +1375,7 @@ mod tests {
                        channel_id: [4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0],
                        next_local_commitment_number: 3,
                        next_remote_commitment_number: 4,
-                       data_loss_protect: None,
+                       data_loss_protect: OptionalField::Absent,
                };
 
                let encoded_value = cr.encode();
@@ -1353,7 +1396,7 @@ mod tests {
                        channel_id: [4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7, 0, 0, 0, 0, 0, 0, 0],
                        next_local_commitment_number: 3,
                        next_remote_commitment_number: 4,
-                       data_loss_protect: Some(msgs::DataLossProtect { your_last_per_commitment_secret: [9;32], my_current_per_commitment_point: public_key}),
+                       data_loss_protect: OptionalField::Present(msgs::DataLossProtect { your_last_per_commitment_secret: [9;32], my_current_per_commitment_point: public_key}),
                };
 
                let encoded_value = cr.encode();
index c1d2802a76da834d03555cf73b0e2f88139d52a5..78fd4c1ce8ab8a77e5e747c390d12520205a68e2 100644 (file)
@@ -302,29 +302,6 @@ impl<R: Read> Readable<R> for Script {
        }
 }
 
-impl Writeable for Option<Script> {
-       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
-               if let &Some(ref script) = self {
-                       script.write(w)?;
-               }
-               Ok(())
-       }
-}
-
-impl<R: Read> Readable<R> for Option<Script> {
-       fn read(r: &mut R) -> Result<Self, DecodeError> {
-               match <u16 as Readable<R>>::read(r) {
-                       Ok(len) => {
-                               let mut buf = vec![0; len as usize];
-                               r.read_exact(&mut buf)?;
-                               Ok(Some(Script::from(buf)))
-                       },
-                       Err(DecodeError::ShortRead) => Ok(None),
-                       Err(e) => Err(e)
-               }
-       }
-}
-
 impl Writeable for PublicKey {
        fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
                self.serialize().write(w)
@@ -413,3 +390,29 @@ impl<R: Read> Readable<R> for PaymentHash {
                Ok(PaymentHash(buf))
        }
 }
+
+impl<T: Writeable> Writeable for Option<T> {
+       fn write<W: Writer>(&self, w: &mut W) -> Result<(), ::std::io::Error> {
+               match *self {
+                       None => 0u8.write(w)?,
+                       Some(ref data) => {
+                               1u8.write(w)?;
+                               data.write(w)?;
+                       }
+               }
+               Ok(())
+       }
+}
+
+impl<R, T> Readable<R> for Option<T>
+       where R: Read,
+             T: Readable<R>
+{
+       fn read(r: &mut R) -> Result<Self, DecodeError> {
+               match <u8 as Readable<R>>::read(r)? {
+                       0 => Ok(None),
+                       1 => Ok(Some(Readable::read(r)?)),
+                       _ => return Err(DecodeError::InvalidValue),
+               }
+       }
+}