diff --git a/src/pool.rs b/src/pool.rs index 019692f..10e0c4d 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -3,10 +3,11 @@ use std::collections::{HashMap, VecDeque}; use std::io::{self, Read}; use std::sync::Mutex; +use crate::response::LimitedRead; use crate::stream::Stream; -use crate::unit::Unit; -use crate::Proxy; +use crate::{Agent, Proxy}; +use chunked_transfer::Decoder; use log::debug; use url::Url; @@ -123,7 +124,7 @@ impl ConnectionPool { } } - fn add(&self, key: PoolKey, stream: Stream) { + fn add(&self, key: &PoolKey, stream: Stream) { if self.noop() { return; } @@ -143,7 +144,7 @@ impl ConnectionPool { streams.len(), stream ); - remove_first_match(&mut inner.lru, &key) + remove_first_match(&mut inner.lru, key) .expect("invariant failed: key in recycle but not in lru"); } } @@ -151,7 +152,7 @@ impl ConnectionPool { vacant_entry.insert(vec![stream].into()); } } - inner.lru.push_back(key); + inner.lru.push_back(key.clone()); if inner.lru.len() > self.max_idle_connections { drop(inner); self.remove_oldest() @@ -219,28 +220,32 @@ impl PoolKey { } } -/// Read wrapper that returns the stream to the pool once the +/// Read wrapper that returns a stream to the pool once the /// read is exhausted (reached a 0). /// /// *Internal API* pub(crate) struct PoolReturnRead> { - // unit that contains the agent where we want to return the reader. - unit: Option>, - // wrapped reader around the same stream + // the agent where we want to return the stream. + agent: Agent, + // wrapped reader around the same stream. It's an Option because we `take()` it + // upon returning the stream to the Agent. reader: Option, + // Key under which to store the stream when we're done. + key: PoolKey, } impl> PoolReturnRead { - pub fn new(unit: Option>, reader: R) -> Self { + pub fn new(agent: &Agent, url: &Url, reader: R) -> Self { PoolReturnRead { - unit, + agent: agent.clone(), + key: PoolKey::new(url, agent.config.proxy.clone()), reader: Some(reader), } } fn return_connection(&mut self) -> io::Result<()> { // guard we only do this once. - if let (Some(unit), Some(reader)) = (self.unit.take(), self.reader.take()) { + if let Some(reader) = self.reader.take() { // bring back stream here to either go into pool or dealloc let mut stream = reader.into(); if !stream.is_poolable() { @@ -252,8 +257,7 @@ impl> PoolReturnRead { stream.reset()?; // insert back into pool - let key = PoolKey::new(&unit.url, unit.agent.config.proxy.clone()); - unit.agent.state.pool.add(key, stream); + self.agent.state.pool.add(&self.key, stream); } Ok(()) @@ -267,12 +271,33 @@ impl> PoolReturnRead { } } -impl> Read for PoolReturnRead { +// Done allows a reader to indicate it is done (next read will return Ok(0)) +// without actually performing a read. This is useful so LimitedRead can +// inform PoolReturnRead to return a stream to the pool even if the user +// never read past the end of the response (For instance because their +// application is handling length information on its own). +pub(crate) trait Done { + fn done(&self) -> bool; +} + +impl Done for LimitedRead { + fn done(&self) -> bool { + self.remaining() == 0 + } +} + +impl Done for Decoder { + fn done(&self) -> bool { + false + } +} + +impl> Read for PoolReturnRead { fn read(&mut self, buf: &mut [u8]) -> io::Result { let amount = self.do_read(buf)?; // only if the underlying reader is exhausted can we send a new // request to the same socket. hence, we only return it now. - if amount == 0 { + if amount == 0 || self.reader.as_ref().map(|r| r.done()).unwrap_or_default() { self.return_connection()?; } Ok(amount) @@ -303,7 +328,7 @@ mod tests { proxy: None, }); for key in poolkeys.clone() { - pool.add(key, Stream::from_vec(vec![])) + pool.add(&key, Stream::from_vec(vec![])) } assert_eq!(pool.len(), pool.max_idle_connections); @@ -328,7 +353,7 @@ mod tests { }; for _ in 0..pool.max_idle_connections_per_host * 2 { - pool.add(poolkey.clone(), Stream::from_vec(vec![])) + pool.add(&poolkey, Stream::from_vec(vec![])) } assert_eq!(pool.len(), pool.max_idle_connections_per_host); @@ -345,23 +370,42 @@ mod tests { // Each insertion should result in an additional entry in the pool. let pool = ConnectionPool::new_with_limits(10, 1); let url = Url::parse("zzz:///example.com").unwrap(); + let pool_key = PoolKey::new(&url, None); - pool.add(PoolKey::new(&url, None), Stream::from_vec(vec![])); + pool.add(&pool_key, Stream::from_vec(vec![])); assert_eq!(pool.len(), 1); - pool.add( - PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())), - Stream::from_vec(vec![]), - ); + let pool_key = PoolKey::new(&url, Some(Proxy::new("localhost:9999").unwrap())); + + pool.add(&pool_key, Stream::from_vec(vec![])); assert_eq!(pool.len(), 2); - pool.add( - PoolKey::new( - &url, - Some(Proxy::new("user:password@localhost:9999").unwrap()), - ), - Stream::from_vec(vec![]), + let pool_key = PoolKey::new( + &url, + Some(Proxy::new("user:password@localhost:9999").unwrap()), ); + + pool.add(&pool_key, Stream::from_vec(vec![])); assert_eq!(pool.len(), 3); } + + // Test that a stream gets returned to the pool if it was wrapped in a LimitedRead, and + // user reads the exact right number of bytes (but never gets a read of 0 bytes). + #[test] + fn read_exact() { + let url = Url::parse("https:///example.com").unwrap(); + + let mut out_buf = [0u8; 500]; + let long_vec = vec![0u8; 1000]; + + let agent = Agent::new(); + let stream = Stream::from_vec_poolable(long_vec); + let limited_read = LimitedRead::new(stream, 500); + + let mut pool_return_read = PoolReturnRead::new(&agent, &url, limited_read); + + pool_return_read.read_exact(&mut out_buf).unwrap(); + + assert_eq!(agent.state.pool.len(), 1); + } } diff --git a/src/response.rs b/src/response.rs index b57d201..2988b8f 100644 --- a/src/response.rs +++ b/src/response.rs @@ -6,12 +6,13 @@ use chunked_transfer::Decoder as ChunkDecoder; use sync_wrapper::SyncWrapper; use url::Url; +use crate::body::SizedReader; use crate::error::{Error, ErrorKind::BadStatus}; use crate::header::{get_all_headers, get_header, Header, HeaderLine}; use crate::pool::PoolReturnRead; use crate::stream::{DeadlineStream, Stream}; use crate::unit::Unit; -use crate::{stream, ErrorKind}; +use crate::{stream, Agent, ErrorKind}; #[cfg(feature = "json")] use serde::de::DeserializeOwned; @@ -60,13 +61,13 @@ const MAX_HEADER_COUNT: usize = 100; /// # } /// ``` pub struct Response { - pub(crate) url: Option, + pub(crate) url: Url, status_line: String, index: ResponseStatusIndex, status: u16, headers: Vec
, // Boxed to avoid taking up too much size. - unit: Option>, + unit: Box, // Boxed to avoid taking up too much size. stream: SyncWrapper>, /// The redirect history of this response, if any. The history starts with @@ -93,14 +94,11 @@ impl fmt::Debug for Response { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, - "Response[status: {}, status_text: {}", + "Response[status: {}, status_text: {}, url: {}]", self.status(), self.status_text(), - )?; - if let Some(url) = &self.url { - write!(f, ", url: {}", url)?; - } - write!(f, "]") + self.url, + ) } } @@ -128,7 +126,7 @@ impl Response { /// The URL we ended up at. This can differ from the request url when /// we have followed redirects. pub fn get_url(&self) -> &str { - self.url.as_ref().map(|s| &s[..]).unwrap_or("") + &self.url[..] } /// The http version: `HTTP/1.1` @@ -270,7 +268,7 @@ impl Response { .map(|c| c.eq_ignore_ascii_case("close")) .unwrap_or(false); - let is_head = self.unit.as_ref().map(|u| u.is_head()).unwrap_or(false); + let is_head = self.unit.is_head(); let has_no_body = is_head || match self.status { 204 | 304 => true, @@ -295,20 +293,24 @@ impl Response { let stream = self.stream.into_inner(); let unit = self.unit; - if let Some(unit) = &unit { - let result = stream.set_read_timeout(unit.agent.config.timeout_read); - if let Err(e) = result { - return Box::new(ErrorReader(e)) as Box; - } + let result = stream.set_read_timeout(unit.agent.config.timeout_read); + if let Err(e) = result { + return Box::new(ErrorReader(e)) as Box; } - let deadline = unit.as_ref().and_then(|u| u.deadline); + let deadline = unit.deadline; let stream = DeadlineStream::new(*stream, deadline); let body_reader: Box = match (use_chunked, limit_bytes) { - (true, _) => Box::new(PoolReturnRead::new(unit, ChunkDecoder::new(stream))), - (false, Some(len)) => { - Box::new(PoolReturnRead::new(unit, LimitedRead::new(stream, len))) - } + (true, _) => Box::new(PoolReturnRead::new( + &unit.agent, + &unit.url, + ChunkDecoder::new(stream), + )), + (false, Some(len)) => Box::new(PoolReturnRead::new( + &unit.agent, + &unit.url, + LimitedRead::new(stream, len), + )), (false, None) => Box::new(stream), }; @@ -467,11 +469,10 @@ impl Response { /// let resp = ureq::Response::do_from_read(read); /// /// assert_eq!(resp.status(), 401); - pub(crate) fn do_from_stream(stream: Stream, unit: Option) -> Result { + pub(crate) fn do_from_stream(stream: Stream, unit: Unit) -> Result { // // HTTP/1.1 200 OK\r\n - let mut stream = - stream::DeadlineStream::new(stream, unit.as_ref().and_then(|u| u.deadline)); + let mut stream = stream::DeadlineStream::new(stream, unit.deadline); // The status line we can ignore non-utf8 chars and parse as_str_lossy(). let status_line = read_next_line(&mut stream, "the status line")?.into_string_lossy(); @@ -504,7 +505,7 @@ impl Response { headers.retain(|h| !h.is_name("content-encoding") && !h.is_name("content-length")); } - let url = unit.as_ref().map(|u| u.url.clone()); + let url = unit.url.clone(); Ok(Response { url, @@ -512,7 +513,7 @@ impl Response { index, status, headers, - unit: unit.map(Box::new), + unit: Box::new(unit), stream: SyncWrapper::new(Box::new(stream.into())), history: vec![], length, @@ -528,14 +529,13 @@ impl Response { #[cfg(test)] pub fn set_url(&mut self, url: Url) { - self.url = Some(url); + self.url = url; } #[cfg(test)] pub fn history_from_previous(&mut self, previous: Response) { - let previous_url = previous.url.expect("previous url"); self.history = previous.history; - self.history.push(previous_url); + self.history.push(previous.url); } } @@ -645,7 +645,20 @@ impl FromStr for Response { /// ``` fn from_str(s: &str) -> Result { let stream = Stream::from_vec(s.as_bytes().to_owned()); - Self::do_from_stream(stream, None) + let request_url = "https://example.com".parse().unwrap(); + let request_reader = SizedReader { + size: crate::body::BodySize::Empty, + reader: Box::new(std::io::empty()), + }; + let unit = Unit::new( + &Agent::new(), + "GET", + &request_url, + vec![], + &request_reader, + None, + ); + Self::do_from_stream(stream, unit) } } @@ -695,30 +708,33 @@ fn read_next_line(reader: &mut impl BufRead, context: &str) -> io::Result
{ +pub(crate) struct LimitedRead { reader: R, limit: usize, position: usize, } impl LimitedRead { - fn new(reader: R, limit: usize) -> Self { + pub(crate) fn new(reader: R, limit: usize) -> Self { LimitedRead { reader, limit, position: 0, } } + + pub(crate) fn remaining(&self) -> usize { + self.limit - self.position + } } impl Read for LimitedRead { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let left = self.limit - self.position; - if left == 0 { + if self.remaining() == 0 { return Ok(0); } - let from = if left < buf.len() { - &mut buf[0..left] + let from = if self.remaining() < buf.len() { + &mut buf[0..self.remaining()] } else { buf }; @@ -990,7 +1006,20 @@ mod tests { ); let v = cow.to_vec(); let s = Stream::from_vec(v); - let resp = Response::do_from_stream(s.into(), None).unwrap(); + let request_url = "https://example.com".parse().unwrap(); + let request_reader = SizedReader { + size: crate::body::BodySize::Empty, + reader: Box::new(std::io::empty()), + }; + let unit = Unit::new( + &Agent::new(), + "GET", + &request_url, + vec![], + &request_reader, + None, + ); + let resp = Response::do_from_stream(s.into(), unit).unwrap(); assert_eq!(resp.status(), 200); assert_eq!(resp.header("x-geo-header"), None); } diff --git a/src/stream.rs b/src/stream.rs index c8cee56..643ae55 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -69,11 +69,11 @@ impl Inner for TcpStream { } } -struct TestStream(Box, Vec); +struct TestStream(Box, Vec, bool); impl Inner for TestStream { fn is_poolable(&self) -> bool { - false + self.2 } fn socket(&self) -> Option<&TcpStream> { None @@ -201,7 +201,18 @@ impl Stream { pub(crate) fn from_vec(v: Vec) -> Stream { Stream::logged_create(Stream { - inner: BufReader::new(Box::new(TestStream(Box::new(Cursor::new(v)), vec![]))), + inner: BufReader::new(Box::new(TestStream( + Box::new(Cursor::new(v)), + vec![], + false, + ))), + }) + } + + #[cfg(test)] + pub(crate) fn from_vec_poolable(v: Vec) -> Stream { + Stream::logged_create(Stream { + inner: BufReader::new(Box::new(TestStream(Box::new(Cursor::new(v)), vec![], true))), }) } diff --git a/src/unit.rs b/src/unit.rs index 4b327bc..6fce8bb 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -284,7 +284,7 @@ fn connect_inner( // TODO: this unit.clone() bothers me. At this stage, we're not // going to use the unit (much) anymore, and it should be possible // to have ownership of it and pass it into the Response. - let result = Response::do_from_stream(stream, Some(unit.clone())); + let result = Response::do_from_stream(stream, unit.clone()); // https://tools.ietf.org/html/rfc7230#section-6.3.1 // When an inbound connection is closed prematurely, a client MAY