Batch-sign local HTLC txn with a well-doc'd API, returning sigs
[rust-lightning] / lightning / src / ln / onchaintx.rs
index 77ce1bdc4be192f60af35759b80fe1e7373c8967..6b5da67c1a526de38c7905694f4dde9063278b34 100644 (file)
@@ -10,7 +10,7 @@ use bitcoin::util::bip143;
 
 use bitcoin_hashes::sha256d::Hash as Sha256dHash;
 
-use secp256k1::Secp256k1;
+use secp256k1::{Secp256k1, Signature};
 use secp256k1;
 
 use ln::msgs::DecodeError;
@@ -137,13 +137,63 @@ macro_rules! subtract_high_prio_fee {
        }
 }
 
+impl Readable for Option<Vec<Option<(usize, Signature)>>> {
+       fn read<R: ::std::io::Read>(reader: &mut R) -> Result<Self, DecodeError> {
+               match Readable::read(reader)? {
+                       0u8 => Ok(None),
+                       1u8 => {
+                               let vlen: u64 = Readable::read(reader)?;
+                               let mut ret = Vec::with_capacity(cmp::min(vlen as usize, MAX_ALLOC_SIZE / ::std::mem::size_of::<Option<(usize, Signature)>>()));
+                               for _ in 0..vlen {
+                                       ret.push(match Readable::read(reader)? {
+                                               0u8 => None,
+                                               1u8 => Some((<u64 as Readable>::read(reader)? as usize, Readable::read(reader)?)),
+                                               _ => return Err(DecodeError::InvalidValue)
+                                       });
+                               }
+                               Ok(Some(ret))
+                       },
+                       _ => Err(DecodeError::InvalidValue),
+               }
+       }
+}
+
+impl Writeable for Option<Vec<Option<(usize, Signature)>>> {
+       fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
+               match self {
+                       &Some(ref vec) => {
+                               1u8.write(writer)?;
+                               (vec.len() as u64).write(writer)?;
+                               for opt in vec.iter() {
+                                       match opt {
+                                               &Some((ref idx, ref sig)) => {
+                                                       1u8.write(writer)?;
+                                                       (*idx as u64).write(writer)?;
+                                                       sig.write(writer)?;
+                                               },
+                                               &None => 0u8.write(writer)?,
+                                       }
+                               }
+                       },
+                       &None => 0u8.write(writer)?,
+               }
+               Ok(())
+       }
+}
+
 
 /// OnchainTxHandler receives claiming requests, aggregates them if it's sound, broadcast and
 /// do RBF bumping if possible.
 pub struct OnchainTxHandler<ChanSigner: ChannelKeys> {
        destination_script: Script,
        local_commitment: Option<LocalCommitmentTransaction>,
+       // local_htlc_sigs and prev_local_htlc_sigs are in the order as they appear in the commitment
+       // transaction outputs (hence the Option<>s inside the Vec). The first usize is the index in
+       // the set of HTLCs in the LocalCommitmentTransaction (including those which do not appear in
+       // the commitment transaction).
+       local_htlc_sigs: Option<Vec<Option<(usize, Signature)>>>,
        prev_local_commitment: Option<LocalCommitmentTransaction>,
+       prev_local_htlc_sigs: Option<Vec<Option<(usize, Signature)>>>,
        local_csv: u16,
 
        key_storage: ChanSigner,
@@ -185,7 +235,9 @@ impl<ChanSigner: ChannelKeys + Writeable> OnchainTxHandler<ChanSigner> {
        pub(crate) fn write<W: Writer>(&self, writer: &mut W) -> Result<(), ::std::io::Error> {
                self.destination_script.write(writer)?;
                self.local_commitment.write(writer)?;
+               self.local_htlc_sigs.write(writer)?;
                self.prev_local_commitment.write(writer)?;
+               self.prev_local_htlc_sigs.write(writer)?;
 
                self.local_csv.write(writer)?;
 
@@ -231,7 +283,9 @@ impl<ChanSigner: ChannelKeys + Readable> ReadableArgs<Arc<Logger>> for OnchainTx
                let destination_script = Readable::read(reader)?;
 
                let local_commitment = Readable::read(reader)?;
+               let local_htlc_sigs = Readable::read(reader)?;
                let prev_local_commitment = Readable::read(reader)?;
+               let prev_local_htlc_sigs = Readable::read(reader)?;
 
                let local_csv = Readable::read(reader)?;
 
@@ -283,7 +337,9 @@ impl<ChanSigner: ChannelKeys + Readable> ReadableArgs<Arc<Logger>> for OnchainTx
                Ok(OnchainTxHandler {
                        destination_script,
                        local_commitment,
+                       local_htlc_sigs,
                        prev_local_commitment,
+                       prev_local_htlc_sigs,
                        local_csv,
                        key_storage,
                        claimable_outpoints,
@@ -303,7 +359,9 @@ impl<ChanSigner: ChannelKeys> OnchainTxHandler<ChanSigner> {
                OnchainTxHandler {
                        destination_script,
                        local_commitment: None,
+                       local_htlc_sigs: None,
                        prev_local_commitment: None,
+                       prev_local_htlc_sigs: None,
                        local_csv,
                        key_storage,
                        pending_claim_requests: HashMap::new(),
@@ -510,19 +568,7 @@ impl<ChanSigner: ChannelKeys> OnchainTxHandler<ChanSigner> {
                        for (_, (outp, per_outp_material)) in cached_claim_datas.per_input_material.iter().enumerate() {
                                match per_outp_material {
                                        &InputMaterial::LocalHTLC { ref preimage, ref amount } => {
-                                               let mut htlc_tx = None;
-                                               if let Some(ref mut local_commitment) = self.local_commitment {
-                                                       if local_commitment.txid() == outp.txid {
-                                                               self.key_storage.sign_htlc_transaction(local_commitment, outp.vout, *preimage, self.local_csv, &self.secp_ctx);
-                                                               htlc_tx = local_commitment.htlc_with_valid_witness(outp.vout).clone();
-                                                       }
-                                               }
-                                               if let Some(ref mut prev_local_commitment) = self.prev_local_commitment {
-                                                       if prev_local_commitment.txid() == outp.txid {
-                                                               self.key_storage.sign_htlc_transaction(prev_local_commitment, outp.vout, *preimage, self.local_csv, &self.secp_ctx);
-                                                               htlc_tx = prev_local_commitment.htlc_with_valid_witness(outp.vout).clone();
-                                                       }
-                                               }
+                                               let htlc_tx = self.get_fully_signed_htlc_tx(outp, preimage);
                                                if let Some(htlc_tx) = htlc_tx {
                                                        let feerate = (amount - htlc_tx.output[0].value) * 1000 / htlc_tx.get_weight() as u64;
                                                        // Timer set to $NEVER given we can't bump tx without anchor outputs
@@ -771,11 +817,47 @@ impl<ChanSigner: ChannelKeys> OnchainTxHandler<ChanSigner> {
                if let Some(ref local_commitment) = self.local_commitment {
                        if local_commitment.has_local_sig() { return Err(()) }
                }
+               if self.local_htlc_sigs.is_some() || self.prev_local_htlc_sigs.is_some() {
+                       return Err(());
+               }
                self.prev_local_commitment = self.local_commitment.take();
                self.local_commitment = Some(tx);
                Ok(())
        }
 
+       fn sign_latest_local_htlcs(&mut self) {
+               if let Some(ref local_commitment) = self.local_commitment {
+                       if let Ok(sigs) = self.key_storage.sign_local_commitment_htlc_transactions(local_commitment, self.local_csv, &self.secp_ctx) {
+                               self.local_htlc_sigs = Some(Vec::new());
+                               let ret = self.local_htlc_sigs.as_mut().unwrap();
+                               for (htlc_idx, (local_sig, &(ref htlc, _))) in sigs.iter().zip(local_commitment.per_htlc.iter()).enumerate() {
+                                       if let Some(tx_idx) = htlc.transaction_output_index {
+                                               if ret.len() <= tx_idx as usize { ret.resize(tx_idx as usize + 1, None); }
+                                               ret[tx_idx as usize] = Some((htlc_idx, local_sig.expect("Did not receive a signature for a non-dust HTLC")));
+                                       } else {
+                                               assert!(local_sig.is_none(), "Received a signature for a dust HTLC");
+                                       }
+                               }
+                       }
+               }
+       }
+       fn sign_prev_local_htlcs(&mut self) {
+               if let Some(ref local_commitment) = self.prev_local_commitment {
+                       if let Ok(sigs) = self.key_storage.sign_local_commitment_htlc_transactions(local_commitment, self.local_csv, &self.secp_ctx) {
+                               self.prev_local_htlc_sigs = Some(Vec::new());
+                               let ret = self.prev_local_htlc_sigs.as_mut().unwrap();
+                               for (htlc_idx, (local_sig, &(ref htlc, _))) in sigs.iter().zip(local_commitment.per_htlc.iter()).enumerate() {
+                                       if let Some(tx_idx) = htlc.transaction_output_index {
+                                               if ret.len() <= tx_idx as usize { ret.resize(tx_idx as usize + 1, None); }
+                                               ret[tx_idx as usize] = Some((htlc_idx, local_sig.expect("Did not receive a signature for a non-dust HTLC")));
+                                       } else {
+                                               assert!(local_sig.is_none(), "Received a signature for a dust HTLC");
+                                       }
+                               }
+                       }
+               }
+       }
+
        //TODO: getting lastest local transactions should be infaillible and result in us "force-closing the channel", but we may
        // have empty local commitment transaction if a ChannelMonitor is asked to force-close just after Channel::get_outbound_funding_created,
        // before providing a initial commitment transaction. For outbound channel, init ChannelMonitor at Channel::funding_signed, there is nothing
@@ -804,14 +886,44 @@ impl<ChanSigner: ChannelKeys> OnchainTxHandler<ChanSigner> {
                None
        }
 
-       pub(super) fn get_fully_signed_htlc_tx(&mut self, txid: Sha256dHash, htlc_index: u32, preimage: Option<PaymentPreimage>) -> Option<Transaction> {
-               //TODO: store preimage in OnchainTxHandler
-               if let Some(ref mut local_commitment) = self.local_commitment {
-                       if local_commitment.txid() == txid {
-                               self.key_storage.sign_htlc_transaction(local_commitment, htlc_index, preimage, self.local_csv, &self.secp_ctx);
-                               return local_commitment.htlc_with_valid_witness(htlc_index).clone();
+       pub(super) fn get_fully_signed_htlc_tx(&mut self, outp: &::bitcoin::OutPoint, preimage: &Option<PaymentPreimage>) -> Option<Transaction> {
+               let mut htlc_tx = None;
+               if self.local_commitment.is_some() {
+                       let commitment_txid = self.local_commitment.as_ref().unwrap().txid();
+                       if commitment_txid == outp.txid {
+                               self.sign_latest_local_htlcs();
+                               if let &Some(ref htlc_sigs) = &self.local_htlc_sigs {
+                                       let &(ref htlc_idx, ref htlc_sig) = htlc_sigs[outp.vout as usize].as_ref().unwrap();
+                                       htlc_tx = Some(self.local_commitment.as_ref().unwrap()
+                                               .get_signed_htlc_tx(*htlc_idx, htlc_sig, preimage, self.local_csv));
+                               }
                        }
                }
-               None
+               if self.prev_local_commitment.is_some() {
+                       let commitment_txid = self.prev_local_commitment.as_ref().unwrap().txid();
+                       if commitment_txid == outp.txid {
+                               self.sign_prev_local_htlcs();
+                               if let &Some(ref htlc_sigs) = &self.prev_local_htlc_sigs {
+                                       let &(ref htlc_idx, ref htlc_sig) = htlc_sigs[outp.vout as usize].as_ref().unwrap();
+                                       htlc_tx = Some(self.prev_local_commitment.as_ref().unwrap()
+                                               .get_signed_htlc_tx(*htlc_idx, htlc_sig, preimage, self.local_csv));
+                               }
+                       }
+               }
+               htlc_tx
+       }
+
+       #[cfg(test)]
+       pub(super) fn unsafe_get_fully_signed_htlc_tx(&mut self, outp: &::bitcoin::OutPoint, preimage: &Option<PaymentPreimage>) -> Option<Transaction> {
+               let latest_had_sigs = self.local_htlc_sigs.is_some();
+               let prev_had_sigs = self.prev_local_htlc_sigs.is_some();
+               let ret = self.get_fully_signed_htlc_tx(outp, preimage);
+               if !latest_had_sigs {
+                       self.local_htlc_sigs = None;
+               }
+               if !prev_had_sigs {
+                       self.prev_local_htlc_sigs = None;
+               }
+               ret
        }
 }