From 131a0264d1b67a67e4f6425542bac753fddb1b6a Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Sat, 28 Nov 2020 12:04:28 -0800 Subject: [PATCH] Move BufReader up the stack in Stream. Stream now has an `Inner` enum, and wraps an instance of that enum in a BufReader. This allows Stream itself to implement BufRead trivially, and simplify some of the match dispatching. Having Stream implement BufRead means we can make use of `read_line` instead of our own `read_next_line` (not done in this PR yet). Also, removes the `Cursor` variant of the Inner enum in favor of using the `Test` variant everywhere, since it's strictly more powerful. --- src/pool.rs | 16 ++---- src/response.rs | 17 +++--- src/stream.rs | 138 ++++++++++++++++++++++++------------------------ src/test/mod.rs | 8 ++- 4 files changed, 86 insertions(+), 93 deletions(-) diff --git a/src/pool.rs b/src/pool.rs index 6f732b3..f6a154e 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -229,7 +229,7 @@ fn pool_connections_limit() { proxy: None, }); for key in poolkeys.clone() { - pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![]))); + pool.add(key, Stream::from_vec(vec![])) } assert_eq!(pool.len(), pool.max_idle_connections); @@ -254,10 +254,7 @@ fn pool_per_host_connections_limit() { }; for _ in 0..pool.max_idle_connections_per_host * 2 { - pool.add( - poolkey.clone(), - Stream::Cursor(std::io::Cursor::new(vec![])), - ); + pool.add(poolkey.clone(), Stream::from_vec(vec![])) } assert_eq!(pool.len(), pool.max_idle_connections_per_host); @@ -275,15 +272,12 @@ fn pool_checks_proxy() { let pool = ConnectionPool::new_with_limits(10, 1); let url = Url::parse("zzz:///example.com").unwrap(); - pool.add( - PoolKey::new(&url, None), - Stream::Cursor(std::io::Cursor::new(vec![])), - ); + pool.add(PoolKey::new(&url, None), Stream::from_vec(vec![])); assert_eq!(pool.len(), 1); pool.add( PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())), - Stream::Cursor(std::io::Cursor::new(vec![])), + Stream::from_vec(vec![]), ); assert_eq!(pool.len(), 2); @@ -292,7 +286,7 @@ fn pool_checks_proxy() { &url, Some(Proxy::new("user:password@localhost:9999").unwrap()), ), - Stream::Cursor(std::io::Cursor::new(vec![])), + Stream::from_vec(vec![]), ); assert_eq!(pool.len(), 3); } diff --git a/src/response.rs b/src/response.rs index 18320e1..803c82b 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,6 +1,6 @@ -use std::fmt; -use std::io::{self, Cursor, Read}; +use std::io::{self, Read}; use std::str::FromStr; +use std::{fmt, io::BufRead}; use chunked_transfer::Decoder as ChunkDecoder; @@ -425,7 +425,7 @@ impl Response { /// let resp = ureq::Response::do_from_read(read); /// /// assert_eq!(resp.status(), 401); - pub(crate) fn do_from_read(mut reader: impl Read) -> Result { + pub(crate) fn do_from_read(mut reader: impl BufRead) -> Result { // // HTTP/1.1 200 OK\r\n let status_line = read_next_line(&mut reader)?; @@ -455,8 +455,8 @@ impl Response { } #[cfg(test)] - pub fn to_write_vec(&self) -> Vec { - self.stream.as_ref().unwrap().to_write_vec() + pub fn to_write_vec(self) -> Vec { + self.stream.unwrap().to_write_vec() } } @@ -508,10 +508,9 @@ impl FromStr for Response { /// assert_eq!(body, "Hello World!!!"); /// ``` fn from_str(s: &str) -> Result { - let bytes = s.as_bytes().to_owned(); - let mut cursor = Cursor::new(bytes); - let mut resp = Self::do_from_read(&mut cursor)?; - set_stream(&mut resp, "".into(), None, Stream::Cursor(cursor)); + let mut stream = Stream::from_vec(s.as_bytes().to_owned()); + let mut resp = Self::do_from_read(&mut stream)?; + set_stream(&mut resp, "".into(), None, stream); Ok(resp) } } diff --git a/src/stream.rs b/src/stream.rs index b551094..a885b5a 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,10 +1,10 @@ use log::debug; -use std::fmt; -use std::io::{self, BufRead, BufReader, Cursor, Read, Write}; +use std::io::{self, BufRead, BufReader, Read, Write}; use std::net::SocketAddr; use std::net::TcpStream; use std::time::Duration; use std::time::Instant; +use std::{fmt, io::Cursor}; use chunked_transfer::Decoder as ChunkDecoder; @@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto}; use crate::error::ErrorKind; use crate::unit::Unit; -#[allow(clippy::large_enum_variant)] -pub enum Stream { - Http(BufReader), +pub(crate) struct Stream { + inner: BufReader, +} + +enum Inner { + Http(TcpStream), #[cfg(feature = "tls")] - Https(BufReader>), - Cursor(Cursor>), - #[cfg(test)] - Test(Box, Vec), + Https(rustls::StreamOwned), + Test(Box, Vec), } // DeadlineStream wraps a stream such that read() will return an error @@ -36,7 +37,7 @@ pub enum Stream { // TcpStream to ensure read() doesn't block beyond the deadline. // When the From trait is used to turn a DeadlineStream back into a // Stream (by PoolReturningRead), the timeouts are removed. -pub struct DeadlineStream { +pub(crate) struct DeadlineStream { stream: Stream, deadline: Option, } @@ -91,22 +92,23 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error { impl fmt::Debug for Stream { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "Stream[{}]", - match self { - Stream::Http(_) => "http", - #[cfg(feature = "tls")] - Stream::Https(_) => "https", - Stream::Cursor(_) => "cursor", - #[cfg(test)] - Stream::Test(_, _) => "test", - } - ) + let mut result = f.debug_struct("Stream"); + match self.inner.get_ref() { + Inner::Http(tcpstream) => result.field("tcp", tcpstream), + #[cfg(feature = "tls")] + Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()), + Inner::Test(_, _) => result.field("test", &String::new()), + }; + result.finish() } } impl Stream { + pub(crate) fn from_vec(v: Vec) -> Stream { + Stream { + inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])), + } + } // Check if the server has closed a stream by performing a one-byte // non-blocking read. If this returns EOF, the server has closed the // connection: return true. If this returns WouldBlock (aka EAGAIN), @@ -134,10 +136,10 @@ impl Stream { } } pub fn is_poolable(&self) -> bool { - match self { - Stream::Http(_) => true, + match self.inner.get_ref() { + Inner::Http(_) => true, #[cfg(feature = "tls")] - Stream::Https(_) => true, + Inner::Https(_) => true, _ => false, } } @@ -154,10 +156,10 @@ impl Stream { } pub(crate) fn socket(&self) -> Option<&TcpStream> { - match self { - Stream::Http(b) => Some(b.get_ref()), + match self.inner.get_ref() { + Inner::Http(b) => Some(b), #[cfg(feature = "tls")] - Stream::Https(b) => Some(&b.get_ref().sock), + Inner::Https(b) => Some(&b.get_ref()), _ => None, } } @@ -171,48 +173,48 @@ impl Stream { } #[cfg(test)] - pub fn to_write_vec(&self) -> Vec { - match self { - Stream::Test(_, writer) => writer.clone(), + pub fn to_write_vec(self) -> Vec { + match self.inner.into_inner() { + Inner::Test(_, writer) => writer.clone(), _ => panic!("to_write_vec on non Test stream"), } } } impl Read for Stream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.inner.read(buf) + } +} + +impl Read for Inner { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { - Stream::Http(sock) => sock.read(buf), + Inner::Http(sock) => sock.read(buf), #[cfg(feature = "tls")] - Stream::Https(stream) => read_https(stream, buf), - Stream::Cursor(read) => read.read(buf), - #[cfg(test)] - Stream::Test(reader, _) => reader.read(buf), + Inner::Https(stream) => read_https(stream, buf), + Inner::Test(reader, _) => reader.read(buf), } } } +impl BufRead for DeadlineStream { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + self.stream.fill_buf() + } + + fn consume(&mut self, amt: usize) { + self.stream.consume(amt) + } +} + impl BufRead for Stream { fn fill_buf(&mut self) -> io::Result<&[u8]> { - match self { - Stream::Http(r) => r.fill_buf(), - #[cfg(feature = "tls")] - Stream::Https(r) => r.fill_buf(), - Stream::Cursor(r) => r.fill_buf(), - #[cfg(test)] - Stream::Test(r, _) => r.fill_buf(), - } + self.inner.fill_buf() } fn consume(&mut self, amt: usize) { - match self { - Stream::Http(r) => r.consume(amt), - #[cfg(feature = "tls")] - Stream::Https(r) => r.consume(amt), - Stream::Cursor(r) => r.consume(amt), - #[cfg(test)] - Stream::Test(r, _) => r.consume(amt), - } + self.inner.consume(amt) } } @@ -228,7 +230,7 @@ where #[cfg(feature = "tls")] fn read_https( - stream: &mut BufReader>, + stream: &mut StreamOwned, buf: &mut [u8], ) -> io::Result { match stream.read(buf) { @@ -256,23 +258,19 @@ fn is_close_notify(e: &std::io::Error) -> bool { impl Write for Stream { fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Stream::Http(sock) => sock.get_mut().write(buf), + match self.inner.get_mut() { + Inner::Http(sock) => sock.write(buf), #[cfg(feature = "tls")] - Stream::Https(stream) => stream.get_mut().write(buf), - Stream::Cursor(_) => panic!("Write to read only stream"), - #[cfg(test)] - Stream::Test(_, writer) => writer.write(buf), + Inner::Https(stream) => stream.write(buf), + Inner::Test(_, writer) => writer.write(buf), } } fn flush(&mut self) -> io::Result<()> { - match self { - Stream::Http(sock) => sock.get_mut().flush(), + match self.inner.get_mut() { + Inner::Http(sock) => sock.flush(), #[cfg(feature = "tls")] - Stream::Https(stream) => stream.get_mut().flush(), - Stream::Cursor(_) => panic!("Flush read only stream"), - #[cfg(test)] - Stream::Test(_, writer) => writer.flush(), + Inner::Https(stream) => stream.flush(), + Inner::Test(_, writer) => writer.flush(), } } } @@ -282,8 +280,10 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result let port = unit.url.port().unwrap_or(80); connect_host(unit, hostname, port) - .map(BufReader::new) - .map(Stream::Http) + .map(Inner::Http) + .map(|h| Stream { + inner: BufReader::new(h), + }) } #[cfg(all(feature = "tls", feature = "native-certs"))] @@ -327,7 +327,9 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result Result { diff --git a/src/test/mod.rs b/src/test/mod.rs index 09747cf..15a918d 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -3,7 +3,7 @@ use crate::stream::Stream; use crate::unit::Unit; use once_cell::sync::Lazy; use std::collections::HashMap; -use std::io::{Cursor, Write}; +use std::io::Write; use std::sync::{Arc, Mutex}; mod agent_test; @@ -29,7 +29,7 @@ where } #[allow(clippy::write_with_newline)] -pub fn make_response( +pub(crate) fn make_response( status: u16, status_text: &str, headers: Vec<&str>, @@ -42,9 +42,7 @@ pub fn make_response( } write!(&mut buf, "\r\n").ok(); buf.append(&mut body); - let cursor = Cursor::new(buf); - let write: Vec = vec![]; - Ok(Stream::Test(Box::new(cursor), write)) + Ok(Stream::from_vec(buf)) } pub(crate) fn resolve_handler(unit: &Unit) -> Result {