]> git.bitcoin.ninja Git - rust-lightning/commitdiff
Refactor/dont re-enter block_conencted on duplicate watch calls
authorMatt Corallo <git@bluematt.me>
Fri, 7 Sep 2018 15:56:41 +0000 (11:56 -0400)
committerMatt Corallo <git@bluematt.me>
Fri, 7 Sep 2018 17:34:58 +0000 (13:34 -0400)
Previously we'd hit an infinite loop if a block_connected call
always resulted in the same ChainWatchInterface registrations.
While we're at it, we also split ChainWatchUtil in two to make
things a bit more flexible for users, though not sure if that
actually matters, and make the matching more aggressive in testing,
even if we pick the more performant option at runtime.

src/chain/chaininterface.rs

index 3a0b2d735d9a877eb98f95dd23739a472ad6a54d..5ff720c32e5fed57d8c0ad23e20c3180c675674e 100644 (file)
@@ -5,9 +5,12 @@ use bitcoin::blockdata::constants::genesis_block;
 use bitcoin::util::hash::Sha256dHash;
 use bitcoin::network::constants::Network;
 use bitcoin::network::serialize::BitcoinHash;
+
 use util::logger::Logger;
+
 use std::sync::{Mutex,Weak,MutexGuard,Arc};
 use std::sync::atomic::{AtomicUsize, Ordering};
+use std::collections::HashSet;
 
 /// Used to give chain error details upstream
 pub enum ChainError {
@@ -57,6 +60,8 @@ pub trait ChainListener: Sync + Send {
        /// Note that if a new transaction/outpoint is watched during a block_connected call, the block
        /// *must* be re-scanned with the new transaction/outpoints and block_connected should be
        /// called again with the same header and (at least) the new transactions.
+       /// Note that if non-new transaction/outpoints may be registered during a call, a second call
+       /// *must not* happen.
        /// This also means those counting confirmations using block_connected callbacks should watch
        /// for duplicate headers and not count them towards confirmations!
        fn block_connected(&self, header: &BlockHeader, height: u32, txn_matched: &[&Transaction], indexes_of_txn_matched: &[u32]);
@@ -85,11 +90,98 @@ pub trait FeeEstimator: Sync + Send {
        fn get_est_sat_per_1000_weight(&self, confirmation_target: ConfirmationTarget) -> u64;
 }
 
+/// Utility for tracking registered txn/outpoints and checking for matches
+pub struct ChainWatchedUtil {
+       watch_all: bool,
+
+       // We are more conservative in matching during testing to ensure everything matches *exactly*,
+       // even though during normal runtime we take more optimized match approaches...
+       #[cfg(test)]
+       watched_txn: HashSet<(Sha256dHash, Script)>,
+       #[cfg(not(test))]
+       watched_txn: HashSet<Script>,
+
+       watched_outpoints: HashSet<(Sha256dHash, u32)>,
+}
+
+impl ChainWatchedUtil {
+       /// Constructs an empty (watches nothing) ChainWatchedUtil
+       pub fn new() -> Self {
+               Self {
+                       watch_all: false,
+                       watched_txn: HashSet::new(),
+                       watched_outpoints: HashSet::new(),
+               }
+       }
+
+       /// Registers a tx for monitoring, returning true if it was a new tx and false if we'd already
+       /// been watching for it.
+       pub fn register_tx(&mut self, txid: &Sha256dHash, script_pub_key: &Script) -> bool {
+               if self.watch_all { return false; }
+               #[cfg(test)]
+               {
+                       self.watched_txn.insert((txid.clone(), script_pub_key.clone()))
+               }
+               #[cfg(not(test))]
+               {
+                       let _tx_unused = txid; // Its used in cfg(test), though
+                       self.watched_txn.insert(script_pub_key.clone())
+               }
+       }
+
+       /// Registers an outpoint for monitoring, returning true if it was a new outpoint and false if
+       /// we'd already been watching for it
+       pub fn register_outpoint(&mut self, outpoint: (Sha256dHash, u32), _script_pub_key: &Script) -> bool {
+               if self.watch_all { return false; }
+               self.watched_outpoints.insert(outpoint)
+       }
+
+       /// Sets us to match all transactions, returning true if this is a new setting anf false if
+       /// we'd already been set to match everything.
+       pub fn watch_all(&mut self) -> bool {
+               if self.watch_all { return false; }
+               self.watch_all = true;
+               true
+       }
+
+       /// Checks if a given transaction matches the current filter.
+       pub fn does_match_tx(&self, tx: &Transaction) -> bool {
+               if self.watch_all {
+                       return true;
+               }
+               for out in tx.output.iter() {
+                       #[cfg(test)]
+                       for &(ref txid, ref script) in self.watched_txn.iter() {
+                               if *script == out.script_pubkey {
+                                       if tx.txid() == *txid {
+                                               return true;
+                                       }
+                               }
+                       }
+                       #[cfg(not(test))]
+                       for script in self.watched_txn.iter() {
+                               if *script == out.script_pubkey {
+                                       return true;
+                               }
+                       }
+               }
+               for input in tx.input.iter() {
+                       for outpoint in self.watched_outpoints.iter() {
+                               let &(outpoint_hash, outpoint_index) = outpoint;
+                               if outpoint_hash == input.previous_output.txid && outpoint_index == input.previous_output.vout {
+                                       return true;
+                               }
+                       }
+               }
+               false
+       }
+}
+
 /// Utility to capture some common parts of ChainWatchInterface implementors.
 /// Keeping a local copy of this in a ChainWatchInterface implementor is likely useful.
 pub struct ChainWatchInterfaceUtil {
        network: Network,
-       watched: Mutex<(Vec<Script>, Vec<(Sha256dHash, u32)>, bool)>, //TODO: Something clever to optimize this
+       watched: Mutex<ChainWatchedUtil>,
        listeners: Mutex<Vec<Weak<ChainListener>>>,
        reentered: AtomicUsize,
        logger: Arc<Logger>,
@@ -97,22 +189,25 @@ pub struct ChainWatchInterfaceUtil {
 
 /// Register listener
 impl ChainWatchInterface for ChainWatchInterfaceUtil {
-       fn install_watch_tx(&self, _txid: &Sha256dHash, script_pub_key: &Script) {
+       fn install_watch_tx(&self, txid: &Sha256dHash, script_pub_key: &Script) {
                let mut watched = self.watched.lock().unwrap();
-               watched.0.push(script_pub_key.clone());
-               self.reentered.fetch_add(1, Ordering::Relaxed);
+               if watched.register_tx(txid, script_pub_key) {
+                       self.reentered.fetch_add(1, Ordering::Relaxed);
+               }
        }
 
-       fn install_watch_outpoint(&self, outpoint: (Sha256dHash, u32), _out_script: &Script) {
+       fn install_watch_outpoint(&self, outpoint: (Sha256dHash, u32), out_script: &Script) {
                let mut watched = self.watched.lock().unwrap();
-               watched.1.push(outpoint);
-               self.reentered.fetch_add(1, Ordering::Relaxed);
+               if watched.register_outpoint(outpoint, out_script) {
+                       self.reentered.fetch_add(1, Ordering::Relaxed);
+               }
        }
 
        fn watch_all_txn(&self) {
                let mut watched = self.watched.lock().unwrap();
-               watched.2 = true;
-               self.reentered.fetch_add(1, Ordering::Relaxed);
+               if watched.watch_all() {
+                       self.reentered.fetch_add(1, Ordering::Relaxed);
+               }
        }
 
        fn register_listener(&self, listener: Weak<ChainListener>) {
@@ -132,7 +227,7 @@ impl ChainWatchInterfaceUtil {
        pub fn new(network: Network, logger: Arc<Logger>) -> ChainWatchInterfaceUtil {
                ChainWatchInterfaceUtil {
                        network: network,
-                       watched: Mutex::new((Vec::new(), Vec::new(), false)),
+                       watched: Mutex::new(ChainWatchedUtil::new()),
                        listeners: Mutex::new(Vec::new()),
                        reentered: AtomicUsize::new(1),
                        logger: logger,
@@ -195,25 +290,7 @@ impl ChainWatchInterfaceUtil {
                self.does_match_tx_unguarded (tx, &watched)
        }
 
-       fn does_match_tx_unguarded(&self, tx: &Transaction, watched: &MutexGuard<(Vec<Script>, Vec<(Sha256dHash, u32)>, bool)>) -> bool {
-               if watched.2 {
-                       return true;
-               }
-               for out in tx.output.iter() {
-                       for script in watched.0.iter() {
-                               if script[..] == out.script_pubkey[..] {
-                                       return true;
-                               }
-                       }
-               }
-               for input in tx.input.iter() {
-                       for outpoint in watched.1.iter() {
-                               let &(outpoint_hash, outpoint_index) = outpoint;
-                               if outpoint_hash == input.previous_output.txid && outpoint_index == input.previous_output.vout {
-                                       return true;
-                               }
-                       }
-               }
-               false
+       fn does_match_tx_unguarded(&self, tx: &Transaction, watched: &MutexGuard<ChainWatchedUtil>) -> bool {
+               watched.does_match_tx(tx)
        }
 }