Redo ChannelMonitor deserialization to avoid read_to_end()
[rust-lightning] / src / ln / channelmonitor.rs
index fb965434514230a6ce094f4d4e75d778bff775f1..3ba6ebc106351a5530a06275c2a70172cc68119e 100644 (file)
@@ -16,6 +16,7 @@ use bitcoin::blockdata::transaction::{TxIn,TxOut,SigHashType,Transaction};
 use bitcoin::blockdata::transaction::OutPoint as BitcoinOutPoint;
 use bitcoin::blockdata::script::Script;
 use bitcoin::network::serialize;
+use bitcoin::network::encodable::{ConsensusDecodable, ConsensusEncodable};
 use bitcoin::util::hash::Sha256dHash;
 use bitcoin::util::bip143;
 
@@ -32,7 +33,7 @@ use chain::chaininterface::{ChainListener, ChainWatchInterface, BroadcasterInter
 use chain::transaction::OutPoint;
 use chain::keysinterface::SpendableOutputDescriptor;
 use util::logger::Logger;
-use util::ser::{ReadableArgs, Writer};
+use util::ser::{ReadableArgs, Readable, Writer, Writeable, WriterWriteAdaptor, U48};
 use util::sha2::Sha256;
 use util::{byte_utils, events};
 
@@ -613,8 +614,7 @@ impl ChannelMonitor {
                        &Some((ref outpoint, ref script)) => {
                                writer.write_all(&outpoint.txid[..])?;
                                writer.write_all(&byte_utils::be16_to_array(outpoint.index))?;
-                               writer.write_all(&byte_utils::be64_to_array(script.len() as u64))?;
-                               writer.write_all(&script[..])?;
+                               script.write(writer)?;
                        },
                        &None => {
                                // We haven't even been initialized...not sure why anyone is serializing us, but
@@ -624,7 +624,7 @@ impl ChannelMonitor {
                }
 
                // Set in initial Channel-object creation, so should always be set by now:
-               writer.write_all(&byte_utils::be48_to_array(self.commitment_transaction_number_obscure_factor))?;
+               U48(self.commitment_transaction_number_obscure_factor).write(writer)?;
 
                match self.key_storage {
                        KeyStorage::PrivMode { ref revocation_base_key, ref htlc_base_key, ref delayed_payment_base_key, ref prev_latest_per_commitment_point, ref latest_per_commitment_point } => {
@@ -718,9 +718,12 @@ impl ChannelMonitor {
 
                macro_rules! serialize_local_tx {
                        ($local_tx: expr) => {
-                               let tx_ser = serialize::serialize(&$local_tx.tx).unwrap();
-                               writer.write_all(&byte_utils::be64_to_array(tx_ser.len() as u64))?;
-                               writer.write_all(&tx_ser)?;
+                               if let Err(e) = $local_tx.tx.consensus_encode(&mut serialize::RawEncoder::new(WriterWriteAdaptor(writer))) {
+                                       match e {
+                                               serialize::Error::Io(e) => return Err(e),
+                                               _ => panic!("local tx must have been well-formed!"),
+                                       }
+                               }
 
                                writer.write_all(&$local_tx.revocation_key.serialize())?;
                                writer.write_all(&$local_tx.a_htlc_key.serialize())?;
@@ -756,8 +759,7 @@ impl ChannelMonitor {
                        writer.write_all(payment_preimage)?;
                }
 
-               writer.write_all(&byte_utils::be64_to_array(self.destination_script.len() as u64))?;
-               writer.write_all(&self.destination_script[..])?;
+               self.destination_script.write(writer)?;
 
                Ok(())
        }
@@ -1378,27 +1380,10 @@ impl ChannelMonitor {
        }
 }
 
+const MAX_ALLOC_SIZE: usize = 64*1024;
+
 impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
        fn read(reader: &mut R, logger: Arc<Logger>) -> Result<Self, DecodeError> {
-               // TODO: read_to_end and then deserializing from that vector is really dumb, we should
-               // actually use the fancy serialization framework we have instead of hacking around it.
-               let mut datavec = Vec::new();
-               reader.read_to_end(&mut datavec)?;
-               let data = &datavec;
-
-               let mut read_pos = 0;
-               macro_rules! read_bytes {
-                       ($byte_count: expr) => {
-                               {
-                                       if ($byte_count as usize) > data.len() - read_pos {
-                                               return Err(DecodeError::ShortRead);
-                                       }
-                                       read_pos += $byte_count as usize;
-                                       &data[read_pos - $byte_count as usize..read_pos]
-                               }
-                       }
-               }
-
                let secp_ctx = Secp256k1::new();
                macro_rules! unwrap_obj {
                        ($key: expr) => {
@@ -1409,8 +1394,8 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        }
                }
 
-               let _ver = read_bytes!(1)[0];
-               let min_ver = read_bytes!(1)[0];
+               let _ver: u8 = Readable::read(reader)?;
+               let min_ver: u8 = Readable::read(reader)?;
                if min_ver > SERIALIZATION_VERSION {
                        return Err(DecodeError::UnknownVersion);
                }
@@ -1418,31 +1403,26 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                // Technically this can fail and serialize fail a round-trip, but only for serialization of
                // barely-init'd ChannelMonitors that we can't do anything with.
                let outpoint = OutPoint {
-                       txid: Sha256dHash::from(read_bytes!(32)),
-                       index: byte_utils::slice_to_be16(read_bytes!(2)),
+                       txid: Readable::read(reader)?,
+                       index: Readable::read(reader)?,
                };
-               let script_len = byte_utils::slice_to_be64(read_bytes!(8));
-               let funding_txo = Some((outpoint, Script::from(read_bytes!(script_len).to_vec())));
-               let commitment_transaction_number_obscure_factor = byte_utils::slice_to_be48(read_bytes!(6));
+               let funding_txo = Some((outpoint, Readable::read(reader)?));
+               let commitment_transaction_number_obscure_factor = <U48 as Readable<R>>::read(reader)?.0;
 
-               let key_storage = match read_bytes!(1)[0] {
+               let key_storage = match <u8 as Readable<R>>::read(reader)? {
                        0 => {
-                               let revocation_base_key = unwrap_obj!(SecretKey::from_slice(&secp_ctx, read_bytes!(32)));
-                               let htlc_base_key = unwrap_obj!(SecretKey::from_slice(&secp_ctx, read_bytes!(32)));
-                               let delayed_payment_base_key = unwrap_obj!(SecretKey::from_slice(&secp_ctx, read_bytes!(32)));
-                               let prev_latest_per_commitment_point = match read_bytes!(1)[0] {
-                                               0 => None,
-                                               1 => {
-                                                       Some(unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33))))
-                                               },
-                                               _ => return Err(DecodeError::InvalidValue),
+                               let revocation_base_key = Readable::read(reader)?;
+                               let htlc_base_key = Readable::read(reader)?;
+                               let delayed_payment_base_key = Readable::read(reader)?;
+                               let prev_latest_per_commitment_point = match <u8 as Readable<R>>::read(reader)? {
+                                       0 => None,
+                                       1 => Some(Readable::read(reader)?),
+                                       _ => return Err(DecodeError::InvalidValue),
                                };
-                               let latest_per_commitment_point = match read_bytes!(1)[0] {
-                                               0 => None,
-                                               1 => {
-                                                       Some(unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33))))
-                                               },
-                                               _ => return Err(DecodeError::InvalidValue),
+                               let latest_per_commitment_point = match <u8 as Readable<R>>::read(reader)? {
+                                       0 => None,
+                                       1 => Some(Readable::read(reader)?),
+                                       _ => return Err(DecodeError::InvalidValue),
                                };
                                KeyStorage::PrivMode {
                                        revocation_base_key,
@@ -1455,45 +1435,41 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        _ => return Err(DecodeError::InvalidValue),
                };
 
-               let their_htlc_base_key = Some(unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33))));
-               let their_delayed_payment_base_key = Some(unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33))));
+               let their_htlc_base_key = Some(Readable::read(reader)?);
+               let their_delayed_payment_base_key = Some(Readable::read(reader)?);
 
                let their_cur_revocation_points = {
-                       let first_idx = byte_utils::slice_to_be48(read_bytes!(6));
+                       let first_idx = <U48 as Readable<R>>::read(reader)?.0;
                        if first_idx == 0 {
                                None
                        } else {
-                               let first_point = unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33)));
-                               let second_point_slice = read_bytes!(33);
+                               let first_point = Readable::read(reader)?;
+                               let second_point_slice: [u8; 33] = Readable::read(reader)?;
                                if second_point_slice[0..32] == [0; 32] && second_point_slice[32] == 0 {
                                        Some((first_idx, first_point, None))
                                } else {
-                                       Some((first_idx, first_point, Some(unwrap_obj!(PublicKey::from_slice(&secp_ctx, second_point_slice)))))
+                                       Some((first_idx, first_point, Some(unwrap_obj!(PublicKey::from_slice(&secp_ctx, &second_point_slice)))))
                                }
                        }
                };
 
-               let our_to_self_delay = byte_utils::slice_to_be16(read_bytes!(2));
-               let their_to_self_delay = Some(byte_utils::slice_to_be16(read_bytes!(2)));
+               let our_to_self_delay: u16 = Readable::read(reader)?;
+               let their_to_self_delay: Option<u16> = Some(Readable::read(reader)?);
 
                let mut old_secrets = [([0; 32], 1 << 48); 49];
                for &mut (ref mut secret, ref mut idx) in old_secrets.iter_mut() {
-                       secret.copy_from_slice(read_bytes!(32));
-                       *idx = byte_utils::slice_to_be64(read_bytes!(8));
+                       *secret = Readable::read(reader)?;
+                       *idx = Readable::read(reader)?;
                }
 
                macro_rules! read_htlc_in_commitment {
                        () => {
                                {
-                                       let offered = match read_bytes!(1)[0] {
-                                               0 => false, 1 => true,
-                                               _ => return Err(DecodeError::InvalidValue),
-                                       };
-                                       let amount_msat = byte_utils::slice_to_be64(read_bytes!(8));
-                                       let cltv_expiry = byte_utils::slice_to_be32(read_bytes!(4));
-                                       let mut payment_hash = [0; 32];
-                                       payment_hash[..].copy_from_slice(read_bytes!(32));
-                                       let transaction_output_index = byte_utils::slice_to_be32(read_bytes!(4));
+                                       let offered: bool = Readable::read(reader)?;
+                                       let amount_msat: u64 = Readable::read(reader)?;
+                                       let cltv_expiry: u32 = Readable::read(reader)?;
+                                       let payment_hash: [u8; 32] = Readable::read(reader)?;
+                                       let transaction_output_index: u32 = Readable::read(reader)?;
 
                                        HTLCOutputInCommitment {
                                                offered, amount_msat, cltv_expiry, payment_hash, transaction_output_index
@@ -1502,14 +1478,12 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        }
                }
 
-               let remote_claimable_outpoints_len = byte_utils::slice_to_be64(read_bytes!(8));
-               if remote_claimable_outpoints_len > data.len() as u64 / 64 { return Err(DecodeError::BadLengthDescriptor); }
-               let mut remote_claimable_outpoints = HashMap::with_capacity(remote_claimable_outpoints_len as usize);
+               let remote_claimable_outpoints_len: u64 = Readable::read(reader)?;
+               let mut remote_claimable_outpoints = HashMap::with_capacity(cmp::min(remote_claimable_outpoints_len as usize, MAX_ALLOC_SIZE / 64));
                for _ in 0..remote_claimable_outpoints_len {
-                       let txid = Sha256dHash::from(read_bytes!(32));
-                       let outputs_count = byte_utils::slice_to_be64(read_bytes!(8));
-                       if outputs_count > data.len() as u64 / 32 { return Err(DecodeError::BadLengthDescriptor); }
-                       let mut outputs = Vec::with_capacity(outputs_count as usize);
+                       let txid: Sha256dHash = Readable::read(reader)?;
+                       let outputs_count: u64 = Readable::read(reader)?;
+                       let mut outputs = Vec::with_capacity(cmp::min(outputs_count as usize, MAX_ALLOC_SIZE / 32));
                        for _ in 0..outputs_count {
                                outputs.push(read_htlc_in_commitment!());
                        }
@@ -1518,24 +1492,21 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        }
                }
 
-               let remote_commitment_txn_on_chain_len = byte_utils::slice_to_be64(read_bytes!(8));
-               if remote_commitment_txn_on_chain_len > data.len() as u64 / 32 { return Err(DecodeError::BadLengthDescriptor); }
-               let mut remote_commitment_txn_on_chain = HashMap::with_capacity(remote_commitment_txn_on_chain_len as usize);
+               let remote_commitment_txn_on_chain_len: u64 = Readable::read(reader)?;
+               let mut remote_commitment_txn_on_chain = HashMap::with_capacity(cmp::min(remote_commitment_txn_on_chain_len as usize, MAX_ALLOC_SIZE / 32));
                for _ in 0..remote_commitment_txn_on_chain_len {
-                       let txid = Sha256dHash::from(read_bytes!(32));
-                       let commitment_number = byte_utils::slice_to_be48(read_bytes!(6));
+                       let txid: Sha256dHash = Readable::read(reader)?;
+                       let commitment_number = <U48 as Readable<R>>::read(reader)?.0;
                        if let Some(_) = remote_commitment_txn_on_chain.insert(txid, commitment_number) {
                                return Err(DecodeError::InvalidValue);
                        }
                }
 
-               let remote_hash_commitment_number_len = byte_utils::slice_to_be64(read_bytes!(8));
-               if remote_hash_commitment_number_len > data.len() as u64 / 32 { return Err(DecodeError::BadLengthDescriptor); }
-               let mut remote_hash_commitment_number = HashMap::with_capacity(remote_hash_commitment_number_len as usize);
+               let remote_hash_commitment_number_len: u64 = Readable::read(reader)?;
+               let mut remote_hash_commitment_number = HashMap::with_capacity(cmp::min(remote_hash_commitment_number_len as usize, MAX_ALLOC_SIZE / 32));
                for _ in 0..remote_hash_commitment_number_len {
-                       let mut txid = [0; 32];
-                       txid[..].copy_from_slice(read_bytes!(32));
-                       let commitment_number = byte_utils::slice_to_be48(read_bytes!(6));
+                       let txid: [u8; 32] = Readable::read(reader)?;
+                       let commitment_number = <U48 as Readable<R>>::read(reader)?.0;
                        if let Some(_) = remote_hash_commitment_number.insert(txid, commitment_number) {
                                return Err(DecodeError::InvalidValue);
                        }
@@ -1544,29 +1515,29 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                macro_rules! read_local_tx {
                        () => {
                                {
-                                       let tx_len = byte_utils::slice_to_be64(read_bytes!(8));
-                                       let tx_ser = read_bytes!(tx_len);
-                                       let tx: Transaction = unwrap_obj!(serialize::deserialize(tx_ser));
-                                       if serialize::serialize(&tx).unwrap() != tx_ser {
-                                               // We check that the tx re-serializes to the same form to ensure there is
-                                               // no extra data, and as rust-bitcoin doesn't handle the 0-input ambiguity
-                                               // all that well.
+                                       let tx = match Transaction::consensus_decode(&mut serialize::RawDecoder::new(reader.by_ref())) {
+                                               Ok(tx) => tx,
+                                               Err(e) => match e {
+                                                       serialize::Error::Io(ioe) => return Err(DecodeError::Io(ioe)),
+                                                       _ => return Err(DecodeError::InvalidValue),
+                                               },
+                                       };
+
+                                       if tx.input.is_empty() {
+                                               // Ensure tx didn't hit the 0-input ambiguity case.
                                                return Err(DecodeError::InvalidValue);
                                        }
 
-                                       let revocation_key = unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33)));
-                                       let a_htlc_key = unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33)));
-                                       let b_htlc_key = unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33)));
-                                       let delayed_payment_key = unwrap_obj!(PublicKey::from_slice(&secp_ctx, read_bytes!(33)));
-                                       let feerate_per_kw = byte_utils::slice_to_be64(read_bytes!(8));
+                                       let revocation_key = Readable::read(reader)?;
+                                       let a_htlc_key = Readable::read(reader)?;
+                                       let b_htlc_key = Readable::read(reader)?;
+                                       let delayed_payment_key = Readable::read(reader)?;
+                                       let feerate_per_kw: u64 = Readable::read(reader)?;
 
-                                       let htlc_outputs_len = byte_utils::slice_to_be64(read_bytes!(8));
-                                       if htlc_outputs_len > data.len() as u64 / 128 { return Err(DecodeError::BadLengthDescriptor); }
-                                       let mut htlc_outputs = Vec::with_capacity(htlc_outputs_len as usize);
+                                       let htlc_outputs_len: u64 = Readable::read(reader)?;
+                                       let mut htlc_outputs = Vec::with_capacity(cmp::min(htlc_outputs_len as usize, MAX_ALLOC_SIZE / 128));
                                        for _ in 0..htlc_outputs_len {
-                                               htlc_outputs.push((read_htlc_in_commitment!(),
-                                                               unwrap_obj!(Signature::from_compact(&secp_ctx, read_bytes!(64))),
-                                                               unwrap_obj!(Signature::from_compact(&secp_ctx, read_bytes!(64)))));
+                                               htlc_outputs.push((read_htlc_in_commitment!(), Readable::read(reader)?, Readable::read(reader)?));
                                        }
 
                                        LocalSignedTx {
@@ -1577,7 +1548,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        }
                }
 
-               let prev_local_signed_commitment_tx = match read_bytes!(1)[0] {
+               let prev_local_signed_commitment_tx = match <u8 as Readable<R>>::read(reader)? {
                        0 => None,
                        1 => {
                                Some(read_local_tx!())
@@ -1585,7 +1556,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        _ => return Err(DecodeError::InvalidValue),
                };
 
-               let current_local_signed_commitment_tx = match read_bytes!(1)[0] {
+               let current_local_signed_commitment_tx = match <u8 as Readable<R>>::read(reader)? {
                        0 => None,
                        1 => {
                                Some(read_local_tx!())
@@ -1593,13 +1564,11 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        _ => return Err(DecodeError::InvalidValue),
                };
 
-               let payment_preimages_len = byte_utils::slice_to_be64(read_bytes!(8));
-               if payment_preimages_len > data.len() as u64 / 32 { return Err(DecodeError::InvalidValue); }
-               let mut payment_preimages = HashMap::with_capacity(payment_preimages_len as usize);
+               let payment_preimages_len: u64 = Readable::read(reader)?;
+               let mut payment_preimages = HashMap::with_capacity(cmp::min(payment_preimages_len as usize, MAX_ALLOC_SIZE / 32));
                let mut sha = Sha256::new();
                for _ in 0..payment_preimages_len {
-                       let mut preimage = [0; 32];
-                       preimage[..].copy_from_slice(read_bytes!(32));
+                       let preimage: [u8; 32] = Readable::read(reader)?;
                        sha.reset();
                        sha.input(&preimage);
                        let mut hash = [0; 32];
@@ -1609,8 +1578,7 @@ impl<R: ::std::io::Read> ReadableArgs<R, Arc<Logger>> for ChannelMonitor {
                        }
                }
 
-               let destination_script_len = byte_utils::slice_to_be64(read_bytes!(8));
-               let destination_script = Script::from(read_bytes!(destination_script_len).to_vec());
+               let destination_script = Readable::read(reader)?;
 
                Ok(ChannelMonitor {
                        funding_txo,