Add a subcrate which implements a simple SPV client on top of Bitcoin Core's RPC
[rust-lightning] / lightning-block-sync / src / http_clients.rs
1 use serde_json;
2
3 use serde_derive::Deserialize;
4
5 use crate::utils::hex_to_uint256;
6 use crate::{BlockHeaderData, BlockSource, BlockSourceRespErr};
7
8 use bitcoin::hashes::hex::{ToHex, FromHex};
9 use bitcoin::hash_types::{BlockHash, TxMerkleNode};
10
11 use bitcoin::blockdata::block::{Block, BlockHeader};
12 use bitcoin::consensus::encode;
13
14 use std::convert::TryInto;
15 use std::cmp;
16 use std::future::Future;
17 use std::pin::Pin;
18 use std::net::ToSocketAddrs;
19 use std::io::Write;
20 use std::time::Duration;
21
22 #[cfg(feature = "rpc-client")]
23 use crate::utils::hex_to_vec;
24 #[cfg(feature = "rpc-client")]
25 use std::sync::atomic::{AtomicUsize, Ordering};
26 #[cfg(feature = "rpc-client")]
27 use base64;
28
29 #[cfg(feature = "tokio")]
30 use tokio::net::TcpStream;
31 #[cfg(feature = "tokio")]
32 use tokio::io::AsyncReadExt;
33
34 #[cfg(not(feature = "tokio"))]
35 use std::net::TcpStream;
36 #[cfg(not(feature = "tokio"))]
37 use std::io::Read;
38
/// Splits an HTTP URI into its component parts - (is_ssl, hostname, port number, and HTTP path).
///
/// Accepts the `http:` and `https:` schemes, with any number of `/` after the colon (so
/// `http:host` works too). IPv6 hosts must be bracketed (`http://[::1]:8332`); the returned host
/// has the brackets stripped. Without an explicit port, the scheme default (80 or 443) is used.
/// The path is everything from the first `/` after the host and may be empty.
///
/// Returns `None` in any of these cases:
/// - another scheme;
/// - an empty host;
/// - an unterminated `[`;
/// - anything other than `:port` after `]`;
/// - a port that is not a valid `u16`.
fn split_uri<'a>(uri: &'a str) -> Option<(bool, &'a str, u16, &'a str)> {
	let mut scheme_rest = uri.splitn(2, ':');
	let ssl = match scheme_rest.next()? {
		"http" => false,
		"https" => true,
		_ => return None,
	};
	let host_path = scheme_rest.next()?.trim_start_matches('/');
	// The authority ("host[:port]") runs up to the first '/', which starts the path.
	let authority_len = host_path.find('/').unwrap_or(host_path.len());
	let (authority, path) = host_path.split_at(authority_len);

	let (host, port_str) = if authority.starts_with('[') {
		// Bracketed IPv6 literal; the closing bracket is mandatory and may only be followed by
		// ":port".
		let close = authority.find(']')?;
		let after_bracket = &authority[close + 1..];
		let port_str = if after_bracket.is_empty() {
			None
		} else if after_bracket.starts_with(':') {
			Some(&after_bracket[1..])
		} else {
			return None;
		};
		(&authority[1..close], port_str)
	} else {
		let mut host_port = authority.splitn(2, ':');
		(host_port.next()?, host_port.next())
	};
	if host.is_empty() {
		return None;
	}

	let port = match port_str {
		// An explicit but empty or out-of-range port is an error, not a request for the default.
		Some(port) => port.parse::<u16>().ok()?,
		None => if ssl { 443 } else { 80 },
	};
	Some((ssl, host, port, path))
}
82
/// Reads one HTTP/1.x response from `socket` and returns its body.
///
/// Only `200` responses are accepted. The status line and headers must fit in 8KB. The body must
/// be framed by `Content-Length` or by `Transfer-Encoding: chunked`, and may be at most
/// `max_resp` bytes.
///
/// Returns `None` on any of the following:
/// - an I/O error or early EOF;
/// - a non-`200` status;
/// - unsupported framing;
/// - an oversized response;
/// - malformed length or chunk data.
async fn read_http_resp(mut socket: TcpStream, max_resp: usize) -> Option<Vec<u8>> {
	/// Upper bound on the status line plus headers that we are willing to buffer.
	const MAX_HEADER_LEN: usize = 8192;

	/// If `line` starts with `prefix` (compared ASCII case-insensitively), returns the rest of it.
	fn strip_prefix_ignore_case<'a>(line: &'a [u8], prefix: &[u8]) -> Option<&'a [u8]> {
		if line.len() >= prefix.len() && line[..prefix.len()].eq_ignore_ascii_case(prefix) {
			Some(&line[prefix.len()..])
		} else {
			None
		}
	}

	let mut resp = Vec::new();
	let mut bytes_read = 0;
	// Reads more bytes into `resp[bytes_read..]` and evaluates to the number read. EOF or an I/O
	// error aborts the whole response.
	macro_rules! read_socket { () => { {
		#[cfg(feature = "tokio")]
		let res = socket.read(&mut resp[bytes_read..]).await;
		#[cfg(not(feature = "tokio"))]
		let res = socket.read(&mut resp[bytes_read..]);
		match res {
			Ok(0) => return None,
			Ok(b) => b,
			Err(_) => return None,
		}
	} } }

	// Read until the blank line that ends the headers, using `resp` as a temporary buffer until
	// we know the real body length. `header_len` excludes the terminating blank line and
	// `body_start` is the index of the first body byte.
	resp.extend_from_slice(&[0; MAX_HEADER_LEN]);
	let (header_len, body_start) = 'read_headers: loop {
		if bytes_read >= MAX_HEADER_LEN { return None; }
		bytes_read += read_socket!();
		// Check every position, including the last few bytes, so a terminator at the very end
		// of a read is found without waiting for more data.
		for idx in 0..bytes_read {
			let rest = &resp[idx..bytes_read];
			if rest.starts_with(b"\r\n\r\n") {
				break 'read_headers (idx, idx + 4);
			} else if rest.starts_with(b"\n\n") || rest.starts_with(b"\r\r") {
				break 'read_headers (idx, idx + 2);
			}
		}
	};

	// Parse only the header section. Body bytes that arrived in the same read must never be
	// mistaken for headers.
	let mut ok_found = false;
	let mut chunked = false;
	let mut actual_len: usize = 0;
	for line in resp[..header_len].split(|c| *c == b'\n' || *c == b'\r') {
		// The reason phrase after "200 " is optional, so the prefix alone is enough.
		if strip_prefix_ignore_case(line, b"HTTP/1.1 200 ").is_some() ||
				strip_prefix_ignore_case(line, b"HTTP/1.0 200 ").is_some() {
			ok_found = true;
		}
		if let Some(value) = strip_prefix_ignore_case(line, b"Content-Length:") {
			actual_len = std::str::from_utf8(value).ok()?.trim().parse().ok()?;
		}
		if let Some(value) = strip_prefix_ignore_case(line, b"Transfer-Encoding:") {
			if !std::str::from_utf8(value).ok()?.trim().eq_ignore_ascii_case("chunked") {
				return None; // Unsupported
			}
			chunked = true;
		}
	}
	if !ok_found || (!chunked && (actual_len == 0 || actual_len > max_resp)) { return None; } // Sorry, not implemented

	// Discard the headers, keeping any body bytes that were read along with them.
	resp.truncate(bytes_read);
	resp.drain(..body_start);
	bytes_read = resp.len();
	if !chunked {
		resp.resize(actual_len, 0);
		while bytes_read < actual_len {
			bytes_read += read_socket!();
		}
		Some(resp)
	} else {
		// From here on, `actual_len` counts decoded body bytes. They are kept contiguous at the
		// front of `resp`, and chunk-size lines are drained out as they are parsed.
		actual_len = 0;
		let mut chunk_remaining = 0;
		'read_bytes: loop {
			if chunk_remaining == 0 {
				let mut bytes_skipped = 0;
				let mut finished_read = false;
				let mut lineiter = resp[actual_len..bytes_read].split(|c| *c == '\n' as u8 || *c == '\r' as u8).peekable();
				loop {
					let line = match lineiter.next() { Some(line) => line, None => break };
					if lineiter.peek().is_none() { // We haven't yet read to the end of this line
						if line.len() > 8 {
							// More than 8 hex digits would describe a chunk of 4GiB or more,
							// which we would never accept, so treat the line as bogus.
							return None;
						}
						break;
					}
					bytes_skipped += line.len() + 1;
					if line.len() == 0 { continue; } // Probably between the \r and \n
					match usize::from_str_radix(&match std::str::from_utf8(line) {
						Ok(s) => s, Err(_) => return None,
					}, 16) {
						Ok(chunklen) => {
							if chunklen == 0 { finished_read = true; }
							chunk_remaining = chunklen;
							match lineiter.next() {
								Some(l) if l.is_empty() => {
									// Skip the \n following the \r
									bytes_skipped += 1;
									if actual_len + bytes_skipped > bytes_read {
										// Go back and get more bytes so we can skip trailing \n
										chunk_remaining = 0;
									}
								},
								Some(_) => {},
								None => {
									// Go back and get more bytes so we can skip trailing \n
									chunk_remaining = 0;
								},
							}
							break;
						},
						Err(_) => return None,
					}
				}
				if chunk_remaining != 0 {
					bytes_read -= bytes_skipped;
					resp.drain(actual_len..actual_len + bytes_skipped);
					// checked_add: an attacker-chosen chunk length must not overflow on 32-bit.
					if actual_len.checked_add(chunk_remaining).map_or(true, |total| total > max_resp) { return None; }
					let already_in_chunk = cmp::min(bytes_read - actual_len, chunk_remaining);
					actual_len += already_in_chunk;
					chunk_remaining -= already_in_chunk;
					continue 'read_bytes;
				} else {
					if finished_read {
						// Any trailing CRLF left unread is discarded when `socket` (owned by this
						// function) is dropped on return.
						resp.resize(actual_len, 0);
						return Some(resp);
					}
					// Otherwise we need to read more bytes to figure out the chunk length.
				}
			}
			resp.resize(bytes_read + cmp::max(10, chunk_remaining), 0);
			let avail = read_socket!();
			bytes_read += avail;
			if chunk_remaining != 0 {
				let chunk_read = cmp::min(chunk_remaining, avail);
				chunk_remaining -= chunk_read;
				actual_len += chunk_read;
			}
		}
	}
}
225
/// Minimal client for Bitcoin Core's REST interface over plain (non-SSL) HTTP.
#[cfg(feature = "rest-client")]
pub struct RESTClient {
	uri: String,
}

#[cfg(feature = "rest-client")]
impl RESTClient {
	/// Creates a client whose requests are issued relative to `uri`. Request paths such as
	/// `headers/1/<hash>.json` are appended to the URI's path.
	///
	/// Returns `None` if `uri` cannot be parsed or uses `https`, which is not supported.
	pub fn new(uri: String) -> Option<Self> {
		match split_uri(&uri) {
			Some((false, _, _, _)) => Some(Self { uri }),
			_ => None,
		}
	}

	/// Performs a `GET` of `req_path` relative to the configured URI and returns the raw body.
	///
	/// Returns `Err(())` if connecting, sending or reading the response fails.
	async fn make_raw_rest_call(&self, req_path: &str) -> Result<Vec<u8>, ()> {
		// `new` only accepts URIs that parse and are not SSL, so neither can fail here.
		let (ssl, host, port, path) = split_uri(&self.uri).unwrap();
		if ssl { unreachable!(); }

		// Try each resolved address in turn rather than only the first. For example, `localhost`
		// may resolve to both ::1 and 127.0.0.1.
		let mut stream = (host, port).to_socket_addrs().map_err(|_| ())?
			.find_map(|addr| std::net::TcpStream::connect_timeout(&addr, Duration::from_secs(1)).ok())
			.ok_or(())?;
		stream.set_write_timeout(Some(Duration::from_secs(1))).expect("Host kernel is uselessly old?");
		stream.set_read_timeout(Some(Duration::from_secs(2))).expect("Host kernel is uselessly old?");

		// HTTP requires CRLF line endings. Trim any trailing '/' from the base path so a URI like
		// `http://host/rest/` does not produce `//` in the request target.
		let req = format!("GET {}/{} HTTP/1.1\r\nHost: {}\r\nConnection: keep-alive\r\n\r\n", path.trim_end_matches('/'), req_path, host);
		// `write` may accept only part of the buffer; `write_all` keeps going until it is all sent.
		stream.write_all(req.as_bytes()).map_err(|_| ())?;
		#[cfg(feature = "tokio")]
		let stream = TcpStream::from_std(stream).unwrap();
		read_http_resp(stream, 4_000_000).await.ok_or(())
	}

	/// Performs a REST call whose body is JSON and returns the parsed value.
	///
	/// Endpoints return either a JSON object (e.g. `chaininfo.json`) or a JSON array (e.g.
	/// `headers/<count>/<hash>.json`), so both are accepted. Any other JSON value, or a body that
	/// is not JSON, yields `Err(())`.
	async fn make_rest_call(&self, req_path: &str) -> Result<serde_json::Value, ()> {
		let resp = self.make_raw_rest_call(req_path).await?;
		let v: serde_json::Value = serde_json::from_slice(&resp[..]).map_err(|_| ())?;
		// Arrays must be allowed: `get_header` reads the `headers` endpoint, which returns one.
		if !(v.is_object() || v.is_array()) {
			return Err(());
		}
		Ok(v)
	}
}
279
/// Minimal JSON-RPC client for Bitcoin Core over plain (non-SSL) HTTP with Basic authentication.
#[cfg(feature = "rpc-client")]
pub struct RPCClient {
	basic_auth: String,
	uri: String,
	id: AtomicUsize,
}

#[cfg(feature = "rpc-client")]
impl RPCClient {
	/// Creates a client for the JSON-RPC server at `uri`.
	///
	/// `user_auth` is base64-encoded verbatim into an HTTP `Basic` `Authorization` header, so it
	/// should be in `user:password` form.
	///
	/// Returns `None` if `uri` cannot be parsed or uses `https`, which is not supported.
	pub fn new(user_auth: &str, uri: String) -> Option<Self> {
		match split_uri(&uri) {
			Some((false, _, _, _)) => Some(Self {
				basic_auth: "Basic ".to_string() + &base64::encode(user_auth),
				uri,
				id: AtomicUsize::new(0),
			}),
			_ => None,
		}
	}

	/// Sends a JSON-RPC request for `method` and returns its `result` field.
	///
	/// Each entry of `params` is spliced into the request verbatim, so entries must already be
	/// valid JSON (strings pre-quoted, for example).
	///
	/// Returns `Err(())` in any of these cases:
	/// - a connection, HTTP or JSON failure;
	/// - the response's `error` field is missing or not `null`;
	/// - the response has no `result`.
	async fn make_rpc_call(&self, method: &str, params: &[&str]) -> Result<serde_json::Value, ()> {
		// `new` only accepts URIs that parse and are not SSL, so neither can fail here.
		let (ssl, host, port, path) = split_uri(&self.uri).unwrap();
		if ssl { unreachable!(); }
		// A URI with no path (e.g. `http://127.0.0.1:8332`) must still send a request target;
		// an empty one would produce the malformed request line `POST  HTTP/1.1`.
		let path = if path.is_empty() { "/" } else { path };

		// Try each resolved address in turn rather than only the first. For example, `localhost`
		// may resolve to both ::1 and 127.0.0.1.
		let mut stream = (host, port).to_socket_addrs().map_err(|_| ())?
			.find_map(|addr| std::net::TcpStream::connect_timeout(&addr, Duration::from_secs(1)).ok())
			.ok_or(())?;
		stream.set_write_timeout(Some(Duration::from_secs(1))).expect("Host kernel is uselessly old?");
		stream.set_read_timeout(Some(Duration::from_secs(2))).expect("Host kernel is uselessly old?");

		let body = "{\"method\":\"".to_string() + method + "\",\"params\":[" + &params.join(",") + "],\"id\":" + &self.id.fetch_add(1, Ordering::AcqRel).to_string() + "}";
		let req = format!("POST {} HTTP/1.1\r\nHost: {}\r\nAuthorization: {}\r\nConnection: keep-alive\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", path, host, &self.basic_auth, body.len(), body);
		// `write` may accept only part of the buffer; `write_all` keeps going until it is all sent.
		stream.write_all(req.as_bytes()).map_err(|_| ())?;
		#[cfg(feature = "tokio")]
		let stream = TcpStream::from_std(stream).unwrap();
		let resp = read_http_resp(stream, 4_000_000).await.ok_or(())?;

		let mut v: serde_json::Value = serde_json::from_slice(&resp[..]).map_err(|_| ())?;
		let v_obj = v.as_object_mut().ok_or(())?;
		if v_obj.get("error") != Some(&serde_json::Value::Null) {
			return Err(());
		}
		// Move the result out rather than cloning it; the rest of the response is discarded.
		v_obj.remove("result").ok_or(())
	}
}
356
357 #[derive(Deserialize)]
358 struct GetHeaderResponse {
359         pub chainwork: String,
360         pub height: u32,
361
362         pub version: u32,
363         pub merkleroot: String,
364         pub time: u32,
365         pub nonce: u32,
366         pub bits: String,
367         pub previousblockhash: String,
368 }
369
370 impl GetHeaderResponse {
371         /// Always returns BogusData if we return an Err
372         pub fn to_block_header(self) -> Result<BlockHeaderData, BlockSourceRespErr> {
373                 let header = BlockHeader {
374                         version: self.version,
375                         prev_blockhash: BlockHash::from_hex(&self.previousblockhash).map_err(|_| BlockSourceRespErr::BogusData)?,
376                         merkle_root: TxMerkleNode::from_hex(&self.merkleroot).map_err(|_| BlockSourceRespErr::BogusData)?,
377                         time: self.time,
378                         bits: u32::from_str_radix(&self.bits, 16).map_err(|_| BlockSourceRespErr::BogusData)?,
379                         nonce: self.nonce,
380                 };
381
382                 Ok(BlockHeaderData {
383                         chainwork: hex_to_uint256(&self.chainwork).ok_or(BlockSourceRespErr::BogusData)?,
384                         height: self.height,
385                         header,
386                 })
387         }
388 }
389
#[cfg(feature = "rpc-client")]
impl BlockSource for RPCClient {
	/// Fetches a header via `getblockheader`. The height hint is not needed by this source.
	fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height: Option<u32>) -> Pin<Box<dyn Future<Output = Result<BlockHeaderData, BlockSourceRespErr>> + 'a + Send>> {
		let hash_param = format!("\"{}\"", header_hash.to_hex());
		Box::pin(async move {
			let mut header_json = match self.make_rpc_call("getblockheader", &[&hash_param]).await {
				Ok(json) => json,
				Err(()) => return Err(BlockSourceRespErr::NoResponse),
			};
			if let Some(fields) = header_json.as_object_mut() {
				if !fields.contains_key("previousblockhash") {
					// Got a request for genesis block, add a dummy previousblockhash
					fields.insert("previousblockhash".to_string(), serde_json::Value::String(String::new()));
				}
			}
			match serde_json::from_value::<GetHeaderResponse>(header_json) {
				Ok(parsed) => parsed.to_block_header(),
				Err(_) => Err(BlockSourceRespErr::NoResponse),
			}
		})
	}

	/// Fetches a full block via `getblock` with verbosity `0`. The result is expected to be the
	/// serialized block as a hex string.
	fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> Pin<Box<dyn Future<Output = Result<Block, BlockSourceRespErr>> + 'a + Send>> {
		let hash_param = format!("\"{}\"", header_hash.to_hex());
		Box::pin(async move {
			let response = self.make_rpc_call("getblock", &[&hash_param, "0"]).await
				.map_err(|_| BlockSourceRespErr::NoResponse)?;
			let hex = response.as_str().ok_or(BlockSourceRespErr::NoResponse)?;
			let raw_block = hex_to_vec(hex).ok_or(BlockSourceRespErr::NoResponse)?;
			encode::deserialize::<Block>(&raw_block).map_err(|_| BlockSourceRespErr::NoResponse)
		})
	}

	/// Reads the tip hash and height from `getblockchaininfo`.
	fn get_best_block<'a>(&'a mut self) -> Pin<Box<dyn Future<Output = Result<(BlockHash, Option<u32>), BlockSourceRespErr>> + 'a + Send>> {
		Box::pin(async move {
			let info = match self.make_rpc_call("getblockchaininfo", &[]).await {
				Ok(info) => info,
				Err(()) => return Err(BlockSourceRespErr::NoResponse),
			};
			let height: u32 = info["blocks"].as_u64()
				.and_then(|blocks| blocks.try_into().ok())
				.ok_or(BlockSourceRespErr::NoResponse)?;
			let best_hash = info["bestblockhash"].as_str()
				.and_then(|hex| BlockHash::from_hex(hex).ok())
				.ok_or(BlockSourceRespErr::NoResponse)?;
			Ok((best_hash, Some(height)))
		})
	}
}
433
#[cfg(feature = "rest-client")]
impl BlockSource for RESTClient {
	/// Fetches a single header from the `headers/1/<hash>.json` endpoint. The height hint is
	/// not needed by this source.
	fn get_header<'a>(&'a mut self, header_hash: &'a BlockHash, _height: Option<u32>) -> Pin<Box<dyn Future<Output = Result<BlockHeaderData, BlockSourceRespErr>> + 'a + Send>> {
		Box::pin(async move {
			let reqpath = format!("headers/1/{}.json", header_hash.to_hex());
			let mut headers = match self.make_rest_call(&reqpath).await {
				Ok(serde_json::Value::Array(headers)) if !headers.is_empty() => headers,
				_ => return Err(BlockSourceRespErr::NoResponse),
			};
			let mut header = headers.swap_remove(0);
			match header.as_object_mut() {
				Some(fields) => {
					if !fields.contains_key("previousblockhash") {
						// Got a request for genesis block, add a dummy previousblockhash
						fields.insert("previousblockhash".to_string(), serde_json::Value::String(String::new()));
					}
				},
				None => return Err(BlockSourceRespErr::NoResponse),
			}
			match serde_json::from_value::<GetHeaderResponse>(header) {
				Ok(parsed) => parsed.to_block_header(),
				Err(_) => Err(BlockSourceRespErr::NoResponse),
			}
		})
	}

	/// Fetches the consensus-serialized block from the `block/<hash>.bin` endpoint.
	fn get_block<'a>(&'a mut self, header_hash: &'a BlockHash) -> Pin<Box<dyn Future<Output = Result<Block, BlockSourceRespErr>> + 'a + Send>> {
		Box::pin(async move {
			let reqpath = format!("block/{}.bin", header_hash.to_hex());
			match self.make_raw_rest_call(&reqpath).await {
				Ok(raw_block) => encode::deserialize::<Block>(&raw_block).map_err(|_| BlockSourceRespErr::NoResponse),
				Err(()) => Err(BlockSourceRespErr::NoResponse),
			}
		})
	}

	/// Reads the tip hash and height from the `chaininfo.json` endpoint.
	fn get_best_block<'a>(&'a mut self) -> Pin<Box<dyn Future<Output = Result<(BlockHash, Option<u32>), BlockSourceRespErr>> + 'a + Send>> {
		Box::pin(async move {
			let info = match self.make_rest_call("chaininfo.json").await {
				Ok(info) => info,
				Err(()) => return Err(BlockSourceRespErr::NoResponse),
			};
			let height: u32 = info["blocks"].as_u64()
				.and_then(|blocks| blocks.try_into().ok())
				.ok_or(BlockSourceRespErr::NoResponse)?;
			let best_hash = info["bestblockhash"].as_str()
				.and_then(|hex| BlockHash::from_hex(hex).ok())
				.ok_or(BlockSourceRespErr::NoResponse)?;
			Ok((best_hash, Some(height)))
		})
	}
}
477
#[cfg(test)]
#[test]
fn test_split_uri() {
	// Each entry pairs an input URI with the split we expect back.
	let cases: [(&str, Option<(bool, &str, u16, &str)>); 9] = [
		("http://example.com:8080/path", Some((false, "example.com", 8080, "/path"))),
		("http:example.com:8080/path/b", Some((false, "example.com", 8080, "/path/b"))),
		("https://0.0.0.0/", Some((true, "0.0.0.0", 443, "/"))),
		("http:[0:bad::43]:80/", Some((false, "0:bad::43", 80, "/"))),
		("http:[::]", Some((false, "::", 80, ""))),
		("http://", None),
		("http://example.com:70000/", None),
		("ftp://example.com:80/", None),
		("http://example.com", Some((false, "example.com", 80, ""))),
	];
	for (uri, expected) in cases.iter() {
		assert_eq!(split_uri(uri), *expected);
	}
}