4 use std::convert::TryFrom;
5 #[cfg(not(feature = "tokio"))]
7 use std::net::ToSocketAddrs;
8 use std::time::Duration;
10 #[cfg(feature = "tokio")]
11 use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
12 #[cfg(feature = "tokio")]
13 use tokio::net::TcpStream;
15 #[cfg(not(feature = "tokio"))]
18 #[cfg(not(feature = "tokio"))]
19 use std::net::TcpStream;
21 /// Timeout for operations on TCP streams.
/// Applied when connecting and as the read and write timeout of an open stream.
22 const TCP_STREAM_TIMEOUT: Duration = Duration::from_secs(5);
24 /// Maximum HTTP message header size in bytes.
/// Bounds how much is read while parsing the status line and the header fields.
25 const MAX_HTTP_MESSAGE_HEADER_SIZE: usize = 8192;
27 /// Maximum HTTP message body size in bytes. Enough for a hex-encoded block in JSON format and any
28 /// overhead for HTTP chunked transfer encoding.
/// Also the largest `Content-Length` value that is accepted.
29 const MAX_HTTP_MESSAGE_BODY_SIZE: usize = 2 * 4_000_000 + 32_000;
31 /// Endpoint for interacting with an HTTP-based API.
/// Described by a host, a port and a path, each readable through the accessor of the same name.
// NOTE(review): the field declarations are not visible in this listing.
33 pub struct HttpEndpoint {
40 /// Creates an endpoint for the given host and default HTTP port.
/// The path starts out as `/`.
41 pub fn for_host(host: String) -> Self {
45 path: String::from("/"),
49 /// Specifies a port to use with the endpoint.
/// Takes the endpoint by value and stores `port` in place of the default.
50 pub fn with_port(mut self, port: u16) -> Self {
51 self.port = Some(port);
55 /// Specifies a path to use with the endpoint.
/// Takes the endpoint by value and returns an endpoint; the tests expect `path()` to report the
/// given value afterwards.
56 pub fn with_path(mut self, path: String) -> Self {
61 /// Returns the endpoint host.
/// The value is borrowed from the endpoint.
62 pub fn host(&self) -> &str {
66 /// Returns the endpoint port.
/// The tests expect `80` when no port was specified through `with_port`.
67 pub fn port(&self) -> u16 {
74 /// Returns the endpoint path.
/// The value is borrowed from the endpoint.
75 pub fn path(&self) -> &str {
/// Lets a reference to an endpoint be used wherever socket addresses are expected.
/// Resolution is delegated to the standard `(host, port)` tuple implementation; the path plays
/// no part in it.
80 impl<'a> std::net::ToSocketAddrs for &'a HttpEndpoint {
// Same iterator type as the `(&str, u16)` implementation that does the actual work.
81 type Iter = <(&'a str, u16) as std::net::ToSocketAddrs>::Iter;
83 fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
84 (self.host(), self.port()).to_socket_addrs()
88 /// Client for making HTTP requests.
/// Its methods operate on a TCP stream kept in `self.stream`.
// NOTE(review): the field declarations are not visible in this listing.
89 pub(crate) struct HttpClient {
94 /// Opens a connection to an HTTP endpoint.
///
/// Only the first address the endpoint resolves to is tried. The connection attempt and all
/// later reads and writes are bounded by `TCP_STREAM_TIMEOUT`.
///
/// # Errors
///
/// Returns `InvalidInput` when the endpoint resolves to no address, and otherwise passes on the
/// error from resolving, connecting or configuring the socket.
95 pub fn connect<E: ToSocketAddrs>(endpoint: E) -> std::io::Result<Self> {
96 let address = match endpoint.to_socket_addrs()?.next() {
98 return Err(std::io::Error::new(std::io::ErrorKind::InvalidInput, "could not resolve to any addresses"));
100 Some(address) => address,
102 let stream = std::net::TcpStream::connect_timeout(&address, TCP_STREAM_TIMEOUT)?;
103 stream.set_read_timeout(Some(TCP_STREAM_TIMEOUT))?;
104 stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT))?;
// With tokio the blocking std stream is made non-blocking and converted to a tokio stream.
106 #[cfg(feature = "tokio")]
108 stream.set_nonblocking(true)?;
109 TcpStream::from_std(stream)?
115 /// Sends a `GET` request for a resource identified by `uri` at the `host`.
117 /// Returns the response body in `F` format.
///
/// The request asks for a keep-alive connection and is sent through `send_request_with_retry`.
/// The raw body is converted with `F::try_from`, whose error is returned unchanged.
119 pub async fn get<F>(&mut self, uri: &str, host: &str) -> std::io::Result<F>
120 where F: TryFrom<Vec<u8>, Error = std::io::Error> {
121 let request = format!(
122 "GET {} HTTP/1.1\r\n\
124 Connection: keep-alive\r\n\
126 let response_body = self.send_request_with_retry(&request).await?;
127 F::try_from(response_body)
130 /// Sends a `POST` request for a resource identified by `uri` at the `host` using the given HTTP
131 /// authentication credentials.
133 /// The request body consists of the provided JSON `content`. Returns the response body in `F`
/// format.
///
/// `auth` is written verbatim as the value of the `Authorization` header. `Content-Length` is
/// the byte length of the serialized `content`. The request is sent through
/// `send_request_with_retry` and the raw body is converted with `F::try_from`.
136 pub async fn post<F>(&mut self, uri: &str, host: &str, auth: &str, content: serde_json::Value) -> std::io::Result<F>
137 where F: TryFrom<Vec<u8>, Error = std::io::Error> {
// Serialize once so the same text provides both the length and the body.
138 let content = content.to_string();
139 let request = format!(
140 "POST {} HTTP/1.1\r\n\
142 Authorization: {}\r\n\
143 Connection: keep-alive\r\n\
144 Content-Type: application/json\r\n\
145 Content-Length: {}\r\n\
147 {}", uri, host, auth, content.len(), content);
148 let response_body = self.send_request_with_retry(&request).await?;
149 F::try_from(response_body)
152 /// Sends an HTTP request message and reads the response, returning its body. Attempts to
153 /// reconnect and retry if the connection has been closed.
154 async fn send_request_with_retry(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
155 let endpoint = self.stream.peer_addr().unwrap();
156 match self.send_request(request).await {
157 Ok(bytes) => Ok(bytes),
158 Err(e) => match e.kind() {
159 std::io::ErrorKind::ConnectionReset |
160 std::io::ErrorKind::ConnectionAborted |
161 std::io::ErrorKind::UnexpectedEof => {
162 // Reconnect if the connection was closed. This may happen if the server's
163 // keep-alive limits are reached.
164 *self = Self::connect(endpoint)?;
165 self.send_request(request).await
172 /// Sends an HTTP request message and reads the response, returning its body.
/// Makes a single attempt: the request is written first and the response is read only if the
/// write succeeded.
173 async fn send_request(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
174 self.write_request(request).await?;
175 self.read_response().await
178 /// Writes an HTTP request message.
/// The bytes of `request` are written in full to the stream. Which variant is compiled depends
/// on the `tokio` feature.
179 async fn write_request(&mut self, request: &str) -> std::io::Result<()> {
// Asynchronous variant: write everything, then flush.
180 #[cfg(feature = "tokio")]
182 self.stream.write_all(request.as_bytes()).await?;
183 self.stream.flush().await
// Blocking variant for builds without tokio.
185 #[cfg(not(feature = "tokio"))]
187 self.stream.write_all(request.as_bytes())?;
192 /// Reads an HTTP response message.
///
/// The status line and the headers are read under a limit of `MAX_HTTP_MESSAGE_HEADER_SIZE`
/// bytes. The limit is then raised to `MAX_HTTP_MESSAGE_BODY_SIZE` and the body is read as
/// announced by `Content-Length` or by a `chunked` `Transfer-Encoding`. A response announcing
/// neither yields an empty body.
///
/// # Errors
///
/// - `UnexpectedEof` ("no status line" / "no headers") when the stream ends too early.
/// - `InvalidData` when `Content-Length` does not parse, or ("out of range") when it is zero or
///   larger than `MAX_HTTP_MESSAGE_BODY_SIZE`.
/// - `InvalidInput` ("unsupported transfer coding") for a transfer coding other than `chunked`.
/// - `NotFound` ("not found"); see the review note at that return.
/// - Any error from parsing the status line or a header, or from reading the stream.
193 async fn read_response(&mut self) -> std::io::Result<Vec<u8>> {
194 #[cfg(feature = "tokio")]
195 let stream = self.stream.split().0;
196 #[cfg(not(feature = "tokio"))]
197 let stream = std::io::Read::by_ref(&mut self.stream);
// Cap the bytes that may be consumed while reading the status line and the headers.
199 let limited_stream = stream.take(MAX_HTTP_MESSAGE_HEADER_SIZE as u64);
201 #[cfg(feature = "tokio")]
202 let mut reader = tokio::io::BufReader::new(limited_stream);
203 #[cfg(not(feature = "tokio"))]
204 let mut reader = std::io::BufReader::new(limited_stream);
// Reads one line with whichever reader was selected above and strips its line ending. Callers
// treat the result as an `Option` and turn `None` into an `UnexpectedEof` error.
206 macro_rules! read_line { () => { {
207 let mut line = String::new();
208 #[cfg(feature = "tokio")]
209 let bytes_read = reader.read_line(&mut line).await?;
210 #[cfg(not(feature = "tokio"))]
211 let bytes_read = reader.read_line(&mut line)?;
216 // Remove trailing CRLF
217 if line.ends_with('\n') { line.pop(); if line.ends_with('\r') { line.pop(); } }
223 // Read and parse status line
224 let status_line = read_line!()
225 .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no status line"))?;
226 let status = HttpStatus::parse(&status_line)?;
228 // Read and parse relevant headers
229 let mut message_length = HttpMessageLength::Empty;
231 let line = read_line!()
232 .ok_or(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "no headers"))?;
// An empty line terminates the header section.
233 if line.is_empty() { break; }
235 let header = HttpHeader::parse(&line)?;
236 if header.has_name("Content-Length") {
237 let length = header.value.parse()
238 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// A Content-Length is only recorded while nothing else has been recorded yet.
239 if let HttpMessageLength::Empty = message_length {
240 message_length = HttpMessageLength::ContentLength(length);
// A Transfer-Encoding replaces whatever was recorded before it.
245 if header.has_name("Transfer-Encoding") {
246 message_length = HttpMessageLength::TransferEncoding(header.value.into());
252 // TODO: Handle 3xx redirection responses.
// NOTE(review): the condition guarding this return is not visible in this listing; presumably
// it fires when `status` is not successful. Confirm against the full file.
253 return Err(std::io::Error::new(std::io::ErrorKind::NotFound, "not found"));
// Switch to the body budget; bytes already held in the reader's buffer count against it.
257 let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len();
258 reader.get_mut().set_limit(read_limit as u64);
259 match message_length {
260 HttpMessageLength::Empty => { Ok(Vec::new()) },
261 HttpMessageLength::ContentLength(length) => {
// NOTE(review): a length of zero is rejected here rather than producing an empty body.
262 if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE {
263 Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "out of range"))
// Read exactly the announced number of bytes.
265 let mut content = vec![0; length];
266 #[cfg(feature = "tokio")]
267 reader.read_exact(&mut content[..]).await?;
268 #[cfg(not(feature = "tokio"))]
269 reader.read_exact(&mut content[..])?;
273 HttpMessageLength::TransferEncoding(coding) => {
// Only the `chunked` coding is understood; the comparison ignores ASCII case.
274 if !coding.eq_ignore_ascii_case("chunked") {
275 Err(std::io::Error::new(
276 std::io::ErrorKind::InvalidInput, "unsupported transfer coding"))
278 let mut content = Vec::new();
279 #[cfg(feature = "tokio")]
281 // Since chunked_transfer doesn't have an async interface, only use it to
282 // determine the size of each chunk to read.
284 // TODO: Replace with an async interface when available.
285 // https://github.com/frewsxcv/rust-chunked-transfer/issues/7
287 // Read the chunk header which contains the chunk size.
288 let mut chunk_header = String::new();
289 reader.read_line(&mut chunk_header).await?;
290 if chunk_header == "0\r\n" {
291 // Read the terminator chunk since the decoder consumes the CRLF
292 // immediately when this chunk is encountered.
293 reader.read_line(&mut chunk_header).await?;
296 // Decode the chunk header to obtain the chunk size.
297 let mut buffer = Vec::new();
298 let mut decoder = chunked_transfer::Decoder::new(chunk_header.as_bytes());
299 decoder.read_to_end(&mut buffer)?;
301 // Read the chunk body.
302 let chunk_size = match decoder.remaining_chunks_size() {
304 Some(chunk_size) => chunk_size,
306 let chunk_offset = content.len();
// Read the chunk together with its trailing CRLF, then cut the CRLF off again.
307 content.resize(chunk_offset + chunk_size + "\r\n".len(), 0);
308 reader.read_exact(&mut content[chunk_offset..]).await?;
309 content.resize(chunk_offset + chunk_size, 0);
313 #[cfg(not(feature = "tokio"))]
// Without tokio the decoder reads the whole chunked body straight from the reader.
315 let mut decoder = chunked_transfer::Decoder::new(reader);
316 decoder.read_to_end(&mut content)?;
/// HTTP response status code as defined by [RFC 7231].
///
/// Borrows the status code from the status line it was parsed from, so it cannot outlive that
/// line.
///
/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-6
struct HttpStatus<'a> {
    /// The status code exactly as it appeared in the status line: three ASCII digits.
    code: &'a str,
}

impl<'a> HttpStatus<'a> {
    /// Parses an HTTP status line as defined by [RFC 7230].
    ///
    /// `line` is expected without its trailing CRLF and in the form
    /// `HTTP-Version SP Status-Code SP Reason-Phrase`. Only `HTTP/1.0` and `HTTP/1.1` are
    /// accepted as version, compared without regard to ASCII case. The reason phrase may be
    /// empty but the space in front of it must be present.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when a part is missing, the version is not supported, or the
    /// status code is not made of exactly three ASCII digits.
    ///
    /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.1.2
    fn parse(line: &'a str) -> std::io::Result<HttpStatus<'a>> {
        // Errors are only built on the failure path.
        let invalid =
            |reason: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidData, reason);

        // At most three parts, so that spaces inside the reason phrase are left alone.
        let mut tokens = line.splitn(3, ' ');

        let http_version = tokens.next().ok_or_else(|| invalid("no HTTP-Version"))?;
        if !http_version.eq_ignore_ascii_case("HTTP/1.1")
            && !http_version.eq_ignore_ascii_case("HTTP/1.0")
        {
            return Err(invalid("invalid HTTP-Version"));
        }

        let code = tokens.next().ok_or_else(|| invalid("no Status-Code"))?;
        if code.len() != 3 || !code.chars().all(|c| c.is_ascii_digit()) {
            return Err(invalid("invalid Status-Code"));
        }

        // The reason phrase is required to be present but its content is not used.
        let _reason = tokens.next().ok_or_else(|| invalid("no Reason-Phrase"))?;

        Ok(Self { code })
    }

    /// Returns whether the status is successful (i.e., 2xx status class).
    fn is_ok(&self) -> bool {
        self.code.starts_with('2')
    }
}
/// HTTP response header as defined by [RFC 7231].
///
/// Borrows both parts from the header line it was parsed from, so it cannot outlive that line.
///
/// [RFC 7231]: https://tools.ietf.org/html/rfc7231#section-7
struct HttpHeader<'a> {
    /// Field name: everything in front of the first colon, unmodified.
    name: &'a str,
    /// Field value: everything after the first colon, without surrounding whitespace.
    value: &'a str,
}

impl<'a> HttpHeader<'a> {
    /// Parses an HTTP header field as defined by [RFC 7230].
    ///
    /// `line` is expected without its trailing CRLF. It is split at the first colon only, so
    /// colons inside the value are preserved. Whitespace around the value is optional in the
    /// grammar and is removed, which lets callers parse the value directly (for example as a
    /// number).
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when the line contains no colon.
    ///
    /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.2
    fn parse(line: &'a str) -> std::io::Result<HttpHeader<'a>> {
        // Errors are only built on the failure path.
        let invalid =
            |reason: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidData, reason);

        let mut tokens = line.splitn(2, ':');
        let name = tokens.next().ok_or_else(|| invalid("no header name"))?;
        let value = tokens.next().ok_or_else(|| invalid("no header value"))?.trim();
        Ok(Self { name, value })
    }

    /// Returns whether the header field has the given name.
    ///
    /// Field names are compared without regard to ASCII case.
    fn has_name(&self, name: &str) -> bool {
        self.name.eq_ignore_ascii_case(name)
    }
}
392 /// HTTP message body length as defined by [RFC 7230].
/// Records which header, if any, announced how the body is delimited.
394 /// [RFC 7230]: https://tools.ietf.org/html/rfc7230#section-3.3.3
// NOTE(review): an `Empty` variant is used elsewhere in this file but its declaration is not
// visible in this listing.
395 enum HttpMessageLength {
// Body length in bytes, taken from a `Content-Length` header.
397 ContentLength(usize),
// Value of a `Transfer-Encoding` header, kept as text.
398 TransferEncoding(String),
401 /// An HTTP response body in binary format.
/// Wraps the body bytes as received.
402 pub struct BinaryResponse(pub Vec<u8>);
404 /// An HTTP response body in JSON format.
/// Wraps the JSON value parsed from the body bytes.
405 pub struct JsonResponse(pub serde_json::Value);
407 /// Interprets bytes from an HTTP response body as binary data.
/// The conversion takes ownership of the bytes and always succeeds.
408 impl TryFrom<Vec<u8>> for BinaryResponse {
409 type Error = std::io::Error;
411 fn try_from(bytes: Vec<u8>) -> std::io::Result<Self> {
412 Ok(BinaryResponse(bytes))
416 /// Interprets bytes from an HTTP response body as a JSON value.
/// Fails when the bytes are not valid JSON.
417 impl TryFrom<Vec<u8>> for JsonResponse {
418 type Error = std::io::Error;
420 fn try_from(bytes: Vec<u8>) -> std::io::Result<Self> {
// A parse failure is converted into `std::io::Error` by `?`.
421 Ok(JsonResponse(serde_json::from_slice(&bytes)?))
427 use super::HttpEndpoint;
// An endpoint created from a host alone reports that host and port 80.
430 fn with_default_port() {
431 let endpoint = HttpEndpoint::for_host("foo.com".into());
432 assert_eq!(endpoint.host(), "foo.com");
433 assert_eq!(endpoint.port(), 80);
// A port given through `with_port` is the one reported by `port()`.
437 fn with_custom_port() {
438 let endpoint = HttpEndpoint::for_host("foo.com".into()).with_port(8080);
439 assert_eq!(endpoint.host(), "foo.com");
440 assert_eq!(endpoint.port(), 8080);
// NOTE(review): the header of the enclosing test function is not visible in this listing.
// A path given through `with_path` is the one reported by `path()`.
445 let endpoint = HttpEndpoint::for_host("foo.com".into()).with_path("/path".into());
446 assert_eq!(endpoint.host(), "foo.com");
447 assert_eq!(endpoint.path(), "/path");
// An endpoint created from a host alone reports `/` as its path.
451 fn without_uri_path() {
452 let endpoint = HttpEndpoint::for_host("foo.com".into());
453 assert_eq!(endpoint.host(), "foo.com");
454 assert_eq!(endpoint.path(), "/");
// Resolving an endpoint reference must give the same first address as resolving its
// `(host, port)` pair directly, and no further address.
// NOTE(review): this resolves `foo.com`, so it presumably needs working name resolution.
458 fn convert_to_socket_addrs() {
459 let endpoint = HttpEndpoint::for_host("foo.com".into());
460 let host = endpoint.host();
461 let port = endpoint.port();
463 use std::net::ToSocketAddrs;
464 match (&endpoint).to_socket_addrs() {
465 Err(e) => panic!("Unexpected error: {:?}", e),
466 Ok(mut socket_addrs) => {
467 match socket_addrs.next() {
468 None => panic!("Expected socket address"),
470 assert_eq!(addr, (host, port).to_socket_addrs().unwrap().next().unwrap());
471 assert!(socket_addrs.next().is_none());
480 pub(crate) mod client_tests {
482 use std::io::BufRead;
485 /// Server for handling HTTP client requests with a stock response.
/// Connections are served by a background thread.
486 pub struct HttpServer {
// Local address the listener is bound to.
487 address: std::net::SocketAddr,
// Thread that accepts connections and writes the response.
488 handler: std::thread::JoinHandle<()>,
// Flag the handler thread checks before writing each piece of the response.
489 shutdown: std::sync::Arc<std::sync::atomic::AtomicBool>,
492 /// Body of HTTP response messages.
/// `responding_with_ok` matches on the variants `Empty`, `Content` and `ChunkedContent`.
// NOTE(review): the variant declarations are not visible in this listing.
493 pub enum MessageBody<T: ToString> {
/// Starts a server that answers with a `200 OK` response carrying the given body.
///
/// An empty body produces a response without headers, `Content` announces the body through
/// `Content-Length`, and `ChunkedContent` sends it with `Transfer-Encoding: chunked`.
500 pub fn responding_with_ok<T: ToString>(body: MessageBody<T>) -> Self {
501 let response = match body {
502 MessageBody::Empty => "HTTP/1.1 200 OK\r\n\r\n".to_string(),
503 MessageBody::Content(body) => {
504 let body = body.to_string();
506 "HTTP/1.1 200 OK\r\n\
507 Content-Length: {}\r\n\
509 {}", body.len(), body)
511 MessageBody::ChunkedContent(body) => {
512 let mut chuncked_body = Vec::new();
514 use chunked_transfer::Encoder;
// Encode with a chunk size of 8.
515 let mut encoder = Encoder::with_chunks_size(&mut chuncked_body, 8);
516 encoder.write_all(body.to_string().as_bytes()).unwrap();
519 "HTTP/1.1 200 OK\r\n\
520 Transfer-Encoding: chunked\r\n\
522 {}", String::from_utf8(chuncked_body).unwrap())
525 HttpServer::responding_with(response)
/// Starts a server that answers with a `404 Not Found` response without headers or body.
528 pub fn responding_with_not_found() -> Self {
529 let response = "HTTP/1.1 404 Not Found\r\n\r\n".to_string();
530 HttpServer::responding_with(response)
/// Starts a server on the loopback interface that writes `response` verbatim to its clients.
///
/// Connections are handled on a spawned thread. The thread panics if accepting or configuring a
/// connection fails, or if a request line cannot be read.
533 fn responding_with(response: String) -> Self {
// Port 0 leaves the choice of port to the system; the bound address is read back below.
534 let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
535 let address = listener.local_addr().unwrap();
537 let shutdown = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
538 let shutdown_signaled = std::sync::Arc::clone(&shutdown);
539 let handler = std::thread::spawn(move || {
540 for stream in listener.incoming() {
541 let mut stream = stream.unwrap();
542 stream.set_write_timeout(Some(TCP_STREAM_TIMEOUT)).unwrap();
// Consume the request up to its first empty line.
544 let lines_read = std::io::BufReader::new(&stream)
546 .take_while(|line| !line.as_ref().unwrap().is_empty())
// Nothing was sent on this connection, so there is nothing to answer.
548 if lines_read == 0 { continue; }
// Write the response in pieces of 16 bytes, checking the shutdown flag before each one.
550 for chunk in response.as_bytes().chunks(16) {
551 if shutdown_signaled.load(std::sync::atomic::Ordering::SeqCst) {
// NOTE(review): `write` may accept fewer bytes than the piece holds; `write_all` would not.
554 if let Err(_) = stream.write(chunk) { break; }
555 if let Err(_) = stream.flush() { break; }
561 Self { address, handler, shutdown }
// NOTE(review): the signature of the enclosing function is not visible in this listing.
// Raises the shutdown flag, then waits for the handler thread; panics if that thread panicked.
565 self.shutdown.store(true, std::sync::atomic::Ordering::SeqCst);
566 self.handler.join().unwrap();
/// Returns an endpoint made of the IP address and port the server is bound to.
569 pub fn endpoint(&self) -> HttpEndpoint {
570 HttpEndpoint::for_host(self.address.ip().to_string()).with_port(self.address.port())
// Connecting to `example.invalid` must fail with an error of kind `Other`.
575 fn connect_to_unresolvable_host() {
576 match HttpClient::connect(("example.invalid", 80)) {
577 Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other),
578 Ok(_) => panic!("Expected error"),
// An empty list of addresses must be rejected with `InvalidInput`.
583 fn connect_with_no_socket_address() {
584 match HttpClient::connect(&vec![][..]) {
585 Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput),
586 Ok(_) => panic!("Expected error"),
// Connecting to port 80 of the address `::` must fail; the expected error kind depends on the
// operating system.
591 fn connect_with_unknown_server() {
592 match HttpClient::connect(("::", 80)) {
593 #[cfg(target_os = "windows")]
594 Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::AddrNotAvailable),
595 #[cfg(not(target_os = "windows"))]
596 Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::ConnectionRefused),
597 Ok(_) => panic!("Expected error"),
// Connecting to the endpoint of a running test server must not fail.
602 async fn connect_with_valid_endpoint() {
603 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
605 match HttpClient::connect(&server.endpoint()) {
606 Err(e) => panic!("Unexpected error: {:?}", e),
// A server that sends nothing must lead to `UnexpectedEof` with the message "no status line".
612 async fn read_empty_message() {
613 let server = HttpServer::responding_with("".to_string());
615 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
616 match client.get::<BinaryResponse>("/foo", "foo.com").await {
618 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
619 assert_eq!(e.get_ref().unwrap().to_string(), "no status line");
621 Ok(_) => panic!("Expected error"),
// A response that ends right after the status line must lead to `UnexpectedEof` with the
// message "no headers".
626 async fn read_incomplete_message() {
627 let server = HttpServer::responding_with("HTTP/1.1 200 OK".to_string());
629 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
630 match client.get::<BinaryResponse>("/foo", "foo.com").await {
632 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
633 assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
635 Ok(_) => panic!("Expected error"),
// A response whose header section exceeds `MAX_HTTP_MESSAGE_HEADER_SIZE` must lead to
// `UnexpectedEof` with the message "no headers".
640 async fn read_too_large_message_headers() {
641 let response = format!(
642 "HTTP/1.1 302 Found\r\n\
644 \r\n", "Z".repeat(MAX_HTTP_MESSAGE_HEADER_SIZE));
645 let server = HttpServer::responding_with(response);
647 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
648 match client.get::<BinaryResponse>("/foo", "foo.com").await {
650 assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
651 assert_eq!(e.get_ref().unwrap().to_string(), "no headers");
653 Ok(_) => panic!("Expected error"),
// A body one byte larger than `MAX_HTTP_MESSAGE_BODY_SIZE` must be rejected with `InvalidData`
// and the message "out of range".
658 async fn read_too_large_message_body() {
659 let body = "Z".repeat(MAX_HTTP_MESSAGE_BODY_SIZE + 1);
660 let server = HttpServer::responding_with_ok::<String>(MessageBody::Content(body));
662 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
663 match client.get::<BinaryResponse>("/foo", "foo.com").await {
665 assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
666 assert_eq!(e.get_ref().unwrap().to_string(), "out of range");
668 Ok(_) => panic!("Expected error"),
// A `gzip` transfer coding must be rejected with `InvalidInput` and the message
// "unsupported transfer coding".
674 async fn read_message_with_unsupported_transfer_coding() {
675 let response = String::from(
676 "HTTP/1.1 200 OK\r\n\
677 Transfer-Encoding: gzip\r\n\
680 let server = HttpServer::responding_with(response);
682 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
683 match client.get::<BinaryResponse>("/foo", "foo.com").await {
685 assert_eq!(e.kind(), std::io::ErrorKind::InvalidInput);
686 assert_eq!(e.get_ref().unwrap().to_string(), "unsupported transfer coding");
688 Ok(_) => panic!("Expected error"),
// A `200 OK` response without headers must give an empty body.
693 async fn read_empty_message_body() {
694 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
696 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
697 match client.get::<BinaryResponse>("/foo", "foo.com").await {
698 Err(e) => panic!("Unexpected error: {:?}", e),
699 Ok(bytes) => assert_eq!(bytes.0, Vec::<u8>::new()),
// A body announced through `Content-Length` must be returned byte for byte.
704 async fn read_message_body_with_length() {
705 let body = "foo bar baz qux".repeat(32);
706 let content = MessageBody::Content(body.clone());
707 let server = HttpServer::responding_with_ok::<String>(content);
709 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
710 match client.get::<BinaryResponse>("/foo", "foo.com").await {
711 Err(e) => panic!("Unexpected error: {:?}", e),
712 Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()),
// A body sent with chunked transfer encoding must be reassembled into the original bytes.
717 async fn read_chunked_message_body() {
718 let body = "foo bar baz qux".repeat(32);
719 let chunked_content = MessageBody::ChunkedContent(body.clone());
720 let server = HttpServer::responding_with_ok::<String>(chunked_content);
722 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
723 match client.get::<BinaryResponse>("/foo", "foo.com").await {
724 Err(e) => panic!("Unexpected error: {:?}", e),
725 Ok(bytes) => assert_eq!(bytes.0, body.as_bytes()),
// Two requests in a row on the same client must both succeed.
// NOTE(review): the test server appears to answer one request per connection, which is
// presumably what makes the second request exercise the reconnect path.
730 async fn reconnect_closed_connection() {
731 let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);
733 let mut client = HttpClient::connect(&server.endpoint()).unwrap();
734 assert!(client.get::<BinaryResponse>("/foo", "foo.com").await.is_ok());
735 match client.get::<BinaryResponse>("/foo", "foo.com").await {
736 Err(e) => panic!("Unexpected error: {:?}", e),
737 Ok(bytes) => assert_eq!(bytes.0, Vec::<u8>::new()),
// Converting bytes into a `BinaryResponse` must keep them unchanged.
// NOTE(review): the definition of `bytes` is not visible in this listing.
742 fn from_bytes_into_binary_response() {
744 match BinaryResponse::try_from(bytes.to_vec()) {
745 Err(e) => panic!("Unexpected error: {:?}", e),
746 Ok(response) => assert_eq!(&response.0, bytes),
// The first five bytes of a serialized JSON object are not valid JSON, so the conversion must
// fail.
751 fn from_invalid_bytes_into_json_response() {
752 let json = serde_json::json!({ "result": 42 });
753 match JsonResponse::try_from(json.to_string().as_bytes()[..5].to_vec()) {
755 Ok(_) => panic!("Expected error"),
// Converting the serialized form of a JSON value must give that value back.
760 fn from_valid_bytes_into_json_response() {
761 let json = serde_json::json!({ "result": 42 });
762 match JsonResponse::try_from(json.to_string().as_bytes().to_vec()) {
763 Err(e) => panic!("Unexpected error: {:?}", e),
764 Ok(response) => assert_eq!(response.0, json),