Merge pull request #455 from TheBlueMatt/2020-01-monitor-reload-watch
authorMatt Corallo <649246+TheBlueMatt@users.noreply.github.com>
Wed, 19 Feb 2020 00:06:26 +0000 (00:06 +0000)
committerGitHub <noreply@github.com>
Wed, 19 Feb 2020 00:06:26 +0000 (00:06 +0000)
Track the full list of outpoints a chanmon wants monitoring for

lightning/src/chain/chaininterface.rs
lightning/src/ln/channelmonitor.rs
lightning/src/ln/functional_test_utils.rs

index 0845eb5fc6965392cf6340a5a682bffe70ebe7db..7a077e895a615fce737c4aa4db1cd13c0667e092 100644 (file)
@@ -126,6 +126,7 @@ pub trait FeeEstimator: Sync + Send {
 pub const MIN_RELAY_FEE_SAT_PER_1000_WEIGHT: u64 = 4000;
 
 /// Utility for tracking registered txn/outpoints and checking for matches
+#[cfg_attr(test, derive(PartialEq))]
 pub struct ChainWatchedUtil {
        watch_all: bool,
 
@@ -305,6 +306,17 @@ pub struct ChainWatchInterfaceUtil {
        logger: Arc<Logger>,
 }
 
+// We only expose PartialEq in test since its somewhat unclear exactly what it should do and we're
+// only comparing a subset of fields (essentially just checking that the set of things we're
+// watching is the same).
+#[cfg(test)]
+impl PartialEq for ChainWatchInterfaceUtil {
+       fn eq(&self, o: &Self) -> bool {
+               self.network == o.network &&
+               *self.watched.lock().unwrap() == *o.watched.lock().unwrap()
+       }
+}
+
 /// Register listener
 impl ChainWatchInterface for ChainWatchInterfaceUtil {
        fn install_watch_tx(&self, txid: &Sha256dHash, script_pub_key: &Script) {
index a1577712fdd8e2196529535fa912c3432c46aebd..68674ac3ae098c3db7d70a173755e44232c77c44 100644 (file)
@@ -117,9 +117,16 @@ pub struct HTLCUpdate {
 pub trait ManyChannelMonitor<ChanSigner: ChannelKeys>: Send + Sync {
        /// Adds or updates a monitor for the given `funding_txo`.
        ///
-       /// Implementor must also ensure that the funding_txo outpoint is registered with any relevant
-       /// ChainWatchInterfaces such that the provided monitor receives block_connected callbacks with
-       /// any spends of it.
+       /// Implementer must also ensure that the funding_txo txid *and* outpoint are registered with
+       /// any relevant ChainWatchInterfaces such that the provided monitor receives block_connected
+       /// callbacks with the funding transaction, or any spends of it.
+       ///
+       /// Further, the implementer must also ensure that each output returned in
+       /// monitor.get_outputs_to_watch() is registered to ensure that the provided monitor learns about
+       /// any spends of any of the outputs.
+       ///
+       /// Any spends of outputs which should have been registered which aren't passed to
+       /// ChannelMonitors via block_connected may result in funds loss.
        fn add_update_monitor(&self, funding_txo: OutPoint, monitor: ChannelMonitor<ChanSigner>) -> Result<(), ChannelMonitorUpdateErr>;
 
        /// Used by ChannelManager to get list of HTLC resolved onchain and which needed to be updated
@@ -259,6 +266,11 @@ impl<Key : Send + cmp::Eq + hash::Hash + 'static, ChanSigner: ChannelKeys> Simpl
                                self.chain_monitor.watch_all_txn();
                        }
                }
+               for (txid, outputs) in monitor.get_outputs_to_watch().iter() {
+                       for (idx, script) in outputs.iter().enumerate() {
+                               self.chain_monitor.install_watch_outpoint((*txid, idx as u32), script);
+                       }
+               }
                monitors.insert(key, monitor);
                Ok(())
        }
@@ -666,6 +678,12 @@ pub struct ChannelMonitor<ChanSigner: ChannelKeys> {
        // actions when we receive a block with given height. Actions depend on OnchainEvent type.
        onchain_events_waiting_threshold_conf: HashMap<u32, Vec<OnchainEvent>>,
 
+       // If we get serialized out and re-read, we need to make sure that the chain monitoring
+       // interface knows about the TXOs that we want to be notified of spends of. We could probably
+       // be smart and derive them from the above storage fields, but its much simpler and more
+       // Obviously Correct (tm) if we just keep track of them explicitly.
+       outputs_to_watch: HashMap<Sha256dHash, Vec<Script>>,
+
        // We simply modify last_block_hash in Channel's block_connected so that serialization is
        // consistent but hopefully the users' copy handles block_connected in a consistent way.
        // (we do *not*, however, update them in insert_combine to ensure any local user copies keep
@@ -736,7 +754,8 @@ impl<ChanSigner: ChannelKeys> PartialEq for ChannelMonitor<ChanSigner> {
                        self.to_remote_rescue != other.to_remote_rescue ||
                        self.pending_claim_requests != other.pending_claim_requests ||
                        self.claimable_outpoints != other.claimable_outpoints ||
-                       self.onchain_events_waiting_threshold_conf != other.onchain_events_waiting_threshold_conf
+                       self.onchain_events_waiting_threshold_conf != other.onchain_events_waiting_threshold_conf ||
+                       self.outputs_to_watch != other.outputs_to_watch
                {
                        false
                } else {
@@ -966,6 +985,15 @@ impl<ChanSigner: ChannelKeys + Writeable> ChannelMonitor<ChanSigner> {
                        }
                }
 
+               (self.outputs_to_watch.len() as u64).write(writer)?;
+               for (txid, output_scripts) in self.outputs_to_watch.iter() {
+                       txid.write(writer)?;
+                       (output_scripts.len() as u64).write(writer)?;
+                       for script in output_scripts.iter() {
+                               script.write(writer)?;
+                       }
+               }
+
                Ok(())
        }
 
@@ -1036,6 +1064,7 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                        claimable_outpoints: HashMap::new(),
 
                        onchain_events_waiting_threshold_conf: HashMap::new(),
+                       outputs_to_watch: HashMap::new(),
 
                        last_block_hash: Default::default(),
                        secp_ctx: Secp256k1::new(),
@@ -1370,6 +1399,12 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                }
        }
 
+       /// Gets a list of txids, with their output scripts (in the order they appear in the
+       /// transaction), which we must learn about spends of via block_connected().
+       pub fn get_outputs_to_watch(&self) -> &HashMap<Sha256dHash, Vec<Script>> {
+               &self.outputs_to_watch
+       }
+
        /// Gets the sets of all outpoints which this ChannelMonitor expects to hear about spends of.
        /// Generally useful when deserializing as during normal operation the return values of
        /// block_connected are sufficient to ensure all relevant outpoints are being monitored (note
@@ -2362,6 +2397,11 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                }
        }
 
+       /// Called by SimpleManyChannelMonitor::block_connected, which implements
+       /// ChainListener::block_connected.
+       /// Eventually this should be pub and, roughly, implement ChainListener, however this requires
+       /// &mut self, as well as returns new spendable outputs and outpoints to watch for spending of
+       /// on-chain.
        fn block_connected(&mut self, txn_matched: &[&Transaction], height: u32, block_hash: &Sha256dHash, broadcaster: &BroadcasterInterface, fee_estimator: &FeeEstimator)-> (Vec<(Sha256dHash, Vec<TxOut>)>, Vec<SpendableOutputDescriptor>, Vec<(HTLCSource, Option<PaymentPreimage>, PaymentHash)>) {
                for tx in txn_matched {
                        let mut output_val = 0;
@@ -2589,6 +2629,9 @@ impl<ChanSigner: ChannelKeys> ChannelMonitor<ChanSigner> {
                        }
                }
                self.last_block_hash = block_hash.clone();
+               for &(ref txid, ref output_scripts) in watch_outputs.iter() {
+                       self.outputs_to_watch.insert(txid.clone(), output_scripts.iter().map(|o| o.script_pubkey.clone()).collect());
+               }
                (watch_outputs, spendable_outputs, htlc_updated)
        }
 
@@ -3241,6 +3284,20 @@ impl<R: ::std::io::Read, ChanSigner: ChannelKeys + Readable<R>> ReadableArgs<R,
                        onchain_events_waiting_threshold_conf.insert(height_target, events);
                }
 
+               let outputs_to_watch_len: u64 = Readable::read(reader)?;
+               let mut outputs_to_watch = HashMap::with_capacity(cmp::min(outputs_to_watch_len as usize, MAX_ALLOC_SIZE / (mem::size_of::<Sha256dHash>() + mem::size_of::<Vec<Script>>())));
+               for _ in 0..outputs_to_watch_len {
+                       let txid = Readable::read(reader)?;
+                       let outputs_len: u64 = Readable::read(reader)?;
+                       let mut outputs = Vec::with_capacity(cmp::min(outputs_len as usize, MAX_ALLOC_SIZE / mem::size_of::<Script>()));
+                       for _ in 0..outputs_len {
+                               outputs.push(Readable::read(reader)?);
+                       }
+                       if let Some(_) = outputs_to_watch.insert(txid, outputs) {
+                               return Err(DecodeError::InvalidValue);
+                       }
+               }
+
                Ok((last_block_hash.clone(), ChannelMonitor {
                        commitment_transaction_number_obscure_factor,
 
@@ -3273,6 +3330,7 @@ impl<R: ::std::io::Read, ChanSigner: ChannelKeys + Readable<R>> ReadableArgs<R,
                        claimable_outpoints,
 
                        onchain_events_waiting_threshold_conf,
+                       outputs_to_watch,
 
                        last_block_hash,
                        secp_ctx,
index 1ae8ca1e0058cbb59a006017fc30f9aa4bb75849..674e7d0ff322df1b17f7db5926601789c4ddbe50 100644 (file)
@@ -5,6 +5,7 @@ use chain::chaininterface;
 use chain::transaction::OutPoint;
 use chain::keysinterface::KeysInterface;
 use ln::channelmanager::{ChannelManager,RAACommitmentOrder, PaymentPreimage, PaymentHash};
+use ln::channelmonitor::{ChannelMonitor, ManyChannelMonitor};
 use ln::router::{Route, Router};
 use ln::features::InitFeatures;
 use ln::msgs;
@@ -16,6 +17,7 @@ use util::events::{Event, EventsProvider, MessageSendEvent, MessageSendEventsPro
 use util::errors::APIError;
 use util::logger::Logger;
 use util::config::UserConfig;
+use util::ser::ReadableArgs;
 
 use bitcoin::util::hash::BitcoinHash;
 use bitcoin::blockdata::block::BlockHeader;
@@ -89,6 +91,27 @@ impl<'a, 'b> Drop for Node<'a, 'b> {
                        assert!(self.node.get_and_clear_pending_msg_events().is_empty());
                        assert!(self.node.get_and_clear_pending_events().is_empty());
                        assert!(self.chan_monitor.added_monitors.lock().unwrap().is_empty());
+
+                       // Check that if we serialize and then deserialize all our channel monitors we get the
+                       // same set of outputs to watch for on chain as we have now. Note that if we write
+                       // tests that fully close channels and remove the monitors at some point this may break.
+                       let chain_watch = Arc::new(chaininterface::ChainWatchInterfaceUtil::new(Network::Testnet, Arc::clone(&self.logger) as Arc<Logger>));
+                       let feeest = Arc::new(test_utils::TestFeeEstimator { sat_per_kw: 253 });
+                       let channel_monitor = test_utils::TestChannelMonitor::new(chain_watch.clone(), self.tx_broadcaster.clone(), self.logger.clone(), feeest);
+                       let old_monitors = self.chan_monitor.simple_monitor.monitors.lock().unwrap();
+                       for (_, old_monitor) in old_monitors.iter() {
+                               let mut w = test_utils::TestVecWriter(Vec::new());
+                               old_monitor.write_for_disk(&mut w).unwrap();
+                               let (_, deserialized_monitor) = <(Sha256d, ChannelMonitor<EnforcingChannelKeys>)>::read(
+                                       &mut ::std::io::Cursor::new(&w.0), Arc::clone(&self.logger) as Arc<Logger>).unwrap();
+                               if let Err(_) = channel_monitor.add_update_monitor(deserialized_monitor.get_funding_txo().unwrap(), deserialized_monitor) {
+                                       panic!();
+                               }
+                       }
+
+                       if *chain_watch != *self.chain_monitor {
+                               panic!();
+                       }
                }
        }
 }