Add req field to Unit and remove cloned parts from request (#158)

Instead of cloning most of `Request`'s fields individually when
creating a `Unit`, this PR switches to just cloning `Request` and
stuffing it in `Unit`, and changes references to `unit.[field]` to
`unit.req.[field]` where appropriate.

Fixes #155
This commit is contained in:
Daniel Rivas
2020-09-26 18:22:10 +01:00
committed by GitHub
parent be9e3ca936
commit 8bba07a9af
4 changed files with 31 additions and 53 deletions

View File

@@ -394,7 +394,7 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
fn return_connection(&mut self) {
// guard we only do this once.
if let (Some(unit), Some(reader)) = (self.unit.take(), self.reader.take()) {
let state = &mut unit.agent.lock().unwrap();
let state = &mut unit.req.agent.lock().unwrap();
// bring back stream here to either go into pool or dealloc
let stream = reader.into();
if !stream.is_poolable() {
@@ -402,7 +402,7 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
return;
}
// insert back into pool
let key = PoolKey::new(&unit.url, &unit.proxy);
let key = PoolKey::new(&unit.url, &unit.req.proxy);
state.pool().add(key, stream);
}
}

View File

@@ -343,8 +343,12 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
let sni = webpki::DNSNameRef::try_from_ascii_str(hostname)
.map_err(|err| Error::DnsFailed(err.to_string()))?;
let tls_conf: &Arc<rustls::ClientConfig> =
unit.tls_config.as_ref().map(|c| &c.0).unwrap_or(&*TLS_CONF);
let tls_conf: &Arc<rustls::ClientConfig> = unit
.req
.tls_config
.as_ref()
.map(|c| &c.0)
.unwrap_or(&*TLS_CONF);
let sess = rustls::ClientSession::new(&tls_conf, sni);
let sock = connect_host(unit, hostname, port)?;
@@ -362,7 +366,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
let port = unit.url.port().unwrap_or(443);
let sock = connect_host(unit, hostname, port)?;
let tls_connector: Arc<native_tls::TlsConnector> = match &unit.tls_connector {
let tls_connector: Arc<native_tls::TlsConnector> = match &unit.req.tls_connector {
Some(connector) => connector.0.clone(),
None => Arc::new(native_tls::TlsConnector::new().map_err(|e| Error::TlsError(e))?),
};
@@ -377,14 +381,14 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
let deadline: Option<Instant> = if unit.timeout_connect > 0 {
Instant::now().checked_add(Duration::from_millis(unit.timeout_connect))
let deadline: Option<Instant> = if unit.req.timeout_connect > 0 {
Instant::now().checked_add(Duration::from_millis(unit.req.timeout_connect))
} else {
unit.deadline
};
// TODO: Find a way to apply deadline to DNS lookup.
let sock_addrs: Vec<SocketAddr> = match unit.proxy {
let sock_addrs: Vec<SocketAddr> = match unit.req.proxy {
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
None => format!("{}:{}", hostname, port),
}
@@ -396,7 +400,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
}
let proto = if let Some(ref proxy) = unit.proxy {
let proto = if let Some(ref proxy) = unit.req.proxy {
Some(proxy.proto)
} else {
None
@@ -415,7 +419,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
// connect with a configured timeout.
let stream = if Some(Proto::SOCKS5) == proto {
connect_socks5(
unit.proxy.to_owned().unwrap(),
unit.req.proxy.to_owned().unwrap(),
deadline,
sock_addr,
hostname,
@@ -448,9 +452,9 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
stream
.set_read_timeout(Some(time_until_deadline(deadline)?))
.ok();
} else if unit.timeout_read > 0 {
} else if unit.req.timeout_read > 0 {
stream
.set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64)))
.set_read_timeout(Some(Duration::from_millis(unit.req.timeout_read as u64)))
.ok();
} else {
stream.set_read_timeout(None).ok();
@@ -460,16 +464,16 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
stream
.set_write_timeout(Some(time_until_deadline(deadline)?))
.ok();
} else if unit.timeout_write > 0 {
} else if unit.req.timeout_write > 0 {
stream
.set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64)))
.set_write_timeout(Some(Duration::from_millis(unit.req.timeout_write as u64)))
.ok();
} else {
stream.set_write_timeout(None).ok();
}
if proto == Some(Proto::HTTPConnect) {
if let Some(ref proxy) = unit.proxy {
if let Some(ref proxy) = unit.req.proxy {
write!(stream, "{}", proxy.connect(hostname, port)).unwrap();
stream.flush()?;

View File

@@ -46,7 +46,7 @@ fn redirect_head() {
test::make_response(302, "Go here", vec!["Location: /redirect_head2"], vec![])
});
test::set_handler("/redirect_head2", |unit| {
assert_eq!(unit.method, "HEAD");
assert_eq!(unit.req.method, "HEAD");
test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
});
let resp = head("test://host/redirect_head1").call();
@@ -62,7 +62,7 @@ fn redirect_get() {
test::make_response(302, "Go here", vec!["Location: /redirect_get2"], vec![])
});
test::set_handler("/redirect_get2", |unit| {
assert_eq!(unit.method, "GET");
assert_eq!(unit.req.method, "GET");
assert!(unit.has("Range"));
assert_eq!(unit.header("Range").unwrap(), "bytes=10-50");
test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
@@ -82,7 +82,7 @@ fn redirect_post() {
test::make_response(302, "Go here", vec!["Location: /redirect_post2"], vec![])
});
test::set_handler("/redirect_post2", |unit| {
assert_eq!(unit.method, "GET");
assert_eq!(unit.req.method, "GET");
test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
});
let resp = post("test://host/redirect_post1").call();

View File

@@ -1,5 +1,4 @@
use std::io::{Result as IoResult, Write};
use std::sync::{Arc, Mutex};
use std::time;
use qstring::QString;
@@ -12,15 +11,8 @@ use crate::agent::AgentState;
use crate::body::{self, Payload, SizedReader};
use crate::header;
use crate::stream::{self, connect_test, Stream};
use crate::Proxy;
use crate::{Error, Header, Request, Response};
#[cfg(feature = "tls")]
use crate::request::TLSClientConfig;
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
use crate::request::TLSConnector;
#[cfg(feature = "cookie")]
use crate::pool::DEFAULT_HOST;
@@ -28,21 +20,12 @@ use crate::pool::DEFAULT_HOST;
///
/// *Internal API*
pub(crate) struct Unit {
pub agent: Arc<Mutex<AgentState>>,
pub req: Request,
pub url: Url,
pub is_chunked: bool,
pub query_string: String,
pub headers: Vec<Header>,
pub timeout_connect: u64,
pub timeout_read: u64,
pub timeout_write: u64,
pub deadline: Option<time::Instant>,
pub method: String,
pub proxy: Option<Proxy>,
#[cfg(feature = "tls")]
pub tls_config: Option<TLSClientConfig>,
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
pub tls_connector: Option<TLSConnector>,
}
impl Unit {
@@ -99,26 +82,17 @@ impl Unit {
};
Unit {
agent: Arc::clone(&req.agent),
req: req.clone(),
url: url.clone(),
is_chunked,
query_string,
headers,
timeout_connect: req.timeout_connect,
timeout_read: req.timeout_read,
timeout_write: req.timeout_write,
deadline,
method: req.method.clone(),
proxy: req.proxy.clone(),
#[cfg(feature = "tls")]
tls_config: req.tls_config.clone(),
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
tls_connector: req.tls_connector.clone(),
}
}
pub fn is_head(&self) -> bool {
self.method.eq_ignore_ascii_case("head")
self.req.method.eq_ignore_ascii_case("head")
}
#[cfg(test)]
@@ -217,8 +191,8 @@ pub(crate) fn connect(
let mut new_unit = Unit::new(req, &new_url, false, &empty);
// this is to follow how curl does it. POST, PUT etc change
// to GET on a redirect.
new_unit.method = match &unit.method[..] {
"GET" | "HEAD" => unit.method,
new_unit.req.method = match &unit.req.method[..] {
"GET" | "HEAD" => unit.req.method,
_ => "GET".into(),
};
return connect(req, new_unit, use_pooled, redirect_count + 1, empty, true);
@@ -305,11 +279,11 @@ fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error
_ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())),
};
if use_pooled {
let state = &mut unit.agent.lock().unwrap();
let state = &mut unit.req.agent.lock().unwrap();
// The connection may have been closed by the server
// due to idle timeout while it was sitting in the pool.
// Loop until we find one that is still good or run out of connections.
while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.proxy) {
while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.req.proxy) {
let server_closed = stream.server_closed()?;
if !server_closed {
return Ok((stream, true));
@@ -337,7 +311,7 @@ fn send_prelude(unit: &Unit, stream: &mut Stream, redir: bool) -> IoResult<()> {
write!(
prelude,
"{} {}{} HTTP/1.1\r\n",
unit.method,
unit.req.method,
unit.url.path(),
&unit.query_string
)?;
@@ -404,7 +378,7 @@ fn save_cookies(unit: &Unit, resp: &Response) {
}
// only lock if we know there is something to process
let state = &mut unit.agent.lock().unwrap();
let state = &mut unit.req.agent.lock().unwrap();
for raw_cookie in cookies.iter() {
let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
(*raw_cookie).to_string()