efd34f74ee5b3cb219566f1dcad5a8d4d9d57315
[rust-lightning] / lightning-block-sync / src / test_utils.rs
1 use crate::{AsyncBlockSourceResult, BlockHeaderData, BlockSource, BlockSourceError};
2 use crate::poll::{Validate, ValidatedBlockHeader};
3
4 use bitcoin::blockdata::block::{Block, BlockHeader};
5 use bitcoin::blockdata::constants::genesis_block;
6 use bitcoin::hash_types::BlockHash;
7 use bitcoin::network::constants::Network;
8 use bitcoin::util::uint::Uint256;
9
10 #[derive(Default)]
11 pub struct Blockchain {
12         pub blocks: Vec<Block>,
13         without_headers: bool,
14         malformed_headers: bool,
15 }
16
17 impl Blockchain {
18         pub fn default() -> Self {
19                 Blockchain::with_network(Network::Bitcoin)
20         }
21
22         pub fn with_network(network: Network) -> Self {
23                 let blocks = vec![genesis_block(network)];
24                 Self { blocks, ..Default::default() }
25         }
26
27         pub fn with_height(mut self, height: usize) -> Self {
28                 self.blocks.reserve_exact(height);
29                 let bits = BlockHeader::compact_target_from_u256(&Uint256::from_be_bytes([0xff; 32]));
30                 for i in 1..=height {
31                         let prev_block = &self.blocks[i - 1];
32                         let prev_blockhash = prev_block.block_hash();
33                         let time = prev_block.header.time + height as u32;
34                         self.blocks.push(Block {
35                                 header: BlockHeader {
36                                         version: 0,
37                                         prev_blockhash,
38                                         merkle_root: Default::default(),
39                                         time,
40                                         bits,
41                                         nonce: 0,
42                                 },
43                                 txdata: vec![],
44                         });
45                 }
46                 self
47         }
48
49         pub fn without_headers(self) -> Self {
50                 Self { without_headers: true, ..self }
51         }
52
53         pub fn malformed_headers(self) -> Self {
54                 Self { malformed_headers: true, ..self }
55         }
56
57         pub fn at_height(&self, height: usize) -> ValidatedBlockHeader {
58                 let block_header = self.at_height_unvalidated(height);
59                 let block_hash = self.blocks[height].block_hash();
60                 block_header.validate(block_hash).unwrap()
61         }
62
63         fn at_height_unvalidated(&self, height: usize) -> BlockHeaderData {
64                 assert!(!self.blocks.is_empty());
65                 assert!(height < self.blocks.len());
66                 BlockHeaderData {
67                         chainwork: self.blocks[0].header.work() + Uint256::from_u64(height as u64).unwrap(),
68                         height: height as u32,
69                         header: self.blocks[height].header.clone(),
70                 }
71         }
72
73         pub fn tip(&self) -> ValidatedBlockHeader {
74                 assert!(!self.blocks.is_empty());
75                 self.at_height(self.blocks.len() - 1)
76         }
77
78         pub fn disconnect_tip(&mut self) -> Option<Block> {
79                 self.blocks.pop()
80         }
81 }
82
83 impl BlockSource for Blockchain {
84         fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height_hint: Option<u32>) -> AsyncBlockSourceResult<'a, BlockHeaderData> {
85                 Box::pin(async move {
86                         if self.without_headers {
87                                 return Err(BlockSourceError::persistent("header not found"));
88                         }
89
90                         for (height, block) in self.blocks.iter().enumerate() {
91                                 if block.header.block_hash() == *header_hash {
92                                         let mut header_data = self.at_height_unvalidated(height);
93                                         if self.malformed_headers {
94                                                 header_data.header.time += 1;
95                                         }
96
97                                         return Ok(header_data);
98                                 }
99                         }
100                         Err(BlockSourceError::transient("header not found"))
101                 })
102         }
103
104         fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, Block> {
105                 Box::pin(async move {
106                         for block in self.blocks.iter() {
107                                 if block.header.block_hash() == *header_hash {
108                                         return Ok(block.clone());
109                                 }
110                         }
111                         Err(BlockSourceError::transient("block not found"))
112                 })
113         }
114
115         fn get_best_block<'a>(&'a mut self) -> AsyncBlockSourceResult<'a, (BlockHash, Option<u32>)> {
116                 Box::pin(async move {
117                         match self.blocks.last() {
118                                 None => Err(BlockSourceError::transient("empty chain")),
119                                 Some(block) => {
120                                         let height = (self.blocks.len() - 1) as u32;
121                                         Ok((block.block_hash(), Some(height)))
122                                 },
123                         }
124                 })
125         }
126 }