Use BufRead in more places (#261)

This moves Stream's enum into an `Inner` enum, and wraps a single
BufReader around the whole thing. This makes it easier to consistently
treat the contents of Stream as being wrapped in a BufReader.

Also, implement BufRead for DeadlineStream. This means when a timeout
is set, we don't have to set that timeout on the socket with every
small read, just when we fill up the buffer. This reduces the number
of syscalls.

Remove the `Cursor` variant from Stream. It was strictly less powerful
than the `Test` variant, so I've replaced the handful of Cursor uses
with `Test`. Because some of those uses weren't in test code, I removed
the `#[cfg(test)]` attribute on the `Test` variant.

Now that all inputs to `do_from_read` are `impl BufRead`, add that
as a type constraint. Change `read_next_line` to take advantage of
`BufRead::read_line`, which may be somewhat faster (though I haven't
benchmarked it).
This commit is contained in:
Jacob Hoffman-Andrews
2020-11-30 09:23:19 -08:00
committed by GitHub
4 changed files with 124 additions and 124 deletions

View File

@@ -229,7 +229,7 @@ fn pool_connections_limit() {
proxy: None, proxy: None,
}); });
for key in poolkeys.clone() { for key in poolkeys.clone() {
pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![]))); pool.add(key, Stream::from_vec(vec![]))
} }
assert_eq!(pool.len(), pool.max_idle_connections); assert_eq!(pool.len(), pool.max_idle_connections);
@@ -254,10 +254,7 @@ fn pool_per_host_connections_limit() {
}; };
for _ in 0..pool.max_idle_connections_per_host * 2 { for _ in 0..pool.max_idle_connections_per_host * 2 {
pool.add( pool.add(poolkey.clone(), Stream::from_vec(vec![]))
poolkey.clone(),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
} }
assert_eq!(pool.len(), pool.max_idle_connections_per_host); assert_eq!(pool.len(), pool.max_idle_connections_per_host);
@@ -275,15 +272,12 @@ fn pool_checks_proxy() {
let pool = ConnectionPool::new_with_limits(10, 1); let pool = ConnectionPool::new_with_limits(10, 1);
let url = Url::parse("zzz:///example.com").unwrap(); let url = Url::parse("zzz:///example.com").unwrap();
pool.add( pool.add(PoolKey::new(&url, None), Stream::from_vec(vec![]));
PoolKey::new(&url, None),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
assert_eq!(pool.len(), 1); assert_eq!(pool.len(), 1);
pool.add( pool.add(
PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())), PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())),
Stream::Cursor(std::io::Cursor::new(vec![])), Stream::from_vec(vec![]),
); );
assert_eq!(pool.len(), 2); assert_eq!(pool.len(), 2);
@@ -292,7 +286,7 @@ fn pool_checks_proxy() {
&url, &url,
Some(Proxy::new("user:password@localhost:9999").unwrap()), Some(Proxy::new("user:password@localhost:9999").unwrap()),
), ),
Stream::Cursor(std::io::Cursor::new(vec![])), Stream::from_vec(vec![]),
); );
assert_eq!(pool.len(), 3); assert_eq!(pool.len(), 3);
} }

View File

@@ -1,6 +1,6 @@
use std::fmt; use std::io::{self, Read};
use std::io::{self, Cursor, Read};
use std::str::FromStr; use std::str::FromStr;
use std::{fmt, io::BufRead};
use chunked_transfer::Decoder as ChunkDecoder; use chunked_transfer::Decoder as ChunkDecoder;
@@ -424,7 +424,7 @@ impl Response {
/// let resp = ureq::Response::do_from_read(read); /// let resp = ureq::Response::do_from_read(read);
/// ///
/// assert_eq!(resp.status(), 401); /// assert_eq!(resp.status(), 401);
pub(crate) fn do_from_read(mut reader: impl Read) -> Result<Response, Error> { pub(crate) fn do_from_read(mut reader: impl BufRead) -> Result<Response, Error> {
// //
// HTTP/1.1 200 OK\r\n // HTTP/1.1 200 OK\r\n
let status_line = read_next_line(&mut reader)?; let status_line = read_next_line(&mut reader)?;
@@ -454,8 +454,8 @@ impl Response {
} }
#[cfg(test)] #[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> { pub fn to_write_vec(self) -> Vec<u8> {
self.stream.as_ref().unwrap().to_write_vec() self.stream.unwrap().to_write_vec()
} }
} }
@@ -507,10 +507,9 @@ impl FromStr for Response {
/// assert_eq!(body, "Hello World!!!"); /// assert_eq!(body, "Hello World!!!");
/// ``` /// ```
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let bytes = s.as_bytes().to_owned(); let mut stream = Stream::from_vec(s.as_bytes().to_owned());
let mut cursor = Cursor::new(bytes); let mut resp = Self::do_from_read(&mut stream)?;
let mut resp = Self::do_from_read(&mut cursor)?; set_stream(&mut resp, "".into(), None, stream);
set_stream(&mut resp, "".into(), None, Stream::Cursor(cursor));
Ok(resp) Ok(resp)
} }
} }
@@ -524,34 +523,24 @@ pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option<Unit>, s
resp.stream = Some(stream); resp.stream = Some(stream);
} }
fn read_next_line<R: Read>(reader: &mut R) -> io::Result<String> { fn read_next_line(reader: &mut impl BufRead) -> io::Result<String> {
let mut buf = Vec::new(); let mut s = String::new();
let mut prev_byte_was_cr = false; if reader.read_line(&mut s)? == 0 {
let mut one = [0_u8];
loop {
let amt = reader.read(&mut one[..])?;
if amt == 0 {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::ConnectionAborted, io::ErrorKind::ConnectionAborted,
"Unexpected EOF", "Unexpected EOF",
)); ));
} }
let byte = one[0]; if !s.ends_with("\r\n") {
return Err(io::Error::new(
if byte == b'\n' && prev_byte_was_cr { io::ErrorKind::InvalidInput,
buf.pop(); // removing the '\r' format!("Header field didn't end with \\r: {}", s),
return String::from_utf8(buf).map_err(|_| { ));
io::Error::new(io::ErrorKind::InvalidInput, "Header is not in ASCII")
});
}
prev_byte_was_cr = byte == b'\r';
buf.push(byte);
} }
s.pop();
s.pop();
Ok(s)
} }
/// Limits a `Read` to a content size (as set by a "Content-Length" header). /// Limits a `Read` to a content size (as set by a "Content-Length" header).

View File

@@ -1,10 +1,10 @@
use log::debug; use log::debug;
use std::fmt; use std::io::{self, BufRead, BufReader, Read, Write};
use std::io::{self, BufRead, BufReader, Cursor, Read, Write};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::net::TcpStream; use std::net::TcpStream;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use std::{fmt, io::Cursor};
use chunked_transfer::Decoder as ChunkDecoder; use chunked_transfer::Decoder as ChunkDecoder;
@@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto};
use crate::error::ErrorKind; use crate::error::ErrorKind;
use crate::unit::Unit; use crate::unit::Unit;
#[allow(clippy::large_enum_variant)] pub(crate) struct Stream {
pub enum Stream { inner: BufReader<Inner>,
Http(BufReader<TcpStream>), }
enum Inner {
Http(TcpStream),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>), Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
Cursor(Cursor<Vec<u8>>), Test(Box<dyn Read + Send + Sync>, Vec<u8>),
#[cfg(test)]
Test(Box<dyn BufRead + Send + Sync>, Vec<u8>),
} }
// DeadlineStream wraps a stream such that read() will return an error // DeadlineStream wraps a stream such that read() will return an error
@@ -36,7 +37,7 @@ pub enum Stream {
// TcpStream to ensure read() doesn't block beyond the deadline. // TcpStream to ensure read() doesn't block beyond the deadline.
// When the From trait is used to turn a DeadlineStream back into a // When the From trait is used to turn a DeadlineStream back into a
// Stream (by PoolReturningRead), the timeouts are removed. // Stream (by PoolReturningRead), the timeouts are removed.
pub struct DeadlineStream { pub(crate) struct DeadlineStream {
stream: Stream, stream: Stream,
deadline: Option<Instant>, deadline: Option<Instant>,
} }
@@ -53,8 +54,8 @@ impl From<DeadlineStream> for Stream {
} }
} }
impl Read for DeadlineStream { impl BufRead for DeadlineStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn fill_buf(&mut self) -> io::Result<&[u8]> {
if let Some(deadline) = self.deadline { if let Some(deadline) = self.deadline {
let timeout = time_until_deadline(deadline)?; let timeout = time_until_deadline(deadline)?;
if let Some(socket) = self.stream.socket() { if let Some(socket) = self.stream.socket() {
@@ -62,7 +63,7 @@ impl Read for DeadlineStream {
socket.set_write_timeout(Some(timeout))?; socket.set_write_timeout(Some(timeout))?;
} }
} }
self.stream.read(buf).map_err(|e| { self.stream.fill_buf().map_err(|e| {
// On unix-y platforms set_read_timeout and set_write_timeout // On unix-y platforms set_read_timeout and set_write_timeout
// causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut. // causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut.
// Since the socket most definitely not set_nonblocking(true), // Since the socket most definitely not set_nonblocking(true),
@@ -73,6 +74,25 @@ impl Read for DeadlineStream {
e e
}) })
} }
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl Read for DeadlineStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
// All reads on a DeadlineStream use the BufRead impl. This ensures
// that we have a chance to set the correct timeout before each recv
// syscall.
// Copied from the BufReader implementation of `read()`.
let nread = {
let mut rem = self.fill_buf()?;
rem.read(buf)?
};
self.consume(nread);
Ok(nread)
}
} }
// If the deadline is in the future, return the remaining time until // If the deadline is in the future, return the remaining time until
@@ -91,22 +111,37 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
impl fmt::Debug for Stream { impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!( let mut result = f.debug_struct("Stream");
f, match self.inner.get_ref() {
"Stream[{}]", Inner::Http(tcpstream) => result.field("tcp", tcpstream),
match self {
Stream::Http(_) => "http",
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(_) => "https", Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()),
Stream::Cursor(_) => "cursor", Inner::Test(_, _) => result.field("test", &String::new()),
#[cfg(test)] };
Stream::Test(_, _) => "test", result.finish()
}
)
} }
} }
impl Stream { impl Stream {
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
Stream {
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
}
}
fn from_tcp_stream(t: TcpStream) -> Stream {
Stream {
inner: BufReader::new(Inner::Http(t)),
}
}
#[cfg(feature = "tls")]
fn from_tls_stream(t: StreamOwned<ClientSession, TcpStream>) -> Stream {
Stream {
inner: BufReader::new(Inner::Https(t)),
}
}
// Check if the server has closed a stream by performing a one-byte // Check if the server has closed a stream by performing a one-byte
// non-blocking read. If this returns EOF, the server has closed the // non-blocking read. If this returns EOF, the server has closed the
// connection: return true. If this returns WouldBlock (aka EAGAIN), // connection: return true. If this returns WouldBlock (aka EAGAIN),
@@ -134,10 +169,10 @@ impl Stream {
} }
} }
pub fn is_poolable(&self) -> bool { pub fn is_poolable(&self) -> bool {
match self { match self.inner.get_ref() {
Stream::Http(_) => true, Inner::Http(_) => true,
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(_) => true, Inner::Https(_) => true,
_ => false, _ => false,
} }
} }
@@ -154,10 +189,10 @@ impl Stream {
} }
pub(crate) fn socket(&self) -> Option<&TcpStream> { pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self { match self.inner.get_ref() {
Stream::Http(b) => Some(b.get_ref()), Inner::Http(b) => Some(b),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(b) => Some(&b.get_ref().sock), Inner::Https(b) => Some(&b.get_ref()),
_ => None, _ => None,
} }
} }
@@ -171,48 +206,38 @@ impl Stream {
} }
#[cfg(test)] #[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> { pub fn to_write_vec(self) -> Vec<u8> {
match self { match self.inner.into_inner() {
Stream::Test(_, writer) => writer.clone(), Inner::Test(_, writer) => writer.clone(),
_ => panic!("to_write_vec on non Test stream"), _ => panic!("to_write_vec on non Test stream"),
} }
} }
} }
impl Read for Stream { impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl Read for Inner {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self { match self {
Stream::Http(sock) => sock.read(buf), Inner::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(stream) => read_https(stream, buf), Inner::Https(stream) => read_https(stream, buf),
Stream::Cursor(read) => read.read(buf), Inner::Test(reader, _) => reader.read(buf),
#[cfg(test)]
Stream::Test(reader, _) => reader.read(buf),
} }
} }
} }
impl BufRead for Stream { impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> { fn fill_buf(&mut self) -> io::Result<&[u8]> {
match self { self.inner.fill_buf()
Stream::Http(r) => r.fill_buf(),
#[cfg(feature = "tls")]
Stream::Https(r) => r.fill_buf(),
Stream::Cursor(r) => r.fill_buf(),
#[cfg(test)]
Stream::Test(r, _) => r.fill_buf(),
}
} }
fn consume(&mut self, amt: usize) { fn consume(&mut self, amt: usize) {
match self { self.inner.consume(amt)
Stream::Http(r) => r.consume(amt),
#[cfg(feature = "tls")]
Stream::Https(r) => r.consume(amt),
Stream::Cursor(r) => r.consume(amt),
#[cfg(test)]
Stream::Test(r, _) => r.consume(amt),
}
} }
} }
@@ -228,7 +253,7 @@ where
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
fn read_https( fn read_https(
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>, stream: &mut StreamOwned<ClientSession, TcpStream>,
buf: &mut [u8], buf: &mut [u8],
) -> io::Result<usize> { ) -> io::Result<usize> {
match stream.read(buf) { match stream.read(buf) {
@@ -256,23 +281,19 @@ fn is_close_notify(e: &std::io::Error) -> bool {
impl Write for Stream { impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self { match self.inner.get_mut() {
Stream::Http(sock) => sock.get_mut().write(buf), Inner::Http(sock) => sock.write(buf),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().write(buf), Inner::Https(stream) => stream.write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"), Inner::Test(_, writer) => writer.write(buf),
#[cfg(test)]
Stream::Test(_, writer) => writer.write(buf),
} }
} }
fn flush(&mut self) -> io::Result<()> { fn flush(&mut self) -> io::Result<()> {
match self { match self.inner.get_mut() {
Stream::Http(sock) => sock.get_mut().flush(), Inner::Http(sock) => sock.flush(),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().flush(), Inner::Https(stream) => stream.flush(),
Stream::Cursor(_) => panic!("Flush read only stream"), Inner::Test(_, writer) => writer.flush(),
#[cfg(test)]
Stream::Test(_, writer) => writer.flush(),
} }
} }
} }
@@ -281,9 +302,7 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
// //
let port = unit.url.port().unwrap_or(80); let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port) connect_host(unit, hostname, port).map(Stream::from_tcp_stream)
.map(BufReader::new)
.map(Stream::Http)
} }
#[cfg(all(feature = "tls", feature = "native-certs"))] #[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -327,7 +346,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
let stream = rustls::StreamOwned::new(sess, sock); let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::Https(BufReader::new(stream))) Ok(Stream::from_tls_stream(stream))
} }
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> { pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {

View File

@@ -3,7 +3,7 @@ use crate::stream::Stream;
use crate::unit::Unit; use crate::unit::Unit;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::{Cursor, Write}; use std::io::Write;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
mod agent_test; mod agent_test;
@@ -29,7 +29,7 @@ where
} }
#[allow(clippy::write_with_newline)] #[allow(clippy::write_with_newline)]
pub fn make_response( pub(crate) fn make_response(
status: u16, status: u16,
status_text: &str, status_text: &str,
headers: Vec<&str>, headers: Vec<&str>,
@@ -42,9 +42,7 @@ pub fn make_response(
} }
write!(&mut buf, "\r\n").ok(); write!(&mut buf, "\r\n").ok();
buf.append(&mut body); buf.append(&mut body);
let cursor = Cursor::new(buf); Ok(Stream::from_vec(buf))
let write: Vec<u8> = vec![];
Ok(Stream::Test(Box::new(cursor), write))
} }
pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> { pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {