Move BufReader up the stack in Stream.

Stream now has an `Inner` enum, and wraps an instance of that enum in a
BufReader. This allows Stream itself to implement BufRead trivially, and
simplify some of the match dispatching. Having Stream implement BufRead
means we can make use of `read_line` instead of our own `read_next_line`
(not done in this PR yet).

Also, removes the `Cursor` variant of the Inner enum in favor of using
the `Test` variant everywhere, since it's strictly more powerful.
This commit is contained in:
Jacob Hoffman-Andrews
2020-11-28 12:04:28 -08:00
parent a0b88926fa
commit 131a0264d1
4 changed files with 86 additions and 93 deletions

View File

@@ -1,10 +1,10 @@
use log::debug;
use std::fmt;
use std::io::{self, BufRead, BufReader, Cursor, Read, Write};
use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::time::Duration;
use std::time::Instant;
use std::{fmt, io::Cursor};
use chunked_transfer::Decoder as ChunkDecoder;
@@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto};
use crate::error::ErrorKind;
use crate::unit::Unit;
#[allow(clippy::large_enum_variant)]
pub enum Stream {
Http(BufReader<TcpStream>),
pub(crate) struct Stream {
inner: BufReader<Inner>,
}
enum Inner {
Http(TcpStream),
#[cfg(feature = "tls")]
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>),
Cursor(Cursor<Vec<u8>>),
#[cfg(test)]
Test(Box<dyn BufRead + Send + Sync>, Vec<u8>),
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
Test(Box<dyn Read + Send + Sync>, Vec<u8>),
}
// DeadlineStream wraps a stream such that read() will return an error
@@ -36,7 +37,7 @@ pub enum Stream {
// TcpStream to ensure read() doesn't block beyond the deadline.
// When the From trait is used to turn a DeadlineStream back into a
// Stream (by PoolReturningRead), the timeouts are removed.
pub struct DeadlineStream {
pub(crate) struct DeadlineStream {
stream: Stream,
deadline: Option<Instant>,
}
@@ -91,22 +92,23 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Stream[{}]",
match self {
Stream::Http(_) => "http",
#[cfg(feature = "tls")]
Stream::Https(_) => "https",
Stream::Cursor(_) => "cursor",
#[cfg(test)]
Stream::Test(_, _) => "test",
}
)
let mut result = f.debug_struct("Stream");
match self.inner.get_ref() {
Inner::Http(tcpstream) => result.field("tcp", tcpstream),
#[cfg(feature = "tls")]
Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()),
Inner::Test(_, _) => result.field("test", &String::new()),
};
result.finish()
}
}
impl Stream {
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
Stream {
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
}
}
// Check if the server has closed a stream by performing a one-byte
// non-blocking read. If this returns EOF, the server has closed the
// connection: return true. If this returns WouldBlock (aka EAGAIN),
@@ -134,10 +136,10 @@ impl Stream {
}
}
pub fn is_poolable(&self) -> bool {
match self {
Stream::Http(_) => true,
match self.inner.get_ref() {
Inner::Http(_) => true,
#[cfg(feature = "tls")]
Stream::Https(_) => true,
Inner::Https(_) => true,
_ => false,
}
}
@@ -154,10 +156,10 @@ impl Stream {
}
pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self {
Stream::Http(b) => Some(b.get_ref()),
match self.inner.get_ref() {
Inner::Http(b) => Some(b),
#[cfg(feature = "tls")]
Stream::Https(b) => Some(&b.get_ref().sock),
Inner::Https(b) => Some(&b.get_ref()),
_ => None,
}
}
@@ -171,48 +173,48 @@ impl Stream {
}
#[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> {
match self {
Stream::Test(_, writer) => writer.clone(),
pub fn to_write_vec(self) -> Vec<u8> {
match self.inner.into_inner() {
Inner::Test(_, writer) => writer.clone(),
_ => panic!("to_write_vec on non Test stream"),
}
}
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl Read for Inner {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Stream::Http(sock) => sock.read(buf),
Inner::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")]
Stream::Https(stream) => read_https(stream, buf),
Stream::Cursor(read) => read.read(buf),
#[cfg(test)]
Stream::Test(reader, _) => reader.read(buf),
Inner::Https(stream) => read_https(stream, buf),
Inner::Test(reader, _) => reader.read(buf),
}
}
}
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.stream.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
match self {
Stream::Http(r) => r.fill_buf(),
#[cfg(feature = "tls")]
Stream::Https(r) => r.fill_buf(),
Stream::Cursor(r) => r.fill_buf(),
#[cfg(test)]
Stream::Test(r, _) => r.fill_buf(),
}
self.inner.fill_buf()
}
fn consume(&mut self, amt: usize) {
match self {
Stream::Http(r) => r.consume(amt),
#[cfg(feature = "tls")]
Stream::Https(r) => r.consume(amt),
Stream::Cursor(r) => r.consume(amt),
#[cfg(test)]
Stream::Test(r, _) => r.consume(amt),
}
self.inner.consume(amt)
}
}
@@ -228,7 +230,7 @@ where
#[cfg(feature = "tls")]
fn read_https(
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>,
stream: &mut StreamOwned<ClientSession, TcpStream>,
buf: &mut [u8],
) -> io::Result<usize> {
match stream.read(buf) {
@@ -256,23 +258,19 @@ fn is_close_notify(e: &std::io::Error) -> bool {
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Stream::Http(sock) => sock.get_mut().write(buf),
match self.inner.get_mut() {
Inner::Http(sock) => sock.write(buf),
#[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.write(buf),
Inner::Https(stream) => stream.write(buf),
Inner::Test(_, writer) => writer.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Stream::Http(sock) => sock.get_mut().flush(),
match self.inner.get_mut() {
Inner::Http(sock) => sock.flush(),
#[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().flush(),
Stream::Cursor(_) => panic!("Flush read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.flush(),
Inner::Https(stream) => stream.flush(),
Inner::Test(_, writer) => writer.flush(),
}
}
}
@@ -282,8 +280,10 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port)
.map(BufReader::new)
.map(Stream::Http)
.map(Inner::Http)
.map(|h| Stream {
inner: BufReader::new(h),
})
}
#[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -327,7 +327,9 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::Https(BufReader::new(stream)))
Ok(Stream {
inner: BufReader::new(Inner::Https(stream)),
})
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {