WIP: Common ChainListener implementations and example
[rust-lightning] / lightning-block-sync / src / test_utils.rs
1 use crate::{AsyncBlockSourceResult, BlockHeaderData, BlockSource, BlockSourceError, ChainListener, UnboundedCache};
2 use crate::poll::{Validate, ValidatedBlockHeader};
3
4 use bitcoin::blockdata::block::{Block, BlockHeader};
5 use bitcoin::blockdata::constants::genesis_block;
6 use bitcoin::hash_types::BlockHash;
7 use bitcoin::network::constants::Network;
8 use bitcoin::util::uint::Uint256;
9
10 use std::cell::RefCell;
11 use std::collections::VecDeque;
12
13 #[derive(Default)]
14 pub struct Blockchain {
15         pub blocks: Vec<Block>,
16         without_blocks: Option<std::ops::RangeFrom<usize>>,
17         without_headers: bool,
18         malformed_headers: bool,
19 }
20
21 impl Blockchain {
22         pub fn default() -> Self {
23                 Blockchain::with_network(Network::Bitcoin)
24         }
25
26         pub fn with_network(network: Network) -> Self {
27                 let blocks = vec![genesis_block(network)];
28                 Self { blocks, ..Default::default() }
29         }
30
31         pub fn with_height(mut self, height: usize) -> Self {
32                 self.blocks.reserve_exact(height);
33                 let bits = BlockHeader::compact_target_from_u256(&Uint256::from_be_bytes([0xff; 32]));
34                 for i in 1..=height {
35                         let prev_block = &self.blocks[i - 1];
36                         let prev_blockhash = prev_block.block_hash();
37                         let time = prev_block.header.time + height as u32;
38                         self.blocks.push(Block {
39                                 header: BlockHeader {
40                                         version: 0,
41                                         prev_blockhash,
42                                         merkle_root: Default::default(),
43                                         time,
44                                         bits,
45                                         nonce: 0,
46                                 },
47                                 txdata: vec![],
48                         });
49                 }
50                 self
51         }
52
53         pub fn without_blocks(self, range: std::ops::RangeFrom<usize>) -> Self {
54                 Self { without_blocks: Some(range), ..self }
55         }
56
57         pub fn without_headers(self) -> Self {
58                 Self { without_headers: true, ..self }
59         }
60
61         pub fn malformed_headers(self) -> Self {
62                 Self { malformed_headers: true, ..self }
63         }
64
65         pub fn fork_at_height(&self, height: usize) -> Self {
66                 assert!(height + 1 < self.blocks.len());
67                 let mut blocks = self.blocks.clone();
68                 let mut prev_blockhash = blocks[height].block_hash();
69                 for block in blocks.iter_mut().skip(height + 1) {
70                         block.header.prev_blockhash = prev_blockhash;
71                         block.header.nonce += 1;
72                         prev_blockhash = block.block_hash();
73                 }
74                 Self { blocks, without_blocks: None, ..*self }
75         }
76
77         pub fn at_height(&self, height: usize) -> ValidatedBlockHeader {
78                 let block_header = self.at_height_unvalidated(height);
79                 let block_hash = self.blocks[height].block_hash();
80                 block_header.validate(block_hash).unwrap()
81         }
82
83         fn at_height_unvalidated(&self, height: usize) -> BlockHeaderData {
84                 assert!(!self.blocks.is_empty());
85                 assert!(height < self.blocks.len());
86                 BlockHeaderData {
87                         chainwork: self.blocks[0].header.work() + Uint256::from_u64(height as u64).unwrap(),
88                         height: height as u32,
89                         header: self.blocks[height].header.clone(),
90                 }
91         }
92
93         pub fn tip(&self) -> ValidatedBlockHeader {
94                 assert!(!self.blocks.is_empty());
95                 self.at_height(self.blocks.len() - 1)
96         }
97
98         pub fn disconnect_tip(&mut self) -> Option<Block> {
99                 self.blocks.pop()
100         }
101
102         pub fn header_cache(&self, heights: std::ops::RangeInclusive<usize>) -> UnboundedCache {
103                 let mut cache = UnboundedCache::new();
104                 for i in heights {
105                         let value = self.at_height(i);
106                         let key = value.header.block_hash();
107                         assert!(cache.insert(key, value).is_none());
108                 }
109                 cache
110         }
111 }
112
113 impl BlockSource for Blockchain {
114         fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height_hint: Option<u32>) -> AsyncBlockSourceResult<'a, BlockHeaderData> {
115                 Box::pin(async move {
116                         if self.without_headers {
117                                 return Err(BlockSourceError::persistent("header not found"));
118                         }
119
120                         for (height, block) in self.blocks.iter().enumerate() {
121                                 if block.header.block_hash() == *header_hash {
122                                         let mut header_data = self.at_height_unvalidated(height);
123                                         if self.malformed_headers {
124                                                 header_data.header.time += 1;
125                                         }
126
127                                         return Ok(header_data);
128                                 }
129                         }
130                         Err(BlockSourceError::transient("header not found"))
131                 })
132         }
133
134         fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> {
135                 Box::pin(async move {
136                         for (height, block) in self.blocks.iter().enumerate() {
137                                 if block.header.block_hash() == *header_hash {
138                                         if let Some(without_blocks) = &self.without_blocks {
139                                                 if without_blocks.contains(&height) {
140                                                         return Err(BlockSourceError::persistent("block not found"));
141                                                 }
142                                         }
143
144                                         return Ok(block.clone());
145                                 }
146                         }
147                         Err(BlockSourceError::transient("block not found"))
148                 })
149         }
150
151         fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option<u32>)> {
152                 Box::pin(async move {
153                         match self.blocks.last() {
154                                 None => Err(BlockSourceError::transient("empty chain")),
155                                 Some(block) => {
156                                         let height = (self.blocks.len() - 1) as u32;
157                                         Ok((block.block_hash(), Some(height)))
158                                 },
159                         }
160                 })
161         }
162 }
163
164 pub struct NullChainListener;
165
166 impl ChainListener for NullChainListener {
167         fn block_connected(&self, _block: &Block, _height: u32) {}
168         fn block_disconnected(&self, _header: &BlockHeader, _height: u32) {}
169 }
170
171 pub struct MockChainListener {
172         expected_blocks_connected: RefCell<VecDeque<BlockHeaderData>>,
173         expected_blocks_disconnected: RefCell<VecDeque<BlockHeaderData>>,
174 }
175
176 impl MockChainListener {
177         pub fn new() -> Self {
178                 Self {
179                         expected_blocks_connected: RefCell::new(VecDeque::new()),
180                         expected_blocks_disconnected: RefCell::new(VecDeque::new()),
181                 }
182         }
183
184         pub fn expect_block_connected(self, block: BlockHeaderData) -> Self {
185                 self.expected_blocks_connected.borrow_mut().push_back(block);
186                 self
187         }
188
189         pub fn expect_block_disconnected(self, block: BlockHeaderData) -> Self {
190                 self.expected_blocks_disconnected.borrow_mut().push_back(block);
191                 self
192         }
193 }
194
195 impl ChainListener for MockChainListener {
196         fn block_connected(&self, block: &Block, height: u32) {
197                 match self.expected_blocks_connected.borrow_mut().pop_front() {
198                         None => {
199                                 panic!("Unexpected block connected: {:?}", block.block_hash());
200                         },
201                         Some(expected_block) => {
202                                 assert_eq!(block.block_hash(), expected_block.header.block_hash());
203                                 assert_eq!(height, expected_block.height);
204                         },
205                 }
206         }
207
208         fn block_disconnected(&self, header: &BlockHeader, height: u32) {
209                 match self.expected_blocks_disconnected.borrow_mut().pop_front() {
210                         None => {
211                                 panic!("Unexpected block disconnected: {:?}", header.block_hash());
212                         },
213                         Some(expected_block) => {
214                                 assert_eq!(header.block_hash(), expected_block.header.block_hash());
215                                 assert_eq!(height, expected_block.height);
216                         },
217                 }
218         }
219 }
220
221 impl Drop for MockChainListener {
222         fn drop(&mut self) {
223                 if std::thread::panicking() {
224                         return;
225                 }
226
227                 let expected_blocks_connected = self.expected_blocks_connected.borrow();
228                 if !expected_blocks_connected.is_empty() {
229                         panic!("Expected blocks connected: {:?}", expected_blocks_connected);
230                 }
231
232                 let expected_blocks_disconnected = self.expected_blocks_disconnected.borrow();
233                 if !expected_blocks_disconnected.is_empty() {
234                         panic!("Expected blocks disconnected: {:?}", expected_blocks_disconnected);
235                 }
236         }
237 }