Use BufRead in more places (#261)

This moves Stream's enum into an `Inner` enum, and wraps a single
BufReader around the whole thing. This makes it easier to consistently
treat the contents of Stream as being wrapped in a BufReader.

Also, implement BufRead for DeadlineStream. This means when a timeout
is set, we don't have to set that timeout on the socket with every
small read, just when we fill up the buffer. This reduces the number
of syscalls.

Remove the `Cursor` variant from Stream. It was strictly less powerful
than that `Test` variant, so I've replaced the handful of Cursor uses
with `Test`. Because some of those cases weren't test, I removed the
`#[cfg(test)]` param on the `Test` variant.

Now that all inputs to `do_from_read` are `impl BufRead`, add that
as a type constraint. Change `read_next_line` to take advantage of
`BufRead::read_line`, which may be somewhat faster (though I haven't
benchmarked).
This commit is contained in:
Jacob Hoffman-Andrews
2020-11-30 09:23:19 -08:00
committed by GitHub
4 changed files with 124 additions and 124 deletions

View File

@@ -229,7 +229,7 @@ fn pool_connections_limit() {
proxy: None,
});
for key in poolkeys.clone() {
pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![])));
pool.add(key, Stream::from_vec(vec![]))
}
assert_eq!(pool.len(), pool.max_idle_connections);
@@ -254,10 +254,7 @@ fn pool_per_host_connections_limit() {
};
for _ in 0..pool.max_idle_connections_per_host * 2 {
pool.add(
poolkey.clone(),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
pool.add(poolkey.clone(), Stream::from_vec(vec![]))
}
assert_eq!(pool.len(), pool.max_idle_connections_per_host);
@@ -275,15 +272,12 @@ fn pool_checks_proxy() {
let pool = ConnectionPool::new_with_limits(10, 1);
let url = Url::parse("zzz:///example.com").unwrap();
pool.add(
PoolKey::new(&url, None),
Stream::Cursor(std::io::Cursor::new(vec![])),
);
pool.add(PoolKey::new(&url, None), Stream::from_vec(vec![]));
assert_eq!(pool.len(), 1);
pool.add(
PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())),
Stream::Cursor(std::io::Cursor::new(vec![])),
Stream::from_vec(vec![]),
);
assert_eq!(pool.len(), 2);
@@ -292,7 +286,7 @@ fn pool_checks_proxy() {
&url,
Some(Proxy::new("user:password@localhost:9999").unwrap()),
),
Stream::Cursor(std::io::Cursor::new(vec![])),
Stream::from_vec(vec![]),
);
assert_eq!(pool.len(), 3);
}

View File

@@ -1,6 +1,6 @@
use std::fmt;
use std::io::{self, Cursor, Read};
use std::io::{self, Read};
use std::str::FromStr;
use std::{fmt, io::BufRead};
use chunked_transfer::Decoder as ChunkDecoder;
@@ -424,7 +424,7 @@ impl Response {
/// let resp = ureq::Response::do_from_read(read);
///
/// assert_eq!(resp.status(), 401);
pub(crate) fn do_from_read(mut reader: impl Read) -> Result<Response, Error> {
pub(crate) fn do_from_read(mut reader: impl BufRead) -> Result<Response, Error> {
//
// HTTP/1.1 200 OK\r\n
let status_line = read_next_line(&mut reader)?;
@@ -454,8 +454,8 @@ impl Response {
}
#[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> {
self.stream.as_ref().unwrap().to_write_vec()
pub fn to_write_vec(self) -> Vec<u8> {
self.stream.unwrap().to_write_vec()
}
}
@@ -507,10 +507,9 @@ impl FromStr for Response {
/// assert_eq!(body, "Hello World!!!");
/// ```
fn from_str(s: &str) -> Result<Self, Self::Err> {
let bytes = s.as_bytes().to_owned();
let mut cursor = Cursor::new(bytes);
let mut resp = Self::do_from_read(&mut cursor)?;
set_stream(&mut resp, "".into(), None, Stream::Cursor(cursor));
let mut stream = Stream::from_vec(s.as_bytes().to_owned());
let mut resp = Self::do_from_read(&mut stream)?;
set_stream(&mut resp, "".into(), None, stream);
Ok(resp)
}
}
@@ -524,34 +523,24 @@ pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option<Unit>, s
resp.stream = Some(stream);
}
fn read_next_line<R: Read>(reader: &mut R) -> io::Result<String> {
let mut buf = Vec::new();
let mut prev_byte_was_cr = false;
let mut one = [0_u8];
loop {
let amt = reader.read(&mut one[..])?;
if amt == 0 {
return Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"Unexpected EOF",
));
}
let byte = one[0];
if byte == b'\n' && prev_byte_was_cr {
buf.pop(); // removing the '\r'
return String::from_utf8(buf).map_err(|_| {
io::Error::new(io::ErrorKind::InvalidInput, "Header is not in ASCII")
});
}
prev_byte_was_cr = byte == b'\r';
buf.push(byte);
fn read_next_line(reader: &mut impl BufRead) -> io::Result<String> {
let mut s = String::new();
if reader.read_line(&mut s)? == 0 {
return Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"Unexpected EOF",
));
}
if !s.ends_with("\r\n") {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("Header field didn't end with \\r: {}", s),
));
}
s.pop();
s.pop();
Ok(s)
}
/// Limits a `Read` to a content size (as set by a "Content-Length" header).

View File

@@ -1,10 +1,10 @@
use log::debug;
use std::fmt;
use std::io::{self, BufRead, BufReader, Cursor, Read, Write};
use std::io::{self, BufRead, BufReader, Read, Write};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::time::Duration;
use std::time::Instant;
use std::{fmt, io::Cursor};
use chunked_transfer::Decoder as ChunkDecoder;
@@ -21,14 +21,15 @@ use crate::{error::Error, proxy::Proto};
use crate::error::ErrorKind;
use crate::unit::Unit;
#[allow(clippy::large_enum_variant)]
pub enum Stream {
Http(BufReader<TcpStream>),
pub(crate) struct Stream {
inner: BufReader<Inner>,
}
enum Inner {
Http(TcpStream),
#[cfg(feature = "tls")]
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>),
Cursor(Cursor<Vec<u8>>),
#[cfg(test)]
Test(Box<dyn BufRead + Send + Sync>, Vec<u8>),
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
Test(Box<dyn Read + Send + Sync>, Vec<u8>),
}
// DeadlineStream wraps a stream such that read() will return an error
@@ -36,7 +37,7 @@ pub enum Stream {
// TcpStream to ensure read() doesn't block beyond the deadline.
// When the From trait is used to turn a DeadlineStream back into a
// Stream (by PoolReturningRead), the timeouts are removed.
pub struct DeadlineStream {
pub(crate) struct DeadlineStream {
stream: Stream,
deadline: Option<Instant>,
}
@@ -53,8 +54,8 @@ impl From<DeadlineStream> for Stream {
}
}
impl Read for DeadlineStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
if let Some(deadline) = self.deadline {
let timeout = time_until_deadline(deadline)?;
if let Some(socket) = self.stream.socket() {
@@ -62,7 +63,7 @@ impl Read for DeadlineStream {
socket.set_write_timeout(Some(timeout))?;
}
}
self.stream.read(buf).map_err(|e| {
self.stream.fill_buf().map_err(|e| {
// On unix-y platforms set_read_timeout and set_write_timeout
// causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut.
// Since the socket most definitely not set_nonblocking(true),
@@ -73,6 +74,25 @@ impl Read for DeadlineStream {
e
})
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl Read for DeadlineStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
// All reads on a DeadlineStream use the BufRead impl. This ensures
// that we have a chance to set the correct timeout before each recv
// syscall.
// Copied from the BufReader implementation of `read()`.
let nread = {
let mut rem = self.fill_buf()?;
rem.read(buf)?
};
self.consume(nread);
Ok(nread)
}
}
// If the deadline is in the future, return the remaining time until
@@ -91,22 +111,37 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Stream[{}]",
match self {
Stream::Http(_) => "http",
#[cfg(feature = "tls")]
Stream::Https(_) => "https",
Stream::Cursor(_) => "cursor",
#[cfg(test)]
Stream::Test(_, _) => "test",
}
)
let mut result = f.debug_struct("Stream");
match self.inner.get_ref() {
Inner::Http(tcpstream) => result.field("tcp", tcpstream),
#[cfg(feature = "tls")]
Inner::Https(tlsstream) => result.field("tls", tlsstream.get_ref()),
Inner::Test(_, _) => result.field("test", &String::new()),
};
result.finish()
}
}
impl Stream {
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
Stream {
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
}
}
fn from_tcp_stream(t: TcpStream) -> Stream {
Stream {
inner: BufReader::new(Inner::Http(t)),
}
}
#[cfg(feature = "tls")]
fn from_tls_stream(t: StreamOwned<ClientSession, TcpStream>) -> Stream {
Stream {
inner: BufReader::new(Inner::Https(t)),
}
}
// Check if the server has closed a stream by performing a one-byte
// non-blocking read. If this returns EOF, the server has closed the
// connection: return true. If this returns WouldBlock (aka EAGAIN),
@@ -134,10 +169,10 @@ impl Stream {
}
}
pub fn is_poolable(&self) -> bool {
match self {
Stream::Http(_) => true,
match self.inner.get_ref() {
Inner::Http(_) => true,
#[cfg(feature = "tls")]
Stream::Https(_) => true,
Inner::Https(_) => true,
_ => false,
}
}
@@ -154,10 +189,10 @@ impl Stream {
}
pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self {
Stream::Http(b) => Some(b.get_ref()),
match self.inner.get_ref() {
Inner::Http(b) => Some(b),
#[cfg(feature = "tls")]
Stream::Https(b) => Some(&b.get_ref().sock),
Inner::Https(b) => Some(&b.get_ref()),
_ => None,
}
}
@@ -171,48 +206,38 @@ impl Stream {
}
#[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> {
match self {
Stream::Test(_, writer) => writer.clone(),
pub fn to_write_vec(self) -> Vec<u8> {
match self.inner.into_inner() {
Inner::Test(_, writer) => writer.clone(),
_ => panic!("to_write_vec on non Test stream"),
}
}
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read(buf)
}
}
impl Read for Inner {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Stream::Http(sock) => sock.read(buf),
Inner::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")]
Stream::Https(stream) => read_https(stream, buf),
Stream::Cursor(read) => read.read(buf),
#[cfg(test)]
Stream::Test(reader, _) => reader.read(buf),
Inner::Https(stream) => read_https(stream, buf),
Inner::Test(reader, _) => reader.read(buf),
}
}
}
impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
match self {
Stream::Http(r) => r.fill_buf(),
#[cfg(feature = "tls")]
Stream::Https(r) => r.fill_buf(),
Stream::Cursor(r) => r.fill_buf(),
#[cfg(test)]
Stream::Test(r, _) => r.fill_buf(),
}
self.inner.fill_buf()
}
fn consume(&mut self, amt: usize) {
match self {
Stream::Http(r) => r.consume(amt),
#[cfg(feature = "tls")]
Stream::Https(r) => r.consume(amt),
Stream::Cursor(r) => r.consume(amt),
#[cfg(test)]
Stream::Test(r, _) => r.consume(amt),
}
self.inner.consume(amt)
}
}
@@ -228,7 +253,7 @@ where
#[cfg(feature = "tls")]
fn read_https(
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>,
stream: &mut StreamOwned<ClientSession, TcpStream>,
buf: &mut [u8],
) -> io::Result<usize> {
match stream.read(buf) {
@@ -256,23 +281,19 @@ fn is_close_notify(e: &std::io::Error) -> bool {
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self {
Stream::Http(sock) => sock.get_mut().write(buf),
match self.inner.get_mut() {
Inner::Http(sock) => sock.write(buf),
#[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.write(buf),
Inner::Https(stream) => stream.write(buf),
Inner::Test(_, writer) => writer.write(buf),
}
}
fn flush(&mut self) -> io::Result<()> {
match self {
Stream::Http(sock) => sock.get_mut().flush(),
match self.inner.get_mut() {
Inner::Http(sock) => sock.flush(),
#[cfg(feature = "tls")]
Stream::Https(stream) => stream.get_mut().flush(),
Stream::Cursor(_) => panic!("Flush read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.flush(),
Inner::Https(stream) => stream.flush(),
Inner::Test(_, writer) => writer.flush(),
}
}
}
@@ -281,9 +302,7 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
//
let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port)
.map(BufReader::new)
.map(Stream::Http)
connect_host(unit, hostname, port).map(Stream::from_tcp_stream)
}
#[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -327,7 +346,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::Https(BufReader::new(stream)))
Ok(Stream::from_tls_stream(stream))
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {

View File

@@ -3,7 +3,7 @@ use crate::stream::Stream;
use crate::unit::Unit;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::io::{Cursor, Write};
use std::io::Write;
use std::sync::{Arc, Mutex};
mod agent_test;
@@ -29,7 +29,7 @@ where
}
#[allow(clippy::write_with_newline)]
pub fn make_response(
pub(crate) fn make_response(
status: u16,
status_text: &str,
headers: Vec<&str>,
@@ -42,9 +42,7 @@ pub fn make_response(
}
write!(&mut buf, "\r\n").ok();
buf.append(&mut body);
let cursor = Cursor::new(buf);
let write: Vec<u8> = vec![];
Ok(Stream::Test(Box::new(cursor), write))
Ok(Stream::from_vec(buf))
}
pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {