From 40959b74b78dd2cf99547c10d2888ce014930b3d Mon Sep 17 00:00:00 2001 From: Valentine Wallace Date: Tue, 22 Jun 2021 16:50:18 -0400 Subject: [PATCH] Fix TLV serialization to work with large types. Previous to this PR, TLV serialization involved iterating from 0 to the highest given TLV type. This worked until we decided to implement keysend, which has a TLV type of ~5.48 billion. So instead, we now specify the type of whatever is being (de)serialized (which can be an Option, a Vec type, or a non-Option (specified in the serialization macros as "required"). --- lightning/src/chain/channelmonitor.rs | 96 +++---- lightning/src/chain/keysinterface.rs | 36 +-- lightning/src/chain/onchaintx.rs | 20 +- lightning/src/chain/package.rs | 85 +++--- lightning/src/ln/chan_utils.rs | 90 +++--- lightning/src/ln/channel.rs | 4 +- lightning/src/ln/channelmanager.rs | 94 +++---- lightning/src/ln/msgs.rs | 26 +- lightning/src/routing/network_graph.rs | 64 ++--- lightning/src/routing/router.rs | 18 +- lightning/src/util/events.rs | 54 ++-- lightning/src/util/ser_macros.rs | 372 ++++++++++++------------- 12 files changed, 462 insertions(+), 497 deletions(-) diff --git a/lightning/src/chain/channelmonitor.rs b/lightning/src/chain/channelmonitor.rs index 0bd21270..3463d1c8 100644 --- a/lightning/src/chain/channelmonitor.rs +++ b/lightning/src/chain/channelmonitor.rs @@ -95,7 +95,7 @@ impl Writeable for ChannelMonitorUpdate { for update_step in self.updates.iter() { update_step.write(w)?; } - write_tlv_fields!(w, {}, {}); + write_tlv_fields!(w, {}); Ok(()) } } @@ -108,7 +108,7 @@ impl Readable for ChannelMonitorUpdate { for _ in 0..len { updates.push(Readable::read(r)?); } - read_tlv_fields!(r, {}, {}); + read_tlv_fields!(r, {}); Ok(Self { update_id, updates }) } } @@ -202,11 +202,10 @@ pub struct HTLCUpdate { pub(crate) source: HTLCSource } impl_writeable_tlv_based!(HTLCUpdate, { - (0, payment_hash), - (2, source), -}, { - (4, payment_preimage) -}, {}); + (0, payment_hash, required), + (2, source, required), + (4, payment_preimage, option), +}); /// If an HTLC expires within this many blocks, don't try to claim it in a shared transaction, /// instead claiming it in its own individual transaction. @@ -273,15 +272,14 @@ struct HolderSignedTx { htlc_outputs: Vec<(HTLCOutputInCommitment, Option, Option)>, } impl_writeable_tlv_based!(HolderSignedTx, { - (0, txid), - (2, revocation_key), - (4, a_htlc_key), - (6, b_htlc_key), - (8, delayed_payment_key), - (10, per_commitment_point), - (12, feerate_per_kw), -}, {}, { - (14, htlc_outputs) + (0, txid, required), + (2, revocation_key, required), + (4, a_htlc_key, required), + (6, b_htlc_key, required), + (8, delayed_payment_key, required), + (10, per_commitment_point, required), + (12, feerate_per_kw, required), + (14, htlc_outputs, vec_type) }); /// We use this to track counterparty commitment transactions and htlcs outputs and @@ -305,10 +303,10 @@ impl Writeable for CounterpartyCommitmentTransaction { } } write_tlv_fields!(w, { - (0, self.counterparty_delayed_payment_base_key), - (2, self.counterparty_htlc_base_key), - (4, self.on_counterparty_tx_csv), - }, {}); + (0, self.counterparty_delayed_payment_base_key, required), + (2, self.counterparty_htlc_base_key, required), + (4, self.on_counterparty_tx_csv, required), + }); Ok(()) } } @@ -333,10 +331,10 @@ impl Readable for CounterpartyCommitmentTransaction { let mut counterparty_htlc_base_key = OptionDeserWrapper(None); let mut on_counterparty_tx_csv: u16 = 0; read_tlv_fields!(r, { - (0, counterparty_delayed_payment_base_key), - (2, counterparty_htlc_base_key), - (4, on_counterparty_tx_csv), - }, {}); + (0, counterparty_delayed_payment_base_key, required), + (2, counterparty_htlc_base_key, required), + (4, on_counterparty_tx_csv, required), + }); CounterpartyCommitmentTransaction { counterparty_delayed_payment_base_key: counterparty_delayed_payment_base_key.0.unwrap(), counterparty_htlc_base_key: counterparty_htlc_base_key.0.unwrap(), @@ -394,19 +392,19 @@ enum OnchainEvent { } impl_writeable_tlv_based!(OnchainEventEntry, { - (0, txid), - (2, height), - (4, event), -}, {}, {}); + (0, txid, required), + (2, height, required), + (4, event, required), +}); impl_writeable_tlv_based_enum!(OnchainEvent, (0, HTLCUpdate) => { - (0, source), - (2, payment_hash), - }, {}, {}, + (0, source, required), + (2, payment_hash, required), + }, (1, MaturingOutput) => { - (0, descriptor), - }, {}, {}, + (0, descriptor, required), + }, ;); #[cfg_attr(any(test, feature = "fuzztarget", feature = "_test_utils"), derive(PartialEq))] @@ -440,27 +438,25 @@ pub(crate) enum ChannelMonitorUpdateStep { impl_writeable_tlv_based_enum!(ChannelMonitorUpdateStep, (0, LatestHolderCommitmentTXInfo) => { - (0, commitment_tx), - }, {}, { - (2, htlc_outputs), + (0, commitment_tx, required), + (2, htlc_outputs, vec_type), }, (1, LatestCounterpartyCommitmentTXInfo) => { - (0, commitment_txid), - (2, commitment_number), - (4, their_revocation_point), - }, {}, { - (6, htlc_outputs), + (0, commitment_txid, required), + (2, commitment_number, required), + (4, their_revocation_point, required), + (6, htlc_outputs, vec_type), }, (2, PaymentPreimage) => { - (0, payment_preimage), - }, {}, {}, + (0, payment_preimage, required), + }, (3, CommitmentSecret) => { - (0, idx), - (2, secret), - }, {}, {}, + (0, idx, required), + (2, secret, required), + }, (4, ChannelForceClosed) => { - (0, should_broadcast), - }, {}, {}, + (0, should_broadcast, required), + }, ;); /// A ChannelMonitor handles chain events (blocks connected and disconnected) and generates @@ -792,7 +788,7 @@ impl Writeable for ChannelMonitorImpl { self.lockdown_from_offchain.write(writer)?; self.holder_tx_signed.write(writer)?; - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); Ok(()) } @@ -2740,7 +2736,7 @@ impl<'a, Signer: Sign, K: KeysInterface> ReadableArgs<&'a K> let lockdown_from_offchain = Readable::read(reader)?; let holder_tx_signed = Readable::read(reader)?; - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); let mut secp_ctx = Secp256k1::new(); secp_ctx.seeded_randomize(&keys_manager.get_secure_random_bytes()); diff --git a/lightning/src/chain/keysinterface.rs b/lightning/src/chain/keysinterface.rs index c5ad6f28..d7ff2a63 100644 --- a/lightning/src/chain/keysinterface.rs +++ b/lightning/src/chain/keysinterface.rs @@ -73,14 +73,14 @@ impl DelayedPaymentOutputDescriptor { } impl_writeable_tlv_based!(DelayedPaymentOutputDescriptor, { - (0, outpoint), - (2, per_commitment_point), - (4, to_self_delay), - (6, output), - (8, revocation_pubkey), - (10, channel_keys_id), - (12, channel_value_satoshis), -}, {}, {}); + (0, outpoint, required), + (2, per_commitment_point, required), + (4, to_self_delay, required), + (6, output, required), + (8, revocation_pubkey, required), + (10, channel_keys_id, required), + (12, channel_value_satoshis, required), +}); /// Information about a spendable output to our "payment key". See /// SpendableOutputDescriptor::StaticPaymentOutput for more details on how to spend this. @@ -104,11 +104,11 @@ impl StaticPaymentOutputDescriptor { pub const MAX_WITNESS_LENGTH: usize = 1 + 73 + 34; } impl_writeable_tlv_based!(StaticPaymentOutputDescriptor, { - (0, outpoint), - (2, output), - (4, channel_keys_id), - (6, channel_value_satoshis), -}, {}, {}); + (0, outpoint, required), + (2, output, required), + (4, channel_keys_id, required), + (6, channel_value_satoshis, required), +}); /// When on-chain outputs are created by rust-lightning (which our counterparty is not able to /// claim at any point in the future) an event is generated which you must track and be able to @@ -169,9 +169,9 @@ pub enum SpendableOutputDescriptor { impl_writeable_tlv_based_enum!(SpendableOutputDescriptor, (0, StaticOutput) => { - (0, outpoint), - (2, output), - }, {}, {}, + (0, outpoint, required), + (2, output, required), + }, ; (1, DelayedPaymentOutput), (2, StaticPaymentOutput), @@ -692,7 +692,7 @@ impl Writeable for InMemorySigner { self.channel_value_satoshis.write(writer)?; self.channel_keys_id.write(writer)?; - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); Ok(()) } @@ -717,7 +717,7 @@ impl Readable for InMemorySigner { &htlc_base_key); let keys_id = Readable::read(reader)?; - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); Ok(InMemorySigner { funding_key, diff --git a/lightning/src/chain/onchaintx.rs b/lightning/src/chain/onchaintx.rs index 30493eb5..7e616ef2 100644 --- a/lightning/src/chain/onchaintx.rs +++ b/lightning/src/chain/onchaintx.rs @@ -79,18 +79,18 @@ enum OnchainEvent { } impl_writeable_tlv_based!(OnchainEventEntry, { - (0, txid), - (2, height), - (4, event), -}, {}, {}); + (0, txid, required), + (2, height, required), + (4, event, required), +}); impl_writeable_tlv_based_enum!(OnchainEvent, (0, Claim) => { - (0, claim_request), - }, {}, {}, + (0, claim_request, required), + }, (1, ContentiousOutpoint) => { - (0, package), - }, {}, {}, + (0, package, required), + }, ;); impl Readable for Option>> { @@ -236,7 +236,7 @@ impl OnchainTxHandler { entry.write(writer)?; } - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); Ok(()) } } @@ -298,7 +298,7 @@ impl<'a, K: KeysInterface> ReadableArgs<&'a K> for OnchainTxHandler { onchain_events_awaiting_threshold_conf.push(Readable::read(reader)?); } - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); let mut secp_ctx = Secp256k1::new(); secp_ctx.seeded_randomize(&keys_manager.get_secure_random_bytes()); diff --git a/lightning/src/chain/package.rs b/lightning/src/chain/package.rs index bd983ecd..b1e7d60a 100644 --- a/lightning/src/chain/package.rs +++ b/lightning/src/chain/package.rs @@ -87,14 +87,14 @@ impl RevokedOutput { } impl_writeable_tlv_based!(RevokedOutput, { - (0, per_commitment_point), - (2, counterparty_delayed_payment_base_key), - (4, counterparty_htlc_base_key), - (6, per_commitment_key), - (8, weight), - (10, amount), - (12, on_counterparty_tx_csv), -}, {}, {}); + (0, per_commitment_point, required), + (2, counterparty_delayed_payment_base_key, required), + (4, counterparty_htlc_base_key, required), + (6, per_commitment_key, required), + (8, weight, required), + (10, amount, required), + (12, on_counterparty_tx_csv, required), +}); /// A struct to describe a revoked offered output and corresponding information to generate a /// solving witness. @@ -131,14 +131,14 @@ impl RevokedHTLCOutput { } impl_writeable_tlv_based!(RevokedHTLCOutput, { - (0, per_commitment_point), - (2, counterparty_delayed_payment_base_key), - (4, counterparty_htlc_base_key), - (6, per_commitment_key), - (8, weight), - (10, amount), - (12, htlc), -}, {}, {}); + (0, per_commitment_point, required), + (2, counterparty_delayed_payment_base_key, required), + (4, counterparty_htlc_base_key, required), + (6, per_commitment_key, required), + (8, weight, required), + (10, amount, required), + (12, htlc, required), +}); /// A struct to describe a HTLC output on a counterparty commitment transaction. /// @@ -168,12 +168,12 @@ impl CounterpartyOfferedHTLCOutput { } impl_writeable_tlv_based!(CounterpartyOfferedHTLCOutput, { - (0, per_commitment_point), - (2, counterparty_delayed_payment_base_key), - (4, counterparty_htlc_base_key), - (6, preimage), - (8, htlc), -}, {}, {}); + (0, per_commitment_point, required), + (2, counterparty_delayed_payment_base_key, required), + (4, counterparty_htlc_base_key, required), + (6, preimage, required), + (8, htlc, required), +}); /// A struct to describe a HTLC output on a counterparty commitment transaction. /// @@ -199,11 +199,11 @@ impl CounterpartyReceivedHTLCOutput { } impl_writeable_tlv_based!(CounterpartyReceivedHTLCOutput, { - (0, per_commitment_point), - (2, counterparty_delayed_payment_base_key), - (4, counterparty_htlc_base_key), - (6, htlc), -}, {}, {}); + (0, per_commitment_point, required), + (2, counterparty_delayed_payment_base_key, required), + (4, counterparty_htlc_base_key, required), + (6, htlc, required), +}); /// A struct to describe a HTLC output on holder commitment transaction. /// @@ -236,11 +236,10 @@ impl HolderHTLCOutput { } impl_writeable_tlv_based!(HolderHTLCOutput, { - (0, amount), - (2, cltv_expiry), -}, { - (4, preimage), -}, {}); + (0, amount, required), + (2, cltv_expiry, required), + (4, preimage, option) +}); /// A struct to describe the channel output on the funding transaction. /// @@ -259,8 +258,8 @@ impl HolderFundingOutput { } impl_writeable_tlv_based!(HolderFundingOutput, { - (0, funding_redeemscript), -}, {}, {}); + (0, funding_redeemscript, required), +}); /// A wrapper encapsulating all in-protocol differing outputs types. /// @@ -690,10 +689,11 @@ impl Writeable for PackageTemplate { rev_outp.write(writer)?; } write_tlv_fields!(writer, { - (0, self.soonest_conf_deadline), - (2, self.feerate_previous), - (4, self.height_original), - }, { (6, self.height_timer) }); + (0, self.soonest_conf_deadline, required), + (2, self.feerate_previous, required), + (4, self.height_original, required), + (6, self.height_timer, option) + }); Ok(()) } } @@ -722,10 +722,11 @@ impl Readable for PackageTemplate { let mut height_timer = None; let mut height_original = 0; read_tlv_fields!(reader, { - (0, soonest_conf_deadline), - (2, feerate_previous), - (4, height_original) - }, { (6, height_timer) }); + (0, soonest_conf_deadline, required), + (2, feerate_previous, required), + (4, height_original, required), + (6, height_timer, option), + }); Ok(PackageTemplate { inputs, malleability, diff --git a/lightning/src/ln/chan_utils.rs b/lightning/src/ln/chan_utils.rs index 9f98cd03..6e0e5085 100644 --- a/lightning/src/ln/chan_utils.rs +++ b/lightning/src/ln/chan_utils.rs @@ -172,7 +172,7 @@ impl Writeable for CounterpartyCommitmentSecrets { writer.write_all(secret)?; writer.write_all(&byte_utils::be64_to_array(*idx))?; } - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); Ok(()) } } @@ -183,7 +183,7 @@ impl Readable for CounterpartyCommitmentSecrets { *secret = Readable::read(reader)?; *idx = Readable::read(reader)?; } - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); Ok(Self { old_secrets }) } } @@ -318,12 +318,12 @@ pub struct TxCreationKeys { } impl_writeable_tlv_based!(TxCreationKeys, { - (0, per_commitment_point), - (2, revocation_key), - (4, broadcaster_htlc_key), - (6, countersignatory_htlc_key), - (8, broadcaster_delayed_payment_key), -}, {}, {}); + (0, per_commitment_point, required), + (2, revocation_key, required), + (4, broadcaster_htlc_key, required), + (6, countersignatory_htlc_key, required), + (8, broadcaster_delayed_payment_key, required), +}); /// One counterparty's public keys which do not change over the life of a channel. #[derive(Clone, PartialEq)] @@ -350,12 +350,12 @@ pub struct ChannelPublicKeys { } impl_writeable_tlv_based!(ChannelPublicKeys, { - (0, funding_pubkey), - (2, revocation_basepoint), - (4, payment_point), - (6, delayed_payment_basepoint), - (8, htlc_basepoint), -}, {}, {}); + (0, funding_pubkey, required), + (2, revocation_basepoint, required), + (4, payment_point, required), + (6, delayed_payment_basepoint, required), + (8, htlc_basepoint, required), +}); impl TxCreationKeys { /// Create per-state keys from channel base points and the per-commitment point. @@ -429,13 +429,12 @@ pub struct HTLCOutputInCommitment { } impl_writeable_tlv_based!(HTLCOutputInCommitment, { - (0, offered), - (2, amount_msat), - (4, cltv_expiry), - (6, payment_hash), -}, { - (8, transaction_output_index) -}, {}); + (0, offered, required), + (2, amount_msat, required), + (4, cltv_expiry, required), + (6, payment_hash, required), + (8, transaction_output_index, option), +}); #[inline] pub(crate) fn get_htlc_redeemscript_with_explicit_keys(htlc: &HTLCOutputInCommitment, broadcaster_htlc_key: &PublicKey, countersignatory_htlc_key: &PublicKey, revocation_key: &PublicKey) -> Script { @@ -626,18 +625,17 @@ impl ChannelTransactionParameters { } impl_writeable_tlv_based!(CounterpartyChannelTransactionParameters, { - (0, pubkeys), - (2, selected_contest_delay), -}, {}, {}); + (0, pubkeys, required), + (2, selected_contest_delay, required), +}); impl_writeable_tlv_based!(ChannelTransactionParameters, { - (0, holder_pubkeys), - (2, holder_selected_contest_delay), - (4, is_outbound_from_holder), -}, { - (6, counterparty_parameters), - (8, funding_outpoint), -}, {}); + (0, holder_pubkeys, required), + (2, holder_selected_contest_delay, required), + (4, is_outbound_from_holder, required), + (6, counterparty_parameters, option), + (8, funding_outpoint, option), +}); /// Static channel fields used to build transactions given per-commitment fields, organized by /// broadcaster/countersignatory. @@ -720,11 +718,10 @@ impl PartialEq for HolderCommitmentTransaction { } impl_writeable_tlv_based!(HolderCommitmentTransaction, { - (0, inner), - (2, counterparty_sig), - (4, holder_sig_first), -}, {}, { - (6, counterparty_htlc_sigs), + (0, inner, required), + (2, counterparty_sig, required), + (4, holder_sig_first, required), + (6, counterparty_htlc_sigs, vec_type), }); impl HolderCommitmentTransaction { @@ -809,9 +806,9 @@ pub struct BuiltCommitmentTransaction { } impl_writeable_tlv_based!(BuiltCommitmentTransaction, { - (0, transaction), - (2, txid) -}, {}, {}); + (0, transaction, required), + (2, txid, required), +}); impl BuiltCommitmentTransaction { /// Get the SIGHASH_ALL sighash value of the transaction. @@ -866,14 +863,13 @@ impl PartialEq for CommitmentTransaction { } impl_writeable_tlv_based!(CommitmentTransaction, { - (0, commitment_number), - (2, to_broadcaster_value_sat), - (4, to_countersignatory_value_sat), - (6, feerate_per_kw), - (8, keys), - (10, built), -}, {}, { - (12, htlcs), + (0, commitment_number, required), + (2, to_broadcaster_value_sat, required), + (4, to_countersignatory_value_sat, required), + (6, feerate_per_kw, required), + (8, keys, required), + (10, built, required), + (12, htlcs, vec_type), }); impl CommitmentTransaction { diff --git a/lightning/src/ln/channel.rs b/lightning/src/ln/channel.rs index f8c8c15b..80240d0a 100644 --- a/lightning/src/ln/channel.rs +++ b/lightning/src/ln/channel.rs @@ -4606,7 +4606,7 @@ impl Writeable for Channel { self.channel_update_status.write(writer)?; - write_tlv_fields!(writer, {}, {(0, self.announcement_sigs)}); + write_tlv_fields!(writer, {(0, self.announcement_sigs, option)}); Ok(()) } @@ -4779,7 +4779,7 @@ impl<'a, Signer: Sign, K: Deref> ReadableArgs<&'a K> for Channel let channel_update_status = Readable::read(reader)?; let mut announcement_sigs = None; - read_tlv_fields!(reader, {}, {(0, announcement_sigs)}); + read_tlv_fields!(reader, {(0, announcement_sigs, option)}); let mut secp_ctx = Secp256k1::new(); secp_ctx.seeded_randomize(&keys_source.get_secure_random_bytes()); diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index ee266849..e7b1ff48 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -4320,22 +4320,22 @@ const MIN_SERIALIZATION_VERSION: u8 = 1; impl_writeable_tlv_based_enum!(PendingHTLCRouting, (0, Forward) => { - (0, onion_packet), - (2, short_channel_id), - }, {}, {}, + (0, onion_packet, required), + (2, short_channel_id, required), + }, (1, Receive) => { - (0, payment_data), - (2, incoming_cltv_expiry), - }, {}, {} + (0, payment_data, required), + (2, incoming_cltv_expiry, required), + } ;); impl_writeable_tlv_based!(PendingHTLCInfo, { - (0, routing), - (2, incoming_shared_secret), - (4, payment_hash), - (6, amt_to_forward), - (8, outgoing_cltv_value) -}, {}, {}); + (0, routing, required), + (2, incoming_shared_secret, required), + (4, payment_hash, required), + (6, amt_to_forward, required), + (8, outgoing_cltv_value, required) +}); impl_writeable_tlv_based_enum!(HTLCFailureMsg, ; (0, Relay), @@ -4347,60 +4347,58 @@ impl_writeable_tlv_based_enum!(PendingHTLCStatus, ; ); impl_writeable_tlv_based!(HTLCPreviousHopData, { - (0, short_channel_id), - (2, outpoint), - (4, htlc_id), - (6, incoming_packet_shared_secret) -}, {}, {}); + (0, short_channel_id, required), + (2, outpoint, required), + (4, htlc_id, required), + (6, incoming_packet_shared_secret, required) +}); impl_writeable_tlv_based!(ClaimableHTLC, { - (0, prev_hop), - (2, value), - (4, payment_data), - (6, cltv_expiry), -}, {}, {}); + (0, prev_hop, required), + (2, value, required), + (4, payment_data, required), + (6, cltv_expiry, required), +}); impl_writeable_tlv_based_enum!(HTLCSource, (0, OutboundRoute) => { - (0, session_priv), - (2, first_hop_htlc_msat), - }, {}, { - (4, path), - }; + (0, session_priv, required), + (2, first_hop_htlc_msat, required), + (4, path, vec_type), + }, ; (1, PreviousHopData) ); impl_writeable_tlv_based_enum!(HTLCFailReason, (0, LightningError) => { - (0, err), - }, {}, {}, + (0, err, required), + }, (1, Reason) => { - (0, failure_code), - }, {}, { - (2, data), + (0, failure_code, required), + (2, data, vec_type), }, ;); impl_writeable_tlv_based_enum!(HTLCForwardInfo, (0, AddHTLC) => { - (0, forward_info), - (2, prev_short_channel_id), - (4, prev_htlc_id), - (6, prev_funding_outpoint), - }, {}, {}, + (0, forward_info, required), + (2, prev_short_channel_id, required), + (4, prev_htlc_id, required), + (6, prev_funding_outpoint, required), + }, (1, FailHTLC) => { - (0, htlc_id), - (2, err_packet), - }, {}, {}, + (0, htlc_id, required), + (2, err_packet, required), + }, ;); impl_writeable_tlv_based!(PendingInboundPayment, { - (0, payment_secret), - (2, expiry_time), - (4, user_payment_id), - (6, payment_preimage), - (8, min_value_msat), -}, {}, {}); + (0, payment_secret, required), + (2, expiry_time, required), + (4, user_payment_id, required), + (6, payment_preimage, required), + (8, min_value_msat, required), +}); impl Writeable for ChannelManager where M::Target: chain::Watch, @@ -4495,7 +4493,7 @@ impl Writeable f session_priv.write(writer)?; } - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); Ok(()) } @@ -4740,7 +4738,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> } } - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); let mut secp_ctx = Secp256k1::new(); secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes()); diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index f02a4783..a64bbc01 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -1295,20 +1295,19 @@ impl Writeable for OnionHopData { }, OnionHopDataFormat::NonFinalNode { short_channel_id } => { encode_varint_length_prefixed_tlv!(w, { - (2, HighZeroBytesDroppedVarInt(self.amt_to_forward)), - (4, HighZeroBytesDroppedVarInt(self.outgoing_cltv_value)), - (6, short_channel_id) - }, { }); + (2, HighZeroBytesDroppedVarInt(self.amt_to_forward), required), + (4, HighZeroBytesDroppedVarInt(self.outgoing_cltv_value), required), + (6, short_channel_id, required) + }); }, OnionHopDataFormat::FinalNode { ref payment_data } => { if let Some(final_data) = payment_data { if final_data.total_msat > MAX_VALUE_MSAT { panic!("We should never be sending infinite/overflow onion payments"); } } encode_varint_length_prefixed_tlv!(w, { - (2, HighZeroBytesDroppedVarInt(self.amt_to_forward)), - (4, HighZeroBytesDroppedVarInt(self.outgoing_cltv_value)) - }, { - (8, payment_data) + (2, HighZeroBytesDroppedVarInt(self.amt_to_forward), required), + (4, HighZeroBytesDroppedVarInt(self.outgoing_cltv_value), required), + (8, payment_data, option) }); }, } @@ -1331,12 +1330,11 @@ impl Readable for OnionHopData { let mut cltv_value = HighZeroBytesDroppedVarInt(0u32); let mut short_id: Option = None; let mut payment_data: Option = None; - decode_tlv!(&mut rd, { - (2, amt), - (4, cltv_value) - }, { - (6, short_id), - (8, payment_data) + decode_tlv_stream!(&mut rd, { + (2, amt, required), + (4, cltv_value, required), + (6, short_id, option), + (8, payment_data, option), }); rd.eat_remaining().map_err(|_| DecodeError::ShortRead)?; let format = if let Some(short_channel_id) = short_id { diff --git a/lightning/src/routing/network_graph.rs b/lightning/src/routing/network_graph.rs index b4448f72..7f086603 100644 --- a/lightning/src/routing/network_graph.rs +++ b/lightning/src/routing/network_graph.rs @@ -460,14 +460,14 @@ impl fmt::Display for DirectionalChannelInfo { } impl_writeable_tlv_based!(DirectionalChannelInfo, { - (0, last_update), - (2, enabled), - (4, cltv_expiry_delta), - (6, htlc_minimum_msat), - (8, htlc_maximum_msat), - (10, fees), - (12, last_update_message), -}, {}, {}); + (0, last_update, required), + (2, enabled, required), + (4, cltv_expiry_delta, required), + (6, htlc_minimum_msat, required), + (8, htlc_maximum_msat, required), + (10, fees, required), + (12, last_update_message, required), +}); #[derive(Clone, Debug, PartialEq)] /// Details about a channel (both directions). @@ -501,14 +501,14 @@ impl fmt::Display for ChannelInfo { } impl_writeable_tlv_based!(ChannelInfo, { - (0, features), - (2, node_one), - (4, one_to_two), - (6, node_two), - (8, two_to_one), - (10, capacity_sats), - (12, announcement_message), -}, {}, {}); + (0, features, required), + (2, node_one, required), + (4, one_to_two, required), + (6, node_two, required), + (8, two_to_one, required), + (10, capacity_sats, required), + (12, announcement_message, required), +}); /// Fees for routing via a given channel or a node @@ -521,7 +521,10 @@ pub struct RoutingFees { pub proportional_millionths: u32, } -impl_writeable_tlv_based!(RoutingFees, {(0, base_msat), (2, proportional_millionths)}, {}, {}); +impl_writeable_tlv_based!(RoutingFees, { + (0, base_msat, required), + (2, proportional_millionths, required) +}); #[derive(Clone, Debug, PartialEq)] /// Information received in the latest node_announcement from this node. @@ -547,14 +550,12 @@ pub struct NodeAnnouncementInfo { } impl_writeable_tlv_based!(NodeAnnouncementInfo, { - (0, features), - (2, last_update), - (4, rgb), - (6, alias), -}, { - (8, announcement_message), -}, { - (10, addresses), + (0, features, required), + (2, last_update, required), + (4, rgb, required), + (6, alias, required), + (8, announcement_message, option), + (10, addresses, vec_type), }); #[derive(Clone, Debug, PartialEq)] @@ -580,11 +581,10 @@ impl fmt::Display for NodeInfo { } } -impl_writeable_tlv_based!(NodeInfo, {}, { - (0, lowest_inbound_channel_fees), - (2, announcement_info), -}, { - (4, channels), +impl_writeable_tlv_based!(NodeInfo, { + (0, lowest_inbound_channel_fees, option), + (2, announcement_info, option), + (4, channels, vec_type), }); const SERIALIZATION_VERSION: u8 = 1; @@ -606,7 +606,7 @@ impl Writeable for NetworkGraph { node_info.write(writer)?; } - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); Ok(()) } } @@ -630,7 +630,7 @@ impl Readable for NetworkGraph { let node_info = Readable::read(reader)?; nodes.insert(node_id, node_info); } - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); Ok(NetworkGraph { genesis_hash, diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 0edeef0c..12374bad 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -49,13 +49,13 @@ pub struct RouteHop { } impl_writeable_tlv_based!(RouteHop, { - (0, pubkey), - (2, node_features), - (4, short_channel_id), - (6, channel_features), - (8, fee_msat), - (10, cltv_expiry_delta), -}, {}, {}); + (0, pubkey, required), + (2, node_features, required), + (4, short_channel_id, required), + (6, channel_features, required), + (8, fee_msat, required), + (10, cltv_expiry_delta, required), +}); /// A route directs a payment from the sender (us) to the recipient. If the recipient supports MPP, /// it can take multiple paths. Each path is composed of one or more hops through the network. @@ -83,7 +83,7 @@ impl Writeable for Route { hop.write(writer)?; } } - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); Ok(()) } } @@ -101,7 +101,7 @@ impl Readable for Route { } paths.push(hops); } - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); Ok(Route { paths }) } } diff --git a/lightning/src/util/events.rs b/lightning/src/util/events.rs index c8c7ad49..1bdbff03 100644 --- a/lightning/src/util/events.rs +++ b/lightning/src/util/events.rs @@ -148,19 +148,18 @@ impl Writeable for Event { &Event::PaymentReceived { ref payment_hash, ref payment_preimage, ref payment_secret, ref amt, ref user_payment_id } => { 1u8.write(writer)?; write_tlv_fields!(writer, { - (0, payment_hash), - (2, payment_secret), - (4, amt), - (6, user_payment_id), - }, { - (8, payment_preimage), + (0, payment_hash, required), + (2, payment_secret, required), + (4, amt, required), + (6, user_payment_id, required), + (8, payment_preimage, option), }); }, &Event::PaymentSent { ref payment_preimage } => { 2u8.write(writer)?; write_tlv_fields!(writer, { - (0, payment_preimage), - }, {}); + (0, payment_preimage, required), + }); payment_preimage.write(writer)?; }, &Event::PaymentFailed { ref payment_hash, ref rejected_by_dest, @@ -175,21 +174,21 @@ impl Writeable for Event { #[cfg(test)] error_data.write(writer)?; write_tlv_fields!(writer, { - (0, payment_hash), - (2, rejected_by_dest), - }, {}); + (0, payment_hash, required), + (2, rejected_by_dest, required), + }); }, &Event::PendingHTLCsForwardable { time_forwardable: _ } => { 4u8.write(writer)?; - write_tlv_fields!(writer, {}, {}); + write_tlv_fields!(writer, {}); // We don't write the time_fordwardable out at all, as we presume when the user // deserializes us at least that much time has elapsed. }, &Event::SpendableOutputs { ref outputs } => { 5u8.write(writer)?; write_tlv_fields!(writer, { - (0, VecWriteWrapper(outputs)), - }, {}); + (0, VecWriteWrapper(outputs), required), + }); }, } Ok(()) @@ -207,12 +206,11 @@ impl MaybeReadable for Event { let mut amt = 0; let mut user_payment_id = 0; read_tlv_fields!(reader, { - (0, payment_hash), - (2, payment_secret), - (4, amt), - (6, user_payment_id), - }, { - (8, payment_preimage), + (0, payment_hash, required), + (2, payment_secret, required), + (4, amt, required), + (6, user_payment_id, required), + (8, payment_preimage, option), }); Ok(Some(Event::PaymentReceived { payment_hash, @@ -228,8 +226,8 @@ impl MaybeReadable for Event { let f = || { let mut payment_preimage = PaymentPreimage([0; 32]); read_tlv_fields!(reader, { - (0, payment_preimage), - }, {}); + (0, payment_preimage, required), + }); Ok(Some(Event::PaymentSent { payment_preimage, })) @@ -245,9 +243,9 @@ impl MaybeReadable for Event { let mut payment_hash = PaymentHash([0; 32]); let mut rejected_by_dest = false; read_tlv_fields!(reader, { - (0, payment_hash), - (2, rejected_by_dest), - }, {}); + (0, payment_hash, required), + (2, rejected_by_dest, required), + }); Ok(Some(Event::PaymentFailed { payment_hash, rejected_by_dest, @@ -261,7 +259,7 @@ impl MaybeReadable for Event { }, 4u8 => { let f = || { - read_tlv_fields!(reader, {}, {}); + read_tlv_fields!(reader, {}); Ok(Some(Event::PendingHTLCsForwardable { time_forwardable: Duration::from_secs(0) })) @@ -272,8 +270,8 @@ impl MaybeReadable for Event { let f = || { let mut outputs = VecReadWrapper(Vec::new()); read_tlv_fields!(reader, { - (0, outputs), - }, {}); + (0, outputs, required), + }); Ok(Some(Event::SpendableOutputs { outputs: outputs.0 })) }; f() diff --git a/lightning/src/util/ser_macros.rs b/lightning/src/util/ser_macros.rs index cf780ef0..b93115dc 100644 --- a/lightning/src/util/ser_macros.rs +++ b/lightning/src/util/ser_macros.rs @@ -8,81 +8,132 @@ // licenses. macro_rules! encode_tlv { - ($stream: expr, {$(($type: expr, $field: expr)),*}, {$(($optional_type: expr, $optional_field: expr)),*}) => { { + ($stream: expr, $type: expr, $field: expr, required) => { + BigSize($type).write($stream)?; + BigSize($field.serialized_length() as u64).write($stream)?; + $field.write($stream)?; + }; + ($stream: expr, $type: expr, $field: expr, vec_type) => { + encode_tlv!($stream, $type, ::util::ser::VecWriteWrapper(&$field), required); + }; + ($stream: expr, $optional_type: expr, $optional_field: expr, option) => { + if let Some(ref field) = $optional_field { + BigSize($optional_type).write($stream)?; + BigSize(field.serialized_length() as u64).write($stream)?; + field.write($stream)?; + } + }; +} + +macro_rules! encode_tlv_stream { + ($stream: expr, {$(($type: expr, $field: expr, $fieldty: ident)),*}) => { { #[allow(unused_imports)] - use util::ser::BigSize; - // Fields must be serialized in order, so we have to potentially switch between optional - // fields and normal fields while serializing. Thus, we end up having to loop over the type - // counts. - // Sadly, while LLVM does appear smart enough to make `max_field` a constant, it appears to - // refuse to unroll the loop. If we have enough entries that this is slow we can revisit - // this design in the future. - #[allow(unused_mut)] - let mut max_field: u64 = 0; - $( - if $type >= max_field { max_field = $type + 1; } - )* + use { + ln::msgs::DecodeError, + util::ser, + util::ser::BigSize, + }; + $( - if $optional_type >= max_field { max_field = $optional_type + 1; } + encode_tlv!($stream, $type, $field, $fieldty); )* - #[allow(unused_variables)] - for i in 0..max_field { - $( - if i == $type { - BigSize($type).write($stream)?; - BigSize($field.serialized_length() as u64).write($stream)?; - $field.write($stream)?; - } - )* + + #[allow(unused_mut, unused_variables, unused_assignments)] + #[cfg(debug_assertions)] + { + let mut last_seen: Option = None; $( - if i == $optional_type { - if let Some(ref field) = $optional_field { - BigSize($optional_type).write($stream)?; - BigSize(field.serialized_length() as u64).write($stream)?; - field.write($stream)?; - } + if let Some(t) = last_seen { + debug_assert!(t <= $type); } + last_seen = Some($type); )* } } } } macro_rules! get_varint_length_prefixed_tlv_length { - ({$(($type: expr, $field: expr)),*}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}) => { { - use util::ser::LengthCalculatingWriter; - #[allow(unused_mut)] - let mut len = LengthCalculatingWriter(0); - { - $( - BigSize($type).write(&mut len).expect("No in-memory data may fail to serialize"); - let field_len = $field.serialized_length(); - BigSize(field_len as u64).write(&mut len).expect("No in-memory data may fail to serialize"); - len.0 += field_len; - )* - $( - if let Some(ref field) = $optional_field { - BigSize($optional_type).write(&mut len).expect("No in-memory data may fail to serialize"); - let field_len = field.serialized_length(); - BigSize(field_len as u64).write(&mut len).expect("No in-memory data may fail to serialize"); - len.0 += field_len; - } - )* + ($len: expr, $type: expr, $field: expr, required) => { + BigSize($type).write(&mut $len).expect("No in-memory data may fail to serialize"); + let field_len = $field.serialized_length(); + BigSize(field_len as u64).write(&mut $len).expect("No in-memory data may fail to serialize"); + $len.0 += field_len; + }; + ($len: expr, $type: expr, $field: expr, vec_type) => { + get_varint_length_prefixed_tlv_length!($len, $type, ::util::ser::VecWriteWrapper(&$field), required); + }; + ($len: expr, $optional_type: expr, $optional_field: expr, option) => { + if let Some(ref field) = $optional_field { + BigSize($optional_type).write(&mut $len).expect("No in-memory data may fail to serialize"); + let field_len = field.serialized_length(); + BigSize(field_len as u64).write(&mut $len).expect("No in-memory data may fail to serialize"); + $len.0 += field_len; } - len.0 - } } + }; } macro_rules! encode_varint_length_prefixed_tlv { - ($stream: expr, {$(($type: expr, $field: expr)),*}, {$(($optional_type: expr, $optional_field: expr)),*}) => { { + ($stream: expr, {$(($type: expr, $field: expr, $fieldty: ident)),*}) => { { use util::ser::BigSize; - let len = get_varint_length_prefixed_tlv_length!({ $(($type, $field)),* }, { $(($optional_type, $optional_field)),* }); + let len = { + #[allow(unused_mut)] + let mut len = ::util::ser::LengthCalculatingWriter(0); + $( + get_varint_length_prefixed_tlv_length!(len, $type, $field, $fieldty); + )* + len.0 + }; BigSize(len as u64).write($stream)?; - encode_tlv!($stream, { $(($type, $field)),* }, { $(($optional_type, $optional_field)),* }); + encode_tlv_stream!($stream, { $(($type, $field, $fieldty)),* }); } } } +macro_rules! check_tlv_order { + ($last_seen_type: expr, $typ: expr, $type: expr, required) => {{ + #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true + let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; + if invalid_order { + Err(DecodeError::InvalidValue)? + } + }}; + ($last_seen_type: expr, $typ: expr, $type: expr, option) => {{ + // no-op + }}; + ($last_seen_type: expr, $typ: expr, $type: expr, vec_type) => {{ + // no-op + }}; +} + +macro_rules! check_missing_tlv { + ($last_seen_type: expr, $type: expr, required) => {{ + #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true + let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type; + if missing_req_type { + Err(DecodeError::InvalidValue)? + } + }}; + ($last_seen_type: expr, $type: expr, vec_type) => {{ + // no-op + }}; + ($last_seen_type: expr, $type: expr, option) => {{ + // no-op + }}; +} + macro_rules! decode_tlv { - ($stream: expr, {$(($reqtype: expr, $reqfield: ident)),*}, {$(($type: expr, $field: ident)),*}) => { { + ($reader: expr, $field: ident, required) => {{ + $field = ser::Readable::read(&mut $reader)?; + }}; + ($reader: expr, $field: ident, vec_type) => {{ + $field = Some(ser::Readable::read(&mut $reader)?); + }}; + ($reader: expr, $field: ident, option) => {{ + $field = Some(ser::Readable::read(&mut $reader)?); + }}; +} + +macro_rules! decode_tlv_stream { + ($stream: expr, {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*}) => { { use ln::msgs::DecodeError; let mut last_seen_type: Option = None; 'tlv_read: loop { @@ -117,11 +168,7 @@ macro_rules! decode_tlv { } // As we read types, make sure we hit every required type: $({ - #[allow(unused_comparisons)] // Note that $reqtype may be 0 making the second comparison always true - let invalid_order = (last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype) && typ.0 > $reqtype; - if invalid_order { - Err(DecodeError::InvalidValue)? - } + check_tlv_order!(last_seen_type, typ, $type, $fieldty); })* last_seen_type = Some(typ.0); @@ -129,15 +176,8 @@ macro_rules! decode_tlv { let length: ser::BigSize = Readable::read($stream)?; let mut s = ser::FixedLengthReader::new($stream, length.0); match typ.0 { - $($reqtype => { - $reqfield = ser::Readable::read(&mut s)?; - if s.bytes_remain() { - s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes - Err(DecodeError::InvalidValue)? - } - },)* $($type => { - $field = Some(ser::Readable::read(&mut s)?); + decode_tlv!(s, $field, $fieldty); if s.bytes_remain() { s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes Err(DecodeError::InvalidValue)? @@ -152,11 +192,7 @@ macro_rules! decode_tlv { } // Make sure we got to each required type after we've read every TLV: $({ - #[allow(unused_comparisons)] // Note that $reqtype may be 0 making the second comparison always true - let missing_req_type = last_seen_type.is_none() || last_seen_type.unwrap() < $reqtype; - if missing_req_type { - Err(DecodeError::InvalidValue)? - } + check_missing_tlv!(last_seen_type, $type, $fieldty); })* } } } @@ -172,8 +208,7 @@ macro_rules! impl_writeable { { // In tests, assert that the hard-coded length matches the actual one if $len != 0 { - use util::ser::LengthCalculatingWriter; - let mut len_calc = LengthCalculatingWriter(0); + let mut len_calc = ::util::ser::LengthCalculatingWriter(0); $( self.$field.write(&mut len_calc).expect("No in-memory data may fail to serialize"); )* assert_eq!(len_calc.0, $len); assert_eq!(self.serialized_length(), $len); @@ -219,8 +254,7 @@ macro_rules! impl_writeable_len_match { #[cfg(any(test, feature = "fuzztarget"))] { // In tests, assert that the hard-coded length matches the actual one - use util::ser::LengthCalculatingWriter; - let mut len_calc = LengthCalculatingWriter(0); + let mut len_calc = ::util::ser::LengthCalculatingWriter(0); $( self.$field.write(&mut len_calc).expect("No in-memory data may fail to serialize"); )* assert!(len_calc.0 $cmp len); assert_eq!(len_calc.0, self.serialized_length()); @@ -292,8 +326,8 @@ macro_rules! write_ver_prefix { /// This is the preferred method of adding new fields that old nodes can ignore and still function /// correctly. macro_rules! write_tlv_fields { - ($stream: expr, {$(($type: expr, $field: expr)),* $(,)*}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}) => { - encode_varint_length_prefixed_tlv!($stream, {$(($type, $field)),*} , {$(($optional_type, $optional_field)),*}); + ($stream: expr, {$(($type: expr, $field: expr, $fieldty: ident)),* $(,)*}) => { + encode_varint_length_prefixed_tlv!($stream, {$(($type, $field, $fieldty)),*}); } } @@ -313,103 +347,65 @@ macro_rules! read_ver_prefix { /// Reads a suffix added by write_tlv_fields. macro_rules! read_tlv_fields { - ($stream: expr, {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, {$(($type: expr, $field: ident)),* $(,)*}) => { { + ($stream: expr, {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*}) => { { let tlv_len = ::util::ser::BigSize::read($stream)?; let mut rd = ::util::ser::FixedLengthReader::new($stream, tlv_len.0); - decode_tlv!(&mut rd, {$(($reqtype, $reqfield)),*}, {$(($type, $field)),*}); + decode_tlv_stream!(&mut rd, {$(($type, $field, $fieldty)),*}); rd.eat_remaining().map_err(|_| ::ln::msgs::DecodeError::ShortRead)?; } } } -// If we naively create a struct in impl_writeable_tlv_based below, we may end up returning -// `Self { ,,vecfield: vecfield }` which is obviously incorrect. Instead, we have to match here to -// detect at least one empty field set and skip the potentially-extra comma. -macro_rules! _init_tlv_based_struct { - ($($type: ident)::*, {}, {$($field: ident),*}, {$($vecfield: ident),*}) => { - Ok($($type)::* { - $($field),*, - $($vecfield: $vecfield.unwrap().0),* - }) +macro_rules! init_tlv_based_struct_field { + ($field: ident, option) => { + $field }; - ($($type: ident)::*, {$($reqfield: ident),*}, {}, {$($vecfield: ident),*}) => { - Ok($($type)::* { - $($reqfield: $reqfield.0.unwrap()),*, - $($vecfield: $vecfield.unwrap().0),* - }) + ($field: ident, required) => { + $field.0.unwrap() }; - ($($type: ident)::*, {$($reqfield: ident),*}, {$($field: ident),*}, {}) => { - Ok($($type)::* { - $($reqfield: $reqfield.0.unwrap()),*, - $($field),* - }) + ($field: ident, vec_type) => { + $field.unwrap().0 }; - ($($type: ident)::*, {$($reqfield: ident),*}, {$($field: ident),*}, {$($vecfield: ident),*}) => { - Ok($($type)::* { - $($reqfield: $reqfield.0.unwrap()),*, - $($field),*, - $($vecfield: $vecfield.unwrap().0),* - }) - } } -// If we don't have any optional types below, but do have some vec types, we end up calling -// `write_tlv_field!($stream, {..}, {, (vec_ty, vec_val)})`, which is obviously broken. -// Instead, for write and read we match the missing values and skip the extra comma. -macro_rules! _write_tlv_fields { - ($stream: expr, {$(($type: expr, $field: expr)),* $(,)*}, {}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}) => { - write_tlv_fields!($stream, {$(($type, $field)),*} , {$(($optional_type, $optional_field)),*}); +macro_rules! init_tlv_field_var { + ($field: ident, required) => { + let mut $field = ::util::ser::OptionDeserWrapper(None); }; - ($stream: expr, {$(($type: expr, $field: expr)),* $(,)*}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}, {$(($optional_type_2: expr, $optional_field_2: expr)),* $(,)*}) => { - write_tlv_fields!($stream, {$(($type, $field)),*} , {$(($optional_type, $optional_field)),*, $(($optional_type_2, $optional_field_2)),*}); - } -} -macro_rules! _get_tlv_len { - ({$(($type: expr, $field: expr)),* $(,)*}, {}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}) => { - get_varint_length_prefixed_tlv_length!({$(($type, $field)),*} , {$(($optional_type, $optional_field)),*}) + ($field: ident, vec_type) => { + let mut $field = Some(::util::ser::VecReadWrapper(Vec::new())); }; - ({$(($type: expr, $field: expr)),* $(,)*}, {$(($optional_type: expr, $optional_field: expr)),* $(,)*}, {$(($optional_type_2: expr, $optional_field_2: expr)),* $(,)*}) => { - get_varint_length_prefixed_tlv_length!({$(($type, $field)),*} , {$(($optional_type, $optional_field)),*, $(($optional_type_2, $optional_field_2)),*}) - } -} -macro_rules! _read_tlv_fields { - ($stream: expr, {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, {}, {$(($type: expr, $field: ident)),* $(,)*}) => { - read_tlv_fields!($stream, {$(($reqtype, $reqfield)),*}, {$(($type, $field)),*}); - }; - ($stream: expr, {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, {$(($type: expr, $field: ident)),* $(,)*}, {$(($type_2: expr, $field_2: ident)),* $(,)*}) => { - read_tlv_fields!($stream, {$(($reqtype, $reqfield)),*}, {$(($type, $field)),*, $(($type_2, $field_2)),*}); + ($field: ident, option) => { + let mut $field = None; } } /// Implements Readable/Writeable for a struct storing it as a set of TLVs -/// First block includes all the required fields including a dummy value which is used during -/// deserialization but which will never be exposed to other code. -/// The second block includes optional fields. -/// The third block includes any Vecs which need to have their individual elements serialized. +/// If $fieldty is `required`, then $field is a required field that is not an Option nor a Vec. +/// If $fieldty is `option`, then $field is optional field. +/// if $fieldty is `vec_type`, then $field is a Vec, which needs to have its individual elements +/// serialized. macro_rules! impl_writeable_tlv_based { - ($st: ident, {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, {$(($type: expr, $field: ident)),* $(,)*}, {$(($vectype: expr, $vecfield: ident)),* $(,)*}) => { + ($st: ident, {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*}) => { impl ::util::ser::Writeable for $st { fn write(&self, writer: &mut W) -> Result<(), ::std::io::Error> { - _write_tlv_fields!(writer, { - $(($reqtype, self.$reqfield)),* - }, { - $(($type, self.$field)),* - }, { - $(($vectype, Some(::util::ser::VecWriteWrapper(&self.$vecfield)))),* + write_tlv_fields!(writer, { + $(($type, self.$field, $fieldty)),* }); Ok(()) } #[inline] fn serialized_length(&self) -> usize { - let len = _get_tlv_len!({ - $(($reqtype, self.$reqfield)),* - }, { - $(($type, self.$field)),* - }, { - $(($vectype, Some(::util::ser::VecWriteWrapper(&self.$vecfield)))),* - }); - use util::ser::{BigSize, LengthCalculatingWriter}; - let mut len_calc = LengthCalculatingWriter(0); + use util::ser::BigSize; + let len = { + #[allow(unused_mut)] + let mut len = ::util::ser::LengthCalculatingWriter(0); + $( + get_varint_length_prefixed_tlv_length!(len, $type, self.$field, $fieldty); + )* + len.0 + }; + let mut len_calc = ::util::ser::LengthCalculatingWriter(0); BigSize(len as u64).write(&mut len_calc).expect("No in-memory data may fail to serialize"); len + len_calc.0 } @@ -418,22 +414,16 @@ macro_rules! impl_writeable_tlv_based { impl ::util::ser::Readable for $st { fn read(reader: &mut R) -> Result { $( - let mut $reqfield = ::util::ser::OptionDeserWrapper(None); - )* - $( - let mut $field = None; + init_tlv_field_var!($field, $fieldty); )* - $( - let mut $vecfield = Some(::util::ser::VecReadWrapper(Vec::new())); - )* - _read_tlv_fields!(reader, { - $(($reqtype, $reqfield)),* - }, { - $(($type, $field)),* - }, { - $(($vectype, $vecfield)),* + read_tlv_fields!(reader, { + $(($type, $field, $fieldty)),* }); - _init_tlv_based_struct!($st, {$($reqfield),*}, {$($field),*}, {$($vecfield),*}) + Ok(Self { + $( + $field: init_tlv_based_struct_field!($field, $fieldty) + ),* + }) } } } @@ -443,31 +433,25 @@ macro_rules! impl_writeable_tlv_based { /// variants stored directly. /// The format is, for example /// impl_writeable_tlv_based_enum!(EnumName, -/// (0, StructVariantA) => {(0, variant_field)}, {(1, variant_optional_field)}, {}, -/// (1, StructVariantB) => {(0, variant_field_a), (1, variant_field_b)}, {}, {(2, variant_vec_field)}; +/// (0, StructVariantA) => {(0, required_variant_field, required), (1, optional_variant_field, option)}, +/// (1, StructVariantB) => {(0, variant_field_a, required), (1, variant_field_b, required), (2, variant_vec_field, vec_type)}; /// (2, TupleVariantA), (3, TupleVariantB), /// ); /// The type is written as a single byte, followed by any variant data. /// Attempts to read an unknown type byte result in DecodeError::UnknownRequiredFeature. macro_rules! impl_writeable_tlv_based_enum { ($st: ident, $(($variant_id: expr, $variant_name: ident) => - {$(($reqtype: expr, $reqfield: ident)),* $(,)*}, - {$(($type: expr, $field: ident)),* $(,)*}, - {$(($vectype: expr, $vecfield: ident)),* $(,)*} + {$(($type: expr, $field: ident, $fieldty: ident)),* $(,)*} ),* $(,)*; $(($tuple_variant_id: expr, $tuple_variant_name: ident)),* $(,)*) => { impl ::util::ser::Writeable for $st { fn write(&self, writer: &mut W) -> Result<(), ::std::io::Error> { match self { - $($st::$variant_name { $(ref $reqfield),* $(ref $field),*, $(ref $vecfield),* } => { + $($st::$variant_name { $(ref $field),* } => { let id: u8 = $variant_id; id.write(writer)?; - _write_tlv_fields!(writer, { - $(($reqtype, $reqfield)),* - }, { - $(($type, $field)),* - }, { - $(($vectype, Some(::util::ser::VecWriteWrapper(&$vecfield)))),* + write_tlv_fields!(writer, { + $(($type, $field, $fieldty)),* }); }),* $($st::$tuple_variant_name (ref field) => { @@ -489,22 +473,16 @@ macro_rules! impl_writeable_tlv_based_enum { // in the same function body. Instead, we define a closure and call it. let f = || { $( - let mut $reqfield = ::util::ser::OptionDeserWrapper(None); - )* - $( - let mut $field = None; - )* - $( - let mut $vecfield = Some(::util::ser::VecReadWrapper(Vec::new())); + init_tlv_field_var!($field, $fieldty); )* - _read_tlv_fields!(reader, { - $(($reqtype, $reqfield)),* - }, { - $(($type, $field)),* - }, { - $(($vectype, $vecfield)),* + read_tlv_fields!(reader, { + $(($type, $field, $fieldty)),* }); - _init_tlv_based_struct!($st::$variant_name, {$($reqfield),*}, {$($field),*}, {$($vecfield),*}) + Ok($st::$variant_name { + $( + $field: init_tlv_based_struct_field!($field, $fieldty) + ),* + }) }; f() }),* @@ -536,7 +514,7 @@ mod tests { let mut a: u64 = 0; let mut b: u32 = 0; let mut c: Option = None; - decode_tlv!(&mut s, {(2, a), (3, b)}, {(4, c)}); + decode_tlv_stream!(&mut s, {(2, a, required), (3, b, required), (4, c, option)}); Ok((a, b, c)) } @@ -600,7 +578,7 @@ mod tests { let mut tlv2: Option = None; let mut tlv3: Option<(PublicKey, u64, u64)> = None; let mut tlv4: Option = None; - decode_tlv!(&mut s, {}, {(1, tlv1), (2, tlv2), (3, tlv3), (254, tlv4)}); + decode_tlv_stream!(&mut s, {(1, tlv1, option), (2, tlv2, option), (3, tlv3, option), (254, tlv4, option)}); Ok((tlv1, tlv2, tlv3, tlv4)) } @@ -711,27 +689,27 @@ mod tests { let mut stream = VecWriter(Vec::new()); stream.0.clear(); - encode_varint_length_prefixed_tlv!(&mut stream, { (1, 1u8) }, { (42, None::) }); + encode_varint_length_prefixed_tlv!(&mut stream, {(1, 1u8, required), (42, None::, option)}); assert_eq!(stream.0, ::hex::decode("03010101").unwrap()); stream.0.clear(); - encode_varint_length_prefixed_tlv!(&mut stream, { }, { (1, Some(1u8)) }); + encode_varint_length_prefixed_tlv!(&mut stream, {(1, Some(1u8), option)}); assert_eq!(stream.0, ::hex::decode("03010101").unwrap()); stream.0.clear(); - encode_varint_length_prefixed_tlv!(&mut stream, { (4, 0xabcdu16) }, { (42, None::) }); + encode_varint_length_prefixed_tlv!(&mut stream, {(4, 0xabcdu16, required), (42, None::, option)}); assert_eq!(stream.0, ::hex::decode("040402abcd").unwrap()); stream.0.clear(); - encode_varint_length_prefixed_tlv!(&mut stream, { (0xff, 0xabcdu16) }, { (42, None::) }); + encode_varint_length_prefixed_tlv!(&mut stream, {(42, None::, option), (0xff, 0xabcdu16, required)}); assert_eq!(stream.0, ::hex::decode("06fd00ff02abcd").unwrap()); stream.0.clear(); - encode_varint_length_prefixed_tlv!(&mut stream, { (0, 1u64), (0xff, HighZeroBytesDroppedVarInt(0u64)) }, { (42, None::) }); + encode_varint_length_prefixed_tlv!(&mut stream, {(0, 1u64, required), (42, None::, option), (0xff, HighZeroBytesDroppedVarInt(0u64), required)}); assert_eq!(stream.0, ::hex::decode("0e00080000000000000001fd00ff00").unwrap()); stream.0.clear(); - encode_varint_length_prefixed_tlv!(&mut stream, { (0xff, HighZeroBytesDroppedVarInt(0u64)) }, { (0, Some(1u64)) }); + encode_varint_length_prefixed_tlv!(&mut stream, {(0, Some(1u64), option), (0xff, HighZeroBytesDroppedVarInt(0u64), required)}); assert_eq!(stream.0, ::hex::decode("0e00080000000000000001fd00ff00").unwrap()); Ok(()) -- 2.30.2