Check expected amount in claim_funds
[rust-lightning] / fuzz / fuzz_targets / full_stack_target.rs
index acd757882b9825a77e2685ff74dae19608fead07..41ab473fd61c4877412b1dcf36acfde09f5bb3d2 100644 (file)
@@ -49,7 +49,7 @@ use std::collections::{HashMap, hash_map};
 use std::cmp;
 use std::hash::Hash;
 use std::sync::Arc;
-use std::sync::atomic::{AtomicU8,AtomicUsize,Ordering};
+use std::sync::atomic::{AtomicU64,AtomicUsize,Ordering};
 
 #[inline]
 pub fn slice_to_be16(v: &[u8]) -> u16 {
@@ -124,9 +124,8 @@ struct Peer<'a> {
        peers_connected: &'a RefCell<[bool; 256]>,
 }
 impl<'a> SocketDescriptor for Peer<'a> {
-       fn send_data(&mut self, data: &Vec<u8>, write_offset: usize, _resume_read: bool) -> usize {
-               assert!(write_offset < data.len());
-               data.len() - write_offset
+       fn send_data(&mut self, data: &[u8], _resume_read: bool) -> usize {
+               data.len()
        }
        fn disconnect_socket(&mut self) {
                assert!(self.peers_connected.borrow()[self.id as usize]);
@@ -236,7 +235,7 @@ impl<'a> Drop for MoneyLossDetector<'a> {
 
 struct KeyProvider {
        node_secret: SecretKey,
-       counter: AtomicU8,
+       counter: AtomicU64,
 }
 impl KeysInterface for KeyProvider {
        fn get_node_secret(&self) -> SecretKey {
@@ -256,7 +255,7 @@ impl KeysInterface for KeyProvider {
        }
 
        fn get_channel_keys(&self, inbound: bool) -> ChannelKeys {
-               let ctr = self.counter.fetch_add(1, Ordering::Relaxed);
+               let ctr = self.counter.fetch_add(1, Ordering::Relaxed) as u8;
                if inbound {
                        ChannelKeys {
                                funding_key:               SecretKey::from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ctr]).unwrap(),
@@ -279,13 +278,14 @@ impl KeysInterface for KeyProvider {
        }
 
        fn get_session_key(&self) -> SecretKey {
-               let ctr = self.counter.fetch_add(1, Ordering::Relaxed);
+               let ctr = self.counter.fetch_add(1, Ordering::Relaxed) as u8;
                SecretKey::from_slice(&[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, ctr]).unwrap()
        }
 
        fn get_channel_id(&self) -> [u8; 32] {
                let ctr = self.counter.fetch_add(1, Ordering::Relaxed);
-               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, ctr]
+               [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+               (ctr >> 8*7) as u8, (ctr >> 8*6) as u8, (ctr >> 8*5) as u8, (ctr >> 8*4) as u8, (ctr >> 8*3) as u8, (ctr >> 8*2) as u8, (ctr >> 8*1) as u8, 14, (ctr >> 8*0) as u8]
        }
 }
 
@@ -326,12 +326,12 @@ pub fn do_test(data: &[u8], logger: &Arc<Logger>) {
        let broadcast = Arc::new(TestBroadcaster{});
        let monitor = channelmonitor::SimpleManyChannelMonitor::new(watch.clone(), broadcast.clone(), Arc::clone(&logger), fee_est.clone());
 
-       let keys_manager = Arc::new(KeyProvider { node_secret: our_network_key.clone(), counter: AtomicU8::new(0) });
+       let keys_manager = Arc::new(KeyProvider { node_secret: our_network_key.clone(), counter: AtomicU64::new(0) });
        let mut config = UserConfig::new();
        config.channel_options.fee_proportional_millionths =  slice_to_be32(get_slice!(4));
        config.channel_options.announced_channel = get_slice!(1)[0] != 0;
        config.peer_channel_config_limits.min_dust_limit_satoshis = 0;
-       let channelmanager = ChannelManager::new(Network::Bitcoin, fee_est.clone(), monitor.clone(), watch.clone(), broadcast.clone(), Arc::clone(&logger), keys_manager.clone(), config).unwrap();
+       let channelmanager = ChannelManager::new(Network::Bitcoin, fee_est.clone(), monitor.clone(), watch.clone(), broadcast.clone(), Arc::clone(&logger), keys_manager.clone(), config, 0).unwrap();
        let router = Arc::new(Router::new(PublicKey::from_secret_key(&Secp256k1::signing_only(), &keys_manager.get_node_secret()), watch.clone(), Arc::clone(&logger)));
 
        let peers = RefCell::new([false; 256]);
@@ -341,7 +341,7 @@ pub fn do_test(data: &[u8], logger: &Arc<Logger>) {
        }, our_network_key, &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, 0], Arc::clone(&logger)));
 
        let mut should_forward = false;
-       let mut payments_received: Vec<PaymentHash> = Vec::new();
+       let mut payments_received: Vec<(PaymentHash, u64)> = Vec::new();
        let mut payments_sent = 0;
        let mut pending_funding_generation: Vec<([u8; 32], u64, Script)> = Vec::new();
        let mut pending_funding_signatures = HashMap::new();
@@ -426,7 +426,7 @@ pub fn do_test(data: &[u8], logger: &Arc<Logger>) {
                                }
                        },
                        8 => {
-                               for payment in payments_received.drain(..) {
+                               for (payment, amt) in payments_received.drain(..) {
                                        // SHA256 is defined as XOR of all input bytes placed in the first byte, and 0s
                                        // for the remaining bytes. Thus, if not all remaining bytes are 0s we cannot
                                        // fulfill this HTLC, but if they are, we can just take the first byte and
@@ -436,12 +436,12 @@ pub fn do_test(data: &[u8], logger: &Arc<Logger>) {
                                        } else {
                                                let mut payment_preimage = PaymentPreimage([0; 32]);
                                                payment_preimage.0[0] = payment.0[0];
-                                               channelmanager.claim_funds(payment_preimage);
+                                               channelmanager.claim_funds(payment_preimage, amt);
                                        }
                                }
                        },
                        9 => {
-                               for payment in payments_received.drain(..) {
+                               for (payment, _) in payments_received.drain(..) {
                                        channelmanager.fail_htlc_backwards(&payment);
                                }
                        },
@@ -516,8 +516,9 @@ pub fn do_test(data: &[u8], logger: &Arc<Logger>) {
                                Event::FundingBroadcastSafe { funding_txo, .. } => {
                                        pending_funding_relay.push(pending_funding_signatures.remove(&funding_txo).unwrap());
                                },
-                               Event::PaymentReceived { payment_hash, .. } => {
-                                       payments_received.push(payment_hash);
+                               Event::PaymentReceived { payment_hash, amt } => {
+                                       //TODO: enhance by fetching random amounts from fuzz input?
+                                       payments_received.push((payment_hash, amt));
                                },
                                Event::PaymentSent {..} => {},
                                Event::PaymentFailed {..} => {},