Merge pull request #502 from rloomba/rloomba/add_unregister_listener
[rust-lightning] / lightning / src / ln / channelmonitor.rs
index 0a70eb0e676bb34d84e521d92524664530a6243e..68674ac3ae098c3db7d70a173755e44232c77c44 100644 (file)
@@ -117,9 +117,16 @@ pub struct HTLCUpdate {
 pub trait ManyChannelMonitor<ChanSigner: ChannelKeys>: Send + Sync {
        /// Adds or updates a monitor for the given `funding_txo`.
        ///
-       /// Implementor must also ensure that the funding_txo outpoint is registered with any relevant
-       /// ChainWatchInterfaces such that the provided monitor receives block_connected callbacks with
-       /// any spends of it.
+       /// Implementer must also ensure that the funding_txo txid *and* outpoint are registered with
+       /// any relevant ChainWatchInterfaces such that the provided monitor receives block_connected
+       /// callbacks with the funding transaction, or any spends of it.
+       ///
+       /// Further, the implementer must also ensure that each output returned in
+       /// monitor.get_outputs_to_watch() is registered to ensure that the provided monitor learns about
+       /// any spends of any of the outputs.
+       ///
+       /// Any spends of outputs which should have been registered which aren't passed to
+       /// ChannelMonitors via block_connected may result in funds loss.
        fn add_update_monitor(&self, funding_txo: OutPoint, monitor: ChannelMonitor<ChanSigner>) -> Result<(), ChannelMonitorUpdateErr>;
 
        /// Used by ChannelManager to get list of HTLC resolved onchain and which needed to be updated
@@ -259,6 +266,11 @@ impl<Key : Send + cmp::Eq + hash::Hash + 'static, ChanSigner: ChannelKeys> Simpl
                                self.chain_monitor.watch_all_txn();
                        }
                }
+               for (txid, outputs) in monitor.get_outputs_to_watch().iter() {
+                       for (idx, script) in outputs.iter().enumerate() {
+                               self.chain_monitor.install_watch_outpoint((*txid, idx as u32), script);
+                       }
+               }
                monitors.insert(key, monitor);
                Ok(())
        }
@@ -666,6 +678,12 @@ pub struct ChannelMonitor<ChanSigner: ChannelKeys> {
        // actions when we receive a block with given height. Actions depend on OnchainEvent type.
        onchain_events_waiting_threshold_conf: HashMap<u32, Vec<OnchainEvent>>,
 
+       // If we get serialized out and re-read, we need to make sure that the chain monitoring
+       // interface knows about the TXOs that we want to be notified of spends of. We could probably
+       // be smart and derive them from the above storage fields, but its much simpler and more
+       // Obviously Correct (tm) if we just keep track of them explicitly.
+       outputs_to_watch: HashMap<Sha256dHash, Vec<Script>>,
+
        // We simply modify last_block_hash in Channel's block_connected so that serialization is
        // consistent but hopefully the users' copy handles block_connected in a consistent way.
        // (we do *not*, however, update them in insert_combine to ensure any local user copies keep
@@ -736,7 +754,8 @@ impl<ChanSigner: ChannelKeys> PartialEq for ChannelMonitor<ChanSigner> {
                        self.to_remote_rescue != other.to_remote_rescue ||
                        self.pending_claim_requests != other.pending_claim_requests ||
                        self.claimable_outpoints != other.claimable_outpoints ||
-                       self.onchain_events_waiting_threshold_conf != other.onchain_events_waiting_threshold_conf
+                       self.onchain_events_waiting_threshold_conf != other.onchain_events_waiting_threshold_conf ||
+                       self.outputs_to_watch != other.outputs_to_watch
                {
                        false
                } else {
@@ -966,6 +985,15 @@ impl<ChanSigner: ChannelKeys + Writeable> ChannelMonitor<ChanSigner> {
                        }
                }
 
+               (self.outputs_to_watch.len() as u64).write(writer)?;
+               for (txid, output_scripts) in self.outputs_to_watch.iter() {
+                       txid.write(writer)?;
+                       (output_scripts.len() as u64).write(writer)?;
+                       for script in output_scripts.iter() {
+                               script.write(writer)?;
+                       }
+               }
+
                Ok(())
        }
 
@@ -1036,6 +1064,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                        claimable_outpoints: HashMap::new(),
 
                        onchain_events_waiting_threshold_conf: HashMap::new(),
+                       outputs_to_watch: HashMap::new(),
 
                        last_block_hash: Default::default(),
                        secp_ctx: Secp256k1::new(),
@@ -1370,6 +1399,12 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                }
        }
 
+       /// Gets a list of txids, with their output scripts (in the order they appear in the
+       /// transaction), which we must learn about spends of via block_connected().
+       pub fn get_outputs_to_watch(&self) -> &HashMap<Sha256dHash, Vec<Script>> {
+               &self.outputs_to_watch
+       }
+
        /// Gets the sets of all outpoints which this ChannelMonitor expects to hear about spends of.
        /// Generally useful when deserializing as during normal operation the return values of
        /// block_connected are sufficient to ensure all relevant outpoints are being monitored (note
@@ -1799,11 +1834,11 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                                        let mut inputs_info = Vec::new();
 
                                        macro_rules! sign_input {
-                                               ($sighash_parts: expr, $input: expr, $amount: expr, $preimage: expr) => {
+                                               ($sighash_parts: expr, $input: expr, $amount: expr, $preimage: expr, $idx: expr) => {
                                                        {
                                                                let (sig, redeemscript, htlc_key) = match self.key_storage {
                                                                        Storage::Local { ref htlc_base_key, .. } => {
-                                                                               let htlc = &per_commitment_option.unwrap()[$input.sequence as usize].0;
+                                                                               let htlc = &per_commitment_option.unwrap()[$idx as usize].0;
                                                                                let redeemscript = chan_utils::get_htlc_redeemscript_with_explicit_keys(htlc, &a_htlc_key, &b_htlc_key, &revocation_pubkey);
                                                                                let sighash = hash_to_message!(&$sighash_parts.sighash_all(&$input, &redeemscript, $amount)[..]);
                                                                                let htlc_key = ignore_error!(chan_utils::derive_private_key(&self.secp_ctx, revocation_point, &htlc_base_key));
@@ -1838,13 +1873,13 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                                                                                        vout: transaction_output_index,
                                                                                },
                                                                                script_sig: Script::new(),
-                                                                               sequence: idx as u32, // reset to 0xfffffffd in sign_input
+                                                                               sequence: 0xff_ff_ff_fd,
                                                                                witness: Vec::new(),
                                                                        };
                                                                        if htlc.cltv_expiry > height + CLTV_SHARED_CLAIM_BUFFER {
                                                                                inputs.push(input);
                                                                                inputs_desc.push(if htlc.offered { InputDescriptors::OfferedHTLC } else { InputDescriptors::ReceivedHTLC });
-                                                                               inputs_info.push((payment_preimage, tx.output[transaction_output_index as usize].value, htlc.cltv_expiry));
+                                                                               inputs_info.push((payment_preimage, tx.output[transaction_output_index as usize].value, htlc.cltv_expiry, idx));
                                                                                total_value += tx.output[transaction_output_index as usize].value;
                                                                        } else {
                                                                                let mut single_htlc_tx = Transaction {
@@ -1861,7 +1896,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                                                                                let mut used_feerate;
                                                                                if subtract_high_prio_fee!(self, fee_estimator, single_htlc_tx.output[0].value, predicted_weight, used_feerate) {
                                                                                        let sighash_parts = bip143::SighashComponents::new(&single_htlc_tx);
-                                                                                       let (redeemscript, htlc_key) = sign_input!(sighash_parts, single_htlc_tx.input[0], htlc.amount_msat / 1000, payment_preimage.0.to_vec());
+                                                                                       let (redeemscript, htlc_key) = sign_input!(sighash_parts, single_htlc_tx.input[0], htlc.amount_msat / 1000, payment_preimage.0.to_vec(), idx);
                                                                                        assert!(predicted_weight >= single_htlc_tx.get_weight());
                                                                                        spendable_outputs.push(SpendableOutputDescriptor::StaticOutput {
                                                                                                outpoint: BitcoinOutPoint { txid: single_htlc_tx.txid(), vout: 0 },
@@ -1892,7 +1927,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                                                                                vout: transaction_output_index,
                                                                        },
                                                                        script_sig: Script::new(),
-                                                                       sequence: idx as u32,
+                                                                       sequence: 0xff_ff_ff_fd,
                                                                        witness: Vec::new(),
                                                                };
                                                                let mut timeout_tx = Transaction {
@@ -1909,7 +1944,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                                                                let mut used_feerate;
                                                                if subtract_high_prio_fee!(self, fee_estimator, timeout_tx.output[0].value, predicted_weight, used_feerate) {
                                                                        let sighash_parts = bip143::SighashComponents::new(&timeout_tx);
-                                                                       let (redeemscript, htlc_key) = sign_input!(sighash_parts, timeout_tx.input[0], htlc.amount_msat / 1000, vec![0]);
+                                                                       let (redeemscript, htlc_key) = sign_input!(sighash_parts, timeout_tx.input[0], htlc.amount_msat / 1000, vec![0], idx);
                                                                        assert!(predicted_weight >= timeout_tx.get_weight());
                                                                        //TODO: track SpendableOutputDescriptor
                                                                        log_trace!(self, "Outpoint {}:{} is being being claimed, if it doesn't succeed, a bumped claiming txn is going to be broadcast at height {}", timeout_tx.input[0].previous_output.txid, timeout_tx.input[0].previous_output.vout, height_timer);
@@ -1961,7 +1996,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                                        let height_timer = Self::get_height_timer(height, soonest_timelock);
                                        let spend_txid = spend_tx.txid();
                                        for (input, info) in spend_tx.input.iter_mut().zip(inputs_info.iter()) {
-                                               let (redeemscript, htlc_key) = sign_input!(sighash_parts, input, info.1, (info.0).0.to_vec());
+                                               let (redeemscript, htlc_key) = sign_input!(sighash_parts, input, info.1, (info.0).0.to_vec(), info.3);
                                                log_trace!(self, "Outpoint {}:{} is being being claimed, if it doesn't succeed, a bumped claiming txn is going to be broadcast at height {}", input.previous_output.txid, input.previous_output.vout, height_timer);
                                                per_input_material.insert(input.previous_output, InputMaterial::RemoteHTLC { script: redeemscript, key: htlc_key, preimage: Some(*(info.0)), amount: info.1, locktime: 0});
                                                match self.claimable_outpoints.entry(input.previous_output) {
@@ -2362,7 +2397,21 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                }
        }
 
+       /// Called by SimpleManyChannelMonitor::block_connected, which implements
+       /// ChainListener::block_connected.
+       /// Eventually this should be pub and, roughly, implement ChainListener, however this requires
+       /// &mut self, as well as returns new spendable outputs and outpoints to watch for spending of
+       /// on-chain.
        fn block_connected(&mut self, txn_matched: &[&Transaction], height: u32, block_hash: &Sha256dHash, broadcaster: &BroadcasterInterface, fee_estimator: &FeeEstimator)-> (Vec<(Sha256dHash, Vec<TxOut>)>, Vec<SpendableOutputDescriptor>, Vec<(HTLCSource, Option<PaymentPreimage>, PaymentHash)>) {
+               for tx in txn_matched {
+                       let mut output_val = 0;
+                       for out in tx.output.iter() {
+                               if out.value > 21_000_000_0000_0000 { panic!("Value-overflowing transaction provided to block connected"); }
+                               output_val += out.value;
+                               if output_val > 21_000_000_0000_0000 { panic!("Value-overflowing transaction provided to block connected"); }
+                       }
+               }
+
                log_trace!(self, "Block {} at height {} connected with {} txn matched", block_hash, height, txn_matched.len());
                let mut watch_outputs = Vec::new();
                let mut spendable_outputs = Vec::new();
@@ -2580,6 +2629,9 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                        }
                }
                self.last_block_hash = block_hash.clone();
+               for &(ref txid, ref output_scripts) in watch_outputs.iter() {
+                       self.outputs_to_watch.insert(txid.clone(), output_scripts.iter().map(|o| o.script_pubkey.clone()).collect());
+               }
                (watch_outputs, spendable_outputs, htlc_updated)
        }
 
@@ -2899,7 +2951,6 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                for per_outp_material in cached_claim_datas.per_input_material.values() {
                        match per_outp_material {
                                &InputMaterial::Revoked { ref script, ref is_htlc, ref amount, .. } => {
-                                       log_trace!(self, "Is HLTC ? {}", is_htlc);
                                        inputs_witnesses_weight += Self::get_witnesses_weight(if !is_htlc { &[InputDescriptors::RevokedOutput] } else if HTLCType::scriptlen_to_htlctype(script.len()) == Some(HTLCType::OfferedHTLC) { &[InputDescriptors::RevokedOfferedHTLC] } else if HTLCType::scriptlen_to_htlctype(script.len()) == Some(HTLCType::AcceptedHTLC) { &[InputDescriptors::RevokedReceivedHTLC] } else { unreachable!() });
                                        amt += *amount;
                                },
@@ -3233,6 +3284,20 @@ impl<R: ::std::io::Read, ChanSigner: ChannelKeys + Readable<R>> ReadableArgs<R,
                        onchain_events_waiting_threshold_conf.insert(height_target, events);
                }
 
+               let outputs_to_watch_len: u64 = Readable::read(reader)?;
+               let mut outputs_to_watch = HashMap::with_capacity(cmp::min(outputs_to_watch_len as usize, MAX_ALLOC_SIZE / (mem::size_of::<Sha256dHash>() + mem::size_of::<Vec<Script>>())));
+               for _ in 0..outputs_to_watch_len {
+                       let txid = Readable::read(reader)?;
+                       let outputs_len: u64 = Readable::read(reader)?;
+                       let mut outputs = Vec::with_capacity(cmp::min(outputs_len as usize, MAX_ALLOC_SIZE / mem::size_of::<Script>()));
+                       for _ in 0..outputs_len {
+                               outputs.push(Readable::read(reader)?);
+                       }
+                       if let Some(_) = outputs_to_watch.insert(txid, outputs) {
+                               return Err(DecodeError::InvalidValue);
+                       }
+               }
+
                Ok((last_block_hash.clone(), ChannelMonitor {
                        commitment_transaction_number_obscure_factor,
 
@@ -3265,6 +3330,7 @@ impl<R: ::std::io::Read, ChanSigner: ChannelKeys + Readable<R>> ReadableArgs<R,
                        claimable_outpoints,
 
                        onchain_events_waiting_threshold_conf,
+                       outputs_to_watch,
 
                        last_block_hash,
                        secp_ctx,