diff --git a/src/pool.rs b/src/pool.rs index 812365d..ea88a52 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -3,11 +3,9 @@ use std::collections::{HashMap, VecDeque}; use std::io::{self, Read}; use std::sync::Mutex; -use crate::response::LimitedRead; use crate::stream::Stream; use crate::{Agent, Proxy}; -use chunked_transfer::Decoder; use log::debug; use url::Url; @@ -124,7 +122,7 @@ impl ConnectionPool { } } - fn add(&self, key: &PoolKey, stream: Stream) { + pub(crate) fn add(&self, key: &PoolKey, stream: Stream) { if self.noop() { return; } @@ -188,7 +186,7 @@ impl ConnectionPool { } #[derive(PartialEq, Clone, Eq, Hash)] -struct PoolKey { +pub(crate) struct PoolKey { scheme: String, hostname: String, port: Option, @@ -218,6 +216,40 @@ impl PoolKey { proxy, } } + + pub(crate) fn from_parts(scheme: &str, hostname: &str, port: u16) -> Self { + PoolKey { + scheme: scheme.to_string(), + hostname: hostname.to_string(), + port: Some(port), + proxy: None, + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct PoolReturner { + inner: Option<(Agent, PoolKey)>, +} + +impl PoolReturner { + /// A PoolReturner that returns to the given Agent's Pool. + pub(crate) fn new(agent: Agent, pool_key: PoolKey) -> Self { + Self { + inner: Some((agent, pool_key)), + } + } + + /// A PoolReturner that does nothing + pub(crate) fn none() -> Self { + Self { inner: None } + } + + pub(crate) fn return_to_pool(&self, stream: Stream) { + if let Some((agent, pool_key)) = &self.inner { + agent.state.pool.add(pool_key, stream); + } + } } /// Read wrapper that returns a stream to the pool once the @@ -225,20 +257,14 @@ impl PoolKey { /// /// *Internal API* pub(crate) struct PoolReturnRead> { - // the agent where we want to return the stream. - agent: Agent, // wrapped reader around the same stream. It's an Option because we `take()` it // upon returning the stream to the Agent. reader: Option, - // Key under which to store the stream when we're done. - key: PoolKey, } impl> PoolReturnRead { - pub fn new(agent: &Agent, url: &Url, reader: R) -> Self { + pub fn new(reader: R) -> Self { PoolReturnRead { - agent: agent.clone(), - key: PoolKey::new(url, agent.config.proxy.clone()), reader: Some(reader), } } @@ -247,13 +273,8 @@ impl> PoolReturnRead { // guard we only do this once. if let Some(reader) = self.reader.take() { // bring back stream here to either go into pool or dealloc - let mut stream = reader.into(); - - // ensure stream can be reused - stream.reset()?; - - // insert back into pool - self.agent.state.pool.add(&self.key, stream); + let stream: Stream = reader.into(); + stream.return_to_pool()?; } Ok(()) @@ -267,33 +288,12 @@ impl> PoolReturnRead { } } -// Done allows a reader to indicate it is done (next read will return Ok(0)) -// without actually performing a read. This is useful so LimitedRead can -// inform PoolReturnRead to return a stream to the pool even if the user -// never read past the end of the response (For instance because their -// application is handling length information on its own). -pub(crate) trait Done { - fn done(&self) -> bool; -} - -impl Done for LimitedRead { - fn done(&self) -> bool { - self.remaining() == 0 - } -} - -impl Done for Decoder { - fn done(&self) -> bool { - false - } -} - -impl> Read for PoolReturnRead { +impl> Read for PoolReturnRead { fn read(&mut self, buf: &mut [u8]) -> io::Result { let amount = self.do_read(buf)?; // only if the underlying reader is exhausted can we send a new // request to the same socket. hence, we only return it now. - if amount == 0 || self.reader.as_ref().map(|r| r.done()).unwrap_or_default() { + if amount == 0 { self.return_connection()?; } Ok(amount) @@ -313,8 +313,8 @@ mod tests { struct NoopStream; impl NoopStream { - fn stream() -> Stream { - Stream::new(NoopStream, remote_addr_for_test()) + fn stream(pool_returner: PoolReturner) -> Stream { + Stream::new(NoopStream, remote_addr_for_test(), pool_returner) } } @@ -360,7 +360,7 @@ mod tests { proxy: None, }); for key in poolkeys.clone() { - pool.add(&key, NoopStream::stream()); + pool.add(&key, NoopStream::stream(PoolReturner::none())); } assert_eq!(pool.len(), pool.max_idle_connections); @@ -385,7 +385,7 @@ mod tests { }; for _ in 0..pool.max_idle_connections_per_host * 2 { - pool.add(&poolkey, NoopStream::stream()) + pool.add(&poolkey, NoopStream::stream(PoolReturner::none())) } assert_eq!(pool.len(), pool.max_idle_connections_per_host); @@ -404,12 +404,12 @@ mod tests { let url = Url::parse("zzz:///example.com").unwrap(); let pool_key = PoolKey::new(&url, None); - pool.add(&pool_key, NoopStream::stream()); + pool.add(&pool_key, NoopStream::stream(PoolReturner::none())); assert_eq!(pool.len(), 1); let pool_key = PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())); - pool.add(&pool_key, NoopStream::stream()); + pool.add(&pool_key, NoopStream::stream(PoolReturner::none())); assert_eq!(pool.len(), 2); let pool_key = PoolKey::new( @@ -417,7 +417,7 @@ mod tests { Some(Proxy::new("user:password@localhost:9999").unwrap()), ); - pool.add(&pool_key, NoopStream::stream()); + pool.add(&pool_key, NoopStream::stream(PoolReturner::none())); assert_eq!(pool.len(), 3); } @@ -425,17 +425,18 @@ mod tests { // user reads the exact right number of bytes (but never gets a read of 0 bytes). #[test] fn read_exact() { + use crate::response::LimitedRead; + let url = Url::parse("https:///example.com").unwrap(); let mut out_buf = [0u8; 500]; let agent = Agent::new(); - let stream = NoopStream::stream(); - let limited_read = LimitedRead::new(stream, 500); + let pool_key = PoolKey::new(&url, None); + let stream = NoopStream::stream(PoolReturner::new(agent.clone(), pool_key)); + let mut limited_read = LimitedRead::new(stream, std::num::NonZeroUsize::new(500).unwrap()); - let mut pool_return_read = PoolReturnRead::new(&agent, &url, limited_read); - - pool_return_read.read_exact(&mut out_buf).unwrap(); + limited_read.read_exact(&mut out_buf).unwrap(); assert_eq!(agent.state.pool.len(), 1); } @@ -448,6 +449,7 @@ mod tests { fn read_exact_chunked_gzip() { use crate::response::Compression; use chunked_transfer::Decoder as ChunkDecoder; + use std::io::Cursor; let gz_body = vec![ b'E', b'\r', b'\n', // 14 first chunk @@ -464,28 +466,19 @@ mod tests { b'\r', b'\n', // ]; - println!("{:?}", gz_body); - - impl ReadWrite for io::Cursor> { - fn socket(&self) -> Option<&std::net::TcpStream> { - None - } - } - - impl From>> for Stream { - fn from(c: io::Cursor>) -> Self { - Stream::new(c, "1.1.1.1:8080".parse().unwrap()) - } - } - let agent = Agent::new(); - let url = Url::parse("https://example.com").unwrap(); - assert_eq!(agent.state.pool.len(), 0); - let chunked = ChunkDecoder::new(io::Cursor::new(gz_body)); + let ro = crate::test::TestStream::new(Cursor::new(gz_body), std::io::sink()); + let stream = Stream::new( + ro, + "1.1.1.1:4343".parse().unwrap(), + PoolReturner::new(agent.clone(), PoolKey::from_parts("http", "1.1.1.1", 8080)), + ); + + let chunked = ChunkDecoder::new(stream); let pool_return_read: Box<(dyn Read + Send + Sync + 'static)> = - Box::new(PoolReturnRead::new(&agent, &url, chunked)); + Box::new(PoolReturnRead::new(chunked)); let compression = Compression::Gzip; let mut stream = compression.wrap_reader(pool_return_read); diff --git a/src/response.rs b/src/response.rs index 35fd1ab..8867a4a 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,5 +1,6 @@ use std::io::{self, Cursor, Read}; use std::net::SocketAddr; +use std::num::NonZeroUsize; use std::str::FromStr; use std::{fmt, io::BufRead}; @@ -10,7 +11,7 @@ use url::Url; use crate::body::SizedReader; use crate::error::{Error, ErrorKind::BadStatus}; use crate::header::{get_all_headers, get_header, Header, HeaderLine}; -use crate::pool::PoolReturnRead; +use crate::pool::{PoolReturnRead, PoolReturner}; use crate::stream::{DeadlineStream, ReadOnlyStream, Stream}; use crate::unit::Unit; use crate::{stream, Agent, ErrorKind}; @@ -327,7 +328,7 @@ impl Response { let inner = stream.inner_ref(); let result = inner.set_read_timeout(unit.agent.config.timeout_read); if let Err(e) = result { - return Box::new(ErrorReader(e)) as Box; + return Box::new(ErrorReader(e)); } let buffer_len = inner.buffer().len(); @@ -337,29 +338,40 @@ impl Response { // to the connection pool. BodyType::Chunked => { debug!("Chunked body in response"); - Box::new(PoolReturnRead::new( - &unit.agent, - &unit.url, - ChunkDecoder::new(stream), - )) + Box::new(PoolReturnRead::new(ChunkDecoder::new(stream))) } // Responses with a content-length header means we should limit the reading // of the body to the number of bytes in the header. Once done, we can // return the underlying stream to the connection pool. BodyType::LengthDelimited(len) => { - let mut pooler = - PoolReturnRead::new(&unit.agent, &unit.url, LimitedRead::new(stream, len)); + match NonZeroUsize::new(len) { + None => { + debug!("zero-length body returning stream directly to pool"); + let stream: Stream = stream.into(); + // TODO: This expect can actually panic if we get an error when + // returning the stream to the pool. We reset the read timeouts + // when we do that, and since that's a syscall it can fail. + stream.return_to_pool().expect("returning stream to pool"); + Box::new(std::io::empty()) + } + Some(len) => { + let mut limited_read = LimitedRead::new(stream, len); - if len <= buffer_len { - debug!("Body entirely buffered (length: {})", len); - let mut buf = vec![0; len]; - pooler - .read_exact(&mut buf) - .expect("failed to read exact buffer length from stream"); - Box::new(Cursor::new(buf)) - } else { - debug!("Streaming body until content-length: {}", len); - Box::new(pooler) + if len.get() <= buffer_len { + debug!("Body entirely buffered (length: {})", len); + let mut buf = vec![0; len.get()]; + // TODO: This expect can actually panic if we get an error when + // returning the stream to the pool. We reset the read timeouts + // when we do that, and since that's a syscall it can fail. + limited_read + .read_exact(&mut buf) + .expect("failed to read exact buffer length from stream"); + Box::new(Cursor::new(buf)) + } else { + debug!("Streaming body until content-length: {}", len); + Box::new(limited_read) + } + } } } BodyType::CloseDelimited => { @@ -698,7 +710,11 @@ impl FromStr for Response { /// ``` fn from_str(s: &str) -> Result { let remote_addr = "0.0.0.0:0".parse().unwrap(); - let stream = Stream::new(ReadOnlyStream::new(s.into()), remote_addr); + let stream = Stream::new( + ReadOnlyStream::new(s.into()), + remote_addr, + PoolReturner::none(), + ); let request_url = "https://example.com".parse().unwrap(); let request_reader = SizedReader { size: crate::body::BodySize::Empty, @@ -763,16 +779,16 @@ fn read_next_line(reader: &mut impl BufRead, context: &str) -> io::Result
{ - reader: R, + reader: Option, limit: usize, position: usize, } -impl LimitedRead { - pub(crate) fn new(reader: R, limit: usize) -> Self { +impl> LimitedRead { + pub(crate) fn new(reader: R, limit: NonZeroUsize) -> Self { LimitedRead { - reader, - limit, + reader: Some(reader), + limit: limit.get(), position: 0, } } @@ -780,9 +796,20 @@ impl LimitedRead { pub(crate) fn remaining(&self) -> usize { self.limit - self.position } + + fn return_stream_to_pool(&mut self) -> io::Result<()> { + if let Some(reader) = self.reader.take() { + // Convert back to a stream. If return_to_pool fails, the stream will + // drop and the connection will be closed. + let stream: Stream = reader.into(); + stream.return_to_pool()?; + } + + Ok(()) + } } -impl Read for LimitedRead { +impl> Read for LimitedRead { fn read(&mut self, buf: &mut [u8]) -> io::Result { if self.remaining() == 0 { return Ok(0); @@ -792,18 +819,27 @@ impl Read for LimitedRead { } else { buf }; - match self.reader.read(from) { + let reader = match self.reader.as_mut() { + // If the reader has already been taken, return Ok(0) to all reads. + None => return Ok(0), + Some(r) => r, + }; + match reader.read(from) { // https://tools.ietf.org/html/rfc7230#page-33 // If the sender closes the connection or // the recipient times out before the indicated number of octets are // received, the recipient MUST consider the message to be // incomplete and close the connection. + // TODO: actually close the connection by dropping the stream Ok(0) => Err(io::Error::new( io::ErrorKind::UnexpectedEof, "response body closed before all bytes were read", )), Ok(amount) => { self.position += amount; + if self.remaining() == 0 { + self.return_stream_to_pool()?; + } Ok(amount) } Err(e) => Err(e), @@ -811,15 +847,6 @@ impl Read for LimitedRead { } } -impl From> for Stream -where - Stream: From, -{ - fn from(limited_read: LimitedRead) -> Stream { - limited_read.reader.into() - } -} - /// Extract the charset from a "Content-Type" header. /// /// "Content-Type: text/plain; charset=iso8859-1" -> "iso8859-1" @@ -852,12 +879,20 @@ impl Read for ErrorReader { mod tests { use std::io::Cursor; + use crate::{body::Payload, pool::PoolKey}; + use super::*; #[test] fn short_read() { use std::io::Cursor; - let mut lr = LimitedRead::new(Cursor::new(vec![b'a'; 3]), 10); + let test_stream = crate::test::TestStream::new(Cursor::new(vec![b'a'; 3]), std::io::sink()); + let stream = Stream::new( + test_stream, + "1.1.1.1:4343".parse().unwrap(), + PoolReturner::none(), + ); + let mut lr = LimitedRead::new(stream, std::num::NonZeroUsize::new(10).unwrap()); let mut buf = vec![0; 1000]; let result = lr.read_to_end(&mut buf); assert!(result.err().unwrap().kind() == io::ErrorKind::UnexpectedEof); @@ -1062,6 +1097,7 @@ mod tests { let s = Stream::new( ReadOnlyStream::new(v), crate::stream::remote_addr_for_test(), + PoolReturner::none(), ); let request_url = "https://example.com".parse().unwrap(); let request_reader = SizedReader { @@ -1112,4 +1148,39 @@ mod tests { println!("Response size: {}", size); assert!(size < 400); // 200 on Macbook M1 } + + // Test that a stream gets returned to the pool immediately for a zero-length response, and + // that reads from the response's body consistently return Ok(0). + #[test] + fn zero_length_body_immediate_return() { + use std::io::Cursor; + let response_bytes = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n" + .as_bytes() + .to_vec(); + let test_stream = + crate::test::TestStream::new(Cursor::new(response_bytes), std::io::sink()); + let agent = Agent::new(); + let agent2 = agent.clone(); + let stream = Stream::new( + test_stream, + "1.1.1.1:4343".parse().unwrap(), + PoolReturner::new( + agent.clone(), + PoolKey::from_parts("https", "example.com", 443), + ), + ); + Response::do_from_stream( + stream, + Unit::new( + &agent, + "GET", + &"https://example.com/".parse().unwrap(), + vec![], + &Payload::Empty.into_read(), + None, + ), + ) + .unwrap(); + assert_eq!(agent2.state.pool.len(), 1); + } } diff --git a/src/stream.rs b/src/stream.rs index 5f978a5..f911afb 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -11,6 +11,7 @@ use chunked_transfer::Decoder as ChunkDecoder; #[cfg(feature = "socks-proxy")] use socks::{TargetAddr, ToTargetAddr}; +use crate::pool::{PoolKey, PoolReturner}; use crate::proxy::Proxy; use crate::{error::Error, proxy::Proto}; @@ -40,6 +41,7 @@ pub(crate) struct Stream { inner: BufReader>, /// The remote address the stream is connected to. pub(crate) remote_addr: SocketAddr, + pool_returner: PoolReturner, } impl ReadWrite for Box { @@ -179,10 +181,15 @@ impl fmt::Debug for Stream { } impl Stream { - pub(crate) fn new(t: impl ReadWrite, remote_addr: SocketAddr) -> Stream { + pub(crate) fn new( + t: impl ReadWrite, + remote_addr: SocketAddr, + pool_returner: PoolReturner, + ) -> Stream { Stream::logged_create(Stream { inner: BufReader::new(Box::new(t)), remote_addr, + pool_returner, }) } @@ -235,6 +242,13 @@ impl Stream { } } + pub(crate) fn return_to_pool(mut self) -> io::Result<()> { + // ensure stream can be reused + self.reset()?; + self.pool_returner.clone().return_to_pool(self); + Ok(()) + } + pub(crate) fn reset(&mut self) -> io::Result<()> { // When we are turning this back into a regular, non-deadline Stream, // remove any timeouts we set. @@ -303,8 +317,9 @@ impl Drop for Stream { pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result { // let port = unit.url.port().unwrap_or(80); - - connect_host(unit, hostname, port).map(|(t, r)| Stream::new(t, r)) + let pool_key = PoolKey::from_parts("http", hostname, port); + let pool_returner = PoolReturner::new(unit.agent.clone(), pool_key); + connect_host(unit, hostname, port).map(|(t, r)| Stream::new(t, r, pool_returner)) } pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result { @@ -314,7 +329,9 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result