a3935edf5cb6e9880217196112d8f4ead5bacd57
[rust-lightning] / lightning-block-sync / src / http.rs
1 //! Simple HTTP implementation which supports both async and traditional execution environments
2 //! with minimal dependencies. This is used as the basis for REST and RPC clients.
3
4 use chunked_transfer;
5 use serde_json;
6
7 use std::convert::TryFrom;
8 use std::fmt;
9 #[cfg(not(feature = "tokio"))]
10 use std::io::Write;
11 use std::net::{SocketAddr, ToSocketAddrs};
12 use std::time::Duration;
13
14 #[cfg(feature = "tokio")]
15 use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
16 #[cfg(feature = "tokio")]
17 use tokio::net::TcpStream;
18
19 #[cfg(not(feature = "tokio"))]
20 use std::io::BufRead;
21 use std::io::Read;
22 #[cfg(not(feature = "tokio"))]
23 use std::net::TcpStream;
24
/// Timeout applied when connecting, and set as the read and write timeout, on TCP streams.
const TCP_STREAM_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Upper bound on how long we wait for the first line of a response. It is kept separate from
/// [`TCP_STREAM_TIMEOUT`] because Bitcoin Core can stall for a long time (upwards of 10 minutes on
/// slow devices such as RPis with SSDs over USB) while waiting on UTXO cache flushes. A timed-out
/// request is retried once, so in total Bitcoin Core may block for up to twice this duration.
const TCP_STREAM_RESPONSE_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Largest HTTP message header section we accept, in bytes.
const MAX_HTTP_MESSAGE_HEADER_SIZE: usize = 8 * 1024;

/// Largest HTTP message body we accept, in bytes: two hex characters per byte of a block of up to
/// 4,000,000 bytes, plus slack for the JSON wrapper and chunked transfer encoding overhead.
const MAX_HTTP_MESSAGE_BODY_SIZE: usize = 4_000_000 * 2 + 32_000;
41
/// Location of an HTTP-based API: a host, an optional explicit port, and a request path.
#[derive(Debug)]
pub struct HttpEndpoint {
	host: String,
	port: Option<u16>,
	path: String,
}

impl HttpEndpoint {
	/// Builds an endpoint for `host` that uses the default HTTP port and the root path `/`.
	pub fn for_host(host: String) -> Self {
		let path = "/".to_string();
		Self { host, port: None, path }
	}

	/// Replaces the port, consuming the endpoint and returning it for chaining.
	pub fn with_port(self, port: u16) -> Self {
		Self { port: Some(port), ..self }
	}

	/// Replaces the request path, consuming the endpoint and returning it for chaining.
	pub fn with_path(self, path: String) -> Self {
		Self { path, ..self }
	}

	/// The host this endpoint targets.
	pub fn host(&self) -> &str {
		self.host.as_str()
	}

	/// The port this endpoint targets, falling back to 80 when none was set.
	pub fn port(&self) -> u16 {
		self.port.unwrap_or(80)
	}

	/// The request path of this endpoint.
	pub fn path(&self) -> &str {
		self.path.as_str()
	}
}

/// Lets a `&HttpEndpoint` be used wherever socket addresses are expected by resolving its
/// `(host, port)` pair with the standard library resolver.
impl<'a> std::net::ToSocketAddrs for &'a HttpEndpoint {
	type Iter = <(&'a str, u16) as std::net::ToSocketAddrs>::Iter;

	fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
		let host_and_port = (self.host(), self.port());
		std::net::ToSocketAddrs::to_socket_addrs(&host_and_port)
	}
}
98
99 /// Client for making HTTP requests.
100 pub(crate) struct HttpClient {
101         address: SocketAddr,
102         stream: TcpStream,
103 }
104
105 impl HttpClient {
106         /// Opens a connection to an HTTP endpoint.
107         pub fn connect<E: ToSocketAddrs>(endpoint: E) -> std::io::Result<Self> {
108                 let address = match endpoint.to_socket_addrs()?.next() {
109                         None => {
110                                 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "could not resolve to any addresses"));
111                         },
112                         Some(address) => address,
113                 };
114                 let stream = std::net::TcpStream::connect_timeout(&address, TCP_STREAM_TIMEOUT)?;
115                 stream.set_read_timeout(Some(TCP_STREAM_TIMEOUT))?;
116                 stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT))?;
117
118                 #[cfg(feature = "tokio")]
119                 let stream = {
120                         stream.set_nonblocking(true)?;
121                         TcpStream::from_std(stream)?
122                 };
123
124                 Ok(Self { address, stream })
125         }
126
127         /// Sends a `GET` request for a resource identified by `uri` at the `host`.
128         ///
129         /// Returns the response body in `F` format.
130         #[allow(dead_code)]
131         pub async fn get<F>(&mut self, uri: &str, host: &str) -> std::io::Result<F>
132         where F: TryFrom<Vec<u8>, Error = std::io::Error> {
133                 let request = format!(
134                         "GET {} HTTP/1.1\r\n\
135                          Host: {}\r\n\
136                          Connection: keep-alive\r\n\
137                          \r\n", uri, host);
138                 let response_body = self.send_request_with_retry(&request).await?;
139                 F::try_from(response_body)
140         }
141
142         /// Sends a `POST` request for a resource identified by `uri` at the `host` using the given HTTP
143         /// authentication credentials.
144         ///
145         /// The request body consists of the provided JSON `content`. Returns the response body in `F`
146         /// format.
147         #[allow(dead_code)]
148         pub async fn post<F>(&mut self, uri: &str, host: &str, auth: &str, content: serde_json::Value) -> std::io::Result<F>
149         where F: TryFrom<Vec<u8>, Error = std::io::Error> {
150                 let content = content.to_string();
151                 let request = format!(
152                         "POST {} HTTP/1.1\r\n\
153                          Host: {}\r\n\
154                          Authorization: {}\r\n\
155                          Connection: keep-alive\r\n\
156                          Content-Type: application/json\r\n\
157                          Content-Length: {}\r\n\
158                          \r\n\
159                          {}", uri, host, auth, content.len(), content);
160                 let response_body = self.send_request_with_retry(&request).await?;
161                 F::try_from(response_body)
162         }
163
164         /// Sends an HTTP request message and reads the response, returning its body. Attempts to
165         /// reconnect and retry if the connection has been closed.
166         async fn send_request_with_retry(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
167                 match self.send_request(request).await {
168                         Ok(bytes) => Ok(bytes),
169                         Err(_) => {
170                                 // Reconnect and retry on fail. This can happen if the connection was closed after
171                                 // the keep-alive limits are reached, or generally if the request timed out due to
172                                 // Bitcoin Core being stuck on a long-running operation or its RPC queue being
173                                 // full.
174                                 // Block 100ms before retrying the request as in many cases the source of the error
175                                 // may be persistent for some time.
176                                 #[cfg(feature = "tokio")]
177                                 tokio::time::sleep(Duration::from_millis(100)).await;
178                                 #[cfg(not(feature = "tokio"))]
179                                 std::thread::sleep(Duration::from_millis(100));
180                                 *self = Self::connect(self.address)?;
181                                 self.send_request(request).await
182                         },
183                 }
184         }
185
186         /// Sends an HTTP request message and reads the response, returning its body.
187         async fn send_request(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
188                 self.write_request(request).await?;
189                 self.read_response().await
190         }
191
192         /// Writes an HTTP request message.
193         async fn write_request(&mut self, request: &str) -> std::io::Result<()> {
194                 #[cfg(feature = "tokio")]
195                 {
196                         self.stream.write_all(request.as_bytes()).await?;
197                         self.stream.flush().await
198                 }
199                 #[cfg(not(feature = "tokio"))]
200                 {
201                         self.stream.write_all(request.as_bytes())?;
202                         self.stream.flush()
203                 }
204         }
205
206         /// Reads an HTTP response message.
207         async fn read_response(&mut self) -> std::io::Result<Vec<u8>> {
208                 #[cfg(feature = "tokio")]
209                 let stream = self.stream.split().0;
210                 #[cfg(not(feature = "tokio"))]
211                 let stream = std::io::Read::by_ref(&mut self.stream);
212
213                 let limited_stream = stream.take(MAX_HTTP_MESSAGE_HEADER_SIZE as u64);
214
215                 #[cfg(feature = "tokio")]
216                 let mut reader = tokio::io::BufReader::new(limited_stream);
217                 #[cfg(not(feature = "tokio"))]
218                 let mut reader = std::io::BufReader::new(limited_stream);
219
220                 macro_rules! read_line {
221                         () => { read_line!(0) };
222                         ($retry_count: expr) => { {
223                                 let mut line = String::new();
224                                 let mut timeout_count: u64 = 0;
225                                 let bytes_read = loop {
226                                         #[cfg(feature = "tokio")]
227                                         let read_res = reader.read_line(&mut line).await;
228                                         #[cfg(not(feature = "tokio"))]
229                                         let read_res = reader.read_line(&mut line);
230                                         match read_res {
231                                                 Ok(bytes_read) => break bytes_read,
232                                                 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
233                                                         timeout_count += 1;
234                                                         if timeout_count > $retry_count {
235                                                                 return Err(e);
236                                                         } else {
237                                                                 continue;
238                                                         }
239                                                 }
240                                                 Err(e) => return Err(e),
241                                         }
242                                 };
243
244                                 match bytes_read {
245                                         0 => None,
246                                         _ => {
247                                                 // Remove trailing CRLF
248                                                 if line.ends_with('\n') { line.pop(); if line.ends_with('\r') { line.pop(); } }
249                                                 Some(line)
250                                         },
251                                 }
252                         } }
253                 }
254
255                 // Read and parse status line
256                 // Note that we allow retrying a few times to reach TCP_STREAM_RESPONSE_TIMEOUT.
257                 let status_line = read_line!(TCP_STREAM_RESPONSE_TIMEOUT.as_secs() / TCP_STREAM_TIMEOUT.as_secs())
258                         .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?;
259                 let status = HttpStatus::parse(&status_line)?;
260
261                 // Read and parse relevant headers
262                 let mut message_length = HttpMessageLength::Empty;
263                 loop {
264                         let line = read_line!()
265                                 .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no headers"))?;
266                         if line.is_empty() { break; }
267
268                         let header = HttpHeader::parse(&line)?;
269                         if header.has_name("Content-Length") {
270                                 let length = header.value.parse()
271                                         .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
272                                 if let HttpMessageLength::Empty = message_length {
273                                         message_length = HttpMessageLength::ContentLength(length);
274                                 }
275                                 continue;
276                         }
277
278                         if header.has_name("Transfer-Encoding") {
279                                 message_length = HttpMessageLength::TransferEncoding(header.value.into());
280                                 continue;
281                         }
282                 }
283
284                 // Read message body
285                 let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len();
286                 reader.get_mut().set_limit(read_limit as u64);
287                 let contents = match message_length {
288                         HttpMessageLength::Empty => { Vec::new() },
289                         HttpMessageLength::ContentLength(length) => {
290                                 if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE {
291                                         return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "out of range"))
292                                 } else {
293                                         let mut content = vec![0; length];
294                                         #[cfg(feature = "tokio")]
295                                         reader.read_exact(&mut content[..]).await?;
296                                         #[cfg(not(feature = "tokio"))]
297                                         reader.read_exact(&mut content[..])?;
298                                         content
299                                 }
300                         },
301                         HttpMessageLength::TransferEncoding(coding) => {
302                                 if !coding.eq_ignore_ascii_case("chunked") {
303                                         return Err(std::io::Error::new(
304                                                 std::io::ErrorKind::InvalidInput, "unsupported transfer coding"))
305                                 } else {
306                                         let mut content = Vec::new();
307                                         #[cfg(feature = "tokio")]
308                                         {
309                                                 // Since chunked_transfer doesn't have an async interface, only use it to
310                                                 // determine the size of each chunk to read.
311                                                 //
312                                                 // TODO: Replace with an async interface when available.
313                                                 // https://github.com/frewsxcv/rust-chunked-transfer/issues/7
314                                                 loop {
315                                                         // Read the chunk header which contains the chunk size.
316                                                         let mut chunk_header = String::new();
317                                                         reader.read_line(&mut chunk_header).await?;
318                                                         if chunk_header == "0\r\n" {
319                                                                 // Read the terminator chunk since the decoder consumes the CRLF
320                                                                 // immediately when this chunk is encountered.
321                                                                 reader.read_line(&mut chunk_header).await?;
322                                                         }
323
324                                                         // Decode the chunk header to obtain the chunk size.
325                                                         let mut buffer = Vec::new();
326                                                         let mut decoder = chunked_transfer::Decoder::new(chunk_header.as_bytes());
327                                                         decoder.read_to_end(&mut buffer)?;
328
329                                                         // Read the chunk body.
330                                                         let chunk_size = match decoder.remaining_chunks_size() {
331                                                                 None => break,
332                                                                 Some(chunk_size) => chunk_size,
333                                                         };
334                                                         let chunk_offset = content.len();
335                                                         content.resize(chunk_offset + chunk_size + "\r\n".len(), 0);
336                                                         reader.read_exact(&mut content[chunk_offset..]).await?;
337                                                         content.resize(chunk_offset + chunk_size, 0);
338                                                 }
339                                                 content
340                                         }
341                                         #[cfg(not(feature = "tokio"))]
342                                         {
343                                                 let mut decoder = chunked_transfer::Decoder::new(reader);
344                                                 decoder.read_to_end(&mut content)?;
345                                                 content
346                                         }
347                                 }
348                         },
349                 };
350
351                 if !status.is_ok() {
352                         // TODO: Handle 3xx redirection responses.
353                         let error = HttpError {
354                                 status_code: status.code.to_string(),
355                                 contents,
356                         };
357                         return Err(std::io::Error::new(std::io::ErrorKind::Other, error));
358                 }
359
360                 Ok(contents)
361         }
362 }
363
/// Error for an HTTP response with a non-2xx status, carrying the status code and the raw
/// response body.
#[derive(Debug)]
pub(crate) struct HttpError {
	pub(crate) status_code: String,
	pub(crate) contents: Vec<u8>,
}

impl std::error::Error for HttpError {}

impl fmt::Display for HttpError {
	/// Renders as `status_code: <code>, contents: <body>`, decoding the body as lossy UTF-8.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let HttpError { status_code, contents } = self;
		write!(f, "status_code: {}, contents: {}", status_code, String::from_utf8_lossy(contents))
	}
}
379
/// HTTP response status code as defined by [RFC 7231].
///
/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-6
struct HttpStatus<'a> {
	/// The three-digit status code, borrowed from the parsed status line.
	code: &'a str,
}

impl<'a> HttpStatus<'a> {
	/// Parses an HTTP status line as defined by [RFC 7230].
	///
	/// `line` must already have its line terminator removed. It must have the form
	/// `<version> <code> <reason>`:
	/// - the version must be `HTTP/1.1` or `HTTP/1.0`, compared case-insensitively;
	/// - the code must be exactly three ASCII digits;
	/// - the reason phrase may be empty, but the space before it must be present.
	///
	/// # Errors
	///
	/// Returns `InvalidData` if the version, status code, or reason phrase is missing or invalid.
	///
	/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.1.2
	fn parse(line: &'a str) -> std::io::Result<HttpStatus<'a>> {
		let mut tokens = line.splitn(3, ' ');

		// Errors are built lazily so successful parses do not allocate them.
		let http_version = tokens.next()
			.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "no HTTP-Version"))?;
		if !http_version.eq_ignore_ascii_case("HTTP/1.1") &&
			!http_version.eq_ignore_ascii_case("HTTP/1.0") {
			return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid HTTP-Version"));
		}

		let code = tokens.next()
			.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "no Status-Code"))?;
		if code.len() != 3 || !code.chars().all(|c| c.is_ascii_digit()) {
			return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid Status-Code"));
		}

		// The reason phrase itself is unused; only its (possibly empty) presence is required.
		let _reason = tokens.next()
			.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "no Reason-Phrase"))?;

		Ok(Self { code })
	}

	/// Returns whether the status is successful (i.e., 2xx status class).
	fn is_ok(&self) -> bool {
		self.code.starts_with('2')
	}
}
418
/// HTTP response header as defined by [RFC 7231].
///
/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-7
struct HttpHeader<'a> {
	/// Field name exactly as received; use `has_name` for case-insensitive comparison.
	name: &'a str,
	/// Field value with surrounding whitespace removed.
	value: &'a str,
}

impl<'a> HttpHeader<'a> {
	/// Parses an HTTP header field as defined by [RFC 7230].
	///
	/// The line is split at its first `:`. RFC 7230 allows optional whitespace both before and
	/// after a field value, so the value is trimmed on both sides (e.g. `Content-Length: 42 `
	/// yields the value `42`).
	///
	/// # Errors
	///
	/// Returns `InvalidData` if the line contains no `:`.
	///
	/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.2
	fn parse(line: &'a str) -> std::io::Result<HttpHeader<'a>> {
		let mut tokens = line.splitn(2, ':');
		// `splitn` always yields a first item, so this error is defensive only.
		let name = tokens.next()
			.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "no header name"))?;
		let value = tokens.next()
			.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "no header value"))?
			.trim();
		Ok(Self { name, value })
	}

	/// Returns whether the header field has the given name (compared case-insensitively).
	fn has_name(&self, name: &str) -> bool {
		self.name.eq_ignore_ascii_case(name)
	}
}

/// HTTP message body length as defined by [RFC 7230].
///
/// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.3.3
enum HttpMessageLength {
	/// No length-determining header was seen.
	Empty,
	/// Body length in bytes, taken from a `Content-Length` header.
	ContentLength(usize),
	/// Transfer coding named by a `Transfer-Encoding` header.
	TransferEncoding(String),
}
455
456 /// An HTTP response body in binary format.
457 pub struct BinaryResponse(pub Vec<u8>);
458
459 /// An HTTP response body in JSON format.
460 pub struct JsonResponse(pub serde_json::Value);
461
462 /// Interprets bytes from an HTTP response body as binary data.
463 impl TryFrom<Vec<u8>> for BinaryResponse {
464         type Error = std::io::Error;
465
466         fn try_from(bytes: Vec<u8>) -> std::io::Result<Self> {
467                 Ok(BinaryResponse(bytes))
468         }
469 }
470
471 /// Interprets bytes from an HTTP response body as a JSON value.
472 impl TryFrom<Vec<u8>> for JsonResponse {
473         type Error = std::io::Error;
474
475         fn try_from(bytes: Vec<u8>) -> std::io::Result<Self> {
476                 Ok(JsonResponse(serde_json::from_slice(&bytes)?))
477         }
478 }
479
#[cfg(test)]
mod endpoint_tests {
	use super::HttpEndpoint;

	/// An endpoint built only from a host targets port 80.
	#[test]
	fn with_default_port() {
		let endpoint = HttpEndpoint::for_host("foo.com".into());
		assert_eq!(endpoint.host(), "foo.com");
		assert_eq!(endpoint.port(), 80);
	}

	/// An explicit port overrides the default.
	#[test]
	fn with_custom_port() {
		let endpoint = HttpEndpoint::for_host("foo.com".into())
			.with_port(8080);
		assert_eq!(endpoint.host(), "foo.com");
		assert_eq!(endpoint.port(), 8080);
	}

	/// An explicit path overrides the default.
	#[test]
	fn with_uri_path() {
		let endpoint = HttpEndpoint::for_host("foo.com".into())
			.with_path("/path".into());
		assert_eq!(endpoint.host(), "foo.com");
		assert_eq!(endpoint.path(), "/path");
	}

	/// Without an explicit path the endpoint targets `/`.
	#[test]
	fn without_uri_path() {
		let endpoint = HttpEndpoint::for_host("foo.com".into());
		assert_eq!(endpoint.host(), "foo.com");
		assert_eq!(endpoint.path(), "/");
	}

	/// Resolving an endpoint yields the same single address as resolving its `(host, port)` pair.
	#[test]
	fn convert_to_socket_addrs() {
		use std::net::ToSocketAddrs;

		let endpoint = HttpEndpoint::for_host("foo.com".into());
		let (host, port) = (endpoint.host(), endpoint.port());

		let mut socket_addrs = match (&endpoint).to_socket_addrs() {
			Ok(addrs) => addrs,
			Err(e) => panic!("Unexpected error: {:?}", e),
		};
		let addr = match socket_addrs.next() {
			Some(addr) => addr,
			None => panic!("Expected socket address"),
		};
		assert_eq!(addr, (host, port).to_socket_addrs().unwrap().next().unwrap());
		assert!(socket_addrs.next().is_none());
	}
}
533
534 #[cfg(test)]
535 pub(crate) mod client_tests {
536         use super::*;
537         use std::io::BufRead;
538         use std::io::Write;
539
540         /// Server for handling HTTP client requests with a stock response.
541         pub struct HttpServer {
542                 address: std::net::SocketAddr,
543                 handler: std::thread::JoinHandle<()>,
544                 shutdown: std::sync::Arc<std::sync::atomic::AtomicBool>,
545         }
546
547         /// Body of HTTP response messages.
548         pub enum MessageBody<T: ToString> {
549                 Empty,
550                 Content(T),
551                 ChunkedContent(T),
552         }
553
554         impl HttpServer {
555                 fn responding_with_body<T: ToString>(status: &str, body: MessageBody<T>) -> Self {
556                         let response = match body {
557                                 MessageBody::Empty => format!("{}\r\n\r\n", status),
558                                 MessageBody::Content(body) => {
559                                         let body = body.to_string();
560                                         format!(
561                                                 "{}\r\n\
562                                                  Content-Length: {}\r\n\
563                                                  \r\n\
564                                                  {}", status, body.len(), body)
565                                 },
566                                 MessageBody::ChunkedContent(body) => {
567                                         let mut chuncked_body = Vec::new();
568                                         {
569                                                 use chunked_transfer::Encoder;
570                                                 let mut encoder = Encoder::with_chunks_size(&mut chuncked_body, 8);
571                                                 encoder.write_all(body.to_string().as_bytes()).unwrap();
572                                         }
573                                         format!(
574                                                 "{}\r\n\
575                                                  Transfer-Encoding: chunked\r\n\
576                                                  \r\n\
577                                                  {}", status, String::from_utf8(chuncked_body).unwrap())
578                                 },
579                         };
580                         HttpServer::responding_with(response)
581                 }
582
583                 pub fn responding_with_ok<T: ToString>(body: MessageBody<T>) -> Self {
584                         HttpServer::responding_with_body("HTTP/1.1 200 OK", body)
585                 }
586
587                 pub fn responding_with_not_found() -> Self {
588                         HttpServer::responding_with_body::<String>("HTTP/1.1 404 Not Found", MessageBody::Empty)
589                 }
590
591                 pub fn responding_with_server_error<T: ToString>(content: T) -> Self {
592                         let body = MessageBody::Content(content);
593                         HttpServer::responding_with_body("HTTP/1.1 500 Internal Server Error", body)
594                 }
595
596                 fn responding_with(response: String) -> Self {
597                         let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
598                         let address = listener.local_addr().unwrap();
599
600                         let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
601                         let shutdown_signaled = std::sync::Arc::clone(&shutdown);
602                         let handler = std::thread::spawn(move || {
603                                 for stream in listener.incoming() {
604                                         let mut stream = stream.unwrap();
605                                         stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT)).unwrap();
606
607                                         let lines_read = std::io::BufReader::new(&stream)
608                                                 .lines()
609                                                 .take_while(|line| !line.as_ref().unwrap().is_empty())
610                                                 .count();
611                                         if lines_read == 0 { continue; }
612
613                                         for chunk in response.as_bytes().chunks(16) {
614                                                 if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) {
615                                                         return;
616                                                 } else {
617                                                         if let Err(_) = stream.write(chunk) { break; }
618                                                         if let Err(_) = stream.flush() { break; }
619                                                 }
620                                         }
621                                 }
622                         });
623
624                         Self { address, handler, shutdown }
625                 }
626
627                 fn shutdown(self) {
628                         self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst);
629                         self.handler.join().unwrap();
630                 }
631
632                 pub fn endpoint(&self) -> HttpEndpoint {
633                         HttpEndpoint::for_host(self.address.ip().to_string()).with_port(self.address.port())
634                 }
635         }
636
637         #[test]
638         fn connect_to_unresolvable_host() {
639                 match HttpClient::connect(("example.invalid", 80)) {
640                         Err(e) => {
641                                 assert!(e.to_string().contains("failed to lookup address information") ||
642                                         e.to_string().contains("No such host"), "{:?}", e);
643                         },
644                         Ok(_) => panic!("Expected error"),
645                 }
646         }
647
648         #[test]
649         fn connect_with_no_socket_address() {
650                 match HttpClient::connect(&vec![][..]) {
651                         Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput),
652                         Ok(_) => panic!("Expected error"),
653                 }
654         }
655
656         #[test]
657         fn connect_with_unknown_server() {
658                 match HttpClient::connect(("::", 80)) {
659                         #[cfg(target_os = "windows")]
660                         Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::AddrNotAvailable),
661                         #[cfg(not(target_os = "windows"))]
662                         Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::ConnectionRefused),
663                         Ok(_) => panic!("Expected error"),
664                 }
665         }
666
667         #[tokio::test]
668         async fn connect_with_valid_endpoint() {
669                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
670
671                 match HttpClient::connect(&server.endpoint()) {
672                         Err(e) => panic!("Unexpected error: {:?}", e),
673                         Ok(_) => {},
674                 }
675         }
676
677         #[tokio::test]
678         async fn read_empty_message() {
679                 let server = HttpServer::responding_with("".to_string());
680
681                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
682                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
683                         Err(e) => {
684                                 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
685                                 assert_eq!(e.get_ref().unwrap().to_string(), "no status line");
686                         },
687                         Ok(_) => panic!("Expected error"),
688                 }
689         }
690
691         #[tokio::test]
692         async fn read_incomplete_message() {
693                 let server = HttpServer::responding_with("HTTP/1.1 200 OK".to_string());
694
695                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
696                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
697                         Err(e) => {
698                                 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
699                                 assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
700                         },
701                         Ok(_) => panic!("Expected error"),
702                 }
703         }
704
705         #[tokio::test]
706         async fn read_too_large_message_headers() {
707                 let response = format!(
708                         "HTTP/1.1 302 Found\r\n\
709                          Location: {}\r\n\
710                          \r\n", "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE));
711                 let server = HttpServer::responding_with(response);
712
713                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
714                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
715                         Err(e) => {
716                                 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
717                                 assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
718                         },
719                         Ok(_) => panic!("Expected error"),
720                 }
721         }
722
723         #[tokio::test]
724         async fn read_too_large_message_body() {
725                 let body = "Z".repeat(MAX_HTTP_MESSAGE_BODY_SIZE + 1);
726                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Content(body));
727
728                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
729                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
730                         Err(e) => {
731                                 assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
732                                 assert_eq!(e.get_ref().unwrap().to_string(), "out of range");
733                         },
734                         Ok(_) => panic!("Expected error"),
735                 }
736                 server.shutdown();
737         }
738
739         #[tokio::test]
740         async fn read_message_with_unsupported_transfer_coding() {
741                 let response = String::from(
742                         "HTTP/1.1 200 OK\r\n\
743                          Transfer-Encoding: gzip\r\n\
744                          \r\n\
745                          foobar");
746                 let server = HttpServer::responding_with(response);
747
748                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
749                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
750                         Err(e) => {
751                                 assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput);
752                                 assert_eq!(e.get_ref().unwrap().to_string(), "unsupported transfer coding");
753                         },
754                         Ok(_) => panic!("Expected error"),
755                 }
756         }
757
758         #[tokio::test]
759         async fn read_error() {
760                 let server = HttpServer::responding_with_server_error("foo");
761
762                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
763                 match client.get::<JsonResponse>("/foo", "foo.com").await {
764                         Err(e) => {
765                                 assert_eq!(e.kind(), std::io::ErrorKind::Other);
766                                 let http_error = e.into_inner().unwrap().downcast::<HttpError>().unwrap();
767                                 assert_eq!(http_error.status_code, "500");
768                                 assert_eq!(http_error.contents, "foo".as_bytes());
769                         },
770                         Ok(_) => panic!("Expected error"),
771                 }
772         }
773
774         #[tokio::test]
775         async fn read_empty_message_body() {
776                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
777
778                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
779                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
780                         Err(e) => panic!("Unexpected error: {:?}", e),
781                         Ok(bytes) => assert_eq!(bytes.0, Vec::<u8>::new()),
782                 }
783         }
784
785         #[tokio::test]
786         async fn read_message_body_with_length() {
787                 let body = "foo bar baz qux".repeat(32);
788                 let content = MessageBody::Content(body.clone());
789                 let server = HttpServer::responding_with_ok::<String>(content);
790
791                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
792                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
793                         Err(e) => panic!("Unexpected error: {:?}", e),
794                         Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()),
795                 }
796         }
797
798         #[tokio::test]
799         async fn read_chunked_message_body() {
800                 let body = "foo bar baz qux".repeat(32);
801                 let chunked_content = MessageBody::ChunkedContent(body.clone());
802                 let server = HttpServer::responding_with_ok::<String>(chunked_content);
803
804                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
805                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
806                         Err(e) => panic!("Unexpected error: {:?}", e),
807                         Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()),
808                 }
809         }
810
811         #[tokio::test]
812         async fn reconnect_closed_connection() {
813                 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
814
815                 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
816                 assert!(client.get::<BinaryResponse>("/foo", "foo.com").await.is_ok());
817                 match client.get::<BinaryResponse>("/foo", "foo.com").await {
818                         Err(e) => panic!("Unexpected error: {:?}", e),
819                         Ok(bytes) => assert_eq!(bytes.0, Vec::<u8>::new()),
820                 }
821         }
822
823         #[test]
824         fn from_bytes_into_binary_response() {
825                 let bytes = b"foo";
826                 match BinaryResponse::try_from(bytes.to_vec()) {
827                         Err(e) => panic!("Unexpected error: {:?}", e),
828                         Ok(response) => assert_eq!(&response.0, bytes),
829                 }
830         }
831
832         #[test]
833         fn from_invalid_bytes_into_json_response() {
834                 let json = serde_json::json!({ "result": 42 });
835                 match JsonResponse::try_from(json.to_string().as_bytes()[..5].to_vec()) {
836                         Err(_) => {},
837                         Ok(_) => panic!("Expected error"),
838                 }
839         }
840
841         #[test]
842         fn from_valid_bytes_into_json_response() {
843                 let json = serde_json::json!({ "result": 42 });
844                 match JsonResponse::try_from(json.to_string().as_bytes().to_vec()) {
845                         Err(e) => panic!("Unexpected error: {:?}", e),
846                         Ok(response) => assert_eq!(response.0, json),
847                 }
848         }
849 }