Simplify ReadWrite interface (#530)

Previously, ReadWrite had methods `is_poolable` and `written_bytes`, which
were solely for the use of unittests.

This replaces `written_bytes` and `TestStream` with a `struct Recorder`
that implements `ReadWrite` and allows unittests to access its recorded
bytes via an `Arc<Mutex<Vec<u8>>>`. It eliminates `is_poolable`; it's fine
to pool a Stream of any kind.

The new `Recorder` also has some convenience methods that abstract away
boilerplate code from many of our unittests.

I got rid of `Stream::from_vec` and `Stream::from_vec_poolable` because
they depended on `TestStream`. They've been replaced by `NoopStream` for
the pool.rs tests, and `ReadOnlyStream` for constructing `Response`s from
`&str` and some test cases.
This commit is contained in:
Jacob Hoffman-Andrews
2022-07-09 10:13:44 -07:00
committed by GitHub
parent 0cf1f8dbb9
commit 9908c446d6
11 changed files with 211 additions and 226 deletions

View File

@@ -31,7 +31,4 @@ impl ReadWrite for native_tls::TlsStream<Box<dyn ReadWrite>> {
fn socket(&self) -> Option<&TcpStream> {
self.get_ref().socket()
}
fn is_poolable(&self) -> bool {
self.get_ref().is_poolable()
}
}

View File

@@ -248,10 +248,6 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
if let Some(reader) = self.reader.take() {
// bring back stream here to either go into pool or dealloc
let mut stream = reader.into();
if !stream.is_poolable() {
// just let it deallocate
return Ok(());
}
// ensure stream can be reused
stream.reset()?;
@@ -306,8 +302,41 @@ impl<R: Read + Sized + Done + Into<Stream>> Read for PoolReturnRead<R> {
#[cfg(test)]
mod tests {
use crate::ReadWrite;
use super::*;
#[derive(Debug)]
struct NoopStream;
impl NoopStream {
fn stream() -> Stream {
Stream::new(NoopStream)
}
}
impl Read for NoopStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
Ok(buf.len())
}
}
impl std::io::Write for NoopStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl ReadWrite for NoopStream {
fn socket(&self) -> Option<&std::net::TcpStream> {
None
}
}
#[test]
fn poolkey_new() {
// Test that PoolKey::new() does not panic on unrecognized schemes.
@@ -328,7 +357,7 @@ mod tests {
proxy: None,
});
for key in poolkeys.clone() {
pool.add(&key, Stream::from_vec(vec![]))
pool.add(&key, NoopStream::stream());
}
assert_eq!(pool.len(), pool.max_idle_connections);
@@ -353,7 +382,7 @@ mod tests {
};
for _ in 0..pool.max_idle_connections_per_host * 2 {
pool.add(&poolkey, Stream::from_vec(vec![]))
pool.add(&poolkey, NoopStream::stream())
}
assert_eq!(pool.len(), pool.max_idle_connections_per_host);
@@ -372,12 +401,12 @@ mod tests {
let url = Url::parse("zzz:///example.com").unwrap();
let pool_key = PoolKey::new(&url, None);
pool.add(&pool_key, Stream::from_vec(vec![]));
pool.add(&pool_key, NoopStream::stream());
assert_eq!(pool.len(), 1);
let pool_key = PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap()));
pool.add(&pool_key, Stream::from_vec(vec![]));
pool.add(&pool_key, NoopStream::stream());
assert_eq!(pool.len(), 2);
let pool_key = PoolKey::new(
@@ -385,7 +414,7 @@ mod tests {
Some(Proxy::new("user:password@localhost:9999").unwrap()),
);
pool.add(&pool_key, Stream::from_vec(vec![]));
pool.add(&pool_key, NoopStream::stream());
assert_eq!(pool.len(), 3);
}
@@ -396,10 +425,9 @@ mod tests {
let url = Url::parse("https:///example.com").unwrap();
let mut out_buf = [0u8; 500];
let long_vec = vec![0u8; 1000];
let agent = Agent::new();
let stream = Stream::from_vec_poolable(long_vec);
let stream = NoopStream::stream();
let limited_read = LimitedRead::new(stream, 500);
let mut pool_return_read = PoolReturnRead::new(&agent, &url, limited_read);

View File

@@ -9,7 +9,7 @@ use crate::body::SizedReader;
use crate::error::{Error, ErrorKind::BadStatus};
use crate::header::{get_all_headers, get_header, Header, HeaderLine};
use crate::pool::PoolReturnRead;
use crate::stream::{DeadlineStream, Stream};
use crate::stream::{DeadlineStream, ReadOnlyStream, Stream};
use crate::unit::Unit;
use crate::{stream, Agent, ErrorKind};
@@ -520,12 +520,6 @@ impl Response {
})
}
#[cfg(test)]
pub fn into_written_bytes(self) -> Vec<u8> {
// Deliberately consume `self` so that any access to `self.stream` must be non-shared.
self.stream.written_bytes()
}
#[cfg(test)]
pub fn set_url(&mut self, url: Url) {
self.url = url;
@@ -643,7 +637,7 @@ impl FromStr for Response {
/// # }
/// ```
fn from_str(s: &str) -> Result<Self, Self::Err> {
let stream = Stream::from_vec(s.as_bytes().to_owned());
let stream = Stream::new(ReadOnlyStream::new(s.into()));
let request_url = "https://example.com".parse().unwrap();
let request_reader = SizedReader {
size: crate::body::BodySize::Empty,
@@ -1004,7 +998,7 @@ mod tests {
OK",
);
let v = cow.to_vec();
let s = Stream::from_vec(v);
let s = Stream::new(ReadOnlyStream::new(v));
let request_url = "https://example.com".parse().unwrap();
let request_reader = SizedReader {
size: crate::body::BodySize::Empty,

View File

@@ -33,9 +33,6 @@ impl ReadWrite for RustlsStream {
fn socket(&self) -> Option<&TcpStream> {
self.0.get_ref().socket()
}
fn is_poolable(&self) -> bool {
self.0.get_ref().is_poolable()
}
}
// TODO: After upgrading to rustls 0.20 or higher, we can remove these Read

View File

@@ -20,22 +20,12 @@ use crate::unit::Unit;
/// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector].
pub trait ReadWrite: Read + Write + Send + Sync + fmt::Debug + 'static {
fn socket(&self) -> Option<&TcpStream>;
fn is_poolable(&self) -> bool;
/// The bytes written to the stream as a Vec<u8>. This is used for tests only.
#[cfg(test)]
fn written_bytes(&self) -> Vec<u8> {
panic!("written_bytes on non Test stream");
}
}
impl ReadWrite for TcpStream {
fn socket(&self) -> Option<&TcpStream> {
Some(self)
}
fn is_poolable(&self) -> bool {
true
}
}
pub trait TlsConnector: Send + Sync {
@@ -54,51 +44,6 @@ impl<T: ReadWrite + ?Sized> ReadWrite for Box<T> {
fn socket(&self) -> Option<&TcpStream> {
ReadWrite::socket(self.as_ref())
}
fn is_poolable(&self) -> bool {
ReadWrite::is_poolable(self.as_ref())
}
#[cfg(test)]
fn written_bytes(&self) -> Vec<u8> {
ReadWrite::written_bytes(self.as_ref())
}
}
struct TestStream(Box<dyn Read + Send + Sync>, Vec<u8>, bool);
impl ReadWrite for TestStream {
fn is_poolable(&self) -> bool {
self.2
}
fn socket(&self) -> Option<&TcpStream> {
None
}
#[cfg(test)]
fn written_bytes(&self) -> Vec<u8> {
self.1.clone()
}
}
impl Read for TestStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for TestStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.1.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl fmt::Debug for TestStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("TestStream").finish()
}
}
// DeadlineStream wraps a stream such that read() will return an error
@@ -187,6 +132,37 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
io::Error::new(io::ErrorKind::TimedOut, error)
}
#[derive(Debug)]
pub(crate) struct ReadOnlyStream(Cursor<Vec<u8>>);
impl ReadOnlyStream {
pub(crate) fn new(v: Vec<u8>) -> Self {
Self(Cursor::new(v))
}
}
impl Read for ReadOnlyStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl std::io::Write for ReadOnlyStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(buf.len())
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl ReadWrite for ReadOnlyStream {
fn socket(&self) -> Option<&std::net::TcpStream> {
None
}
}
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner.get_ref().socket() {
@@ -197,7 +173,7 @@ impl fmt::Debug for Stream {
}
impl Stream {
fn new(t: impl ReadWrite) -> Stream {
pub(crate) fn new(t: impl ReadWrite) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Box::new(t)),
})
@@ -208,23 +184,6 @@ impl Stream {
stream
}
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Box::new(TestStream(
Box::new(Cursor::new(v)),
vec![],
false,
))),
})
}
#[cfg(test)]
pub(crate) fn from_vec_poolable(v: Vec<u8>) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Box::new(TestStream(Box::new(Cursor::new(v)), vec![], true))),
})
}
fn from_tcp_stream(t: TcpStream) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Box::new(t)),
@@ -270,9 +229,6 @@ impl Stream {
None => Ok(false),
}
}
pub fn is_poolable(&self) -> bool {
self.inner.get_ref().is_poolable()
}
pub(crate) fn reset(&mut self) -> io::Result<()> {
// When we are turning this back into a regular, non-deadline Stream,
@@ -296,11 +252,6 @@ impl Stream {
Ok(())
}
}
#[cfg(test)]
pub fn written_bytes(&self) -> Vec<u8> {
self.inner.get_ref().written_bytes()
}
}
impl Read for Stream {
@@ -693,10 +644,6 @@ mod tests {
fn socket(&self) -> Option<&TcpStream> {
unimplemented!()
}
fn is_poolable(&self) -> bool {
unimplemented!()
}
}
// Test that when a DeadlineStream wraps a Stream, and the user performs a series of

View File

@@ -1,79 +1,61 @@
use crate::test;
use crate::test::Recorder;
use super::super::*;
#[test]
fn content_length_on_str() {
test::set_handler("/content_length_on_str", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = post("test://host/content_length_on_str")
let recorder = Recorder::register("/content_length_on_str");
post("test://host/content_length_on_str")
.send_string("Hello World!!!")
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("\r\nContent-Length: 14\r\n"));
assert!(recorder.contains("\r\nContent-Length: 14\r\n"));
}
#[test]
fn user_set_content_length_on_str() {
test::set_handler("/user_set_content_length_on_str", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = post("test://host/user_set_content_length_on_str")
let recorder = Recorder::register("/user_set_content_length_on_str");
post("test://host/user_set_content_length_on_str")
.set("Content-Length", "12345")
.send_string("Hello World!!!")
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("\r\nContent-Length: 12345\r\n"));
assert!(recorder.contains("\r\nContent-Length: 12345\r\n"));
}
#[test]
#[cfg(feature = "json")]
fn content_length_on_json() {
test::set_handler("/content_length_on_json", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let recorder = Recorder::register("/content_length_on_json");
let mut json = serde_json::Map::new();
json.insert(
"Hello".to_string(),
serde_json::Value::String("World!!!".to_string()),
);
let resp = post("test://host/content_length_on_json")
post("test://host/content_length_on_json")
.send_json(serde_json::Value::Object(json))
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("\r\nContent-Length: 20\r\n"));
assert!(recorder.contains("\r\nContent-Length: 20\r\n"));
}
#[test]
fn content_length_and_chunked() {
test::set_handler("/content_length_and_chunked", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = post("test://host/content_length_and_chunked")
let recorder = Recorder::register("/content_length_and_chunked");
post("test://host/content_length_and_chunked")
.set("Transfer-Encoding", "chunked")
.send_string("Hello World!!!")
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("Transfer-Encoding: chunked\r\n"));
assert!(!s.contains("\r\nContent-Length:\r\n"));
assert!(recorder.contains("Transfer-Encoding: chunked\r\n"));
assert!(!recorder.contains("\r\nContent-Length:\r\n"));
}
#[test]
#[cfg(feature = "charset")]
fn str_with_encoding() {
test::set_handler("/str_with_encoding", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = post("test://host/str_with_encoding")
let recorder = Recorder::register("/str_with_encoding");
post("test://host/str_with_encoding")
.set("Content-Type", "text/plain; charset=iso-8859-1")
.send_string("Hällo Wörld!!!")
.unwrap();
let vec = resp.into_written_bytes();
let vec = recorder.to_vec();
assert_eq!(
&vec[vec.len() - 14..],
//H ä l l o _ W ö r l d ! ! !
@@ -84,38 +66,30 @@ fn str_with_encoding() {
#[test]
#[cfg(feature = "json")]
fn content_type_on_json() {
test::set_handler("/content_type_on_json", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let recorder = Recorder::register("/content_type_on_json");
let mut json = serde_json::Map::new();
json.insert(
"Hello".to_string(),
serde_json::Value::String("World!!!".to_string()),
);
let resp = post("test://host/content_type_on_json")
post("test://host/content_type_on_json")
.send_json(serde_json::Value::Object(json))
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("\r\nContent-Type: application/json\r\n"));
assert!(recorder.contains("\r\nContent-Type: application/json\r\n"));
}
#[test]
#[cfg(feature = "json")]
fn content_type_not_overriden_on_json() {
test::set_handler("/content_type_not_overriden_on_json", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let recorder = Recorder::register("/content_type_not_overriden_on_json");
let mut json = serde_json::Map::new();
json.insert(
"Hello".to_string(),
serde_json::Value::String("World!!!".to_string()),
);
let resp = post("test://host/content_type_not_overriden_on_json")
post("test://host/content_type_not_overriden_on_json")
.set("content-type", "text/plain")
.send_json(serde_json::Value::Object(json))
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("\r\ncontent-type: text/plain\r\n"));
assert!(recorder.contains("\r\ncontent-type: text/plain\r\n"));
}

View File

@@ -1,9 +1,12 @@
use crate::error::Error;
use crate::stream::Stream;
use crate::stream::{ReadOnlyStream, Stream};
use crate::unit::Unit;
use crate::ReadWrite;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::io::Write;
use std::fmt;
use std::io::{self, Cursor, Read, Write};
use std::net::TcpStream;
use std::sync::{Arc, Mutex};
mod agent_test;
@@ -48,7 +51,7 @@ pub(crate) fn make_response(
}
write!(&mut buf, "\r\n").ok();
buf.append(&mut body);
Ok(Stream::from_vec(buf))
Ok(Stream::new(ReadOnlyStream::new(buf)))
}
pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {
@@ -66,3 +69,87 @@ pub(crate) fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {
drop(handlers);
handler(unit)
}
#[derive(Default, Clone)]
pub(crate) struct Recorder {
contents: Arc<Mutex<Vec<u8>>>,
}
impl Recorder {
fn register(path: &str) -> Self {
let recorder = Recorder::default();
let recorder2 = recorder.clone();
set_handler(path, move |_unit| Ok(recorder.stream()));
recorder2
}
#[cfg(feature = "charset")]
fn to_vec(self) -> Vec<u8> {
self.contents.lock().unwrap().clone()
}
fn contains(&self, s: &str) -> bool {
std::str::from_utf8(&self.contents.lock().unwrap())
.unwrap()
.contains(s)
}
fn stream(&self) -> Stream {
let cursor = Cursor::new(b"HTTP/1.1 200 OK\r\n\r\n");
Stream::new(TestStream::new(cursor, self.clone()))
}
}
impl Write for Recorder {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.contents.lock().unwrap().write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
pub(crate) struct TestStream(
Box<dyn Read + Send + Sync>,
Box<dyn Write + Send + Sync>,
bool,
);
impl TestStream {
#[cfg(test)]
pub(crate) fn new(
response: impl Read + Send + Sync + 'static,
recorder: impl Write + Send + Sync + 'static,
) -> Self {
Self(Box::new(response), Box::new(recorder), false)
}
}
impl ReadWrite for TestStream {
fn socket(&self) -> Option<&TcpStream> {
None
}
}
impl Read for TestStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for TestStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.1.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
impl fmt::Debug for TestStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("TestStream").finish()
}
}

View File

@@ -1,58 +1,37 @@
use crate::test;
use super::super::*;
use super::Recorder;
use crate::get;
#[test]
fn no_query_string() {
test::set_handler("/no_query_string", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = get("test://host/no_query_string").call().unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("GET /no_query_string HTTP/1.1"))
let recorder = Recorder::register("/no_query_string");
get("test://host/no_query_string").call().unwrap();
assert!(recorder.contains("GET /no_query_string HTTP/1.1"))
}
#[test]
fn escaped_query_string() {
test::set_handler("/escaped_query_string", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = get("test://host/escaped_query_string")
let recorder = Recorder::register("/escaped_query_string");
get("test://host/escaped_query_string")
.query("foo", "bar")
.query("baz", "yo lo")
.call()
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(
s.contains("GET /escaped_query_string?foo=bar&baz=yo+lo HTTP/1.1"),
"req: {}",
s
);
assert!(recorder.contains("GET /escaped_query_string?foo=bar&baz=yo+lo HTTP/1.1"));
}
#[test]
fn query_in_path() {
test::set_handler("/query_in_path", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = get("test://host/query_in_path?foo=bar").call().unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("GET /query_in_path?foo=bar HTTP/1.1"))
let recorder = Recorder::register("/query_in_path");
get("test://host/query_in_path?foo=bar").call().unwrap();
assert!(recorder.contains("GET /query_in_path?foo=bar HTTP/1.1"))
}
#[test]
fn query_in_path_and_req() {
test::set_handler("/query_in_path_and_req", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = get("test://host/query_in_path_and_req?foo=bar")
let recorder = Recorder::register("/query_in_path_and_req");
get("test://host/query_in_path_and_req?foo=bar")
.query("baz", "1 2 3")
.call()
.unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("GET /query_in_path_and_req?foo=bar&baz=1+2+3 HTTP/1.1"))
assert!(recorder.contains("GET /query_in_path_and_req?foo=bar&baz=1+2+3 HTTP/1.1"));
}

View File

@@ -1,7 +1,7 @@
use crate::test;
use std::io::Read;
use super::super::*;
use super::{super::*, Recorder};
#[test]
fn header_passing() {
@@ -116,13 +116,9 @@ fn body_as_reader() {
#[test]
fn escape_path() {
test::set_handler("/escape_path%20here", |_unit| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = get("test://host/escape_path here").call().unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("GET /escape_path%20here HTTP/1.1"))
let recorder = Recorder::register("/escape_path%20here");
get("test://host/escape_path here").call().unwrap();
assert!(recorder.contains("GET /escape_path%20here HTTP/1.1"))
}
#[test]
@@ -194,22 +190,14 @@ pub fn header_with_spaces_before_value() {
#[test]
pub fn host_no_port() {
test::set_handler("/host_no_port", |_| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = get("test://myhost/host_no_port").call().unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("\r\nHost: myhost\r\n"));
let recorder = Recorder::register("/host_no_port");
get("test://myhost/host_no_port").call().unwrap();
assert!(recorder.contains("\r\nHost: myhost\r\n"));
}
#[test]
pub fn host_with_port() {
test::set_handler("/host_with_port", |_| {
test::make_response(200, "OK", vec![], vec![])
});
let resp = get("test://myhost:234/host_with_port").call().unwrap();
let vec = resp.into_written_bytes();
let s = String::from_utf8_lossy(&vec);
assert!(s.contains("\r\nHost: myhost:234\r\n"));
let recorder = Recorder::register("/host_with_port");
get("test://myhost:234/host_with_port").call().unwrap();
assert!(recorder.contains("\r\nHost: myhost:234\r\n"));
}