Check IO errors in test using `raw_os_error()` instead of `kind()`
[rust-lightning] / lightning-block-sync / src / http.rs
1 //! Simple HTTP implementation which supports both async and traditional execution environments
2 //! with minimal dependencies. This is used as the basis for REST and RPC clients.
3
4 use chunked_transfer;
5 use serde_json;
6
7 use std::convert::TryFrom;
8 use std::fmt;
9 #[cfg(not(feature = "tokio"))]
10 use std::io::Write;
11 use std::net::{SocketAddr, ToSocketAddrs};
12 use std::time::Duration;
13
14 #[cfg(feature = "tokio")]
15 use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
16 #[cfg(feature = "tokio")]
17 use tokio::net::TcpStream;
18
19 #[cfg(not(feature = "tokio"))]
20 use std::io::BufRead;
21 use std::io::Read;
22 #[cfg(not(feature = "tokio"))]
23 use std::net::TcpStream;
24
/// Timeout applied to establishing a TCP connection and to each individual blocking read or
/// write on it.
const TCP_STREAM_TIMEOUT: Duration = Duration::from_secs(5);

/// Timeout for the start of a response to arrive. Kept apart from the general read timeout
/// because Bitcoin Core may legitimately stall for a minute or longer (e.g. while flushing its
/// UTXO cache). A timed-out request is retried once, so in total Bitcoin Core may block for
/// roughly twice this value before the request is abandoned.
const TCP_STREAM_RESPONSE_TIMEOUT: Duration = Duration::from_secs(120);

/// Upper bound, in bytes, on the status line and headers of an HTTP response.
const MAX_HTTP_MESSAGE_HEADER_SIZE: usize = 8192;

/// Upper bound, in bytes, on an HTTP response body: room for a hex-encoded block (twice its
/// 4,000,000-byte raw size) inside a JSON response, plus headroom for JSON and chunked
/// transfer-encoding overhead.
const MAX_HTTP_MESSAGE_BODY_SIZE: usize = 2 * 4_000_000 + 32_000;
40
/// Endpoint for interacting with an HTTP-based API: a host, an optional port, and a path.
#[derive(Debug)]
pub struct HttpEndpoint {
	/// Host name or IP address literal.
	host: String,
	/// Explicit port; `None` means the default HTTP port.
	port: Option<u16>,
	/// Request path, `/` unless overridden.
	path: String,
}

impl HttpEndpoint {
	/// Builds an endpoint for `host` using the default HTTP port and the root path `/`.
	pub fn for_host(host: String) -> Self {
		let path = "/".to_string();
		Self { host, port: None, path }
	}

	/// Replaces the port, consuming and returning the endpoint so calls can be chained.
	pub fn with_port(self, port: u16) -> Self {
		Self { port: Some(port), ..self }
	}

	/// Replaces the path, consuming and returning the endpoint so calls can be chained.
	pub fn with_path(self, path: String) -> Self {
		Self { path, ..self }
	}

	/// Returns the host this endpoint refers to.
	pub fn host(&self) -> &str {
		self.host.as_str()
	}

	/// Returns the configured port, or 80 when none was set.
	pub fn port(&self) -> u16 {
		self.port.unwrap_or(80)
	}

	/// Returns the request path.
	pub fn path(&self) -> &str {
		self.path.as_str()
	}
}

/// Resolves an endpoint exactly as its `(host, port)` pair would be resolved.
impl<'a> std::net::ToSocketAddrs for &'a HttpEndpoint {
	type Iter = <(&'a str, u16) as std::net::ToSocketAddrs>::Iter;

	fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
		let host_and_port = (self.host(), self.port());
		host_and_port.to_socket_addrs()
	}
}
97
98 /// Client for making HTTP requests.
99 pub(crate) struct HttpClient {
100         address: SocketAddr,
101         stream: TcpStream,
102 }
103
104 impl HttpClient {
105         /// Opens a connection to an HTTP endpoint.
106         pub fn connect<E: ToSocketAddrs>(endpoint: E) -> std::io::Result<Self> {
107                 let address = match endpoint.to_socket_addrs()?.next() {
108                         None => {
109                                 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "could not resolve to any addresses"));
110                         },
111                         Some(address) => address,
112                 };
113                 let stream = std::net::TcpStream::connect_timeout(&address, TCP_STREAM_TIMEOUT)?;
114                 stream.set_read_timeout(Some(TCP_STREAM_TIMEOUT))?;
115                 stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT))?;
116
117                 #[cfg(feature = "tokio")]
118                 let stream = {
119                         stream.set_nonblocking(true)?;
120                         TcpStream::from_std(stream)?
121                 };
122
123                 Ok(Self { address, stream })
124         }
125
126         /// Sends a `GET` request for a resource identified by `uri` at the `host`.
127         ///
128         /// Returns the response body in `F` format.
129         #[allow(dead_code)]
130         pub async fn get<F>(&mut self, uri: &str, host: &str) -> std::io::Result<F>
131         where F: TryFrom<Vec<u8>, Error = std::io::Error> {
132                 let request = format!(
133                         "GET {} HTTP/1.1\r\n\
134                          Host: {}\r\n\
135                          Connection: keep-alive\r\n\
136                          \r\n", uri, host);
137                 let response_body = self.send_request_with_retry(&request).await?;
138                 F::try_from(response_body)
139         }
140
141         /// Sends a `POST` request for a resource identified by `uri` at the `host` using the given HTTP
142         /// authentication credentials.
143         ///
144         /// The request body consists of the provided JSON `content`. Returns the response body in `F`
145         /// format.
146         #[allow(dead_code)]
147         pub async fn post<F>(&mut self, uri: &str, host: &str, auth: &str, content: serde_json::Value) -> std::io::Result<F>
148         where F: TryFrom<Vec<u8>, Error = std::io::Error> {
149                 let content = content.to_string();
150                 let request = format!(
151                         "POST {} HTTP/1.1\r\n\
152                          Host: {}\r\n\
153                          Authorization: {}\r\n\
154                          Connection: keep-alive\r\n\
155                          Content-Type: application/json\r\n\
156                          Content-Length: {}\r\n\
157                          \r\n\
158                          {}", uri, host, auth, content.len(), content);
159                 let response_body = self.send_request_with_retry(&request).await?;
160                 F::try_from(response_body)
161         }
162
163         /// Sends an HTTP request message and reads the response, returning its body. Attempts to
164         /// reconnect and retry if the connection has been closed.
165         async fn send_request_with_retry(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
166                 match self.send_request(request).await {
167                         Ok(bytes) => Ok(bytes),
168                         Err(_) => {
169                                 // Reconnect and retry on fail. This can happen if the connection was closed after
170                                 // the keep-alive limits are reached, or generally if the request timed out due to
171                                 // Bitcoin Core being stuck on a long-running operation or its RPC queue being
172                                 // full.
173                                 // Block 100ms before retrying the request as in many cases the source of the error
174                                 // may be persistent for some time.
175                                 #[cfg(feature = "tokio")]
176                                 tokio::time::sleep(Duration::from_millis(100)).await;
177                                 #[cfg(not(feature = "tokio"))]
178                                 std::thread::sleep(Duration::from_millis(100));
179                                 *self = Self::connect(self.address)?;
180                                 self.send_request(request).await
181                         },
182                 }
183         }
184
185         /// Sends an HTTP request message and reads the response, returning its body.
186         async fn send_request(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
187                 self.write_request(request).await?;
188                 self.read_response().await
189         }
190
191         /// Writes an HTTP request message.
192         async fn write_request(&mut self, request: &str) -> std::io::Result<()> {
193                 #[cfg(feature = "tokio")]
194                 {
195                         self.stream.write_all(request.as_bytes()).await?;
196                         self.stream.flush().await
197                 }
198                 #[cfg(not(feature = "tokio"))]
199                 {
200                         self.stream.write_all(request.as_bytes())?;
201                         self.stream.flush()
202                 }
203         }
204
205         /// Reads an HTTP response message.
206         async fn read_response(&mut self) -> std::io::Result<Vec<u8>> {
207                 #[cfg(feature = "tokio")]
208                 let stream = self.stream.split().0;
209                 #[cfg(not(feature = "tokio"))]
210                 let stream = std::io::Read::by_ref(&mut self.stream);
211
212                 let limited_stream = stream.take(MAX_HTTP_MESSAGE_HEADER_SIZE as u64);
213
214                 #[cfg(feature = "tokio")]
215                 let mut reader = tokio::io::BufReader::new(limited_stream);
216                 #[cfg(not(feature = "tokio"))]
217                 let mut reader = std::io::BufReader::new(limited_stream);
218
219                 macro_rules! read_line {
220                         () => { read_line!(0) };
221                         ($retry_count: expr) => { {
222                                 let mut line = String::new();
223                                 let mut timeout_count: u64 = 0;
224                                 let bytes_read = loop {
225                                         #[cfg(feature = "tokio")]
226                                         let read_res = reader.read_line(&mut line).await;
227                                         #[cfg(not(feature = "tokio"))]
228                                         let read_res = reader.read_line(&mut line);
229                                         match read_res {
230                                                 Ok(bytes_read) => break bytes_read,
231                                                 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
232                                                         timeout_count += 1;
233                                                         if timeout_count > $retry_count {
234                                                                 return Err(e);
235                                                         } else {
236                                                                 continue;
237                                                         }
238                                                 }
239                                                 Err(e) => return Err(e),
240                                         }
241                                 };
242
243                                 match bytes_read {
244                                         0 => None,
245                                         _ => {
246                                                 // Remove trailing CRLF
247                                                 if line.ends_with('\n') { line.pop(); if line.ends_with('\r') { line.pop(); } }
248                                                 Some(line)
249                                         },
250                                 }
251                         } }
252                 }
253
254                 // Read and parse status line
255                 // Note that we allow retrying a few times to reach TCP_STREAM_RESPONSE_TIMEOUT.
256                 let status_line = read_line!(TCP_STREAM_RESPONSE_TIMEOUT.as_secs() / TCP_STREAM_TIMEOUT.as_secs())
257                         .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?;
258                 let status = HttpStatus::parse(&status_line)?;
259
260                 // Read and parse relevant headers
261                 let mut message_length = HttpMessageLength::Empty;
262                 loop {
263                         let line = read_line!()
264                                 .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no headers"))?;
265                         if line.is_empty() { break; }
266
267                         let header = HttpHeader::parse(&line)?;
268                         if header.has_name("Content-Length") {
269                                 let length = header.value.parse()
270                                         .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
271                                 if let HttpMessageLength::Empty = message_length {
272                                         message_length = HttpMessageLength::ContentLength(length);
273                                 }
274                                 continue;
275                         }
276
277                         if header.has_name("Transfer-Encoding") {
278                                 message_length = HttpMessageLength::TransferEncoding(header.value.into());
279                                 continue;
280                         }
281                 }
282
283                 // Read message body
284                 let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len();
285                 reader.get_mut().set_limit(read_limit as u64);
286                 let contents = match message_length {
287                         HttpMessageLength::Empty => { Vec::new() },
288                         HttpMessageLength::ContentLength(length) => {
289                                 if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE {
290                                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "out of range"))
291                                 } else {
292                                         let mut content = vec![0; length];
293                                         #[cfg(feature = "tokio")]
294                                         reader.read_exact(&mut content[..]).await?;
295                                         #[cfg(not(feature = "tokio"))]
296                                         reader.read_exact(&mut content[..])?;
297                                         content
298                                 }
299                         },
300                         HttpMessageLength::TransferEncoding(coding) => {
301                                 if !coding.eq_ignore_ascii_case("chunked") {
302                                         return Err(std::io::Error::new(
303                                                 std::io::ErrorKind::InvalidInput, "unsupported transfer coding"))
304                                 } else {
305                                         let mut content = Vec::new();
306                                         #[cfg(feature = "tokio")]
307                                         {
308                                                 // Since chunked_transfer doesn't have an async interface, only use it to
309                                                 // determine the size of each chunk to read.
310                                                 //
311                                                 // TODO: Replace with an async interface when available.
312                                                 // https://github.com/frewsxcv/rust-chunked-transfer/issues/7
313                                                 loop {
314                                                         // Read the chunk header which contains the chunk size.
315                                                         let mut chunk_header = String::new();
316                                                         reader.read_line(&mut chunk_header).await?;
317                                                         if chunk_header == "0\r\n" {
318                                                                 // Read the terminator chunk since the decoder consumes the CRLF
319                                                                 // immediately when this chunk is encountered.
320                                                                 reader.read_line(&mut chunk_header).await?;
321                                                         }
322
323                                                         // Decode the chunk header to obtain the chunk size.
324                                                         let mut buffer = Vec::new();
325                                                         let mut decoder = chunked_transfer::Decoder::new(chunk_header.as_bytes());
326                                                         decoder.read_to_end(&mut buffer)?;
327
328                                                         // Read the chunk body.
329                                                         let chunk_size = match decoder.remaining_chunks_size() {
330                                                                 None => break,
331                                                                 Some(chunk_size) => chunk_size,
332                                                         };
333                                                         let chunk_offset = content.len();
334                                                         content.resize(chunk_offset + chunk_size + "\r\n".len(), 0);
335                                                         reader.read_exact(&mut content[chunk_offset..]).await?;
336                                                         content.resize(chunk_offset + chunk_size, 0);
337                                                 }
338                                                 content
339                                         }
340                                         #[cfg(not(feature = "tokio"))]
341                                         {
342                                                 let mut decoder = chunked_transfer::Decoder::new(reader);
343                                                 decoder.read_to_end(&mut content)?;
344                                                 content
345                                         }
346                                 }
347                         },
348                 };
349
350                 if !status.is_ok() {
351                         // TODO: Handle 3xx redirection responses.
352                         let error = HttpError {
353                                 status_code: status.code.to_string(),
354                                 contents,
355                         };
356                         return Err(std::io::Error::new(std::io::ErrorKind::Other, error));
357                 }
358
359                 Ok(contents)
360         }
361 }
362
/// HTTP error carrying the response's status code and raw body, as produced for responses
/// whose status is not in the 2xx class.
#[derive(Debug)]
pub(crate) struct HttpError {
	/// Status code as received, e.g. `"404"`.
	pub(crate) status_code: String,
	/// Raw response body; not necessarily valid UTF-8.
	pub(crate) contents: Vec<u8>,
}

impl std::error::Error for HttpError {}

impl fmt::Display for HttpError {
	/// Formats as `status_code: <code>, contents: <body>`, decoding the body lossily so invalid
	/// UTF-8 becomes U+FFFD.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		write!(f, "status_code: {}, contents: {}", self.status_code, String::from_utf8_lossy(&self.contents))
	}
}
378
/// HTTP response status code as defined by [RFC 7231].
///
/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-6
struct HttpStatus<'a> {
	/// The three-digit status code, borrowed from the parsed status line.
	code: &'a str,
}

impl<'a> HttpStatus<'a> {
	/// Parses an HTTP status line as defined by [RFC 7230].
	///
	/// Accepts `HTTP/1.1` or `HTTP/1.0` (case-insensitively) and a three-digit status code. The
	/// space before the reason phrase is required, but the phrase itself may be empty.
	///
	/// # Errors
	///
	/// Returns an error of kind [`std::io::ErrorKind::InvalidData`] if any component is missing
	/// or malformed.
	///
	/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.1.2
	fn parse(line: &'a str) -> std::io::Result<HttpStatus<'a>> {
		// Build errors lazily so the success path does not allocate.
		let invalid = |msg: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidData, msg);
		let mut tokens = line.splitn(3, ' ');

		let http_version = tokens.next().ok_or_else(|| invalid("no HTTP-Version"))?;
		if !http_version.eq_ignore_ascii_case("HTTP/1.1") &&
			!http_version.eq_ignore_ascii_case("HTTP/1.0") {
			return Err(invalid("invalid HTTP-Version"));
		}

		let code = tokens.next().ok_or_else(|| invalid("no Status-Code"))?;
		if code.len() != 3 || !code.chars().all(|c| c.is_ascii_digit()) {
			return Err(invalid("invalid Status-Code"));
		}

		let _reason = tokens.next().ok_or_else(|| invalid("no Reason-Phrase"))?;

		Ok(Self { code })
	}

	/// Returns whether the status is successful (i.e., 2xx status class).
	fn is_ok(&self) -> bool {
		self.code.starts_with('2')
	}
}
417
/// HTTP response header as defined by [RFC 7231].
///
/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-7
struct HttpHeader<'a> {
	/// Field name exactly as received; compare it case-insensitively via `has_name`.
	name: &'a str,
	/// Field value with surrounding whitespace removed.
	value: &'a str,
}

impl<'a> HttpHeader<'a> {
	/// Parses an HTTP header field as defined by [RFC 7230].
	///
	/// The line is split at its first `:`. Whitespace before and after the value is discarded,
	/// since RFC 7230 excludes that optional whitespace from the field value.
	///
	/// # Errors
	///
	/// Returns an error of kind [`std::io::ErrorKind::InvalidData`] if the line has no `:`.
	///
	/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.2
	fn parse(line: &'a str) -> std::io::Result<HttpHeader<'a>> {
		let mut tokens = line.splitn(2, ':');
		let name = tokens.next()
			.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "no header name"))?;
		let value = tokens.next()
			.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "no header value"))?
			// RFC 7230 §3.2.4: the field value excludes both leading and trailing whitespace,
			// so e.g. `Content-Length: 5 ` must still parse as 5.
			.trim();
		Ok(Self { name, value })
	}

	/// Returns whether the header field has the given name (ASCII case-insensitive).
	fn has_name(&self, name: &str) -> bool {
		self.name.eq_ignore_ascii_case(name)
	}
}
445
/// How the length of an HTTP message body is determined, following [RFC 7230].
///
/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.3.3
enum HttpMessageLength {
	/// No framing header was seen.
	Empty,
	/// Body length in bytes, taken from a `Content-Length` header.
	ContentLength(usize),
	/// Transfer coding named by a `Transfer-Encoding` header.
	TransferEncoding(String),
}
454
455 /// An HTTP response body in binary format.
456 pub struct BinaryResponse(pub Vec<u8>);
457
458 /// An HTTP response body in JSON format.
459 pub struct JsonResponse(pub serde_json::Value);
460
461 /// Interprets bytes from an HTTP response body as binary data.
462 impl TryFrom<Vec<u8>> for BinaryResponse {
463         type Error = std::io::Error;
464
465         fn try_from(bytes: Vec<u8>) -> std::io::Result<Self> {
466                 Ok(BinaryResponse(bytes))
467         }
468 }
469
470 /// Interprets bytes from an HTTP response body as a JSON value.
471 impl TryFrom<Vec<u8>> for JsonResponse {
472         type Error = std::io::Error;
473
474         fn try_from(bytes: Vec<u8>) -> std::io::Result<Self> {
475                 Ok(JsonResponse(serde_json::from_slice(&bytes)?))
476         }
477 }
478
#[cfg(test)]
mod endpoint_tests {
	use super::HttpEndpoint;

	#[test]
	fn with_default_port() {
		let endpoint = HttpEndpoint::for_host(String::from("foo.com"));
		assert_eq!((endpoint.host(), endpoint.port()), ("foo.com", 80));
	}

	#[test]
	fn with_custom_port() {
		let endpoint = HttpEndpoint::for_host(String::from("foo.com")).with_port(8080);
		assert_eq!((endpoint.host(), endpoint.port()), ("foo.com", 8080));
	}

	#[test]
	fn with_uri_path() {
		let endpoint = HttpEndpoint::for_host(String::from("foo.com")).with_path(String::from("/path"));
		assert_eq!((endpoint.host(), endpoint.path()), ("foo.com", "/path"));
	}

	#[test]
	fn without_uri_path() {
		let endpoint = HttpEndpoint::for_host(String::from("foo.com"));
		assert_eq!((endpoint.host(), endpoint.path()), ("foo.com", "/"));
	}

	/// Resolving an endpoint yields the same single address as resolving its `(host, port)` pair.
	#[test]
	fn convert_to_socket_addrs() {
		use std::net::ToSocketAddrs;

		let endpoint = HttpEndpoint::for_host(String::from("foo.com"));
		let (host, port) = (endpoint.host(), endpoint.port());

		let mut socket_addrs = (&endpoint).to_socket_addrs()
			.unwrap_or_else(|e| panic!("Unexpected error: {:?}", e));
		let addr = socket_addrs.next().expect("Expected socket address");
		let expected = (host, port).to_socket_addrs().unwrap().next().unwrap();
		assert_eq!(addr, expected);
		assert!(socket_addrs.next().is_none());
	}
}
532
533 #[cfg(test)]
534 pub(crate) mod client_tests {
535         use super::*;
536         use std::io::BufRead;
537         use std::io::Write;
538
539         /// Server for handling HTTP client requests with a stock response.
540         pub struct HttpServer {
541                 address: std::net::SocketAddr,
542                 handler: std::thread::JoinHandle<()>,
543                 shutdown: std::sync::Arc<std::sync::atomic::AtomicBool>,
544         }
545
546         /// Body of HTTP response messages.
547         pub enum MessageBody<T: ToString> {
548                 Empty,
549                 Content(T),
550                 ChunkedContent(T),
551         }
552
553         impl HttpServer {
554                 fn responding_with_body<T: ToString>(status: &str, body: MessageBody<T>) -> Self {
555                         let response = match body {
556                                 MessageBody::Empty => format!("{}\r\n\r\n", status),
557                                 MessageBody::Content(body) => {
558                                         let body = body.to_string();
559                                         format!(
560                                                 "{}\r\n\
561                                                  Content-Length: {}\r\n\
562                                                  \r\n\
563                                                  {}", status, body.len(), body)
564                                 },
565                                 MessageBody::ChunkedContent(body) => {
566                                         let mut chuncked_body = Vec::new();
567                                         {
568                                                 use chunked_transfer::Encoder;
569                                                 let mut encoder = Encoder::with_chunks_size(&mut chuncked_body, 8);
570                                                 encoder.write_all(body.to_string().as_bytes()).unwrap();
571                                         }
572                                         format!(
573                                                 "{}\r\n\
574                                                  Transfer-Encoding: chunked\r\n\
575                                                  \r\n\
576                                                  {}", status, String::from_utf8(chuncked_body).unwrap())
577                                 },
578                         };
579                         HttpServer::responding_with(response)
580                 }
581
582                 pub fn responding_with_ok<T: ToString>(body: MessageBody<T>) -> Self {
583                         HttpServer::responding_with_body("HTTP/1.1 200 OK", body)
584                 }
585
586                 pub fn responding_with_not_found() -> Self {
587                         HttpServer::responding_with_body::<String>("HTTP/1.1 404 Not Found", MessageBody::Empty)
588                 }
589
590                 pub fn responding_with_server_error<T: ToString>(content: T) -> Self {
591                         let body = MessageBody::Content(content);
592                         HttpServer::responding_with_body("HTTP/1.1 500 Internal Server Error", body)
593                 }
594
595                 fn responding_with(response: String) -> Self {
596                         let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
597                         let address = listener.local_addr().unwrap();
598
599                         let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
600                         let shutdown_signaled = std::sync::Arc::clone(&shutdown);
601                         let handler = std::thread::spawn(move || {
602                                 for stream in listener.incoming() {
603                                         let mut stream = stream.unwrap();
604                                         stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT)).unwrap();
605
606                                         let lines_read = std::io::BufReader::new(&stream)
607                                                 .lines()
608                                                 .take_while(|line| !line.as_ref().unwrap().is_empty())
609                                                 .count();
610                                         if lines_read == 0 { continue; }
611
612                                         for chunk in response.as_bytes().chunks(16) {
613                                                 if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) {
614                                                         return;
615                                                 } else {
616                                                         if let Err(_) = stream.write(chunk) { break; }
617                                                         if let Err(_) = stream.flush() { break; }
618                                                 }
619                                         }
620                                 }
621                         });
622
623                         Self { address, handler, shutdown }
624                 }
625
626                 fn shutdown(self) {
627                         self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst);
628                         self.handler.join().unwrap();
629                 }
630
631                 pub fn endpoint(&self) -> HttpEndpoint {
632                         HttpEndpoint::for_host(self.address.ip().to_string()).with_port(self.address.port())
633                 }
634         }
635
636         #[test]
637         fn connect_to_unresolvable_host() {
638                 match HttpClient::connect(("example.invalid", 80)) {
639                         Err(e) => {
640                                 assert!(e.to_string().contains("failed to lookup address information") ||
641                                         e.to_string().contains("No such host"), "{:?}", e);
642                         },
643                         Ok(_) => panic!("Expected error"),
644                 }
645         }
646
647         #[test]
648         fn connect_with_no_socket_address() {
649                 match HttpClient::connect(&vec![][..]) {
650                         Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput),
651                         Ok(_) => panic!("Expected error"),
652                 }
653         }
654
655         #[test]
656         fn connect_with_unknown_server() {
657                 match HttpClient::connect(("::", 80)) {
658                         #[cfg(target_os = "windows")]
659                         Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::AddrNotAvailable),
660                         #[cfg(not(target_os = "windows"))]
661                         Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::ConnectionRefused),
662                         Ok(_) => panic!("Expected error"),
663                 }
664         }
665
666         #[tokio::test]
667         async fn connect_with_valid_endpoint() {
668                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
669
670                 match HttpClient::connect(&server.endpoint()) {
671                         Err(e) => panic!("Unexpected error: {:?}", e),
672                         Ok(_) => {},
673                 }
674         }
675
676         #[tokio::test]
677         async fn read_empty_message() {
678                 let server = HttpServer::responding_with("".to_string());
679
680                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
681                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
682                         Err(e) => {
683                                 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
684                                 assert_eq!(e.get_ref().unwrap().to_string(), "no status line");
685                         },
686                         Ok(_) => panic!("Expected error"),
687                 }
688         }
689
690         #[tokio::test]
691         async fn read_incomplete_message() {
692                 let server = HttpServer::responding_with("HTTP/1.1 200 OK".to_string());
693
694                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
695                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
696                         Err(e) => {
697                                 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
698                                 assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
699                         },
700                         Ok(_) => panic!("Expected error"),
701                 }
702         }
703
704         #[tokio::test]
705         async fn read_too_large_message_headers() {
706                 let response = format!(
707                         "HTTP/1.1 302 Found\r\n\
708                          Location: {}\r\n\
709                          \r\n", "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE));
710                 let server = HttpServer::responding_with(response);
711
712                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
713                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
714                         Err(e) => {
715                                 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
716                                 assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
717                         },
718                         Ok(_) => panic!("Expected error"),
719                 }
720         }
721
722         #[tokio::test]
723         async fn read_too_large_message_body() {
724                 let body = "Z".repeat(MAX_HTTP_MESSAGE_BODY_SIZE + 1);
725                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Content(body));
726
727                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
728                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
729                         Err(e) => {
730                                 assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
731                                 assert_eq!(e.get_ref().unwrap().to_string(), "out of range");
732                         },
733                         Ok(_) => panic!("Expected error"),
734                 }
735                 server.shutdown();
736         }
737
738         #[tokio::test]
739         async fn read_message_with_unsupported_transfer_coding() {
740                 let response = String::from(
741                         "HTTP/1.1 200 OK\r\n\
742                          Transfer-Encoding: gzip\r\n\
743                          \r\n\
744                          foobar");
745                 let server = HttpServer::responding_with(response);
746
747                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
748                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
749                         Err(e) => {
750                                 assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput);
751                                 assert_eq!(e.get_ref().unwrap().to_string(), "unsupported transfer coding");
752                         },
753                         Ok(_) => panic!("Expected error"),
754                 }
755         }
756
757         #[tokio::test]
758         async fn read_error() {
759                 let server = HttpServer::responding_with_server_error("foo");
760
761                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
762                 match client.get::<JsonResponse>("/foo", "foo.com").await {
763                         Err(e) => {
764                                 assert_eq!(e.kind(), std::io::ErrorKind::Other);
765                                 let http_error = e.into_inner().unwrap().downcast::<HttpError>().unwrap();
766                                 assert_eq!(http_error.status_code, "500");
767                                 assert_eq!(http_error.contents, "foo".as_bytes());
768                         },
769                         Ok(_) => panic!("Expected error"),
770                 }
771         }
772
773         #[tokio::test]
774         async fn read_empty_message_body() {
775                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
776
777                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
778                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
779                         Err(e) => panic!("Unexpected error: {:?}", e),
780                         Ok(bytes) => assert_eq!(bytes.0, Vec::<u8>::new()),
781                 }
782         }
783
784         #[tokio::test]
785         async fn read_message_body_with_length() {
786                 let body = "foo bar baz qux".repeat(32);
787                 let content = MessageBody::Content(body.clone());
788                 let server = HttpServer::responding_with_ok::<String>(content);
789
790                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
791                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
792                         Err(e) => panic!("Unexpected error: {:?}", e),
793                         Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()),
794                 }
795         }
796
797         #[tokio::test]
798         async fn read_chunked_message_body() {
799                 let body = "foo bar baz qux".repeat(32);
800                 let chunked_content = MessageBody::ChunkedContent(body.clone());
801                 let server = HttpServer::responding_with_ok::<String>(chunked_content);
802
803                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
804                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
805                         Err(e) => panic!("Unexpected error: {:?}", e),
806                         Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()),
807                 }
808         }
809
810         #[tokio::test]
811         async fn reconnect_closed_connection() {
812                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
813
814                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
815                 assert!(client.get::<BinaryResponse>("/foo", "foo.com").await.is_ok());
816                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
817                         Err(e) => panic!("Unexpected error: {:?}", e),
818                         Ok(bytes) => assert_eq!(bytes.0, Vec::<u8>::new()),
819                 }
820         }
821
822         #[test]
823         fn from_bytes_into_binary_response() {
824                 let bytes = b"foo";
825                 match BinaryResponse::try_from(bytes.to_vec()) {
826                         Err(e) => panic!("Unexpected error: {:?}", e),
827                         Ok(response) => assert_eq!(&response.0, bytes),
828                 }
829         }
830
831         #[test]
832         fn from_invalid_bytes_into_json_response() {
833                 let json = serde_json::json!({ "result": 42 });
834                 match JsonResponse::try_from(json.to_string().as_bytes()[..5].to_vec()) {
835                         Err(_) => {},
836                         Ok(_) => panic!("Expected error"),
837                 }
838         }
839
840         #[test]
841         fn from_valid_bytes_into_json_response() {
842                 let json = serde_json::json!({ "result": 42 });
843                 match JsonResponse::try_from(json.to_string().as_bytes().to_vec()) {
844                         Err(e) => panic!("Unexpected error: {:?}", e),
845                         Ok(response) => assert_eq!(response.0, json),
846                 }
847         }
848 }