Check IO errors in test using `raw_os_error()` instead of `kind()`
[rust-lightning] / lightning-block-sync / src / http.rs
index 2cfb8e50593d4aef4dc283f47655b3d6d8b5dc6d..0721babfde3d1b626051ba1ccccb7240d1f5a5a7 100644 (file)
@@ -5,9 +5,10 @@ use chunked_transfer;
 use serde_json;
 
 use std::convert::TryFrom;
+use std::fmt;
 #[cfg(not(feature = "tokio"))]
 use std::io::Write;
-use std::net::ToSocketAddrs;
+use std::net::{SocketAddr, ToSocketAddrs};
 use std::time::Duration;
 
 #[cfg(feature = "tokio")]
@@ -96,6 +97,7 @@ impl<'a> std::net::ToSocketAddrs for &'a HttpEndpoint {
 
 /// Client for making HTTP requests.
 pub(crate) struct HttpClient {
+       address: SocketAddr,
        stream: TcpStream,
 }
 
@@ -118,7 +120,7 @@ impl HttpClient {
                        TcpStream::from_std(stream)?
                };
 
-               Ok(Self { stream })
+               Ok(Self { address, stream })
        }
 
        /// Sends a `GET` request for a resource identified by `uri` at the `host`.
@@ -161,7 +163,6 @@ impl HttpClient {
        /// Sends an HTTP request message and reads the response, returning its body. Attempts to
        /// reconnect and retry if the connection has been closed.
        async fn send_request_with_retry(&mut self, request: &str) -> std::io::Result<Vec<u8>> {
-               let endpoint = self.stream.peer_addr().unwrap();
                match self.send_request(request).await {
                        Ok(bytes) => Ok(bytes),
                        Err(_) => {
@@ -175,7 +176,7 @@ impl HttpClient {
                                tokio::time::sleep(Duration::from_millis(100)).await;
                                #[cfg(not(feature = "tokio"))]
                                std::thread::sleep(Duration::from_millis(100));
-                               *self = Self::connect(endpoint)?;
+                               *self = Self::connect(self.address)?;
                                self.send_request(request).await
                        },
                }
@@ -279,32 +280,27 @@ impl HttpClient {
                        }
                }
 
-               if !status.is_ok() {
-                       // TODO: Handle 3xx redirection responses.
-                       return Err(std::io::Error::new(std::io::ErrorKind::NotFound, "not found"));
-               }
-
                // Read message body
                let read_limit = MAX_HTTP_MESSAGE_BODY_SIZE - reader.buffer().len();
                reader.get_mut().set_limit(read_limit as u64);
-               match message_length {
-                       HttpMessageLength::Empty => { Ok(Vec::new()) },
+               let contents = match message_length {
+                       HttpMessageLength::Empty => { Vec::new() },
                        HttpMessageLength::ContentLength(length) => {
                                if length == 0 || length > MAX_HTTP_MESSAGE_BODY_SIZE {
-                                       Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "out of range"))
+                                       return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "out of range"))
                                } else {
                                        let mut content = vec![0; length];
                                        #[cfg(feature = "tokio")]
                                        reader.read_exact(&mut content[..]).await?;
                                        #[cfg(not(feature = "tokio"))]
                                        reader.read_exact(&mut content[..])?;
-                                       Ok(content)
+                                       content
                                }
                        },
                        HttpMessageLength::TransferEncoding(coding) => {
                                if !coding.eq_ignore_ascii_case("chunked") {
-                                       Err(std::io::Error::new(
-                                                       std::io::ErrorKind::InvalidInput, "unsupported transfer coding"))
+                                       return Err(std::io::Error::new(
+                                               std::io::ErrorKind::InvalidInput, "unsupported transfer coding"))
                                } else {
                                        let mut content = Vec::new();
                                        #[cfg(feature = "tokio")]
@@ -339,17 +335,44 @@ impl HttpClient {
                                                        reader.read_exact(&mut content[chunk_offset..]).await?;
                                                        content.resize(chunk_offset + chunk_size, 0);
                                                }
-                                               Ok(content)
+                                               content
                                        }
                                        #[cfg(not(feature = "tokio"))]
                                        {
                                                let mut decoder = chunked_transfer::Decoder::new(reader);
                                                decoder.read_to_end(&mut content)?;
-                                               Ok(content)
+                                               content
                                        }
                                }
                        },
+               };
+
+               if !status.is_ok() {
+                       // TODO: Handle 3xx redirection responses.
+                       let error = HttpError {
+                               status_code: status.code.to_string(),
+                               contents,
+                       };
+                       return Err(std::io::Error::new(std::io::ErrorKind::Other, error));
                }
+
+               Ok(contents)
+       }
+}
+
+/// HTTP error consisting of a status code and body contents.
+#[derive(Debug)]
+pub(crate) struct HttpError {
+       pub(crate) status_code: String,
+       pub(crate) contents: Vec<u8>,
+}
+
+impl std::error::Error for HttpError {}
+
+impl fmt::Display for HttpError {
+       fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+               let contents = String::from_utf8_lossy(&self.contents);
+               write!(f, "status_code: {}, contents: {}", self.status_code, contents)
        }
 }
 
@@ -528,16 +551,16 @@ pub(crate) mod client_tests {
        }
 
        impl HttpServer {
-               pub fn responding_with_ok<T: ToString>(body: MessageBody<T>) -> Self {
+               fn responding_with_body<T: ToString>(status: &str, body: MessageBody<T>) -> Self {
                        let response = match body {
-                               MessageBody::Empty => "HTTP/1.1 200 OK\r\n\r\n".to_string(),
+                               MessageBody::Empty => format!("{}\r\n\r\n", status),
                                MessageBody::Content(body) => {
                                        let body = body.to_string();
                                        format!(
-                                               "HTTP/1.1 200 OK\r\n\
+                                               "{}\r\n\
                                                 Content-Length: {}\r\n\
                                                 \r\n\
-                                                {}", body.len(), body)
+                                                {}", status, body.len(), body)
                                },
                                MessageBody::ChunkedContent(body) => {
                                        let mut chuncked_body = Vec::new();
@@ -547,18 +570,26 @@ pub(crate) mod client_tests {
                                                encoder.write_all(body.to_string().as_bytes()).unwrap();
                                        }
                                        format!(
-                                               "HTTP/1.1 200 OK\r\n\
+                                               "{}\r\n\
                                                 Transfer-Encoding: chunked\r\n\
                                                 \r\n\
-                                                {}", String::from_utf8(chuncked_body).unwrap())
+                                                {}", status, String::from_utf8(chuncked_body).unwrap())
                                },
                        };
                        HttpServer::responding_with(response)
                }
 
+               pub fn responding_with_ok<T: ToString>(body: MessageBody<T>) -> Self {
+                       HttpServer::responding_with_body("HTTP/1.1 200 OK", body)
+               }
+
                pub fn responding_with_not_found() -> Self {
-                       let response = "HTTP/1.1 404 Not Found\r\n\r\n".to_string();
-                       HttpServer::responding_with(response)
+                       HttpServer::responding_with_body::<String>("HTTP/1.1 404 Not Found", MessageBody::Empty)
+               }
+
+               pub fn responding_with_server_error<T: ToString>(content: T) -> Self {
+                       let body = MessageBody::Content(content);
+                       HttpServer::responding_with_body("HTTP/1.1 500 Internal Server Error", body)
                }
 
                fn responding_with(response: String) -> Self {
@@ -605,7 +636,10 @@ pub(crate) mod client_tests {
        #[test]
        fn connect_to_unresolvable_host() {
                match HttpClient::connect(("example.invalid", 80)) {
-                       Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::Other),
+                       Err(e) => {
+                               assert!(e.to_string().contains("failed to lookup address information") ||
+                                       e.to_string().contains("No such host"), "{:?}", e);
+                       },
                        Ok(_) => panic!("Expected error"),
                }
        }
@@ -720,6 +754,22 @@ pub(crate) mod client_tests {
                }
        }
 
+       #[tokio::test]
+       async fn read_error() {
+               let server = HttpServer::responding_with_server_error("foo");
+
+               let mut client = HttpClient::connect(&server.endpoint()).unwrap();
+               match client.get::<JsonResponse>("/foo", "foo.com").await {
+                       Err(e) => {
+                               assert_eq!(e.kind(), std::io::ErrorKind::Other);
+                               let http_error = e.into_inner().unwrap().downcast::<HttpError>().unwrap();
+                               assert_eq!(http_error.status_code, "500");
+                               assert_eq!(http_error.contents, "foo".as_bytes());
+                       },
+                       Ok(_) => panic!("Expected error"),
+               }
+       }
+
        #[tokio::test]
        async fn read_empty_message_body() {
                let server = HttpServer::responding_with_ok::<String>(MessageBody::Empty);