Utility for syncing a set of chain listeners
[rust-lightning] / lightning-block-sync / src / test_utils.rs
1 use crate::{AsyncBlockSourceResult, BlockHeaderData, BlockSource, BlockSourceError, UnboundedCache};
2 use crate::poll::{Validate, ValidatedBlockHeader};
3
4 use bitcoin::blockdata::block::{Block, BlockHeader};
5 use bitcoin::blockdata::constants::genesis_block;
6 use bitcoin::hash_types::BlockHash;
7 use bitcoin::network::constants::Network;
8 use bitcoin::util::uint::Uint256;
9
10 use lightning::chain;
11
12 use std::cell::RefCell;
13 use std::collections::VecDeque;
14
15 #[derive(Default)]
16 pub struct Blockchain {
17         pub blocks: Vec<Block>,
18         without_blocks: Option<std::ops::RangeFrom<usize>>,
19         without_headers: bool,
20         malformed_headers: bool,
21 }
22
23 impl Blockchain {
24         pub fn default() -> Self {
25                 Blockchain::with_network(Network::Bitcoin)
26         }
27
28         pub fn with_network(network: Network) -> Self {
29                 let blocks = vec![genesis_block(network)];
30                 Self { blocks, ..Default::default() }
31         }
32
33         pub fn with_height(mut self, height: usize) -> Self {
34                 self.blocks.reserve_exact(height);
35                 let bits = BlockHeader::compact_target_from_u256(&Uint256::from_be_bytes([0xff; 32]));
36                 for i in 1..=height {
37                         let prev_block = &self.blocks[i - 1];
38                         let prev_blockhash = prev_block.block_hash();
39                         let time = prev_block.header.time + height as u32;
40                         self.blocks.push(Block {
41                                 header: BlockHeader {
42                                         version: 0,
43                                         prev_blockhash,
44                                         merkle_root: Default::default(),
45                                         time,
46                                         bits,
47                                         nonce: 0,
48                                 },
49                                 txdata: vec![],
50                         });
51                 }
52                 self
53         }
54
55         pub fn without_blocks(self, range: std::ops::RangeFrom<usize>) -> Self {
56                 Self { without_blocks: Some(range), ..self }
57         }
58
59         pub fn without_headers(self) -> Self {
60                 Self { without_headers: true, ..self }
61         }
62
63         pub fn malformed_headers(self) -> Self {
64                 Self { malformed_headers: true, ..self }
65         }
66
67         pub fn fork_at_height(&self, height: usize) -> Self {
68                 assert!(height + 1 < self.blocks.len());
69                 let mut blocks = self.blocks.clone();
70                 let mut prev_blockhash = blocks[height].block_hash();
71                 for block in blocks.iter_mut().skip(height + 1) {
72                         block.header.prev_blockhash = prev_blockhash;
73                         block.header.nonce += 1;
74                         prev_blockhash = block.block_hash();
75                 }
76                 Self { blocks, without_blocks: None, ..*self }
77         }
78
79         pub fn at_height(&self, height: usize) -> ValidatedBlockHeader {
80                 let block_header = self.at_height_unvalidated(height);
81                 let block_hash = self.blocks[height].block_hash();
82                 block_header.validate(block_hash).unwrap()
83         }
84
85         fn at_height_unvalidated(&self, height: usize) -> BlockHeaderData {
86                 assert!(!self.blocks.is_empty());
87                 assert!(height < self.blocks.len());
88                 BlockHeaderData {
89                         chainwork: self.blocks[0].header.work() + Uint256::from_u64(height as u64).unwrap(),
90                         height: height as u32,
91                         header: self.blocks[height].header.clone(),
92                 }
93         }
94
95         pub fn tip(&self) -> ValidatedBlockHeader {
96                 assert!(!self.blocks.is_empty());
97                 self.at_height(self.blocks.len() - 1)
98         }
99
100         pub fn disconnect_tip(&mut self) -> Option<Block> {
101                 self.blocks.pop()
102         }
103
104         pub fn header_cache(&self, heights: std::ops::RangeInclusive<usize>) -> UnboundedCache {
105                 let mut cache = UnboundedCache::new();
106                 for i in heights {
107                         let value = self.at_height(i);
108                         let key = value.header.block_hash();
109                         assert!(cache.insert(key, value).is_none());
110                 }
111                 cache
112         }
113 }
114
115 impl BlockSource for Blockchain {
116         fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height_hint: Option<u32>) -> AsyncBlockSourceResult<'a, BlockHeaderData> {
117                 Box::pin(async move {
118                         if self.without_headers {
119                                 return Err(BlockSourceError::persistent("header not found"));
120                         }
121
122                         for (height, block) in self.blocks.iter().enumerate() {
123                                 if block.header.block_hash() == *header_hash {
124                                         let mut header_data = self.at_height_unvalidated(height);
125                                         if self.malformed_headers {
126                                                 header_data.header.time += 1;
127                                         }
128
129                                         return Ok(header_data);
130                                 }
131                         }
132                         Err(BlockSourceError::transient("header not found"))
133                 })
134         }
135
136         fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> {
137                 Box::pin(async move {
138                         for (height, block) in self.blocks.iter().enumerate() {
139                                 if block.header.block_hash() == *header_hash {
140                                         if let Some(without_blocks) = &self.without_blocks {
141                                                 if without_blocks.contains(&height) {
142                                                         return Err(BlockSourceError::persistent("block not found"));
143                                                 }
144                                         }
145
146                                         return Ok(block.clone());
147                                 }
148                         }
149                         Err(BlockSourceError::transient("block not found"))
150                 })
151         }
152
153         fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option<u32>)> {
154                 Box::pin(async move {
155                         match self.blocks.last() {
156                                 None => Err(BlockSourceError::transient("empty chain")),
157                                 Some(block) => {
158                                         let height = (self.blocks.len() - 1) as u32;
159                                         Ok((block.block_hash(), Some(height)))
160                                 },
161                         }
162                 })
163         }
164 }
165
166 pub struct NullChainListener;
167
168 impl chain::Listen for NullChainListener {
169         fn block_connected(&self, _block: &Block, _height: u32) {}
170         fn block_disconnected(&self, _header: &BlockHeader, _height: u32) {}
171 }
172
173 pub struct MockChainListener {
174         expected_blocks_connected: RefCell<VecDeque<BlockHeaderData>>,
175         expected_blocks_disconnected: RefCell<VecDeque<BlockHeaderData>>,
176 }
177
178 impl MockChainListener {
179         pub fn new() -> Self {
180                 Self {
181                         expected_blocks_connected: RefCell::new(VecDeque::new()),
182                         expected_blocks_disconnected: RefCell::new(VecDeque::new()),
183                 }
184         }
185
186         pub fn expect_block_connected(self, block: BlockHeaderData) -> Self {
187                 self.expected_blocks_connected.borrow_mut().push_back(block);
188                 self
189         }
190
191         pub fn expect_block_disconnected(self, block: BlockHeaderData) -> Self {
192                 self.expected_blocks_disconnected.borrow_mut().push_back(block);
193                 self
194         }
195 }
196
197 impl chain::Listen for MockChainListener {
198         fn block_connected(&self, block: &Block, height: u32) {
199                 match self.expected_blocks_connected.borrow_mut().pop_front() {
200                         None => {
201                                 panic!("Unexpected block connected: {:?}", block.block_hash());
202                         },
203                         Some(expected_block) => {
204                                 assert_eq!(block.block_hash(), expected_block.header.block_hash());
205                                 assert_eq!(height, expected_block.height);
206                         },
207                 }
208         }
209
210         fn block_disconnected(&self, header: &BlockHeader, height: u32) {
211                 match self.expected_blocks_disconnected.borrow_mut().pop_front() {
212                         None => {
213                                 panic!("Unexpected block disconnected: {:?}", header.block_hash());
214                         },
215                         Some(expected_block) => {
216                                 assert_eq!(header.block_hash(), expected_block.header.block_hash());
217                                 assert_eq!(height, expected_block.height);
218                         },
219                 }
220         }
221 }
222
223 impl Drop for MockChainListener {
224         fn drop(&mut self) {
225                 if std::thread::panicking() {
226                         return;
227                 }
228
229                 let expected_blocks_connected = self.expected_blocks_connected.borrow();
230                 if !expected_blocks_connected.is_empty() {
231                         panic!("Expected blocks connected: {:?}", expected_blocks_connected);
232                 }
233
234                 let expected_blocks_disconnected = self.expected_blocks_disconnected.borrow();
235                 if !expected_blocks_disconnected.is_empty() {
236                         panic!("Expected blocks disconnected: {:?}", expected_blocks_disconnected);
237                 }
238         }
239 }