Add req field to Unit and remove cloned parts from request (#158)

Instead of cloning most of `Request`'s fields individually when
creating a `Unit`, this PR switches to just cloning `Request` and
stuffing it in `Unit`, and changes references to `unit.[field]` to
`unit.req.[field]` where appropriate.

Fixes #155
This commit is contained in:
Daniel Rivas
2020-09-26 18:22:10 +01:00
committed by GitHub
parent be9e3ca936
commit 8bba07a9af
4 changed files with 31 additions and 53 deletions

View File

@@ -394,7 +394,7 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
fn return_connection(&mut self) { fn return_connection(&mut self) {
// guard we only do this once. // guard we only do this once.
if let (Some(unit), Some(reader)) = (self.unit.take(), self.reader.take()) { if let (Some(unit), Some(reader)) = (self.unit.take(), self.reader.take()) {
let state = &mut unit.agent.lock().unwrap(); let state = &mut unit.req.agent.lock().unwrap();
// bring back stream here to either go into pool or dealloc // bring back stream here to either go into pool or dealloc
let stream = reader.into(); let stream = reader.into();
if !stream.is_poolable() { if !stream.is_poolable() {
@@ -402,7 +402,7 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
return; return;
} }
// insert back into pool // insert back into pool
let key = PoolKey::new(&unit.url, &unit.proxy); let key = PoolKey::new(&unit.url, &unit.req.proxy);
state.pool().add(key, stream); state.pool().add(key, stream);
} }
} }

View File

@@ -343,8 +343,12 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
let sni = webpki::DNSNameRef::try_from_ascii_str(hostname) let sni = webpki::DNSNameRef::try_from_ascii_str(hostname)
.map_err(|err| Error::DnsFailed(err.to_string()))?; .map_err(|err| Error::DnsFailed(err.to_string()))?;
let tls_conf: &Arc<rustls::ClientConfig> = let tls_conf: &Arc<rustls::ClientConfig> = unit
unit.tls_config.as_ref().map(|c| &c.0).unwrap_or(&*TLS_CONF); .req
.tls_config
.as_ref()
.map(|c| &c.0)
.unwrap_or(&*TLS_CONF);
let sess = rustls::ClientSession::new(&tls_conf, sni); let sess = rustls::ClientSession::new(&tls_conf, sni);
let sock = connect_host(unit, hostname, port)?; let sock = connect_host(unit, hostname, port)?;
@@ -362,7 +366,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
let port = unit.url.port().unwrap_or(443); let port = unit.url.port().unwrap_or(443);
let sock = connect_host(unit, hostname, port)?; let sock = connect_host(unit, hostname, port)?;
let tls_connector: Arc<native_tls::TlsConnector> = match &unit.tls_connector { let tls_connector: Arc<native_tls::TlsConnector> = match &unit.req.tls_connector {
Some(connector) => connector.0.clone(), Some(connector) => connector.0.clone(),
None => Arc::new(native_tls::TlsConnector::new().map_err(|e| Error::TlsError(e))?), None => Arc::new(native_tls::TlsConnector::new().map_err(|e| Error::TlsError(e))?),
}; };
@@ -377,14 +381,14 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
} }
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> { pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
let deadline: Option<Instant> = if unit.timeout_connect > 0 { let deadline: Option<Instant> = if unit.req.timeout_connect > 0 {
Instant::now().checked_add(Duration::from_millis(unit.timeout_connect)) Instant::now().checked_add(Duration::from_millis(unit.req.timeout_connect))
} else { } else {
unit.deadline unit.deadline
}; };
// TODO: Find a way to apply deadline to DNS lookup. // TODO: Find a way to apply deadline to DNS lookup.
let sock_addrs: Vec<SocketAddr> = match unit.proxy { let sock_addrs: Vec<SocketAddr> = match unit.req.proxy {
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port), Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
None => format!("{}:{}", hostname, port), None => format!("{}:{}", hostname, port),
} }
@@ -396,7 +400,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
return Err(Error::DnsFailed(format!("No ip address for {}", hostname))); return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
} }
let proto = if let Some(ref proxy) = unit.proxy { let proto = if let Some(ref proxy) = unit.req.proxy {
Some(proxy.proto) Some(proxy.proto)
} else { } else {
None None
@@ -415,7 +419,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
// connect with a configured timeout. // connect with a configured timeout.
let stream = if Some(Proto::SOCKS5) == proto { let stream = if Some(Proto::SOCKS5) == proto {
connect_socks5( connect_socks5(
unit.proxy.to_owned().unwrap(), unit.req.proxy.to_owned().unwrap(),
deadline, deadline,
sock_addr, sock_addr,
hostname, hostname,
@@ -448,9 +452,9 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
stream stream
.set_read_timeout(Some(time_until_deadline(deadline)?)) .set_read_timeout(Some(time_until_deadline(deadline)?))
.ok(); .ok();
} else if unit.timeout_read > 0 { } else if unit.req.timeout_read > 0 {
stream stream
.set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64))) .set_read_timeout(Some(Duration::from_millis(unit.req.timeout_read as u64)))
.ok(); .ok();
} else { } else {
stream.set_read_timeout(None).ok(); stream.set_read_timeout(None).ok();
@@ -460,16 +464,16 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
stream stream
.set_write_timeout(Some(time_until_deadline(deadline)?)) .set_write_timeout(Some(time_until_deadline(deadline)?))
.ok(); .ok();
} else if unit.timeout_write > 0 { } else if unit.req.timeout_write > 0 {
stream stream
.set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64))) .set_write_timeout(Some(Duration::from_millis(unit.req.timeout_write as u64)))
.ok(); .ok();
} else { } else {
stream.set_write_timeout(None).ok(); stream.set_write_timeout(None).ok();
} }
if proto == Some(Proto::HTTPConnect) { if proto == Some(Proto::HTTPConnect) {
if let Some(ref proxy) = unit.proxy { if let Some(ref proxy) = unit.req.proxy {
write!(stream, "{}", proxy.connect(hostname, port)).unwrap(); write!(stream, "{}", proxy.connect(hostname, port)).unwrap();
stream.flush()?; stream.flush()?;

View File

@@ -46,7 +46,7 @@ fn redirect_head() {
test::make_response(302, "Go here", vec!["Location: /redirect_head2"], vec![]) test::make_response(302, "Go here", vec!["Location: /redirect_head2"], vec![])
}); });
test::set_handler("/redirect_head2", |unit| { test::set_handler("/redirect_head2", |unit| {
assert_eq!(unit.method, "HEAD"); assert_eq!(unit.req.method, "HEAD");
test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
}); });
let resp = head("test://host/redirect_head1").call(); let resp = head("test://host/redirect_head1").call();
@@ -62,7 +62,7 @@ fn redirect_get() {
test::make_response(302, "Go here", vec!["Location: /redirect_get2"], vec![]) test::make_response(302, "Go here", vec!["Location: /redirect_get2"], vec![])
}); });
test::set_handler("/redirect_get2", |unit| { test::set_handler("/redirect_get2", |unit| {
assert_eq!(unit.method, "GET"); assert_eq!(unit.req.method, "GET");
assert!(unit.has("Range")); assert!(unit.has("Range"));
assert_eq!(unit.header("Range").unwrap(), "bytes=10-50"); assert_eq!(unit.header("Range").unwrap(), "bytes=10-50");
test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
@@ -82,7 +82,7 @@ fn redirect_post() {
test::make_response(302, "Go here", vec!["Location: /redirect_post2"], vec![]) test::make_response(302, "Go here", vec!["Location: /redirect_post2"], vec![])
}); });
test::set_handler("/redirect_post2", |unit| { test::set_handler("/redirect_post2", |unit| {
assert_eq!(unit.method, "GET"); assert_eq!(unit.req.method, "GET");
test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
}); });
let resp = post("test://host/redirect_post1").call(); let resp = post("test://host/redirect_post1").call();

View File

@@ -1,5 +1,4 @@
use std::io::{Result as IoResult, Write}; use std::io::{Result as IoResult, Write};
use std::sync::{Arc, Mutex};
use std::time; use std::time;
use qstring::QString; use qstring::QString;
@@ -12,15 +11,8 @@ use crate::agent::AgentState;
use crate::body::{self, Payload, SizedReader}; use crate::body::{self, Payload, SizedReader};
use crate::header; use crate::header;
use crate::stream::{self, connect_test, Stream}; use crate::stream::{self, connect_test, Stream};
use crate::Proxy;
use crate::{Error, Header, Request, Response}; use crate::{Error, Header, Request, Response};
#[cfg(feature = "tls")]
use crate::request::TLSClientConfig;
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
use crate::request::TLSConnector;
#[cfg(feature = "cookie")] #[cfg(feature = "cookie")]
use crate::pool::DEFAULT_HOST; use crate::pool::DEFAULT_HOST;
@@ -28,21 +20,12 @@ use crate::pool::DEFAULT_HOST;
/// ///
/// *Internal API* /// *Internal API*
pub(crate) struct Unit { pub(crate) struct Unit {
pub agent: Arc<Mutex<AgentState>>, pub req: Request,
pub url: Url, pub url: Url,
pub is_chunked: bool, pub is_chunked: bool,
pub query_string: String, pub query_string: String,
pub headers: Vec<Header>, pub headers: Vec<Header>,
pub timeout_connect: u64,
pub timeout_read: u64,
pub timeout_write: u64,
pub deadline: Option<time::Instant>, pub deadline: Option<time::Instant>,
pub method: String,
pub proxy: Option<Proxy>,
#[cfg(feature = "tls")]
pub tls_config: Option<TLSClientConfig>,
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
pub tls_connector: Option<TLSConnector>,
} }
impl Unit { impl Unit {
@@ -99,26 +82,17 @@ impl Unit {
}; };
Unit { Unit {
agent: Arc::clone(&req.agent), req: req.clone(),
url: url.clone(), url: url.clone(),
is_chunked, is_chunked,
query_string, query_string,
headers, headers,
timeout_connect: req.timeout_connect,
timeout_read: req.timeout_read,
timeout_write: req.timeout_write,
deadline, deadline,
method: req.method.clone(),
proxy: req.proxy.clone(),
#[cfg(feature = "tls")]
tls_config: req.tls_config.clone(),
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
tls_connector: req.tls_connector.clone(),
} }
} }
pub fn is_head(&self) -> bool { pub fn is_head(&self) -> bool {
self.method.eq_ignore_ascii_case("head") self.req.method.eq_ignore_ascii_case("head")
} }
#[cfg(test)] #[cfg(test)]
@@ -217,8 +191,8 @@ pub(crate) fn connect(
let mut new_unit = Unit::new(req, &new_url, false, &empty); let mut new_unit = Unit::new(req, &new_url, false, &empty);
// this is to follow how curl does it. POST, PUT etc change // this is to follow how curl does it. POST, PUT etc change
// to GET on a redirect. // to GET on a redirect.
new_unit.method = match &unit.method[..] { new_unit.req.method = match &unit.req.method[..] {
"GET" | "HEAD" => unit.method, "GET" | "HEAD" => unit.req.method,
_ => "GET".into(), _ => "GET".into(),
}; };
return connect(req, new_unit, use_pooled, redirect_count + 1, empty, true); return connect(req, new_unit, use_pooled, redirect_count + 1, empty, true);
@@ -305,11 +279,11 @@ fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error
_ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())), _ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())),
}; };
if use_pooled { if use_pooled {
let state = &mut unit.agent.lock().unwrap(); let state = &mut unit.req.agent.lock().unwrap();
// The connection may have been closed by the server // The connection may have been closed by the server
// due to idle timeout while it was sitting in the pool. // due to idle timeout while it was sitting in the pool.
// Loop until we find one that is still good or run out of connections. // Loop until we find one that is still good or run out of connections.
while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.proxy) { while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.req.proxy) {
let server_closed = stream.server_closed()?; let server_closed = stream.server_closed()?;
if !server_closed { if !server_closed {
return Ok((stream, true)); return Ok((stream, true));
@@ -337,7 +311,7 @@ fn send_prelude(unit: &Unit, stream: &mut Stream, redir: bool) -> IoResult<()> {
write!( write!(
prelude, prelude,
"{} {}{} HTTP/1.1\r\n", "{} {}{} HTTP/1.1\r\n",
unit.method, unit.req.method,
unit.url.path(), unit.url.path(),
&unit.query_string &unit.query_string
)?; )?;
@@ -404,7 +378,7 @@ fn save_cookies(unit: &Unit, resp: &Response) {
} }
// only lock if we know there is something to process // only lock if we know there is something to process
let state = &mut unit.agent.lock().unwrap(); let state = &mut unit.req.agent.lock().unwrap();
for raw_cookie in cookies.iter() { for raw_cookie in cookies.iter() {
let to_parse = if raw_cookie.to_lowercase().contains("domain=") { let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
(*raw_cookie).to_string() (*raw_cookie).to_string()