TESTING
[rust-lightning] / fuzz / src / full_stack.rs
index e6496125ac656ded582209d631ae74db71ba728a..720b9620529741421c27a6450911b6a65755fa6b 100644 (file)
@@ -39,7 +39,7 @@ use std::cell::RefCell;
 use std::collections::{HashMap, hash_map};
 use std::cmp;
 use std::hash::Hash;
-use std::sync::Arc;
+use std::sync::{Arc,Mutex};
 use std::sync::atomic::{AtomicU64,AtomicUsize,Ordering};
 
 #[inline]
@@ -96,6 +96,7 @@ struct FuzzEstimator {
 }
 impl FeeEstimator for FuzzEstimator {
        fn get_est_sat_per_1000_weight(&self, _: ConfirmationTarget) -> u64 {
+println!("fee_get");
                //TODO: We should actually be testing at least much more than 64k...
                match self.input.get_slice(2) {
                        Some(slice) => cmp::max(slice_to_be16(slice) as u64, 253),
@@ -104,9 +105,13 @@ impl FeeEstimator for FuzzEstimator {
        }
 }
 
-struct TestBroadcaster {}
+pub struct TestBroadcaster {
+       pub txn_broadcasted: Mutex<Vec<Transaction>>,
+}
 impl BroadcasterInterface for TestBroadcaster {
-       fn broadcast_transaction(&self, _tx: &Transaction) {}
+       fn broadcast_transaction(&self, tx: &Transaction) {
+               self.txn_broadcasted.lock().unwrap().push(tx.clone());
+       }
 }
 
 #[derive(Clone)]
@@ -136,9 +141,10 @@ impl<'a> Hash for Peer<'a> {
 }
 
 struct MoneyLossDetector<'a> {
-       manager: Arc<ChannelManager<EnforcingChannelKeys>>,
+       manager: Arc<ChannelManager<EnforcingChannelKeys, Arc<channelmonitor::SimpleManyChannelMonitor<OutPoint>>>>,
        monitor: Arc<channelmonitor::SimpleManyChannelMonitor<OutPoint>>,
-       handler: PeerManager<Peer<'a>>,
+       broadcaster: Arc<TestBroadcaster>,
+       handler: PeerManager<Peer<'a>, Arc<ChannelManager<EnforcingChannelKeys, Arc<channelmonitor::SimpleManyChannelMonitor<OutPoint>>>>>,
 
        peers: &'a RefCell<[bool; 256]>,
        funding_txn: Vec<Transaction>,
@@ -149,10 +155,11 @@ struct MoneyLossDetector<'a> {
        blocks_connected: u32,
 }
 impl<'a> MoneyLossDetector<'a> {
-       pub fn new(peers: &'a RefCell<[bool; 256]>, manager: Arc<ChannelManager<EnforcingChannelKeys>>, monitor: Arc<channelmonitor::SimpleManyChannelMonitor<OutPoint>>, handler: PeerManager<Peer<'a>>) -> Self {
+       pub fn new(peers: &'a RefCell<[bool; 256]>, manager: Arc<ChannelManager<EnforcingChannelKeys, Arc<channelmonitor::SimpleManyChannelMonitor<OutPoint>>>>, monitor: Arc<channelmonitor::SimpleManyChannelMonitor<OutPoint>>, broadcaster: Arc<TestBroadcaster>, handler: PeerManager<Peer<'a>, Arc<ChannelManager<EnforcingChannelKeys, Arc<channelmonitor::SimpleManyChannelMonitor<OutPoint>>>>>) -> Self {
                MoneyLossDetector {
                        manager,
                        monitor,
+                       broadcaster,
                        handler,
 
                        peers,
@@ -220,7 +227,38 @@ impl<'a> Drop for MoneyLossDetector<'a> {
 
                        // Force all channels onto the chain (and time out claim txn)
                        self.manager.force_close_all_channels();
+                       for _ in 0..6*24*14 {
+                               self.connect_block(&[]);
+                       }
+               }
+
+               // Test that all broadcasted transactions either spend one of our funding transactions or
+               // some other broadcasted transaction:
+
+               let mut txn_map = HashMap::new();
+               let mut funding_txn_map = HashMap::new();
+               for tx in self.funding_txn.drain(..) {
+                       funding_txn_map.insert(tx.bitcoin_hash(), tx);
                }
+               let mut txn_broadcasted = self.broadcaster.txn_broadcasted.lock().unwrap();
+               for tx in txn_broadcasted.drain(..) {
+                       txn_map.insert(tx.bitcoin_hash(), tx);
+               }
+               /*for (_, tx) in txn_map.iter() {
+                       for inp in tx.input.iter() {
+                               let prev_tx = match funding_txn_map.get(&inp.prev_hash) {
+                                       Some(ptx) => ptx,
+                                       None => {
+                                               txn_map.get(&inp.prev_hash).unwrap()
+                                       }
+                               };
+                               assert!(prev_tx.output.len() > inp.prev_index as usize);
+                       }
+               }*/
+
+               //XXX: Find all non-conflicting sets of txn broadcasted and ensure that in each case we
+               //always get back at least the amount we expect minus tx fees (which we should be able to
+               //calculate now!
        }
 }
 
@@ -247,7 +285,7 @@ impl KeysInterface for KeyProvider {
                PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap())
        }
 
-       fn get_channel_keys(&self, inbound: bool) -> EnforcingChannelKeys {
+       fn get_channel_keys(&self, inbound: bool, channel_value_satoshis: u64) -> EnforcingChannelKeys {
                let ctr = self.counter.fetch_add(1, Ordering::Relaxed) as u8;
                EnforcingChannelKeys::new(if inbound {
                        InMemoryChannelKeys {
@@ -257,7 +295,8 @@ impl KeysInterface for KeyProvider {
                                delayed_payment_base_key:  SecretKey::from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, ctr]).unwrap(),
                                htlc_base_key:             SecretKey::from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, ctr]).unwrap(),
                                commitment_seed: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, ctr],
-                               remote_funding_pubkey: None,
+                               remote_channel_pubkeys: None,
+                               channel_value_satoshis: channel_value_satoshis,
                        }
                } else {
                        InMemoryChannelKeys {
@@ -267,7 +306,8 @@ impl KeysInterface for KeyProvider {
                                delayed_payment_base_key:  SecretKey::from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, ctr]).unwrap(),
                                htlc_base_key:             SecretKey::from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, ctr]).unwrap(),
                                commitment_seed: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, ctr],
-                               remote_funding_pubkey: None,
+                               remote_channel_pubkeys: None,
+                               channel_value_satoshis: channel_value_satoshis,
                        }
                })
        }
@@ -319,32 +359,34 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
        };
 
        let watch = Arc::new(ChainWatchInterfaceUtil::new(Network::Bitcoin, Arc::clone(&logger)));
-       let broadcast = Arc::new(TestBroadcaster{});
-       let monitor = channelmonitor::SimpleManyChannelMonitor::new(watch.clone(), broadcast.clone(), Arc::clone(&logger), fee_est.clone());
+       let broadcast = Arc::new(TestBroadcaster{ txn_broadcasted: Mutex::new(Vec::new()) });
+       let monitor = Arc::new(channelmonitor::SimpleManyChannelMonitor::new(watch.clone(), broadcast.clone(), Arc::clone(&logger), fee_est.clone()));
 
        let keys_manager = Arc::new(KeyProvider { node_secret: our_network_key.clone(), counter: AtomicU64::new(0) });
        let mut config = UserConfig::default();
        config.channel_options.fee_proportional_millionths =  slice_to_be32(get_slice!(4));
        config.channel_options.announced_channel = get_slice!(1)[0] != 0;
        config.peer_channel_config_limits.min_dust_limit_satoshis = 0;
-       let channelmanager = ChannelManager::new(Network::Bitcoin, fee_est.clone(), monitor.clone(), broadcast.clone(), Arc::clone(&logger), keys_manager.clone(), config, 0).unwrap();
+       let channelmanager = Arc::new(ChannelManager::new(Network::Bitcoin, fee_est.clone(), monitor.clone(), broadcast.clone(), Arc::clone(&logger), keys_manager.clone(), config, 0).unwrap());
        let router = Arc::new(Router::new(PublicKey::from_secret_key(&Secp256k1::signing_only(), &keys_manager.get_node_secret()), watch.clone(), Arc::clone(&logger)));
 
        let peers = RefCell::new([false; 256]);
-       let mut loss_detector = MoneyLossDetector::new(&peers, channelmanager.clone(), monitor.clone(), PeerManager::new(MessageHandler {
+       let mut loss_detector = MoneyLossDetector::new(&peers, channelmanager.clone(), monitor.clone(), broadcast.clone(), PeerManager::new(MessageHandler {
                chan_handler: channelmanager.clone(),
                route_handler: router.clone(),
        }, our_network_key, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0], Arc::clone(&logger)));
 
        let mut should_forward = false;
-       let mut payments_received: Vec<(PaymentHash, u64)> = Vec::new();
+       let mut payments_received: Vec<(PaymentHash, Option<[u8; 32]>, u64)> = Vec::new();
        let mut payments_sent = 0;
        let mut pending_funding_generation: Vec<([u8; 32], u64, Script)> = Vec::new();
        let mut pending_funding_signatures = HashMap::new();
        let mut pending_funding_relay = Vec::new();
 
        loop {
-               match get_slice!(1)[0] {
+let a = get_slice!(1)[0];
+println!("action: {}", a);
+               match a {
                        0 => {
                                let mut new_id = 0;
                                for i in 1..256 {
@@ -378,7 +420,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                        3 => {
                                let peer_id = get_slice!(1)[0];
                                if !peers.borrow()[peer_id as usize] { return; }
-                               match loss_detector.handler.read_event(&mut Peer{id: peer_id, peers_connected: &peers}, get_slice!(get_slice!(1)[0]).to_vec()) {
+                               match loss_detector.handler.read_event(&mut Peer{id: peer_id, peers_connected: &peers}, get_slice!(get_slice!(1)[0])) {
                                        Ok(res) => assert!(!res),
                                        Err(_) => { peers.borrow_mut()[peer_id as usize] = false; }
                                }
@@ -395,7 +437,7 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                sha.input(&payment_hash.0[..]);
                                payment_hash.0 = Sha256::from_engine(sha).into_inner();
                                payments_sent += 1;
-                               match channelmanager.send_payment(route, payment_hash) {
+                               match channelmanager.send_payment(route, payment_hash, None) {
                                        Ok(_) => {},
                                        Err(_) => return,
                                }
@@ -422,23 +464,23 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                }
                        },
                        8 => {
-                               for (payment, amt) in payments_received.drain(..) {
+                               for (payment, payment_secret, amt) in payments_received.drain(..) {
                                        // SHA256 is defined as XOR of all input bytes placed in the first byte, and 0s
                                        // for the remaining bytes. Thus, if not all remaining bytes are 0s we cannot
                                        // fulfill this HTLC, but if they are, we can just take the first byte and
                                        // place that anywhere in our preimage.
                                        if &payment.0[1..] != &[0; 31] {
-                                               channelmanager.fail_htlc_backwards(&payment);
+                                               channelmanager.fail_htlc_backwards(&payment, &payment_secret);
                                        } else {
                                                let mut payment_preimage = PaymentPreimage([0; 32]);
                                                payment_preimage.0[0] = payment.0[0];
-                                               channelmanager.claim_funds(payment_preimage, amt);
+                                               channelmanager.claim_funds(payment_preimage, &payment_secret, amt);
                                        }
                                }
                        },
                        9 => {
-                               for (payment, _) in payments_received.drain(..) {
-                                       channelmanager.fail_htlc_backwards(&payment);
+                               for (payment, payment_secret, _) in payments_received.drain(..) {
+                                       channelmanager.fail_htlc_backwards(&payment, &payment_secret);
                                }
                        },
                        10 => {
@@ -485,6 +527,12 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                                } else {
                                        let txres: Result<Transaction, _> = deserialize(get_slice!(txlen));
                                        if let Ok(tx) = txres {
+                                               let mut output_val = 0;
+                                               for out in tx.output.iter() {
+                                                       if out.value > 21_000_000_0000_0000 { return; }
+                                                       output_val += out.value;
+                                                       if output_val > 21_000_000_0000_0000 { return; }
+                                               }
                                                loss_detector.connect_block(&[tx]);
                                        } else {
                                                return;
@@ -503,22 +551,26 @@ pub fn do_test(data: &[u8], logger: &Arc<dyn Logger>) {
                        },
                        _ => return,
                }
+println!("PROCESSING EVENTS");
                loss_detector.handler.process_events();
                for event in loss_detector.manager.get_and_clear_pending_events() {
                        match event {
                                Event::FundingGenerationReady { temporary_channel_id, channel_value_satoshis, output_script, .. } => {
+println!("fgr");
                                        pending_funding_generation.push((temporary_channel_id, channel_value_satoshis, output_script));
                                },
                                Event::FundingBroadcastSafe { funding_txo, .. } => {
+println!("fbs");
                                        pending_funding_relay.push(pending_funding_signatures.remove(&funding_txo).unwrap());
                                },
-                               Event::PaymentReceived { payment_hash, amt } => {
+                               Event::PaymentReceived { payment_hash, payment_secret, amt } => {
                                        //TODO: enhance by fetching random amounts from fuzz input?
-                                       payments_received.push((payment_hash, amt));
+                                       payments_received.push((payment_hash, payment_secret, amt));
                                },
                                Event::PaymentSent {..} => {},
                                Event::PaymentFailed {..} => {},
                                Event::PendingHTLCsForwardable {..} => {
+println!("PENDING HTLCS FORWARDABLE");
                                        should_forward = true;
                                },
                                Event::SpendableOutputs {..} => {},
@@ -540,12 +592,6 @@ mod tests {
        use std::collections::HashMap;
        use std::sync::{Arc, Mutex};
 
-       #[test]
-       fn duplicate_crash() {
-               let logger: Arc<dyn Logger> = Arc::new(test_logger::TestLogger::new("".to_owned()));
-               super::do_test(&::hex::decode("00").unwrap(), &logger);
-       }
-
        struct TrackingLogger {
                /// (module, message) -> count
                pub lines: Mutex<HashMap<(String, String), usize>>,