diff --git a/Cargo.toml b/Cargo.toml index 277a7a9..f3ac84b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ serde = { version = "1", features = ["derive"] } rayon = "1.3.0" rayon-core = "1.7.0" chrono = "0.4.11" -env_logger = "0.7.1" +env_logger = "0.8.1" [[example]] name = "smoke-test" diff --git a/src/response.rs b/src/response.rs index 53c8ed6..cc91933 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,7 +1,6 @@ use std::fmt; use std::io::{self, Cursor, ErrorKind, Read}; use std::str::FromStr; -use std::time::Instant; use chunked_transfer::Decoder as ChunkDecoder; @@ -47,7 +46,6 @@ pub struct Response { headers: Vec
, unit: Option, stream: Option, - deadline: Option, } /// index into status_line where we split: HTTP/1.1 200 OK @@ -267,12 +265,17 @@ impl Response { let stream = self.stream.expect("No reader in response?!"); let unit = self.unit; + if let Some(unit) = &unit { + let result = stream.set_read_timeout(unit.req.agent.config.timeout_read); + if let Err(e) = result { + return Box::new(ErrorReader(e)) as Box; + } + } let deadline = unit.as_ref().and_then(|u| u.deadline); let stream = DeadlineStream::new(stream, deadline); match (use_chunked, limit_bytes) { - (true, _) => Box::new(PoolReturnRead::new(unit, ChunkDecoder::new(stream))) - as Box, + (true, _) => Box::new(PoolReturnRead::new(unit, ChunkDecoder::new(stream))), (false, Some(len)) => { Box::new(PoolReturnRead::new(unit, LimitedRead::new(stream, len))) } @@ -438,7 +441,6 @@ impl Response { headers, unit: None, stream: None, - deadline: None, }) } @@ -507,9 +509,6 @@ impl FromStr for Response { /// *Internal API* pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option, stream: Stream) { resp.url = Some(url); - if let Some(unit) = &unit { - resp.deadline = unit.deadline; - } resp.unit = unit; resp.stream = Some(stream); } @@ -730,3 +729,14 @@ mod tests { assert!(matches!(err, Error::BadStatus)); } } + +// ErrorReader returns an error for every read. +// The error is as close to a clone of the underlying +// io::Error as we can get. +struct ErrorReader(io::Error); + +impl Read for ErrorReader { + fn read(&mut self, _buf: &mut [u8]) -> io::Result { + Err(io::Error::new(self.0.kind(), self.0.to_string())) + } +} diff --git a/src/stream.rs b/src/stream.rs index dc775f6..3f10be5 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -162,6 +162,14 @@ impl Stream { } } + pub(crate) fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + if let Some(socket) = self.socket() { + socket.set_read_timeout(timeout) + } else { + Ok(()) + } + } + #[cfg(test)] pub fn to_write_vec(&self) -> Vec { match self { @@ -324,7 +332,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result Result { - let deadline: Option = + let connect_deadline: Option = if let Some(timeout_connect) = unit.req.agent.config.timeout_connect { Instant::now().checked_add(timeout_connect) } else { @@ -357,7 +365,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result Some(time_until_deadline(deadline)?), None => None, }; @@ -368,7 +376,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result Result io::Result<() stream.write_all(&contents[i..i + 1])?; stream.write_all(&[b'\n'; 1])?; stream.flush()?; - thread::sleep(Duration::from_millis(10)); + thread::sleep(Duration::from_millis(100)); } Ok(()) } @@ -47,17 +47,38 @@ fn overall_timeout_during_body() { get_and_expect_timeout(url); } +#[test] +fn read_timeout_during_body() { + let server = TestServer::new(|stream| dribble_body_respond(stream, &[b'a'; 300])); + let url = format!("http://localhost:{}/", server.port); + let agent = builder().timeout_read(Duration::from_millis(70)).build(); + let resp = match agent.get(&url).call() { + Ok(r) => r, + Err(e) => panic!("got error during headers, not body: {:?}", e), + }; + match resp.into_string() { + Err(io_error) => match io_error.kind() { + io::ErrorKind::TimedOut => Ok(()), + _ => Err(format!("{:?}", io_error)), + }, + Ok(_) => Err("successful response".to_string()), + } + .expect("expected timeout but got something else"); +} + // Send HTTP headers on the TcpStream at a rate of one header every 100 // milliseconds, for a total of 30 headers. -//fn dribble_headers_respond(mut stream: TcpStream) -> io::Result<()> { -// stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n")?; -// for _ in 0..30 { -// stream.write_all(b"a: b\n")?; -// stream.flush()?; -// thread::sleep(Duration::from_millis(100)); -// } -// Ok(()) -//} +fn dribble_headers_respond(mut stream: TcpStream) -> io::Result<()> { + stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n")?; + for _ in 0..30 { + stream.write_all(b"a: b\r\n")?; + stream.flush()?; + thread::sleep(Duration::from_millis(100)); + } + stream.write_all(b"\r\n")?; + + Ok(()) +} #[test] // TODO: Our current behavior is actually incorrect (we'll return BadHeader if a timeout occurs during headers). @@ -70,6 +91,35 @@ fn overall_timeout_during_body() { // let url = format!("http://localhost:{}/", server.port); // get_and_expect_timeout(url); //} +#[test] +fn read_timeout_during_headers() { + let server = TestServer::new(dribble_headers_respond); + let url = format!("http://localhost:{}/", server.port); + let agent = builder().timeout_read(Duration::from_millis(10)).build(); + let resp = agent.get(&url).call(); + match resp { + Ok(_) => Err("successful response".to_string()), + Err(Error::Io(e)) if e.kind() == io::ErrorKind::TimedOut => Ok(()), + Err(e) => Err(format!("Unexpected error type: {:?}", e)), + } + .expect("expected timeout but got something else"); +} + +#[test] +fn overall_timeout_during_headers() { + // Start a test server on an available port, that dribbles out a response at 1 write per 10ms. + let server = TestServer::new(dribble_headers_respond); + let url = format!("http://localhost:{}/", server.port); + let agent = builder().timeout(Duration::from_millis(500)).build(); + let resp = agent.get(&url).call(); + match resp { + Ok(_) => Err("successful response".to_string()), + Err(Error::Io(e)) if e.kind() == io::ErrorKind::TimedOut => Ok(()), + Err(e) => Err(format!("Unexpected error type: {:?}", e)), + } + .expect("expected timeout but got something else"); +} + #[test] #[cfg(feature = "json")] fn overall_timeout_reading_json() {