Move BufReader up the stack in Stream.
Stream now has an `Inner` enum, and wraps an instance of that enum in a BufReader. This allows Stream itself to implement BufRead trivially, and simplify some of the match dispatching. Having Stream implement BufRead means we can make use of `read_line` instead of our own `read_next_line` (not done in this PR yet). Also, removes the `Cursor` variant of the Inner enum in favor of using the `Test` variant everywhere, since it's strictly more powerful.
This commit is contained in:
138
src/stream.rs
138
src/stream.rs
@@ -1,10 +1,10 @@
|
||||
use log::debug;
|
||||
use std::fmt;
|
||||
use std::io::{self, BufRead, BufReader, Cursor, Read, Write};
|
||||
use std::io::{self, BufRead, BufReader, Read, Write};
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpStream;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
use std::{fmt, io::Cursor};
|
||||
|
||||
use chunked_transfer::Decoder as ChunkDecoder;
|
||||
|
||||
@@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::unit::Unit;
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
pub enum Stream {
|
||||
Http(BufReader<TcpStream>),
|
||||
pub(crate) struct Stream {
|
||||
inner: BufReader<Inner>,
|
||||
}
|
||||
|
||||
enum Inner {
|
||||
Http(TcpStream),
|
||||
#[cfg(feature = "tls")]
|
||||
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>),
|
||||
Cursor(Cursor<Vec<u8>>),
|
||||
#[cfg(test)]
|
||||
Test(Box<dyn BufRead + Send + Sync>, Vec<u8>),
|
||||
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
|
||||
Test(Box<dyn Read + Send + Sync>, Vec<u8>),
|
||||
}
|
||||
|
||||
// DeadlineStream wraps a stream such that read() will return an error
|
||||
@@ -36,7 +37,7 @@ pub enum Stream {
|
||||
// TcpStream to ensure read() doesn't block beyond the deadline.
|
||||
// When the From trait is used to turn a DeadlineStream back into a
|
||||
// Stream (by PoolReturningRead), the timeouts are removed.
|
||||
pub struct DeadlineStream {
|
||||
pub(crate) struct DeadlineStream {
|
||||
stream: Stream,
|
||||
deadline: Option<Instant>,
|
||||
}
|
||||
@@ -91,22 +92,23 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
|
||||
|
||||
impl fmt::Debug for Stream {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"Stream[{}]",
|
||||
match self {
|
||||
Stream::Http(_) => "http",
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(_) => "https",
|
||||
Stream::Cursor(_) => "cursor",
|
||||
#[cfg(test)]
|
||||
Stream::Test(_, _) => "test",
|
||||
}
|
||||
)
|
||||
let mut result = f.debug_struct("Stream");
|
||||
match self.inner.get_ref() {
|
||||
Inner::Http(tcpstream) => result.field("tcp", tcpstream),
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()),
|
||||
Inner::Test(_, _) => result.field("test", &String::new()),
|
||||
};
|
||||
result.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
|
||||
Stream {
|
||||
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
|
||||
}
|
||||
}
|
||||
// Check if the server has closed a stream by performing a one-byte
|
||||
// non-blocking read. If this returns EOF, the server has closed the
|
||||
// connection: return true. If this returns WouldBlock (aka EAGAIN),
|
||||
@@ -134,10 +136,10 @@ impl Stream {
|
||||
}
|
||||
}
|
||||
pub fn is_poolable(&self) -> bool {
|
||||
match self {
|
||||
Stream::Http(_) => true,
|
||||
match self.inner.get_ref() {
|
||||
Inner::Http(_) => true,
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(_) => true,
|
||||
Inner::Https(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
@@ -154,10 +156,10 @@ impl Stream {
|
||||
}
|
||||
|
||||
pub(crate) fn socket(&self) -> Option<&TcpStream> {
|
||||
match self {
|
||||
Stream::Http(b) => Some(b.get_ref()),
|
||||
match self.inner.get_ref() {
|
||||
Inner::Http(b) => Some(b),
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(b) => Some(&b.get_ref().sock),
|
||||
Inner::Https(b) => Some(&b.get_ref()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -171,48 +173,48 @@ impl Stream {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn to_write_vec(&self) -> Vec<u8> {
|
||||
match self {
|
||||
Stream::Test(_, writer) => writer.clone(),
|
||||
pub fn to_write_vec(self) -> Vec<u8> {
|
||||
match self.inner.into_inner() {
|
||||
Inner::Test(_, writer) => writer.clone(),
|
||||
_ => panic!("to_write_vec on non Test stream"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for Stream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.inner.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for Inner {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Stream::Http(sock) => sock.read(buf),
|
||||
Inner::Http(sock) => sock.read(buf),
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(stream) => read_https(stream, buf),
|
||||
Stream::Cursor(read) => read.read(buf),
|
||||
#[cfg(test)]
|
||||
Stream::Test(reader, _) => reader.read(buf),
|
||||
Inner::Https(stream) => read_https(stream, buf),
|
||||
Inner::Test(reader, _) => reader.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BufRead for DeadlineStream {
|
||||
fn fill_buf(&mut self) -> io::Result<&[u8]> {
|
||||
self.stream.fill_buf()
|
||||
}
|
||||
|
||||
fn consume(&mut self, amt: usize) {
|
||||
self.stream.consume(amt)
|
||||
}
|
||||
}
|
||||
|
||||
impl BufRead for Stream {
|
||||
fn fill_buf(&mut self) -> io::Result<&[u8]> {
|
||||
match self {
|
||||
Stream::Http(r) => r.fill_buf(),
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(r) => r.fill_buf(),
|
||||
Stream::Cursor(r) => r.fill_buf(),
|
||||
#[cfg(test)]
|
||||
Stream::Test(r, _) => r.fill_buf(),
|
||||
}
|
||||
self.inner.fill_buf()
|
||||
}
|
||||
|
||||
fn consume(&mut self, amt: usize) {
|
||||
match self {
|
||||
Stream::Http(r) => r.consume(amt),
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(r) => r.consume(amt),
|
||||
Stream::Cursor(r) => r.consume(amt),
|
||||
#[cfg(test)]
|
||||
Stream::Test(r, _) => r.consume(amt),
|
||||
}
|
||||
self.inner.consume(amt)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,7 +230,7 @@ where
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
fn read_https(
|
||||
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>,
|
||||
stream: &mut StreamOwned<ClientSession, TcpStream>,
|
||||
buf: &mut [u8],
|
||||
) -> io::Result<usize> {
|
||||
match stream.read(buf) {
|
||||
@@ -256,23 +258,19 @@ fn is_close_notify(e: &std::io::Error) -> bool {
|
||||
|
||||
impl Write for Stream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Stream::Http(sock) => sock.get_mut().write(buf),
|
||||
match self.inner.get_mut() {
|
||||
Inner::Http(sock) => sock.write(buf),
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(stream) => stream.get_mut().write(buf),
|
||||
Stream::Cursor(_) => panic!("Write to read only stream"),
|
||||
#[cfg(test)]
|
||||
Stream::Test(_, writer) => writer.write(buf),
|
||||
Inner::Https(stream) => stream.write(buf),
|
||||
Inner::Test(_, writer) => writer.write(buf),
|
||||
}
|
||||
}
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
Stream::Http(sock) => sock.get_mut().flush(),
|
||||
match self.inner.get_mut() {
|
||||
Inner::Http(sock) => sock.flush(),
|
||||
#[cfg(feature = "tls")]
|
||||
Stream::Https(stream) => stream.get_mut().flush(),
|
||||
Stream::Cursor(_) => panic!("Flush read only stream"),
|
||||
#[cfg(test)]
|
||||
Stream::Test(_, writer) => writer.flush(),
|
||||
Inner::Https(stream) => stream.flush(),
|
||||
Inner::Test(_, writer) => writer.flush(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -282,8 +280,10 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
|
||||
let port = unit.url.port().unwrap_or(80);
|
||||
|
||||
connect_host(unit, hostname, port)
|
||||
.map(BufReader::new)
|
||||
.map(Stream::Http)
|
||||
.map(Inner::Http)
|
||||
.map(|h| Stream {
|
||||
inner: BufReader::new(h),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "tls", feature = "native-certs"))]
|
||||
@@ -327,7 +327,9 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
|
||||
|
||||
let stream = rustls::StreamOwned::new(sess, sock);
|
||||
|
||||
Ok(Stream::Https(BufReader::new(stream)))
|
||||
Ok(Stream {
|
||||
inner: BufReader::new(Inner::Https(stream)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
|
||||
|
||||
Reference in New Issue
Block a user