Move BufReader up the stack in Stream.

Stream now has an `Inner` enum, and wraps an instance of that enum in a
BufReader. This allows Stream itself to implement BufRead trivially, and
simplify some of the match dispatching. Having Stream implement BufRead
means we can make use of `read_line` instead of our own `read_next_line`
(not done in this PR yet).

Also, removes the `Cursor` variant of the Inner enum in favor of using
the `Test` variant everywhere, since it's strictly more powerful.
This commit is contained in:
Jacob Hoffman-Andrews
2020-11-28 12:04:28 -08:00
parent a0b88926fa
commit 131a0264d1
4 changed files with 86 additions and 93 deletions

View File

@@ -229,7 +229,7 @@ fn pool_connections_limit() {
proxy: None, proxy: None,
}); });
for key in poolkeys.clone() { for key in poolkeys.clone() {
pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![]))); pool.add(key, Stream::from_vec(vec![]))
} }
assert_eq!(pool.len(), pool.max_idle_connections); assert_eq!(pool.len(), pool.max_idle_connections);
@@ -254,10 +254,7 @@ fn pool_per_host_connections_limit() {
}; };
for _ in 0..pool.max_idle_connections_per_host * 2 { for _ in 0..pool.max_idle_connections_per_host * 2 {
pool.add( pool.add(poolkey.clone(), Stream::from_vec(vec![]))
poolkey.clone(),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
} }
assert_eq!(pool.len(), pool.max_idle_connections_per_host); assert_eq!(pool.len(), pool.max_idle_connections_per_host);
@@ -275,15 +272,12 @@ fn pool_checks_proxy() {
let pool = ConnectionPool::new_with_limits(10, 1); let pool = ConnectionPool::new_with_limits(10, 1);
let url = Url::parse("zzz:///example.com").unwrap(); let url = Url::parse("zzz:///example.com").unwrap();
pool.add( pool.add(PoolKey::new(&url, None), Stream::from_vec(vec![]));
PoolKey::new(&url, None),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
assert_eq!(pool.len(), 1); assert_eq!(pool.len(), 1);
pool.add( pool.add(
PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())), PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())),
Stream::Cursor(std::io::Cursor::new(vec![])), Stream::from_vec(vec![]),
); );
assert_eq!(pool.len(), 2); assert_eq!(pool.len(), 2);
@@ -292,7 +286,7 @@ fn pool_checks_proxy() {
&url, &url,
Some(Proxy::new("user:password@localhost:9999").unwrap()), Some(Proxy::new("user:password@localhost:9999").unwrap()),
), ),
Stream::Cursor(std::io::Cursor::new(vec![])), Stream::from_vec(vec![]),
); );
assert_eq!(pool.len(), 3); assert_eq!(pool.len(), 3);
} }

View File

@@ -1,6 +1,6 @@
use std::fmt; use std::io::{self, Read};
use std::io::{self, Cursor, Read};
use std::str::FromStr; use std::str::FromStr;
use std::{fmt, io::BufRead};
use chunked_transfer::Decoder as ChunkDecoder; use chunked_transfer::Decoder as ChunkDecoder;
@@ -425,7 +425,7 @@ impl Response {
/// let resp = ureq::Response::do_from_read(read); /// let resp = ureq::Response::do_from_read(read);
/// ///
/// assert_eq!(resp.status(), 401); /// assert_eq!(resp.status(), 401);
pub(crate) fn do_from_read(mut reader: impl Read) -> Result<Response, Error> { pub(crate) fn do_from_read(mut reader: impl BufRead) -> Result<Response, Error> {
// //
// HTTP/1.1 200 OK\r\n // HTTP/1.1 200 OK\r\n
let status_line = read_next_line(&mut reader)?; let status_line = read_next_line(&mut reader)?;
@@ -455,8 +455,8 @@ impl Response {
} }
#[cfg(test)] #[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> { pub fn to_write_vec(self) -> Vec<u8> {
self.stream.as_ref().unwrap().to_write_vec() self.stream.unwrap().to_write_vec()
} }
} }
@@ -508,10 +508,9 @@ impl FromStr for Response {
/// assert_eq!(body, "Hello World!!!"); /// assert_eq!(body, "Hello World!!!");
/// ``` /// ```
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let bytes = s.as_bytes().to_owned(); let mut stream = Stream::from_vec(s.as_bytes().to_owned());
let mut cursor = Cursor::new(bytes); let mut resp = Self::do_from_read(&mut stream)?;
let mut resp = Self::do_from_read(&mut cursor)?; set_stream(&mut resp, "".into(), None, stream);
set_stream(&mut resp, "".into(), None, Stream::Cursor(cursor));
Ok(resp) Ok(resp)
} }
} }

View File

@@ -1,10 +1,10 @@
use log::debug; use log::debug;
use std::fmt; use std::io::{self, BufRead, BufReader, Read, Write};
use std::io::{self, BufRead, BufReader, Cursor, Read, Write};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::net::TcpStream; use std::net::TcpStream;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use std::{fmt, io::Cursor};
use chunked_transfer::Decoder as ChunkDecoder; use chunked_transfer::Decoder as ChunkDecoder;
@@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto};
use crate::error::ErrorKind; use crate::error::ErrorKind;
use crate::unit::Unit; use crate::unit::Unit;
#[allow(clippy::large_enum_variant)] pub(crate) struct Stream {
pub enum Stream { inner: BufReader<Inner>,
Http(BufReader<TcpStream>), }
enum Inner {
Http(TcpStream),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>), Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
Cursor(Cursor<Vec<u8>>), Test(Box<dyn Read + Send + Sync>, Vec<u8>),
#[cfg(test)]
Test(Box<dyn BufRead + Send + Sync>, Vec<u8>),
} }
// DeadlineStream wraps a stream such that read() will return an error // DeadlineStream wraps a stream such that read() will return an error
@@ -36,7 +37,7 @@ pub enum Stream {
// TcpStream to ensure read() doesn't block beyond the deadline. // TcpStream to ensure read() doesn't block beyond the deadline.
// When the From trait is used to turn a DeadlineStream back into a // When the From trait is used to turn a DeadlineStream back into a
// Stream (by PoolReturningRead), the timeouts are removed. // Stream (by PoolReturningRead), the timeouts are removed.
pub struct DeadlineStream { pub(crate) struct DeadlineStream {
stream: Stream, stream: Stream,
deadline: Option<Instant>, deadline: Option<Instant>,
} }
@@ -91,22 +92,23 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
impl fmt::Debug for Stream { impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( let mut result = f.debug_struct("Stream");
f, match self.inner.get_ref() {
"Stream[{}]", Inner::Http(tcpstream) => result.field("tcp", tcpstream),
match self { #[cfg(feature = "tls")]
Stream::Http(_) => "http", Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()),
#[cfg(feature = "tls")] Inner::Test(_, _) => result.field("test", &String::new()),
Stream::Https(_) => "https", };
Stream::Cursor(_) => "cursor", result.finish()
#[cfg(test)]
Stream::Test(_, _) => "test",
}
)
} }
} }
impl Stream { impl Stream {
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
Stream {
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
}
}
// Check if the server has closed a stream by performing a one-byte // Check if the server has closed a stream by performing a one-byte
// non-blocking read. If this returns EOF, the server has closed the // non-blocking read. If this returns EOF, the server has closed the
// connection: return true. If this returns WouldBlock (aka EAGAIN), // connection: return true. If this returns WouldBlock (aka EAGAIN),
@@ -134,10 +136,10 @@ impl Stream {
} }
} }
pub fn is_poolable(&self) -> bool { pub fn is_poolable(&self) -> bool {
match self { match self.inner.get_ref() {
Stream::Http(_) => true, Inner::Http(_) => true,
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(_) => true, Inner::Https(_) => true,
_ => false, _ => false,
} }
} }
@@ -154,10 +156,10 @@ impl Stream {
} }
pub(crate) fn socket(&self) -> Option<&TcpStream> { pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self { match self.inner.get_ref() {
Stream::Http(b) => Some(b.get_ref()), Inner::Http(b) => Some(b),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(b) => Some(&b.get_ref().sock), Inner::Https(b) => Some(&b.get_ref()),
_ => None, _ => None,
} }
} }
@@ -171,48 +173,48 @@ impl Stream {
} }
#[cfg(test)] #[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> { pub fn to_write_vec(self) -> Vec<u8> {
match self { match self.inner.into_inner() {
Stream::Test(_, writer) => writer.clone(), Inner::Test(_, writer) => writer.clone(),
_ => panic!("to_write_vec on non Test stream"), _ => panic!("to_write_vec on non Test stream"),
} }
} }
} }
impl Read for Stream { impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl Read for Inner {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self { match self {
Stream::Http(sock) => sock.read(buf), Inner::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(stream) => read_https(stream, buf), Inner::Https(stream) => read_https(stream, buf),
Stream::Cursor(read) => read.read(buf), Inner::Test(reader, _) => reader.read(buf),
#[cfg(test)]
Stream::Test(reader, _) => reader.read(buf),
} }
} }
} }
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.stream.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl BufRead for Stream { impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> { fn fill_buf(&mut self) -> io::Result<&[u8]> {
match self { self.inner.fill_buf()
Stream::Http(r) => r.fill_buf(),
#[cfg(feature = "tls")]
Stream::Https(r) => r.fill_buf(),
Stream::Cursor(r) => r.fill_buf(),
#[cfg(test)]
Stream::Test(r, _) => r.fill_buf(),
}
} }
fn consume(&mut self, amt: usize) { fn consume(&mut self, amt: usize) {
match self { self.inner.consume(amt)
Stream::Http(r) => r.consume(amt),
#[cfg(feature = "tls")]
Stream::Https(r) => r.consume(amt),
Stream::Cursor(r) => r.consume(amt),
#[cfg(test)]
Stream::Test(r, _) => r.consume(amt),
}
} }
} }
@@ -228,7 +230,7 @@ where
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
fn read_https( fn read_https(
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>, stream: &mut StreamOwned<ClientSession, TcpStream>,
buf: &mut [u8], buf: &mut [u8],
) -> io::Result<usize> { ) -> io::Result<usize> {
match stream.read(buf) { match stream.read(buf) {
@@ -256,23 +258,19 @@ fn is_close_notify(e: &std::io::Error) -> bool {
impl Write for Stream { impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self { match self.inner.get_mut() {
Stream::Http(sock) => sock.get_mut().write(buf), Inner::Http(sock) => sock.write(buf),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().write(buf), Inner::Https(stream) => stream.write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"), Inner::Test(_, writer) => writer.write(buf),
#[cfg(test)]
Stream::Test(_, writer) => writer.write(buf),
} }
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
match self { match self.inner.get_mut() {
Stream::Http(sock) => sock.get_mut().flush(), Inner::Http(sock) => sock.flush(),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().flush(), Inner::Https(stream) => stream.flush(),
Stream::Cursor(_) => panic!("Flush read only stream"), Inner::Test(_, writer) => writer.flush(),
#[cfg(test)]
Stream::Test(_, writer) => writer.flush(),
} }
} }
} }
@@ -282,8 +280,10 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
let port = unit.url.port().unwrap_or(80); let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port) connect_host(unit, hostname, port)
.map(BufReader::new) .map(Inner::Http)
.map(Stream::Http) .map(|h| Stream {
inner: BufReader::new(h),
})
} }
#[cfg(all(feature = "tls", feature = "native-certs"))] #[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -327,7 +327,9 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
let stream = rustls::StreamOwned::new(sess, sock); let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::Https(BufReader::new(stream))) Ok(Stream {
inner: BufReader::new(Inner::Https(stream)),
})
} }
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> { pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {

View File

@@ -3,7 +3,7 @@ use crate::stream::Stream;
use crate::unit::Unit; use crate::unit::Unit;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::{Cursor, Write}; use std::io::Write;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
mod agent_test; mod agent_test;
@@ -29,7 +29,7 @@ where
} }
#[allow(clippy::write_with_newline)] #[allow(clippy::write_with_newline)]
pub fn make_response( pub(crate) fn make_response(
status: u16, status: u16,
status_text: &str, status_text: &str,
headers: Vec<&str>, headers: Vec<&str>,
@@ -42,9 +42,7 @@ pub fn make_response(
} }
write!(&mut buf, "\r\n").ok(); write!(&mut buf, "\r\n").ok();
buf.append(&mut body); buf.append(&mut body);
let cursor = Cursor::new(buf); Ok(Stream::from_vec(buf))
let write: Vec<u8> = vec![];
Ok(Stream::Test(Box::new(cursor), write))
} }
pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> { pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {