Move BufReader up the stack in Stream.

Stream now has an `Inner` enum, and wraps an instance of that enum in a
BufReader. This allows Stream itself to implement BufRead trivially, and
simplify some of the match dispatching. Having Stream implement BufRead
means we can make use of `read_line` instead of our own `read_next_line`
(not done in this PR yet).

Also, removes the `Cursor` variant of the Inner enum in favor of using
the `Test` variant everywhere, since it's strictly more powerful.
This commit is contained in:
Jacob Hoffman-Andrews
2020-11-28 12:04:28 -08:00
parent a0b88926fa
commit 131a0264d1
4 changed files with 86 additions and 93 deletions

View File

@@ -229,7 +229,7 @@ fn pool_connections_limit() {
proxy: None,
});
for key in poolkeys.clone() {
pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![])));
pool.add(key, Stream::from_vec(vec![]))
}
assert_eq!(pool.len(), pool.max_idle_connections);
@@ -254,10 +254,7 @@ fn pool_per_host_connections_limit() {
};
for _ in 0..pool.max_idle_connections_per_host * 2 {
pool.add(
poolkey.clone(),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
pool.add(poolkey.clone(), Stream::from_vec(vec![]))
}
assert_eq!(pool.len(), pool.max_idle_connections_per_host);
@@ -275,15 +272,12 @@ fn pool_checks_proxy() {
let pool = ConnectionPool::new_with_limits(10, 1);
let url = Url::parse("zzz:///example.com").unwrap();
pool.add(
PoolKey::new(&url, None),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
pool.add(PoolKey::new(&url, None), Stream::from_vec(vec![]));
assert_eq!(pool.len(), 1);
pool.add(
PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())),
Stream::Cursor(std::io::Cursor::new(vec![])),
Stream::from_vec(vec![]),
);
assert_eq!(pool.len(), 2);
@@ -292,7 +286,7 @@ fn pool_checks_proxy() {
&url,
Some(Proxy::new("user:password@localhost:9999").unwrap()),
),
Stream::Cursor(std::io::Cursor::new(vec![])),
Stream::from_vec(vec![]),
);
assert_eq!(pool.len(), 3);
}

View File

@@ -1,6 +1,6 @@
use std::fmt;
use std::io::{self, Cursor, Read};
use std::io::{self, Read};
use std::str::FromStr;
use std::{fmt, io::BufRead};
use chunked_transfer::Decoder as ChunkDecoder;
@@ -425,7 +425,7 @@ impl Response {
/// let resp = ureq::Response::do_from_read(read);
///
/// assert_eq!(resp.status(), 401);
pub(crate) fn do_from_read(mut reader: impl Read) -> Result<Response, Error> {
pub(crate) fn do_from_read(mut reader: impl BufRead) -> Result<Response, Error> {
//
// HTTP/1.1 200 OK\r\n
let status_line = read_next_line(&mut reader)?;
@@ -455,8 +455,8 @@ impl Response {
}
#[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> {
self.stream.as_ref().unwrap().to_write_vec()
pub fn to_write_vec(self) -> Vec<u8> {
self.stream.unwrap().to_write_vec()
}
}
@@ -508,10 +508,9 @@ impl FromStr for Response {
/// assert_eq!(body, "Hello World!!!");
/// ```
fn from_str(s: &str) -> Result<Self, Self::Err> {
let bytes = s.as_bytes().to_owned();
let mut cursor = Cursor::new(bytes);
let mut resp = Self::do_from_read(&mut cursor)?;
set_stream(&mut resp, "".into(), None, Stream::Cursor(cursor));
let mut stream = Stream::from_vec(s.as_bytes().to_owned());
let mut resp = Self::do_from_read(&mut stream)?;
set_stream(&mut resp, "".into(), None, stream);
Ok(resp)
}
}

View File

@@ -1,10 +1,10 @@
use log::debug;
use std::fmt;
use std::io::{self, BufRead, BufReader, Cursor, Read, Write};
use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::time::Duration;
use std::time::Instant;
use std::{fmt, io::Cursor};
use chunked_transfer::Decoder as ChunkDecoder;
@@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto};
use crate::error::ErrorKind;
use crate::unit::Unit;
#[allow(clippy::large_enum_variant)]
pub enum Stream {
Http(BufReader<TcpStream>),
pub(crate) struct Stream {
inner: BufReader<Inner>,
}
enum Inner {
Http(TcpStream),
#[cfg(feature = "tls")]
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>),
Cursor(Cursor<Vec<u8>>),
#[cfg(test)]
Test(Box<dyn BufRead + Send + Sync>, Vec<u8>),
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
Test(Box<dyn Read + Send + Sync>, Vec<u8>),
}
// DeadlineStream wraps a stream such that read() will return an error
@@ -36,7 +37,7 @@ pub enum Stream {
// TcpStream to ensure read() doesn't block beyond the deadline.
// When the From trait is used to turn a DeadlineStream back into a
// Stream (by PoolReturningRead), the timeouts are removed.
pub struct DeadlineStream {
pub(crate) struct DeadlineStream {
stream: Stream,
deadline: Option<Instant>,
}
@@ -91,22 +92,23 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Stream[{}]",
match self {
Stream::Http(_) => "http",
let mut result = f.debug_struct("Stream");
match self.inner.get_ref() {
Inner::Http(tcpstream) => result.field("tcp", tcpstream),
#[cfg(feature = "tls")]
Stream::Https(_) => "https",
Stream::Cursor(_) => "cursor",
#[cfg(test)]
Stream::Test(_, _) => "test",
}
)
Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()),
Inner::Test(_, _) => result.field("test", &String::new()),
};
result.finish()
}
}
impl Stream {
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
Stream {
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
}
}
// Check if the server has closed a stream by performing a one-byte
// non-blocking read. If this returns EOF, the server has closed the
// connection: return true. If this returns WouldBlock (aka EAGAIN),
@@ -134,10 +136,10 @@ impl Stream {
}
}
pub fn is_poolable(&self) -> bool {
match self {
Stream::Http(_) => true,
match self.inner.get_ref() {
Inner::Http(_) => true,
#[cfg(feature = "tls")]
Stream::Https(_) => true,
Inner::Https(_) => true,
_ => false,
}
}
@@ -154,10 +156,10 @@ impl Stream {
}
pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self {
Stream::Http(b) => Some(b.get_ref()),
match self.inner.get_ref() {
Inner::Http(b) => Some(b),
#[cfg(feature = "tls")]
Stream::Https(b) => Some(&b.get_ref().sock),
Inner::Https(b) => Some(&b.get_ref()),
_ => None,
}
}
@@ -171,9 +173,9 @@ impl Stream {
}
#[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> {
match self {
Stream::Test(_, writer) => writer.clone(),
pub fn to_write_vec(self) -> Vec<u8> {
match self.inner.into_inner() {
Inner::Test(_, writer) => writer.clone(),
_ => panic!("to_write_vec on non Test stream"),
}
}
@@ -181,38 +183,38 @@ impl Stream {
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Stream::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")]
Stream::Https(stream) => read_https(stream, buf),
Stream::Cursor(read) => read.read(buf),
#[cfg(test)]
Stream::Test(reader, _) => reader.read(buf),
self.inner.read(buf)
}
}
impl Read for Inner {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Inner::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")]
Inner::Https(stream) => read_https(stream, buf),
Inner::Test(reader, _) => reader.read(buf),
}
}
}
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.stream.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
match self {
Stream::Http(r) => r.fill_buf(),
#[cfg(feature = "tls")]
Stream::Https(r) => r.fill_buf(),
Stream::Cursor(r) => r.fill_buf(),
#[cfg(test)]
Stream::Test(r, _) => r.fill_buf(),
}
self.inner.fill_buf()
}
fn consume(&mut self, amt: usize) {
match self {
Stream::Http(r) => r.consume(amt),
#[cfg(feature = "tls")]
Stream::Https(r) => r.consume(amt),
Stream::Cursor(r) => r.consume(amt),
#[cfg(test)]
Stream::Test(r, _) => r.consume(amt),
}
self.inner.consume(amt)
}
}
@@ -228,7 +230,7 @@ where
#[cfg(feature = "tls")]
fn read_https(
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>,
stream: &mut StreamOwned<ClientSession, TcpStream>,
buf: &mut [u8],
) -> io::Result<usize> {
match stream.read(buf) {
@@ -256,23 +258,19 @@ fn is_close_notify(e: &std::io::Error) -> bool {
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Stream::Http(sock) => sock.get_mut().write(buf),
match self.inner.get_mut() {
Inner::Http(sock) => sock.write(buf),
#[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.write(buf),
Inner::Https(stream) => stream.write(buf),
Inner::Test(_, writer) => writer.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Stream::Http(sock) => sock.get_mut().flush(),
match self.inner.get_mut() {
Inner::Http(sock) => sock.flush(),
#[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().flush(),
Stream::Cursor(_) => panic!("Flush read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.flush(),
Inner::Https(stream) => stream.flush(),
Inner::Test(_, writer) => writer.flush(),
}
}
}
@@ -282,8 +280,10 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port)
.map(BufReader::new)
.map(Stream::Http)
.map(Inner::Http)
.map(|h| Stream {
inner: BufReader::new(h),
})
}
#[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -327,7 +327,9 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::Https(BufReader::new(stream)))
Ok(Stream {
inner: BufReader::new(Inner::Https(stream)),
})
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {

View File

@@ -3,7 +3,7 @@ use crate::stream::Stream;
use crate::unit::Unit;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::io::{Cursor, Write};
use std::io::Write;
use std::sync::{Arc, Mutex};
mod agent_test;
@@ -29,7 +29,7 @@ where
}
#[allow(clippy::write_with_newline)]
pub fn make_response(
pub(crate) fn make_response(
status: u16,
status_text: &str,
headers: Vec<&str>,
@@ -42,9 +42,7 @@ pub fn make_response(
}
write!(&mut buf, "\r\n").ok();
buf.append(&mut body);
let cursor = Cursor::new(buf);
let write: Vec<u8> = vec![];
Ok(Stream::Test(Box::new(cursor), write))
Ok(Stream::from_vec(buf))
}
pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {