521d94f3d1441a42e07498513c6aadde1a8a1e1c
[rust-lightning] / lightning-block-sync / src / rpc.rs
//! A simple RPC client that implements [`BlockSource`] against a Bitcoin Core RPC
//! endpoint.
3
4 use crate::{BlockData, BlockHeaderData, BlockSource, AsyncBlockSourceResult};
5 use crate::http::{HttpClient, HttpEndpoint, HttpError, JsonResponse};
6
7 use bitcoin::hash_types::BlockHash;
8 use bitcoin::hashes::hex::ToHex;
9
10 use serde_json;
11
12 use std::convert::TryFrom;
13 use std::convert::TryInto;
14 use std::error::Error;
15 use std::fmt;
16 use std::sync::atomic::{AtomicUsize, Ordering};
17
/// An error reported by the RPC server in the `error` field of a response.
#[derive(Debug)]
pub struct RpcError {
	/// Numeric error code supplied by the server.
	pub code: i64,
	/// Human-readable description supplied by the server.
	pub message: String,
}

impl fmt::Display for RpcError {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let RpcError { code, message } = self;
		write!(f, "RPC error {}: {}", code, message)
	}
}

/// Lets an [`RpcError`] be boxed inside a [`std::io::Error`] and later downcast back out.
impl Error for RpcError {}
34
35 /// A simple RPC client for calling methods using HTTP `POST`.
36 ///
37 /// Implements [`BlockSource`] and may return an `Err` containing [`RpcError`]. See
38 /// [`RpcClient::call_method`] for details.
39 pub struct RpcClient {
40         basic_auth: String,
41         endpoint: HttpEndpoint,
42         id: AtomicUsize,
43 }
44
45 impl RpcClient {
46         /// Creates a new RPC client connected to the given endpoint with the provided credentials. The
47         /// credentials should be a base64 encoding of a user name and password joined by a colon, as is
48         /// required for HTTP basic access authentication.
49         pub fn new(credentials: &str, endpoint: HttpEndpoint) -> std::io::Result<Self> {
50                 Ok(Self {
51                         basic_auth: "Basic ".to_string() + credentials,
52                         endpoint,
53                         id: AtomicUsize::new(0),
54                 })
55         }
56
57         /// Calls a method with the response encoded in JSON format and interpreted as type `T`.
58         ///
59         /// When an `Err` is returned, [`std::io::Error::into_inner`] may contain an [`RpcError`] if
60         /// [`std::io::Error::kind`] is [`std::io::ErrorKind::Other`].
61         pub async fn call_method<T>(&self, method: &str, params: &[serde_json::Value]) -> std::io::Result<T>
62         where JsonResponse: TryFrom<Vec<u8>, Error = std::io::Error> + TryInto<T, Error = std::io::Error> {
63                 let host = format!("{}:{}", self.endpoint.host(), self.endpoint.port());
64                 let uri = self.endpoint.path();
65                 let content = serde_json::json!({
66                         "method": method,
67                         "params": params,
68                         "id": &self.id.fetch_add(1, Ordering::AcqRel).to_string()
69                 });
70
71                 let mut client = HttpClient::connect(&self.endpoint)?;
72                 let mut response = match client.post::<JsonResponse>(&uri, &host, &self.basic_auth, content).await {
73                         Ok(JsonResponse(response)) => response,
74                         Err(e) if e.kind() == std::io::ErrorKind::Other => {
75                                 match e.get_ref().unwrap().downcast_ref::<HttpError>() {
76                                         Some(http_error) => match JsonResponse::try_from(http_error.contents.clone()) {
77                                                 Ok(JsonResponse(response)) => response,
78                                                 Err(_) => Err(e)?,
79                                         },
80                                         None => Err(e)?,
81                                 }
82                         },
83                         Err(e) => Err(e)?,
84                 };
85
86                 if !response.is_object() {
87                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON object"));
88                 }
89
90                 let error = &response["error"];
91                 if !error.is_null() {
92                         // TODO: Examine error code for a more precise std::io::ErrorKind.
93                         let rpc_error = RpcError { 
94                                 code: error["code"].as_i64().unwrap_or(-1), 
95                                 message: error["message"].as_str().unwrap_or("unknown error").to_string() 
96                         };
97                         return Err(std::io::Error::new(std::io::ErrorKind::Other, rpc_error));
98                 }
99
100                 let result = &mut response["result"];
101                 if result.is_null() {
102                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected JSON result"));
103                 }
104
105                 JsonResponse(result.take()).try_into()
106         }
107 }
108
109 impl BlockSource for RpcClient {
110         fn get_header<'a>(&'a self, header_hash: &'a BlockHash, _height: Option<u32>) -> AsyncBlockSourceResult<'a, BlockHeaderData> {
111                 Box::pin(async move {
112                         let header_hash = serde_json::json!(header_hash.to_hex());
113                         Ok(self.call_method("getblockheader", &[header_hash]).await?)
114                 })
115         }
116
117         fn get_block<'a>(&'a self, header_hash: &'a BlockHash) -> AsyncBlockSourceResult<'a, BlockData> {
118                 Box::pin(async move {
119                         let header_hash = serde_json::json!(header_hash.to_hex());
120                         let verbosity = serde_json::json!(0);
121                         Ok(BlockData::FullBlock(self.call_method("getblock", &[header_hash, verbosity]).await?))
122                 })
123         }
124
125         fn get_best_block<'a>(&'a self) -> AsyncBlockSourceResult<'a, (BlockHash, Option<u32>)> {
126                 Box::pin(async move {
127                         Ok(self.call_method("getblockchaininfo", &[]).await?)
128                 })
129         }
130 }
131
#[cfg(test)]
mod tests {
	use super::*;
	use crate::http::client_tests::{HttpServer, MessageBody};

	/// Credentials `user:password` encoded in base64.
	const CREDENTIALS: &str = "dXNlcjpwYXNzd29yZA==";

	/// Converts a JSON value into `u64`, failing with `InvalidData` ("not a number") otherwise.
	impl TryInto<u64> for JsonResponse {
		type Error = std::io::Error;

		fn try_into(self) -> std::io::Result<u64> {
			match self.0.as_u64() {
				None => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "not a number")),
				Some(n) => Ok(n),
			}
		}
	}

	/// A non-JSON error response is passed through as an `Other` error.
	#[tokio::test]
	async fn call_method_returning_unknown_response() {
		let server = HttpServer::responding_with_not_found();
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other),
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A JSON response that is not an object is rejected as `InvalidData`.
	#[tokio::test]
	async fn call_method_returning_malformed_response() {
		let response = serde_json::json!("foo");
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
				assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON object");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// An RPC `error` object in a server-error response is surfaced as an `RpcError`.
	#[tokio::test]
	async fn call_method_returning_error() {
		let response = serde_json::json!({
			"error": { "code": -8, "message": "invalid parameter" },
		});
		let server = HttpServer::responding_with_server_error(response);
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		let invalid_block_hash = serde_json::json!("foo");
		match client.call_method::<u64>("getblock", &[invalid_block_hash]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::Other);
				let rpc_error: Box<RpcError> = e.into_inner().unwrap().downcast().unwrap();
				assert_eq!(rpc_error.code, -8);
				assert_eq!(rpc_error.message, "invalid parameter");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A `null` result is rejected as `InvalidData`.
	#[tokio::test]
	async fn call_method_returning_missing_result() {
		let response = serde_json::json!({ "result": null });
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
				assert_eq!(e.get_ref().unwrap().to_string(), "expected JSON result");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A result that fails the `TryInto<T>` conversion reports the conversion's error.
	#[tokio::test]
	async fn call_method_returning_malformed_result() {
		let response = serde_json::json!({ "result": "foo" });
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => {
				assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
				assert_eq!(e.get_ref().unwrap().to_string(), "not a number");
			},
			Ok(_) => panic!("Expected error"),
		}
	}

	/// A well-formed result is converted and returned.
	#[tokio::test]
	async fn call_method_returning_valid_result() {
		let response = serde_json::json!({ "result": 654470 });
		let server = HttpServer::responding_with_ok(MessageBody::Content(response));
		let client = RpcClient::new(CREDENTIALS, server.endpoint()).unwrap();

		match client.call_method::<u64>("getblockcount", &[]).await {
			Err(e) => panic!("Unexpected error: {:?}", e),
			Ok(count) => assert_eq!(count, 654470),
		}
	}
}