Rename ChannelMonitor::write_for_disk --> serialize_for_disk
[rust-lightning] / lightning / src / util / test_utils.rs
1 // This file is Copyright its original authors, visible in version control
2 // history.
3 //
4 // This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
5 // or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
7 // You may not use this file except in accordance with one or both of these
8 // licenses.
9
10 use chain;
11 use chain::chaininterface;
12 use chain::chaininterface::ConfirmationTarget;
13 use chain::chainmonitor;
14 use chain::channelmonitor;
15 use chain::channelmonitor::MonitorEvent;
16 use chain::transaction::OutPoint;
17 use chain::keysinterface;
18 use ln::features::{ChannelFeatures, InitFeatures};
19 use ln::msgs;
20 use ln::msgs::OptionalField;
21 use util::enforcing_trait_impls::EnforcingChannelKeys;
22 use util::events;
23 use util::logger::{Logger, Level, Record};
24 use util::ser::{Readable, Writer, Writeable};
25
26 use bitcoin::blockdata::constants::genesis_block;
27 use bitcoin::blockdata::transaction::{Transaction, TxOut};
28 use bitcoin::blockdata::script::{Builder, Script};
29 use bitcoin::blockdata::opcodes;
30 use bitcoin::network::constants::Network;
31 use bitcoin::hash_types::{BlockHash, Txid};
32
33 use bitcoin::secp256k1::{SecretKey, PublicKey, Secp256k1, Signature};
34
35 use regex;
36
37 use std::time::Duration;
38 use std::sync::Mutex;
39 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
40 use std::{cmp, mem};
41 use std::collections::{HashMap, HashSet};
42
43 pub struct TestVecWriter(pub Vec<u8>);
44 impl Writer for TestVecWriter {
45         fn write_all(&mut self, buf: &[u8]) -> Result<(), ::std::io::Error> {
46                 self.0.extend_from_slice(buf);
47                 Ok(())
48         }
49         fn size_hint(&mut self, size: usize) {
50                 self.0.reserve_exact(size);
51         }
52 }
53
54 pub struct TestFeeEstimator {
55         pub sat_per_kw: u32,
56 }
57 impl chaininterface::FeeEstimator for TestFeeEstimator {
58         fn get_est_sat_per_1000_weight(&self, _confirmation_target: ConfirmationTarget) -> u32 {
59                 self.sat_per_kw
60         }
61 }
62
63 pub struct TestChainMonitor<'a> {
64         pub added_monitors: Mutex<Vec<(OutPoint, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>>,
65         pub latest_monitor_update_id: Mutex<HashMap<[u8; 32], (OutPoint, u64)>>,
66         pub chain_monitor: chainmonitor::ChainMonitor<EnforcingChannelKeys, &'a TestChainSource, &'a chaininterface::BroadcasterInterface, &'a TestFeeEstimator, &'a TestLogger, &'a channelmonitor::Persist<EnforcingChannelKeys>>,
67         pub update_ret: Mutex<Option<Result<(), channelmonitor::ChannelMonitorUpdateErr>>>,
68         // If this is set to Some(), after the next return, we'll always return this until update_ret
69         // is changed:
70         pub next_update_ret: Mutex<Option<Result<(), channelmonitor::ChannelMonitorUpdateErr>>>,
71 }
72 impl<'a> TestChainMonitor<'a> {
73         pub fn new(chain_source: Option<&'a TestChainSource>, broadcaster: &'a chaininterface::BroadcasterInterface, logger: &'a TestLogger, fee_estimator: &'a TestFeeEstimator, persister: &'a channelmonitor::Persist<EnforcingChannelKeys>) -> Self {
74                 Self {
75                         added_monitors: Mutex::new(Vec::new()),
76                         latest_monitor_update_id: Mutex::new(HashMap::new()),
77                         chain_monitor: chainmonitor::ChainMonitor::new(chain_source, broadcaster, logger, fee_estimator, persister),
78                         update_ret: Mutex::new(None),
79                         next_update_ret: Mutex::new(None),
80                 }
81         }
82 }
83 impl<'a> chain::Watch for TestChainMonitor<'a> {
84         type Keys = EnforcingChannelKeys;
85
86         fn watch_channel(&self, funding_txo: OutPoint, monitor: channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
87                 // At every point where we get a monitor update, we should be able to send a useful monitor
88                 // to a watchtower and disk...
89                 let mut w = TestVecWriter(Vec::new());
90                 monitor.serialize_for_disk(&mut w).unwrap();
91                 let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(
92                         &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
93                 assert!(new_monitor == monitor);
94                 self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, monitor.get_latest_update_id()));
95                 self.added_monitors.lock().unwrap().push((funding_txo, monitor));
96                 let watch_res = self.chain_monitor.watch_channel(funding_txo, new_monitor);
97
98                 let ret = self.update_ret.lock().unwrap().clone();
99                 if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
100                         *self.update_ret.lock().unwrap() = Some(next_ret);
101                 }
102                 if ret.is_some() {
103                         assert!(watch_res.is_ok());
104                         return ret.unwrap();
105                 }
106                 watch_res
107         }
108
109         fn update_channel(&self, funding_txo: OutPoint, update: channelmonitor::ChannelMonitorUpdate) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
110                 // Every monitor update should survive roundtrip
111                 let mut w = TestVecWriter(Vec::new());
112                 update.write(&mut w).unwrap();
113                 assert!(channelmonitor::ChannelMonitorUpdate::read(
114                                 &mut ::std::io::Cursor::new(&w.0)).unwrap() == update);
115
116                 self.latest_monitor_update_id.lock().unwrap().insert(funding_txo.to_channel_id(), (funding_txo, update.update_id));
117                 let update_res = self.chain_monitor.update_channel(funding_txo, update);
118                 // At every point where we get a monitor update, we should be able to send a useful monitor
119                 // to a watchtower and disk...
120                 let monitors = self.chain_monitor.monitors.lock().unwrap();
121                 let monitor = monitors.get(&funding_txo).unwrap();
122                 w.0.clear();
123                 monitor.serialize_for_disk(&mut w).unwrap();
124                 let new_monitor = <(BlockHash, channelmonitor::ChannelMonitor<EnforcingChannelKeys>)>::read(
125                         &mut ::std::io::Cursor::new(&w.0)).unwrap().1;
126                 assert!(new_monitor == *monitor);
127                 self.added_monitors.lock().unwrap().push((funding_txo, new_monitor));
128
129                 let ret = self.update_ret.lock().unwrap().clone();
130                 if let Some(next_ret) = self.next_update_ret.lock().unwrap().take() {
131                         *self.update_ret.lock().unwrap() = Some(next_ret);
132                 }
133                 if ret.is_some() {
134                         assert!(update_res.is_ok());
135                         return ret.unwrap();
136                 }
137                 update_res
138         }
139
140         fn release_pending_monitor_events(&self) -> Vec<MonitorEvent> {
141                 return self.chain_monitor.release_pending_monitor_events();
142         }
143 }
144
145 pub struct TestPersister {
146         pub update_ret: Mutex<Result<(), channelmonitor::ChannelMonitorUpdateErr>>
147 }
148 impl TestPersister {
149         pub fn new() -> Self {
150                 Self {
151                         update_ret: Mutex::new(Ok(()))
152                 }
153         }
154
155         pub fn set_update_ret(&self, ret: Result<(), channelmonitor::ChannelMonitorUpdateErr>) {
156                 *self.update_ret.lock().unwrap() = ret;
157         }
158 }
159 impl channelmonitor::Persist<EnforcingChannelKeys> for TestPersister {
160         fn persist_new_channel(&self, _funding_txo: OutPoint, _data: &channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
161                 self.update_ret.lock().unwrap().clone()
162         }
163
164         fn update_persisted_channel(&self, _funding_txo: OutPoint, _update: &channelmonitor::ChannelMonitorUpdate, _data: &channelmonitor::ChannelMonitor<EnforcingChannelKeys>) -> Result<(), channelmonitor::ChannelMonitorUpdateErr> {
165                 self.update_ret.lock().unwrap().clone()
166         }
167 }
168
169 pub struct TestBroadcaster {
170         pub txn_broadcasted: Mutex<Vec<Transaction>>,
171 }
172 impl chaininterface::BroadcasterInterface for TestBroadcaster {
173         fn broadcast_transaction(&self, tx: &Transaction) {
174                 self.txn_broadcasted.lock().unwrap().push(tx.clone());
175         }
176 }
177
178 pub struct TestChannelMessageHandler {
179         pub pending_events: Mutex<Vec<events::MessageSendEvent>>,
180 }
181
182 impl TestChannelMessageHandler {
183         pub fn new() -> Self {
184                 TestChannelMessageHandler {
185                         pending_events: Mutex::new(Vec::new()),
186                 }
187         }
188 }
189
190 impl msgs::ChannelMessageHandler for TestChannelMessageHandler {
191         fn handle_open_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::OpenChannel) {}
192         fn handle_accept_channel(&self, _their_node_id: &PublicKey, _their_features: InitFeatures, _msg: &msgs::AcceptChannel) {}
193         fn handle_funding_created(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingCreated) {}
194         fn handle_funding_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingSigned) {}
195         fn handle_funding_locked(&self, _their_node_id: &PublicKey, _msg: &msgs::FundingLocked) {}
196         fn handle_shutdown(&self, _their_node_id: &PublicKey, _msg: &msgs::Shutdown) {}
197         fn handle_closing_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::ClosingSigned) {}
198         fn handle_update_add_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateAddHTLC) {}
199         fn handle_update_fulfill_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFulfillHTLC) {}
200         fn handle_update_fail_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailHTLC) {}
201         fn handle_update_fail_malformed_htlc(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFailMalformedHTLC) {}
202         fn handle_commitment_signed(&self, _their_node_id: &PublicKey, _msg: &msgs::CommitmentSigned) {}
203         fn handle_revoke_and_ack(&self, _their_node_id: &PublicKey, _msg: &msgs::RevokeAndACK) {}
204         fn handle_update_fee(&self, _their_node_id: &PublicKey, _msg: &msgs::UpdateFee) {}
205         fn handle_announcement_signatures(&self, _their_node_id: &PublicKey, _msg: &msgs::AnnouncementSignatures) {}
206         fn handle_channel_reestablish(&self, _their_node_id: &PublicKey, _msg: &msgs::ChannelReestablish) {}
207         fn peer_disconnected(&self, _their_node_id: &PublicKey, _no_connection_possible: bool) {}
208         fn peer_connected(&self, _their_node_id: &PublicKey, _msg: &msgs::Init) {}
209         fn handle_error(&self, _their_node_id: &PublicKey, _msg: &msgs::ErrorMessage) {}
210 }
211
212 impl events::MessageSendEventsProvider for TestChannelMessageHandler {
213         fn get_and_clear_pending_msg_events(&self) -> Vec<events::MessageSendEvent> {
214                 let mut pending_events = self.pending_events.lock().unwrap();
215                 let mut ret = Vec::new();
216                 mem::swap(&mut ret, &mut *pending_events);
217                 ret
218         }
219 }
220
221 fn get_dummy_channel_announcement(short_chan_id: u64) -> msgs::ChannelAnnouncement {
222         use bitcoin::secp256k1::ffi::Signature as FFISignature;
223         let secp_ctx = Secp256k1::new();
224         let network = Network::Testnet;
225         let node_1_privkey = SecretKey::from_slice(&[42; 32]).unwrap();
226         let node_2_privkey = SecretKey::from_slice(&[41; 32]).unwrap();
227         let node_1_btckey = SecretKey::from_slice(&[40; 32]).unwrap();
228         let node_2_btckey = SecretKey::from_slice(&[39; 32]).unwrap();
229         let unsigned_ann = msgs::UnsignedChannelAnnouncement {
230                 features: ChannelFeatures::known(),
231                 chain_hash: genesis_block(network).header.block_hash(),
232                 short_channel_id: short_chan_id,
233                 node_id_1: PublicKey::from_secret_key(&secp_ctx, &node_1_privkey),
234                 node_id_2: PublicKey::from_secret_key(&secp_ctx, &node_2_privkey),
235                 bitcoin_key_1: PublicKey::from_secret_key(&secp_ctx, &node_1_btckey),
236                 bitcoin_key_2: PublicKey::from_secret_key(&secp_ctx, &node_2_btckey),
237                 excess_data: Vec::new(),
238         };
239
240         msgs::ChannelAnnouncement {
241                 node_signature_1: Signature::from(FFISignature::new()),
242                 node_signature_2: Signature::from(FFISignature::new()),
243                 bitcoin_signature_1: Signature::from(FFISignature::new()),
244                 bitcoin_signature_2: Signature::from(FFISignature::new()),
245                 contents: unsigned_ann,
246         }
247 }
248
249 fn get_dummy_channel_update(short_chan_id: u64) -> msgs::ChannelUpdate {
250         use bitcoin::secp256k1::ffi::Signature as FFISignature;
251         let network = Network::Testnet;
252         msgs::ChannelUpdate {
253                 signature: Signature::from(FFISignature::new()),
254                 contents: msgs::UnsignedChannelUpdate {
255                         chain_hash: genesis_block(network).header.block_hash(),
256                         short_channel_id: short_chan_id,
257                         timestamp: 0,
258                         flags: 0,
259                         cltv_expiry_delta: 0,
260                         htlc_minimum_msat: 0,
261                         htlc_maximum_msat: OptionalField::Absent,
262                         fee_base_msat: 0,
263                         fee_proportional_millionths: 0,
264                         excess_data: vec![],
265                 }
266         }
267 }
268
269 pub struct TestRoutingMessageHandler {
270         pub chan_upds_recvd: AtomicUsize,
271         pub chan_anns_recvd: AtomicUsize,
272         pub chan_anns_sent: AtomicUsize,
273         pub request_full_sync: AtomicBool,
274 }
275
276 impl TestRoutingMessageHandler {
277         pub fn new() -> Self {
278                 TestRoutingMessageHandler {
279                         chan_upds_recvd: AtomicUsize::new(0),
280                         chan_anns_recvd: AtomicUsize::new(0),
281                         chan_anns_sent: AtomicUsize::new(0),
282                         request_full_sync: AtomicBool::new(false),
283                 }
284         }
285 }
286 impl msgs::RoutingMessageHandler for TestRoutingMessageHandler {
287         fn handle_node_announcement(&self, _msg: &msgs::NodeAnnouncement) -> Result<bool, msgs::LightningError> {
288                 Err(msgs::LightningError { err: "".to_owned(), action: msgs::ErrorAction::IgnoreError })
289         }
290         fn handle_channel_announcement(&self, _msg: &msgs::ChannelAnnouncement) -> Result<bool, msgs::LightningError> {
291                 self.chan_anns_recvd.fetch_add(1, Ordering::AcqRel);
292                 Err(msgs::LightningError { err: "".to_owned(), action: msgs::ErrorAction::IgnoreError })
293         }
294         fn handle_channel_update(&self, _msg: &msgs::ChannelUpdate) -> Result<bool, msgs::LightningError> {
295                 self.chan_upds_recvd.fetch_add(1, Ordering::AcqRel);
296                 Err(msgs::LightningError { err: "".to_owned(), action: msgs::ErrorAction::IgnoreError })
297         }
298         fn handle_htlc_fail_channel_update(&self, _update: &msgs::HTLCFailChannelUpdate) {}
299         fn get_next_channel_announcements(&self, starting_point: u64, batch_amount: u8) -> Vec<(msgs::ChannelAnnouncement, Option<msgs::ChannelUpdate>, Option<msgs::ChannelUpdate>)> {
300                 let mut chan_anns = Vec::new();
301                 const TOTAL_UPDS: u64 = 100;
302                 let end: u64 = cmp::min(starting_point + batch_amount as u64, TOTAL_UPDS - self.chan_anns_sent.load(Ordering::Acquire) as u64);
303                 for i in starting_point..end {
304                         let chan_upd_1 = get_dummy_channel_update(i);
305                         let chan_upd_2 = get_dummy_channel_update(i);
306                         let chan_ann = get_dummy_channel_announcement(i);
307
308                         chan_anns.push((chan_ann, Some(chan_upd_1), Some(chan_upd_2)));
309                 }
310
311                 self.chan_anns_sent.fetch_add(chan_anns.len(), Ordering::AcqRel);
312                 chan_anns
313         }
314
315         fn get_next_node_announcements(&self, _starting_point: Option<&PublicKey>, _batch_amount: u8) -> Vec<msgs::NodeAnnouncement> {
316                 Vec::new()
317         }
318
319         fn should_request_full_sync(&self, _node_id: &PublicKey) -> bool {
320                 self.request_full_sync.load(Ordering::Acquire)
321         }
322 }
323
324 pub struct TestLogger {
325         level: Level,
326         id: String,
327         pub lines: Mutex<HashMap<(String, String), usize>>,
328 }
329
330 impl TestLogger {
331         pub fn new() -> TestLogger {
332                 Self::with_id("".to_owned())
333         }
334         pub fn with_id(id: String) -> TestLogger {
335                 TestLogger {
336                         level: Level::Trace,
337                         id,
338                         lines: Mutex::new(HashMap::new())
339                 }
340         }
341         pub fn enable(&mut self, level: Level) {
342                 self.level = level;
343         }
344         pub fn assert_log(&self, module: String, line: String, count: usize) {
345                 let log_entries = self.lines.lock().unwrap();
346                 assert_eq!(log_entries.get(&(module, line)), Some(&count));
347         }
348
349         /// Search for the number of occurrence of the logged lines which
350         /// 1. belongs to the specified module and
351         /// 2. contains `line` in it.
352         /// And asserts if the number of occurrences is the same with the given `count`
353         pub fn assert_log_contains(&self, module: String, line: String, count: usize) {
354                 let log_entries = self.lines.lock().unwrap();
355                 let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| {
356                         m == &module && l.contains(line.as_str())
357                 }).map(|(_, c) | { c }).sum();
358                 assert_eq!(l, count)
359         }
360
361     /// Search for the number of occurrences of logged lines which
362     /// 1. belong to the specified module and
363     /// 2. match the given regex pattern.
364     /// Assert that the number of occurrences equals the given `count`
365         pub fn assert_log_regex(&self, module: String, pattern: regex::Regex, count: usize) {
366                 let log_entries = self.lines.lock().unwrap();
367                 let l: usize = log_entries.iter().filter(|&(&(ref m, ref l), _c)| {
368                         m == &module && pattern.is_match(&l)
369                 }).map(|(_, c) | { c }).sum();
370                 assert_eq!(l, count)
371         }
372 }
373
374 impl Logger for TestLogger {
375         fn log(&self, record: &Record) {
376                 *self.lines.lock().unwrap().entry((record.module_path.to_string(), format!("{}", record.args))).or_insert(0) += 1;
377                 if self.level >= record.level {
378                         println!("{:<5} {} [{} : {}, {}] {}", record.level.to_string(), self.id, record.module_path, record.file, record.line, record.args);
379                 }
380         }
381 }
382
383 pub struct TestKeysInterface {
384         backing: keysinterface::KeysManager,
385         pub override_session_priv: Mutex<Option<[u8; 32]>>,
386         pub override_channel_id_priv: Mutex<Option<[u8; 32]>>,
387 }
388
389 impl keysinterface::KeysInterface for TestKeysInterface {
390         type ChanKeySigner = EnforcingChannelKeys;
391
392         fn get_node_secret(&self) -> SecretKey { self.backing.get_node_secret() }
393         fn get_destination_script(&self) -> Script { self.backing.get_destination_script() }
394         fn get_shutdown_pubkey(&self) -> PublicKey { self.backing.get_shutdown_pubkey() }
395         fn get_channel_keys(&self, inbound: bool, channel_value_satoshis: u64) -> EnforcingChannelKeys {
396                 EnforcingChannelKeys::new(self.backing.get_channel_keys(inbound, channel_value_satoshis))
397         }
398
399         fn get_secure_random_bytes(&self) -> [u8; 32] {
400                 let override_channel_id = self.override_channel_id_priv.lock().unwrap();
401                 let override_session_key = self.override_session_priv.lock().unwrap();
402                 if override_channel_id.is_some() && override_session_key.is_some() {
403                         panic!("We don't know which override key to use!");
404                 }
405                 if let Some(key) = &*override_channel_id {
406                         return *key;
407                 }
408                 if let Some(key) = &*override_session_key {
409                         return *key;
410                 }
411                 self.backing.get_secure_random_bytes()
412         }
413 }
414
415 impl TestKeysInterface {
416         pub fn new(seed: &[u8; 32], network: Network) -> Self {
417                 let now = Duration::from_secs(genesis_block(network).header.time as u64);
418                 Self {
419                         backing: keysinterface::KeysManager::new(seed, network, now.as_secs(), now.subsec_nanos()),
420                         override_session_priv: Mutex::new(None),
421                         override_channel_id_priv: Mutex::new(None),
422                 }
423         }
424         pub fn derive_channel_keys(&self, channel_value_satoshis: u64, user_id_1: u64, user_id_2: u64) -> EnforcingChannelKeys {
425                 EnforcingChannelKeys::new(self.backing.derive_channel_keys(channel_value_satoshis, user_id_1, user_id_2))
426         }
427 }
428
429 pub struct TestChainSource {
430         pub genesis_hash: BlockHash,
431         pub utxo_ret: Mutex<Result<TxOut, chain::AccessError>>,
432         pub watched_txn: Mutex<HashSet<(Txid, Script)>>,
433         pub watched_outputs: Mutex<HashSet<(OutPoint, Script)>>,
434 }
435
436 impl TestChainSource {
437         pub fn new(network: Network) -> Self {
438                 let script_pubkey = Builder::new().push_opcode(opcodes::OP_TRUE).into_script();
439                 Self {
440                         genesis_hash: genesis_block(network).block_hash(),
441                         utxo_ret: Mutex::new(Ok(TxOut { value: u64::max_value(), script_pubkey })),
442                         watched_txn: Mutex::new(HashSet::new()),
443                         watched_outputs: Mutex::new(HashSet::new()),
444                 }
445         }
446 }
447
448 impl chain::Access for TestChainSource {
449         fn get_utxo(&self, genesis_hash: &BlockHash, _short_channel_id: u64) -> Result<TxOut, chain::AccessError> {
450                 if self.genesis_hash != *genesis_hash {
451                         return Err(chain::AccessError::UnknownChain);
452                 }
453
454                 self.utxo_ret.lock().unwrap().clone()
455         }
456 }
457
458 impl chain::Filter for TestChainSource {
459         fn register_tx(&self, txid: &Txid, script_pubkey: &Script) {
460                 self.watched_txn.lock().unwrap().insert((*txid, script_pubkey.clone()));
461         }
462
463         fn register_output(&self, outpoint: &OutPoint, script_pubkey: &Script) {
464                 self.watched_outputs.lock().unwrap().insert((*outpoint, script_pubkey.clone()));
465         }
466 }