From 50cb5cecd103f9fcaaf0c5146b913047dd207760 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sat, 28 Nov 2020 17:47:17 -0800 Subject: [PATCH] Fix buffered DeadlineStream --- src/response.rs | 53 ++++++++++++++++++++++---------------------- src/stream.rs | 59 ++++++++++++++++++++++++++++++++++--------------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/src/response.rs b/src/response.rs index 803c82b..8e032d9 100644 --- a/src/response.rs +++ b/src/response.rs @@ -524,34 +524,35 @@ pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option, s resp.stream = Some(stream); } -fn read_next_line(reader: &mut R) -> io::Result { - let mut buf = Vec::new(); - let mut prev_byte_was_cr = false; - let mut one = [0_u8]; - - loop { - let amt = reader.read(&mut one[..])?; - - if amt == 0 { - return Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Unexpected EOF", - )); +fn read_next_line(reader: &mut impl BufRead) -> io::Result { + let mut s = String::new(); + let result = reader.read_line(&mut s).map_err(|e| { + // On unix-y platforms set_read_timeout and set_write_timeout + // causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut. + // Since the socket most definitely not set_nonblocking(true), + // we can safely normalize WouldBlock to TimedOut + if e.kind() == io::ErrorKind::WouldBlock { + io::Error::new(io::ErrorKind::TimedOut, "timed out reading headers") + } else { + e } - - let byte = one[0]; - - if byte == b'\n' && prev_byte_was_cr { - buf.pop(); // removing the '\r' - return String::from_utf8(buf).map_err(|_| { - io::Error::new(io::ErrorKind::InvalidInput, "Header is not in ASCII") - }); - } - - prev_byte_was_cr = byte == b'\r'; - - buf.push(byte); + }); + if result? == 0 { + return Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "Unexpected EOF", + )); } + + if !s.ends_with("\r\n") { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Header field didn't end with \\r: {}", s), + )); + } + s.pop(); + s.pop(); + Ok(s) } /// Limits a `Read` to a content size (as set by a "Content-Length" header). diff --git a/src/stream.rs b/src/stream.rs index a885b5a..717b590 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -54,6 +54,32 @@ impl From for Stream { } } +impl BufRead for DeadlineStream { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + if let Some(deadline) = self.deadline { + let timeout = time_until_deadline(deadline)?; + if let Some(socket) = self.stream.socket() { + socket.set_read_timeout(Some(timeout))?; + socket.set_write_timeout(Some(timeout))?; + } + } + self.stream.fill_buf().map_err(|e| { + // On unix-y platforms set_read_timeout and set_write_timeout + // causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut. + // Since the socket most definitely not set_nonblocking(true), + // we can safely normalize WouldBlock to TimedOut + if e.kind() == io::ErrorKind::WouldBlock { + return io_err_timeout("timed out reading response".to_string()); + } + e + }) + } + + fn consume(&mut self, amt: usize) { + self.stream.consume(amt) + } +} + impl Read for DeadlineStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { if let Some(deadline) = self.deadline { @@ -109,6 +135,19 @@ impl Stream { inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])), } } + + fn from_tcp_stream(t: TcpStream) -> Stream { + Stream { + inner: BufReader::with_capacity(1000, Inner::Http(t)), + } + } + + fn from_tls_stream(t: StreamOwned) -> Stream { + Stream { + inner: BufReader::with_capacity(1000, Inner::Https(t)), + } + } + // Check if the server has closed a stream by performing a one-byte // non-blocking read. If this returns EOF, the server has closed the // connection: return true. If this returns WouldBlock (aka EAGAIN), @@ -198,16 +237,6 @@ impl Read for Inner { } } -impl BufRead for DeadlineStream { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.stream.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.stream.consume(amt) - } -} - impl BufRead for Stream { fn fill_buf(&mut self) -> io::Result<&[u8]> { self.inner.fill_buf() @@ -279,11 +308,7 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result // let port = unit.url.port().unwrap_or(80); - connect_host(unit, hostname, port) - .map(Inner::Http) - .map(|h| Stream { - inner: BufReader::new(h), - }) + connect_host(unit, hostname, port).map(Stream::from_tcp_stream) } #[cfg(all(feature = "tls", feature = "native-certs"))] @@ -327,9 +352,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result Result {