From 131a0264d1b67a67e4f6425542bac753fddb1b6a Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sat, 28 Nov 2020 12:04:28 -0800 Subject: [PATCH 1/6] Move BufReader up the stack in Stream. Stream now has an `Inner` enum, and wraps an instance of that enum in a BufReader. This allows Stream itself to implement BufRead trivially, and simplify some of the match dispatching. Having Stream implement BufRead means we can make use of `read_line` instead of our own `read_next_line` (not done in this PR yet). Also, removes the `Cursor` variant of the Inner enum in favor of using the `Test` variant everywhere, since it's strictly more powerful. --- src/pool.rs | 16 ++---- src/response.rs | 17 +++--- src/stream.rs | 138 ++++++++++++++++++++++++------------------------ src/test/mod.rs | 8 ++- 4 files changed, 86 insertions(+), 93 deletions(-) diff --git a/src/pool.rs b/src/pool.rs index 6f732b3..f6a154e 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -229,7 +229,7 @@ fn pool_connections_limit() { proxy: None, }); for key in poolkeys.clone() { - pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![]))); + pool.add(key, Stream::from_vec(vec![])) } assert_eq!(pool.len(), pool.max_idle_connections); @@ -254,10 +254,7 @@ fn pool_per_host_connections_limit() { }; for _ in 0..pool.max_idle_connections_per_host * 2 { - pool.add( - poolkey.clone(), - Stream::Cursor(std::io::Cursor::new(vec![])), - ); + pool.add(poolkey.clone(), Stream::from_vec(vec![])) } assert_eq!(pool.len(), pool.max_idle_connections_per_host); @@ -275,15 +272,12 @@ fn pool_checks_proxy() { let pool = ConnectionPool::new_with_limits(10, 1); let url = Url::parse("zzz:///example.com").unwrap(); - pool.add( - PoolKey::new(&url, None), - Stream::Cursor(std::io::Cursor::new(vec![])), - ); + pool.add(PoolKey::new(&url, None), Stream::from_vec(vec![])); assert_eq!(pool.len(), 1); pool.add( PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())), - Stream::Cursor(std::io::Cursor::new(vec![])), + Stream::from_vec(vec![]), ); assert_eq!(pool.len(), 2); @@ -292,7 +286,7 @@ fn pool_checks_proxy() { &url, Some(Proxy::new("user:password@localhost:9999").unwrap()), ), - Stream::Cursor(std::io::Cursor::new(vec![])), + Stream::from_vec(vec![]), ); assert_eq!(pool.len(), 3); } diff --git a/src/response.rs b/src/response.rs index 18320e1..803c82b 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,6 +1,6 @@ -use std::fmt; -use std::io::{self, Cursor, Read}; +use std::io::{self, Read}; use std::str::FromStr; +use std::{fmt, io::BufRead}; use chunked_transfer::Decoder as ChunkDecoder; @@ -425,7 +425,7 @@ impl Response { /// let resp = ureq::Response::do_from_read(read); /// /// assert_eq!(resp.status(), 401); - pub(crate) fn do_from_read(mut reader: impl Read) -> Result { + pub(crate) fn do_from_read(mut reader: impl BufRead) -> Result { // // HTTP/1.1 200 OK\r\n let status_line = read_next_line(&mut reader)?; @@ -455,8 +455,8 @@ impl Response { } #[cfg(test)] - pub fn to_write_vec(&self) -> Vec { - self.stream.as_ref().unwrap().to_write_vec() + pub fn to_write_vec(self) -> Vec { + self.stream.unwrap().to_write_vec() } } @@ -508,10 +508,9 @@ impl FromStr for Response { /// assert_eq!(body, "Hello World!!!"); /// ``` fn from_str(s: &str) -> Result { - let bytes = s.as_bytes().to_owned(); - let mut cursor = Cursor::new(bytes); - let mut resp = Self::do_from_read(&mut cursor)?; - set_stream(&mut resp, "".into(), None, Stream::Cursor(cursor)); + let mut stream = Stream::from_vec(s.as_bytes().to_owned()); + let mut resp = Self::do_from_read(&mut stream)?; + set_stream(&mut resp, "".into(), None, stream); Ok(resp) } } diff --git a/src/stream.rs b/src/stream.rs index b551094..a885b5a 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,10 +1,10 @@ use log::debug; -use std::fmt; -use std::io::{self, BufRead, BufReader, Cursor, Read, Write}; +use std::io::{self, BufRead, BufReader, Read, Write}; use std::net::SocketAddr; use std::net::TcpStream; use std::time::Duration; use std::time::Instant; +use std::{fmt, io::Cursor}; use chunked_transfer::Decoder as ChunkDecoder; @@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto}; use crate::error::ErrorKind; use crate::unit::Unit; -#[allow(clippy::large_enum_variant)] -pub enum Stream { - Http(BufReader), +pub(crate) struct Stream { + inner: BufReader, +} + +enum Inner { + Http(TcpStream), #[cfg(feature = "tls")] - Https(BufReader>), - Cursor(Cursor>), - #[cfg(test)] - Test(Box, Vec), + Https(rustls::StreamOwned), + Test(Box, Vec), } // DeadlineStream wraps a stream such that read() will return an error @@ -36,7 +37,7 @@ pub enum Stream { // TcpStream to ensure read() doesn't block beyond the deadline. // When the From trait is used to turn a DeadlineStream back into a // Stream (by PoolReturningRead), the timeouts are removed. -pub struct DeadlineStream { +pub(crate) struct DeadlineStream { stream: Stream, deadline: Option, } @@ -91,22 +92,23 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error { impl fmt::Debug for Stream { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Stream[{}]", - match self { - Stream::Http(_) => "http", - #[cfg(feature = "tls")] - Stream::Https(_) => "https", - Stream::Cursor(_) => "cursor", - #[cfg(test)] - Stream::Test(_, _) => "test", - } - ) + let mut result = f.debug_struct("Stream"); + match self.inner.get_ref() { + Inner::Http(tcpstream) => result.field("tcp", tcpstream), + #[cfg(feature = "tls")] + Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()), + Inner::Test(_, _) => result.field("test", &String::new()), + }; + result.finish() } } impl Stream { + pub(crate) fn from_vec(v: Vec) -> Stream { + Stream { + inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])), + } + } // Check if the server has closed a stream by performing a one-byte // non-blocking read. If this returns EOF, the server has closed the // connection: return true. If this returns WouldBlock (aka EAGAIN), @@ -134,10 +136,10 @@ impl Stream { } } pub fn is_poolable(&self) -> bool { - match self { - Stream::Http(_) => true, + match self.inner.get_ref() { + Inner::Http(_) => true, #[cfg(feature = "tls")] - Stream::Https(_) => true, + Inner::Https(_) => true, _ => false, } } @@ -154,10 +156,10 @@ impl Stream { } pub(crate) fn socket(&self) -> Option<&TcpStream> { - match self { - Stream::Http(b) => Some(b.get_ref()), + match self.inner.get_ref() { + Inner::Http(b) => Some(b), #[cfg(feature = "tls")] - Stream::Https(b) => Some(&b.get_ref().sock), + Inner::Https(b) => Some(&b.get_ref()), _ => None, } } @@ -171,48 +173,48 @@ impl Stream { } #[cfg(test)] - pub fn to_write_vec(&self) -> Vec { - match self { - Stream::Test(_, writer) => writer.clone(), + pub fn to_write_vec(self) -> Vec { + match self.inner.into_inner() { + Inner::Test(_, writer) => writer.clone(), _ => panic!("to_write_vec on non Test stream"), } } } impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.inner.read(buf) + } +} + +impl Read for Inner { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { - Stream::Http(sock) => sock.read(buf), + Inner::Http(sock) => sock.read(buf), #[cfg(feature = "tls")] - Stream::Https(stream) => read_https(stream, buf), - Stream::Cursor(read) => read.read(buf), - #[cfg(test)] - Stream::Test(reader, _) => reader.read(buf), + Inner::Https(stream) => read_https(stream, buf), + Inner::Test(reader, _) => reader.read(buf), } } } +impl BufRead for DeadlineStream { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + self.stream.fill_buf() + } + + fn consume(&mut self, amt: usize) { + self.stream.consume(amt) + } +} + impl BufRead for Stream { fn fill_buf(&mut self) -> io::Result<&[u8]> { - match self { - Stream::Http(r) => r.fill_buf(), - #[cfg(feature = "tls")] - Stream::Https(r) => r.fill_buf(), - Stream::Cursor(r) => r.fill_buf(), - #[cfg(test)] - Stream::Test(r, _) => r.fill_buf(), - } + self.inner.fill_buf() } fn consume(&mut self, amt: usize) { - match self { - Stream::Http(r) => r.consume(amt), - #[cfg(feature = "tls")] - Stream::Https(r) => r.consume(amt), - Stream::Cursor(r) => r.consume(amt), - #[cfg(test)] - Stream::Test(r, _) => r.consume(amt), - } + self.inner.consume(amt) } } @@ -228,7 +230,7 @@ where #[cfg(feature = "tls")] fn read_https( - stream: &mut BufReader>, + stream: &mut StreamOwned, buf: &mut [u8], ) -> io::Result { match stream.read(buf) { @@ -256,23 +258,19 @@ fn is_close_notify(e: &std::io::Error) -> bool { impl Write for Stream { fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Stream::Http(sock) => sock.get_mut().write(buf), + match self.inner.get_mut() { + Inner::Http(sock) => sock.write(buf), #[cfg(feature = "tls")] - Stream::Https(stream) => stream.get_mut().write(buf), - Stream::Cursor(_) => panic!("Write to read only stream"), - #[cfg(test)] - Stream::Test(_, writer) => writer.write(buf), + Inner::Https(stream) => stream.write(buf), + Inner::Test(_, writer) => writer.write(buf), } } fn flush(&mut self) -> io::Result<()> { - match self { - Stream::Http(sock) => sock.get_mut().flush(), + match self.inner.get_mut() { + Inner::Http(sock) => sock.flush(), #[cfg(feature = "tls")] - Stream::Https(stream) => stream.get_mut().flush(), - Stream::Cursor(_) => panic!("Flush read only stream"), - #[cfg(test)] - Stream::Test(_, writer) => writer.flush(), + Inner::Https(stream) => stream.flush(), + Inner::Test(_, writer) => writer.flush(), } } } @@ -282,8 +280,10 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result let port = unit.url.port().unwrap_or(80); connect_host(unit, hostname, port) - .map(BufReader::new) - .map(Stream::Http) + .map(Inner::Http) + .map(|h| Stream { + inner: BufReader::new(h), + }) } #[cfg(all(feature = "tls", feature = "native-certs"))] @@ -327,7 +327,9 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result Result { diff --git a/src/test/mod.rs b/src/test/mod.rs index 09747cf..15a918d 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -3,7 +3,7 @@ use crate::stream::Stream; use crate::unit::Unit; use once_cell::sync::Lazy; use std::collections::HashMap; -use std::io::{Cursor, Write}; +use std::io::Write; use std::sync::{Arc, Mutex}; mod agent_test; @@ -29,7 +29,7 @@ where } #[allow(clippy::write_with_newline)] -pub fn make_response( +pub(crate) fn make_response( status: u16, status_text: &str, headers: Vec<&str>, @@ -42,9 +42,7 @@ pub fn make_response( } write!(&mut buf, "\r\n").ok(); buf.append(&mut body); - let cursor = Cursor::new(buf); - let write: Vec = vec![]; - Ok(Stream::Test(Box::new(cursor), write)) + Ok(Stream::from_vec(buf)) } pub(crate) fn resolve_handler(unit: &Unit) -> Result { From 50cb5cecd103f9fcaaf0c5146b913047dd207760 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sat, 28 Nov 2020 17:47:17 -0800 Subject: [PATCH 2/6] Fix buffered DeadlineStream --- src/response.rs | 53 ++++++++++++++++++++++---------------------- src/stream.rs | 59 ++++++++++++++++++++++++++++++++++--------------- 2 files changed, 68 insertions(+), 44 deletions(-) diff --git a/src/response.rs b/src/response.rs index 803c82b..8e032d9 100644 --- a/src/response.rs +++ b/src/response.rs @@ -524,34 +524,35 @@ pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option, s resp.stream = Some(stream); } -fn read_next_line(reader: &mut R) -> io::Result { - let mut buf = Vec::new(); - let mut prev_byte_was_cr = false; - let mut one = [0_u8]; - - loop { - let amt = reader.read(&mut one[..])?; - - if amt == 0 { - return Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Unexpected EOF", - )); +fn read_next_line(reader: &mut impl BufRead) -> io::Result { + let mut s = String::new(); + let result = reader.read_line(&mut s).map_err(|e| { + // On unix-y platforms set_read_timeout and set_write_timeout + // causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut. + // Since the socket most definitely not set_nonblocking(true), + // we can safely normalize WouldBlock to TimedOut + if e.kind() == io::ErrorKind::WouldBlock { + io::Error::new(io::ErrorKind::TimedOut, "timed out reading headers") + } else { + e } - - let byte = one[0]; - - if byte == b'\n' && prev_byte_was_cr { - buf.pop(); // removing the '\r' - return String::from_utf8(buf).map_err(|_| { - io::Error::new(io::ErrorKind::InvalidInput, "Header is not in ASCII") - }); - } - - prev_byte_was_cr = byte == b'\r'; - - buf.push(byte); + }); + if result? == 0 { + return Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "Unexpected EOF", + )); } + + if !s.ends_with("\r\n") { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Header field didn't end with \\r: {}", s), + )); + } + s.pop(); + s.pop(); + Ok(s) } /// Limits a `Read` to a content size (as set by a "Content-Length" header). diff --git a/src/stream.rs b/src/stream.rs index a885b5a..717b590 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -54,6 +54,32 @@ impl From for Stream { } } +impl BufRead for DeadlineStream { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + if let Some(deadline) = self.deadline { + let timeout = time_until_deadline(deadline)?; + if let Some(socket) = self.stream.socket() { + socket.set_read_timeout(Some(timeout))?; + socket.set_write_timeout(Some(timeout))?; + } + } + self.stream.fill_buf().map_err(|e| { + // On unix-y platforms set_read_timeout and set_write_timeout + // causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut. + // Since the socket most definitely not set_nonblocking(true), + // we can safely normalize WouldBlock to TimedOut + if e.kind() == io::ErrorKind::WouldBlock { + return io_err_timeout("timed out reading response".to_string()); + } + e + }) + } + + fn consume(&mut self, amt: usize) { + self.stream.consume(amt) + } +} + impl Read for DeadlineStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { if let Some(deadline) = self.deadline { @@ -109,6 +135,19 @@ impl Stream { inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])), } } + + fn from_tcp_stream(t: TcpStream) -> Stream { + Stream { + inner: BufReader::with_capacity(1000, Inner::Http(t)), + } + } + + fn from_tls_stream(t: StreamOwned) -> Stream { + Stream { + inner: BufReader::with_capacity(1000, Inner::Https(t)), + } + } + // Check if the server has closed a stream by performing a one-byte // non-blocking read. If this returns EOF, the server has closed the // connection: return true. If this returns WouldBlock (aka EAGAIN), @@ -198,16 +237,6 @@ impl Read for Inner { } } -impl BufRead for DeadlineStream { - fn fill_buf(&mut self) -> io::Result<&[u8]> { - self.stream.fill_buf() - } - - fn consume(&mut self, amt: usize) { - self.stream.consume(amt) - } -} - impl BufRead for Stream { fn fill_buf(&mut self) -> io::Result<&[u8]> { self.inner.fill_buf() @@ -279,11 +308,7 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result // let port = unit.url.port().unwrap_or(80); - connect_host(unit, hostname, port) - .map(Inner::Http) - .map(|h| Stream { - inner: BufReader::new(h), - }) + connect_host(unit, hostname, port).map(Stream::from_tcp_stream) } #[cfg(all(feature = "tls", feature = "native-certs"))] @@ -327,9 +352,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result Result { From 6a22c54ba2c7c75ec45ddf01f094424f6c776e9f Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sat, 28 Nov 2020 22:34:32 -0800 Subject: [PATCH 3/6] Small cleanups. --- src/response.rs | 13 +------------ src/stream.rs | 4 ++-- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/response.rs b/src/response.rs index 8e032d9..441f6be 100644 --- a/src/response.rs +++ b/src/response.rs @@ -526,18 +526,7 @@ pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option, s fn read_next_line(reader: &mut impl BufRead) -> io::Result { let mut s = String::new(); - let result = reader.read_line(&mut s).map_err(|e| { - // On unix-y platforms set_read_timeout and set_write_timeout - // causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut. - // Since the socket most definitely not set_nonblocking(true), - // we can safely normalize WouldBlock to TimedOut - if e.kind() == io::ErrorKind::WouldBlock { - io::Error::new(io::ErrorKind::TimedOut, "timed out reading headers") - } else { - e - } - }); - if result? == 0 { + if reader.read_line(&mut s)? == 0 { return Err(io::Error::new( io::ErrorKind::ConnectionAborted, "Unexpected EOF", diff --git a/src/stream.rs b/src/stream.rs index 717b590..98cea90 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -138,13 +138,13 @@ impl Stream { fn from_tcp_stream(t: TcpStream) -> Stream { Stream { - inner: BufReader::with_capacity(1000, Inner::Http(t)), + inner: BufReader::new(Inner::Http(t)), } } fn from_tls_stream(t: StreamOwned) -> Stream { Stream { - inner: BufReader::with_capacity(1000, Inner::Https(t)), + inner: BufReader::new(Inner::Https(t)), } } From a286a7a22d358e060c711a4c90c134ed8516e216 Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sat, 28 Nov 2020 23:33:28 -0800 Subject: [PATCH 4/6] Make DeadlineStream Read use the BufRead. --- src/stream.rs | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index 98cea90..5827e53 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -82,23 +82,16 @@ impl BufRead for DeadlineStream { impl Read for DeadlineStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - if let Some(deadline) = self.deadline { - let timeout = time_until_deadline(deadline)?; - if let Some(socket) = self.stream.socket() { - socket.set_read_timeout(Some(timeout))?; - socket.set_write_timeout(Some(timeout))?; - } - } - self.stream.read(buf).map_err(|e| { - // On unix-y platforms set_read_timeout and set_write_timeout - // causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut. - // Since the socket most definitely not set_nonblocking(true), - // we can safely normalize WouldBlock to TimedOut - if e.kind() == io::ErrorKind::WouldBlock { - return io_err_timeout("timed out reading response".to_string()); - } - e - }) + // All reads on a DeadlineStream use the BufRead impl. This ensures + // that we have a chance to set the correct timeout before each recv + // syscall. + // Copied from the BufReader implementation of `read()`. + let nread = { + let mut rem = self.fill_buf()?; + rem.read(buf)? + }; + self.consume(nread); + Ok(nread) } } From 6b6a59f215933caf137c8aeaaed6ec41914b3a4f Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sat, 28 Nov 2020 23:56:56 -0800 Subject: [PATCH 5/6] Fix non-tls case. --- src/stream.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/stream.rs b/src/stream.rs index 5827e53..f47bbf1 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -135,6 +135,7 @@ impl Stream { } } + #[cfg(tls)] fn from_tls_stream(t: StreamOwned) -> Stream { Stream { inner: BufReader::new(Inner::Https(t)), From 36b307423cf23825c636c39e14d1f8841bfb300b Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sun, 29 Nov 2020 00:02:33 -0800 Subject: [PATCH 6/6] Fix non-tls case correctly. --- src/stream.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stream.rs b/src/stream.rs index f47bbf1..bb3419a 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -135,7 +135,7 @@ impl Stream { } } - #[cfg(tls)] + #[cfg(feature = "tls")] fn from_tls_stream(t: StreamOwned) -> Stream { Stream { inner: BufReader::new(Inner::Https(t)),