f04769560246f8537e1e022efa22c8b7a815eab4
[rust-lightning] / lightning-block-sync / src / rpc.rs
1 //! Simple RPC client implementation which implements [`BlockSource`] against a Bitcoin Core RPC
2 //! endpoint.
3
4 use crate::{BlockData, BlockHeaderData, BlockSource, AsyncBlockSourceResult};
5 use crate::http::{HttpClient, HttpEndpoint, HttpError, JsonResponse};
6
7 use bitcoin::hash_types::BlockHash;
8 use bitcoin::hashes::hex::ToHex;
9
10 use futures_util::lock::Mutex;
11
12 use serde_json;
13
14 use std::convert::TryFrom;
15 use std::convert::TryInto;
16 use std::sync::atomic::{AtomicUsize, Ordering};
17
18 /// A simple RPC client for calling methods using HTTP `POST`.
19 pub struct RpcClient {
20         basic_auth: String,
21         endpoint: HttpEndpoint,
22         client: Mutex<HttpClient>,
23         id: AtomicUsize,
24 }
25
26 impl RpcClient {
27         /// Creates a new RPC client connected to the given endpoint with the provided credentials. The
28         /// credentials should be a base64 encoding of a user name and password joined by a colon, as is
29         /// required for HTTP basic access authentication.
30         pub fn new(credentials: &str, endpoint: HttpEndpoint) -> std::io::Result<Self> {
31                 let client = Mutex::new(HttpClient::connect(&endpoint)?);
32                 Ok(Self {
33                         basic_auth: "Basic ".to_string() + credentials,
34                         endpoint,
35                         client,
36                         id: AtomicUsize::new(0),
37                 })
38         }
39
40         /// Calls a method with the response encoded in JSON format and interpreted as type `T`.
41         pub async fn call_method<T>(&self, method: &str, params: &[serde_json::Value]) -> std::io::Result<T>
42         where JsonResponse: TryFrom<Vec<u8>, Error = std::io::Error> + TryInto<T, Error = std::io::Error> {
43                 let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port());
44                 let uri = self.endpoint.path();
45                 let content = serde_json::json!({
46                         "method": method,
47                         "params": params,
48                         "id": &self.id.fetch_add(1, Ordering::AcqRel).to_string()
49                 });
50
51                 let mut response = match self.client.lock().await.post::<JsonResponse>(&uri, &host, &self.basic_auth, content).await {
52                         Ok(JsonResponse(response)) => response,
53                         Err(e) if e.kind() == std::io::ErrorKind::Other => {
54                                 match e.get_ref().unwrap().downcast_ref::<HttpError>() {
55                                         Some(http_error) => match JsonResponse::try_from(http_error.contents.clone()) {
56                                                 Ok(JsonResponse(response)) => response,
57                                                 Err(_) => Err(e)?,
58                                         },
59                                         None => Err(e)?,
60                                 }
61                         },
62                         Err(e) => Err(e)?,
63                 };
64
65                 if !response.is_object() {
66                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object"));
67                 }
68
69                 let error = &response["error"];
70                 if !error.is_null() {
71                         // TODO: Examine error code for a more precise std::io::ErrorKind.
72                         let message = error["message"].as_str().unwrap_or("unknown error");
73                         return Err(std::io::Error::new(std::io::ErrorKind::Other, message));
74                 }
75
76                 let result = &mut response["result"];
77                 if result.is_null() {
78                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON result"));
79                 }
80
81                 JsonResponse(result.take()).try_into()
82         }
83 }
84
85 impl BlockSource for RpcClient {
86         fn get_header<'a>(&'a self, header_hash: &'a BlockHash, _height: Option<u32>) -> AsyncBlockSourceResult<'a, BlockHeaderData> {
87                 Box::pin(async move {
88                         let header_hash = serde_json::json!(header_hash.to_hex());
89                         Ok(self.call_method("getblockheader", &[header_hash]).await?)
90                 })
91         }
92
93         fn get_block<'a>(&'a self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, BlockData> {
94                 Box::pin(async move {
95                         let header_hash = serde_json::json!(header_hash.to_hex());
96                         let verbosity = serde_json::json!(0);
97                         Ok(BlockData::FullBlock(self.call_method("getblock", &[header_hash, verbosity]).await?))
98                 })
99         }
100
101         fn get_best_block<'a>(&'a self) -> AsyncBlockSourceResult<'a, (BlockHash, Option<u32>)> {
102                 Box::pin(async move {
103                         Ok(self.call_method("getblockchaininfo", &[]).await?)
104                 })
105         }
106 }
107
#[cfg(test)]
mod tests {
        use super::*;
        use crate::http::client_tests::{HttpServer, MessageBody};

        /// Credentials encoded in base64.
        const CREDENTIALS: &str = "dXNlcjpwYXNzd29yZA==";

        /// Converts a JSON value into `u64`.
        impl TryInto<u64> for JsonResponse {
                type Error = std::io::Error;

                fn try_into(self) -> std::io::Result<u64> {
                        match self.0.as_u64() {
                                None => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "not a number")),
                                Some(n) => Ok(n),
                        }
                }
        }

        /// A not-found response is reported as an error of kind `Other`.
        #[tokio::test]
        async fn call_method_returning_unknown_response() {
                let server = HttpServer::responding_with_not_found();
                let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

                match client.call_method::<u64>("getblockcount", &[]).await {
                        Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other),
                        Ok(_) => panic!("Expected error"),
                }
        }

        /// A response body that is JSON but not an object is rejected as invalid data.
        #[tokio::test]
        async fn call_method_returning_malformed_response() {
                let response = serde_json::json!("foo");
                let server = HttpServer::responding_with_ok(MessageBody::Content(response));
                let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

                match client.call_method::<u64>("getblockcount", &[]).await {
                        Err(e) => {
                                assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
                                assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON object");
                        },
                        Ok(_) => panic!("Expected error"),
                }
        }

        /// A server error whose body holds a JSON `error` object surfaces that object's message.
        #[tokio::test]
        async fn call_method_returning_error() {
                let response = serde_json::json!({
                        "error": { "code": -8, "message": "invalid parameter" },
                });
                let server = HttpServer::responding_with_server_error(response);
                let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

                let invalid_block_hash = serde_json::json!("foo");
                match client.call_method::<u64>("getblock", &[invalid_block_hash]).await {
                        Err(e) => {
                                assert_eq!(e.kind(), std::io::ErrorKind::Other);
                                assert_eq!(e.get_ref().unwrap().to_string(), "invalid parameter");
                        },
                        Ok(_) => panic!("Expected error"),
                }
        }

        /// A null `result` field is rejected as invalid data.
        #[tokio::test]
        async fn call_method_returning_missing_result() {
                let response = serde_json::json!({ "result": null });
                let server = HttpServer::responding_with_ok(MessageBody::Content(response));
                let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

                match client.call_method::<u64>("getblockcount", &[]).await {
                        Err(e) => {
                                assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
                                assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON result");
                        },
                        Ok(_) => panic!("Expected error"),
                }
        }

        /// A `result` that cannot be converted into the requested type propagates the conversion error.
        #[tokio::test]
        async fn call_method_returning_malformed_result() {
                let response = serde_json::json!({ "result": "foo" });
                let server = HttpServer::responding_with_ok(MessageBody::Content(response));
                let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

                match client.call_method::<u64>("getblockcount", &[]).await {
                        Err(e) => {
                                assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
                                assert_eq!(e.get_ref().unwrap().to_string(), "not a number");
                        },
                        Ok(_) => panic!("Expected error"),
                }
        }

        /// A numeric `result` is converted and returned to the caller.
        #[tokio::test]
        async fn call_method_returning_valid_result() {
                let response = serde_json::json!({ "result": 654470 });
                let server = HttpServer::responding_with_ok(MessageBody::Content(response));
                let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

                match client.call_method::<u64>("getblockcount", &[]).await {
                        Err(e) => panic!("Unexpected error: {:?}", e),
                        Ok(count) => assert_eq!(count, 654470),
                }
        }
}