Avoid connection-per-RPC-call again by caching connections
[rust-lightning] / lightning-block-sync / src / rpc.rs
1 //! Simple RPC client implementation which implements [`BlockSource`] against a Bitcoin Core RPC
2 //! endpoint.
3
4 use crate::{BlockData, BlockHeaderData, BlockSource, AsyncBlockSourceResult};
5 use crate::http::{HttpClient, HttpEndpoint, HttpError, JsonResponse};
6
7 use bitcoin::hash_types::BlockHash;
8 use bitcoin::hashes::hex::ToHex;
9
10 use std::sync::Mutex;
11
12 use serde_json;
13
14 use std::convert::TryFrom;
15 use std::convert::TryInto;
16 use std::error::Error;
17 use std::fmt;
18 use std::sync::atomic::{AtomicUsize, Ordering};
19
/// An error returned by the RPC server.
///
/// Built by [`RpcClient::call_method`] when the JSON response contains a non-null `error`
/// member, and handed back to the caller wrapped in a [`std::io::Error`] of kind
/// [`std::io::ErrorKind::Other`].
#[derive(Debug)]
pub struct RpcError {
	/// The error code.
	pub code: i64,
	/// The error message.
	pub message: String,
}

impl fmt::Display for RpcError {
	/// Renders the error as `RPC error <code>: <message>`.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let RpcError { code, message } = self;
		write!(f, "RPC error {}: {}", code, message)
	}
}

/// Allows an [`RpcError`] to be boxed as the inner error of a [`std::io::Error`].
impl Error for RpcError {}
36
37 /// A simple RPC client for calling methods using HTTP `POST`.
38 ///
39 /// Implements [`BlockSource`] and may return an `Err` containing [`RpcError`]. See
40 /// [`RpcClient::call_method`] for details.
41 pub struct RpcClient {
42         basic_auth: String,
43         endpoint: HttpEndpoint,
44         client: Mutex<Option<HttpClient>>,
45         id: AtomicUsize,
46 }
47
48 impl RpcClient {
49         /// Creates a new RPC client connected to the given endpoint with the provided credentials. The
50         /// credentials should be a base64 encoding of a user name and password joined by a colon, as is
51         /// required for HTTP basic access authentication.
52         pub fn new(credentials: &str, endpoint: HttpEndpoint) -> std::io::Result<Self> {
53                 Ok(Self {
54                         basic_auth: "Basic ".to_string() + credentials,
55                         endpoint,
56                         client: Mutex::new(None),
57                         id: AtomicUsize::new(0),
58                 })
59         }
60
61         /// Calls a method with the response encoded in JSON format and interpreted as type `T`.
62         ///
63         /// When an `Err` is returned, [`std::io::Error::into_inner`] may contain an [`RpcError`] if
64         /// [`std::io::Error::kind`] is [`std::io::ErrorKind::Other`].
65         pub async fn call_method<T>(&self, method: &str, params: &[serde_json::Value]) -> std::io::Result<T>
66         where JsonResponse: TryFrom<Vec<u8>, Error = std::io::Error> + TryInto<T, Error = std::io::Error> {
67                 let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port());
68                 let uri = self.endpoint.path();
69                 let content = serde_json::json!({
70                         "method": method,
71                         "params": params,
72                         "id": &self.id.fetch_add(1, Ordering::AcqRel).to_string()
73                 });
74
75                 let mut client = if let Some(client) = self.client.lock().unwrap().take() { client }
76                         else { HttpClient::connect(&self.endpoint)? };
77                 let http_response = client.post::<JsonResponse>(&uri, &host, &self.basic_auth, content).await;
78                 *self.client.lock().unwrap() = Some(client);
79
80                 let mut response = match http_response {
81                         Ok(JsonResponse(response)) => response,
82                         Err(e) if e.kind() == std::io::ErrorKind::Other => {
83                                 match e.get_ref().unwrap().downcast_ref::<HttpError>() {
84                                         Some(http_error) => match JsonResponse::try_from(http_error.contents.clone()) {
85                                                 Ok(JsonResponse(response)) => response,
86                                                 Err(_) => Err(e)?,
87                                         },
88                                         None => Err(e)?,
89                                 }
90                         },
91                         Err(e) => Err(e)?,
92                 };
93
94                 if !response.is_object() {
95                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object"));
96                 }
97
98                 let error = &response["error"];
99                 if !error.is_null() {
100                         // TODO: Examine error code for a more precise std::io::ErrorKind.
101                         let rpc_error = RpcError { 
102                                 code: error["code"].as_i64().unwrap_or(-1), 
103                                 message: error["message"].as_str().unwrap_or("unknown error").to_string() 
104                         };
105                         return Err(std::io::Error::new(std::io::ErrorKind::Other, rpc_error));
106                 }
107
108                 let result = &mut response["result"];
109                 if result.is_null() {
110                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON result"));
111                 }
112
113                 JsonResponse(result.take()).try_into()
114         }
115 }
116
117 impl BlockSource for RpcClient {
118         fn get_header<'a>(&'a self, header_hash: &'a BlockHash, _height: Option<u32>) -> AsyncBlockSourceResult<'a, BlockHeaderData> {
119                 Box::pin(async move {
120                         let header_hash = serde_json::json!(header_hash.to_hex());
121                         Ok(self.call_method("getblockheader", &[header_hash]).await?)
122                 })
123         }
124
125         fn get_block<'a>(&'a self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, BlockData> {
126                 Box::pin(async move {
127                         let header_hash = serde_json::json!(header_hash.to_hex());
128                         let verbosity = serde_json::json!(0);
129                         Ok(BlockData::FullBlock(self.call_method("getblock", &[header_hash, verbosity]).await?))
130                 })
131         }
132
133         fn get_best_block<'a>(&'a self) -> AsyncBlockSourceResult<'a, (BlockHash, Option<u32>)> {
134                 Box::pin(async move {
135                         Ok(self.call_method("getblockchaininfo", &[]).await?)
136                 })
137         }
138 }
139
#[cfg(test)]
mod tests {
	use super::*;
	use crate::http::client_tests::{HttpServer, MessageBody};

	/// Credentials encoded in base64 (`user:password`).
	const CREDENTIALS: &str = "dXNlcjpwYXNzd29yZA==";

	/// Converts a JSON value into `u64`.
	impl TryInto<u64> for JsonResponse {
		type Error = std::io::Error;

		fn try_into(self) -> std::io::Result<u64> {
			match self.0.as_u64() {
				None => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "not a number")),
				Some(n) => Ok(n),
			}
		}
	}

	/// A non-JSON error response (HTTP 404) surfaces as an `ErrorKind::Other` error.
	#[tokio::test]
	async fn call_method_returning_unknown_response() {
		let server = HttpServer::responding_with_not_found();
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other),
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A JSON body that is not an object is rejected as invalid data.
	#[tokio::test]
	async fn call_method_returning_malformed_response() {
		let response = serde_json::json!("foo");
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
				assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON object");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A JSON-RPC `error` object in a server-error response is returned as an `RpcError`.
	#[tokio::test]
	async fn call_method_returning_error() {
		let response = serde_json::json!({
			"error": { "code": -8, "message": "invalid parameter" },
		});
		let server = HttpServer::responding_with_server_error(response);
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		let invalid_block_hash = serde_json::json!("foo");
		match client.call_method::<u64>("getblock", &[invalid_block_hash]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::Other);
				let rpc_error: Box<RpcError> = e.into_inner().unwrap().downcast().unwrap();
				assert_eq!(rpc_error.code, -8);
				assert_eq!(rpc_error.message, "invalid parameter");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A null `result` member is rejected as invalid data.
	#[tokio::test]
	async fn call_method_returning_missing_result() {
		let response = serde_json::json!({ "result": null });
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
				assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON result");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A `result` that cannot be converted to the requested type propagates the conversion error.
	#[tokio::test]
	async fn call_method_returning_malformed_result() {
		let response = serde_json::json!({ "result": "foo" });
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
				assert_eq!(e.get_ref().unwrap().to_string(), "not a number");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A well-formed `result` is converted and returned.
	#[tokio::test]
	async fn call_method_returning_valid_result() {
		let response = serde_json::json!({ "result": 654470 });
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => panic!("Unexpected error: {:?}", e),
			Ok(count) => assert_eq!(count, 654470),
		}
	}
}