807a33a69dec3fb305fb56ad9b9dc6953a26b138
[rust-lightning] / lightning-block-sync / src / test_utils.rs
1 use crate::{AsyncBlockSourceResult, BlockHeaderData, BlockSource, BlockSourceError, ChainListener, UnboundedCache};
2 use crate::poll::{Validate, ValidatedBlockHeader};
3
4 use bitcoin::blockdata::block::{Block, BlockHeader};
5 use bitcoin::blockdata::constants::genesis_block;
6 use bitcoin::hash_types::BlockHash;
7 use bitcoin::network::constants::Network;
8 use bitcoin::util::uint::Uint256;
9
10 use std::collections::VecDeque;
11
12 #[derive(Default)]
13 pub struct Blockchain {
14         pub blocks: Vec<Block>,
15         without_blocks: Option<std::ops::RangeFrom<usize>>,
16         without_headers: bool,
17         malformed_headers: bool,
18 }
19
20 impl Blockchain {
21         pub fn default() -> Self {
22                 Blockchain::with_network(Network::Bitcoin)
23         }
24
25         pub fn with_network(network: Network) -> Self {
26                 let blocks = vec![genesis_block(network)];
27                 Self { blocks, ..Default::default() }
28         }
29
30         pub fn with_height(mut self, height: usize) -> Self {
31                 self.blocks.reserve_exact(height);
32                 let bits = BlockHeader::compact_target_from_u256(&Uint256::from_be_bytes([0xff; 32]));
33                 for i in 1..=height {
34                         let prev_block = &self.blocks[i - 1];
35                         let prev_blockhash = prev_block.block_hash();
36                         let time = prev_block.header.time + height as u32;
37                         self.blocks.push(Block {
38                                 header: BlockHeader {
39                                         version: 0,
40                                         prev_blockhash,
41                                         merkle_root: Default::default(),
42                                         time,
43                                         bits,
44                                         nonce: 0,
45                                 },
46                                 txdata: vec![],
47                         });
48                 }
49                 self
50         }
51
52         pub fn without_blocks(self, range: std::ops::RangeFrom<usize>) -> Self {
53                 Self { without_blocks: Some(range), ..self }
54         }
55
56         pub fn without_headers(self) -> Self {
57                 Self { without_headers: true, ..self }
58         }
59
60         pub fn malformed_headers(self) -> Self {
61                 Self { malformed_headers: true, ..self }
62         }
63
64         pub fn fork_at_height(&self, height: usize) -> Self {
65                 assert!(height + 1 < self.blocks.len());
66                 let mut blocks = self.blocks.clone();
67                 let mut prev_blockhash = blocks[height].block_hash();
68                 for block in blocks.iter_mut().skip(height + 1) {
69                         block.header.prev_blockhash = prev_blockhash;
70                         block.header.nonce += 1;
71                         prev_blockhash = block.block_hash();
72                 }
73                 Self { blocks, without_blocks: None, ..*self }
74         }
75
76         pub fn at_height(&self, height: usize) -> ValidatedBlockHeader {
77                 let block_header = self.at_height_unvalidated(height);
78                 let block_hash = self.blocks[height].block_hash();
79                 block_header.validate(block_hash).unwrap()
80         }
81
82         fn at_height_unvalidated(&self, height: usize) -> BlockHeaderData {
83                 assert!(!self.blocks.is_empty());
84                 assert!(height < self.blocks.len());
85                 BlockHeaderData {
86                         chainwork: self.blocks[0].header.work() + Uint256::from_u64(height as u64).unwrap(),
87                         height: height as u32,
88                         header: self.blocks[height].header.clone(),
89                 }
90         }
91
92         pub fn tip(&self) -> ValidatedBlockHeader {
93                 assert!(!self.blocks.is_empty());
94                 self.at_height(self.blocks.len() - 1)
95         }
96
97         pub fn disconnect_tip(&mut self) -> Option<Block> {
98                 self.blocks.pop()
99         }
100
101         pub fn header_cache(&self, heights: std::ops::RangeInclusive<usize>) -> UnboundedCache {
102                 let mut cache = UnboundedCache::new();
103                 for i in heights {
104                         let value = self.at_height(i);
105                         let key = value.header.block_hash();
106                         assert!(cache.insert(key, value).is_none());
107                 }
108                 cache
109         }
110 }
111
112 impl BlockSource for Blockchain {
113         fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height_hint: Option<u32>) -> AsyncBlockSourceResult<'a, BlockHeaderData> {
114                 Box::pin(async move {
115                         if self.without_headers {
116                                 return Err(BlockSourceError::persistent("header not found"));
117                         }
118
119                         for (height, block) in self.blocks.iter().enumerate() {
120                                 if block.header.block_hash() == *header_hash {
121                                         let mut header_data = self.at_height_unvalidated(height);
122                                         if self.malformed_headers {
123                                                 header_data.header.time += 1;
124                                         }
125
126                                         return Ok(header_data);
127                                 }
128                         }
129                         Err(BlockSourceError::transient("header not found"))
130                 })
131         }
132
133         fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> {
134                 Box::pin(async move {
135                         for (height, block) in self.blocks.iter().enumerate() {
136                                 if block.header.block_hash() == *header_hash {
137                                         if let Some(without_blocks) = &self.without_blocks {
138                                                 if without_blocks.contains(&height) {
139                                                         return Err(BlockSourceError::persistent("block not found"));
140                                                 }
141                                         }
142
143                                         return Ok(block.clone());
144                                 }
145                         }
146                         Err(BlockSourceError::transient("block not found"))
147                 })
148         }
149
150         fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option<u32>)> {
151                 Box::pin(async move {
152                         match self.blocks.last() {
153                                 None => Err(BlockSourceError::transient("empty chain")),
154                                 Some(block) => {
155                                         let height = (self.blocks.len() - 1) as u32;
156                                         Ok((block.block_hash(), Some(height)))
157                                 },
158                         }
159                 })
160         }
161 }
162
163 pub struct NullChainListener;
164
165 impl ChainListener for NullChainListener {
166         fn block_connected(&mut self, _block: &Block, _height: u32) {}
167         fn block_disconnected(&mut self, _header: &BlockHeader, _height: u32) {}
168 }
169
170 pub struct MockChainListener {
171         expected_blocks_connected: VecDeque<BlockHeaderData>,
172         expected_blocks_disconnected: VecDeque<BlockHeaderData>,
173 }
174
175 impl MockChainListener {
176         pub fn new() -> Self {
177                 Self {
178                         expected_blocks_connected: VecDeque::new(),
179                         expected_blocks_disconnected: VecDeque::new(),
180                 }
181         }
182
183         pub fn expect_block_connected(mut self, block: BlockHeaderData) -> Self {
184                 self.expected_blocks_connected.push_back(block);
185                 self
186         }
187
188         pub fn expect_block_disconnected(mut self, block: BlockHeaderData) -> Self {
189                 self.expected_blocks_disconnected.push_back(block);
190                 self
191         }
192 }
193
194 impl ChainListener for MockChainListener {
195         fn block_connected(&mut self, block: &Block, height: u32) {
196                 match self.expected_blocks_connected.pop_front() {
197                         None => {
198                                 panic!("Unexpected block connected: {:?}", block.block_hash());
199                         },
200                         Some(expected_block) => {
201                                 assert_eq!(block.block_hash(), expected_block.header.block_hash());
202                                 assert_eq!(height, expected_block.height);
203                         },
204                 }
205         }
206
207         fn block_disconnected(&mut self, header: &BlockHeader, height: u32) {
208                 match self.expected_blocks_disconnected.pop_front() {
209                         None => {
210                                 panic!("Unexpected block disconnected: {:?}", header.block_hash());
211                         },
212                         Some(expected_block) => {
213                                 assert_eq!(header.block_hash(), expected_block.header.block_hash());
214                                 assert_eq!(height, expected_block.height);
215                         },
216                 }
217         }
218 }
219
220 impl Drop for MockChainListener {
221         fn drop(&mut self) {
222                 if std::thread::panicking() {
223                         return;
224                 }
225                 if !self.expected_blocks_connected.is_empty() {
226                         panic!("Expected blocks connected: {:?}", self.expected_blocks_connected);
227                 }
228                 if !self.expected_blocks_disconnected.is_empty() {
229                         panic!("Expected blocks disconnected: {:?}", self.expected_blocks_disconnected);
230                 }
231         }
232 }