diff --git a/src/request.rs b/src/request.rs index daee345..7d469fb 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,8 +1,8 @@ -use std::fmt; use std::io::Read; use std::result::Result; use std::sync::{Arc, Mutex}; use std::time; +use std::{fmt, time::Duration}; use qstring::QString; use url::{form_urlencoded, Url}; @@ -40,8 +40,8 @@ pub struct Request { pub(crate) headers: Vec
, pub(crate) query: QString, pub(crate) timeout_connect: u64, - pub(crate) timeout_read: u64, - pub(crate) timeout_write: u64, + pub(crate) timeout_read: Option, + pub(crate) timeout_write: Option, pub(crate) timeout: Option, pub(crate) redirects: u32, pub(crate) proxy: Option, @@ -368,7 +368,10 @@ impl Request { /// println!("{:?}", r); /// ``` pub fn timeout_read(&mut self, millis: u64) -> &mut Request { - self.timeout_read = millis; + match millis { + 0 => self.timeout_read = None, + m => self.timeout_read = Some(Duration::from_millis(m)), + } self } @@ -385,7 +388,10 @@ impl Request { /// println!("{:?}", r); /// ``` pub fn timeout_write(&mut self, millis: u64) -> &mut Request { - self.timeout_write = millis; + match millis { + 0 => self.timeout_write = None, + m => self.timeout_write = Some(Duration::from_millis(m)), + } self } diff --git a/src/response.rs b/src/response.rs index a768535..75ba391 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,7 +1,6 @@ use std::fmt; use std::io::{self, Cursor, ErrorKind, Read}; use std::str::FromStr; -use std::time::Instant; use chunked_transfer::Decoder as ChunkDecoder; @@ -56,7 +55,6 @@ pub struct Response { headers: Vec
, unit: Option, stream: Option, - deadline: Option, } /// index into status_line where we split: HTTP/1.1 200 OK @@ -327,12 +325,17 @@ impl Response { let stream = self.stream.expect("No reader in response?!"); let unit = self.unit; + if let Some(unit) = &unit { + let result = stream.set_read_timeout(unit.req.timeout_read); + if let Err(e) = result { + return Box::new(ErrorReader(e)) as Box; + } + } let deadline = unit.as_ref().and_then(|u| u.deadline); let stream = DeadlineStream::new(stream, deadline); match (use_chunked, limit_bytes) { - (true, _) => Box::new(PoolReturnRead::new(unit, ChunkDecoder::new(stream))) - as Box, + (true, _) => Box::new(PoolReturnRead::new(unit, ChunkDecoder::new(stream))), (false, Some(len)) => { Box::new(PoolReturnRead::new(unit, LimitedRead::new(stream, len))) } @@ -505,7 +508,6 @@ impl Response { headers, unit: None, stream: None, - deadline: None, }) } @@ -585,9 +587,6 @@ impl Into for Error { /// *Internal API* pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option, stream: Stream) { resp.url = Some(url); - if let Some(unit) = &unit { - resp.deadline = unit.deadline; - } resp.unit = unit; resp.stream = Some(stream); } @@ -813,3 +812,14 @@ mod tests { assert_eq!(v, "Bad Status\n"); } } + +// ErrorReader returns an error for every read. +// The error is as close to a clone of the underlying +// io::Error as we can get. +struct ErrorReader(io::Error); + +impl Read for ErrorReader { + fn read(&mut self, _buf: &mut [u8]) -> io::Result { + Err(io::Error::new(self.0.kind(), self.0.to_string())) + } +} diff --git a/src/stream.rs b/src/stream.rs index f14a3ae..cf53c4d 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -172,6 +172,14 @@ impl Stream { } } + pub(crate) fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { + if let Some(socket) = self.socket() { + socket.set_read_timeout(timeout) + } else { + Ok(()) + } + } + #[cfg(test)] pub fn to_write_vec(&self) -> Vec { match self { @@ -453,24 +461,16 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result 0 { - stream - .set_read_timeout(Some(Duration::from_millis(unit.req.timeout_read as u64))) - .ok(); } else { - stream.set_read_timeout(None).ok(); + stream.set_read_timeout(unit.req.timeout_read)?; } if let Some(deadline) = deadline { stream .set_write_timeout(Some(time_until_deadline(deadline)?)) .ok(); - } else if unit.req.timeout_write > 0 { - stream - .set_write_timeout(Some(Duration::from_millis(unit.req.timeout_write as u64))) - .ok(); } else { - stream.set_write_timeout(None).ok(); + stream.set_read_timeout(unit.req.timeout_read)?; } if proto == Some(Proto::HTTPConnect) { diff --git a/src/test/timeout.rs b/src/test/timeout.rs index 7a296e4..c9a0821 100644 --- a/src/test/timeout.rs +++ b/src/test/timeout.rs @@ -19,7 +19,7 @@ fn dribble_body_respond(mut stream: TcpStream, contents: &[u8]) -> io::Result<() stream.write_all(&contents[i..i + 1])?; stream.write_all(&[b'\n'; 1])?; stream.flush()?; - thread::sleep(Duration::from_millis(10)); + thread::sleep(Duration::from_millis(100)); } Ok(()) } @@ -47,6 +47,22 @@ fn overall_timeout_during_body() { get_and_expect_timeout(url); } +#[test] +fn read_timeout_during_body() { + let server = TestServer::new(|stream| dribble_body_respond(stream, &[b'a'; 300])); + let url = format!("http://localhost:{}/", server.port); + let agent = Agent::default().build(); + let resp = agent.get(&url).timeout_read(5).call(); + match resp.into_string() { + Err(io_error) => match io_error.kind() { + io::ErrorKind::TimedOut => Ok(()), + _ => Err(format!("{:?}", io_error)), + }, + Ok(_) => Err("successful response".to_string()), + } + .expect("expected timeout but got something else"); +} + // Send HTTP headers on the TcpStream at a rate of one header every 100 // milliseconds, for a total of 30 headers. fn dribble_headers_respond(mut stream: TcpStream) -> io::Result<()> {