Add req field to Unit and remove cloned parts from request (#158)
Instead of cloning most of `Request`'s fields individually when creating a `Unit`, this PR switches to just cloning `Request` and stuffing it in `Unit`, and changes references to `unit.[field]` to `unit.req.[field]` where appropriate. Fixes #155
This commit is contained in:
@@ -394,7 +394,7 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
|
||||
fn return_connection(&mut self) {
|
||||
// guard we only do this once.
|
||||
if let (Some(unit), Some(reader)) = (self.unit.take(), self.reader.take()) {
|
||||
let state = &mut unit.agent.lock().unwrap();
|
||||
let state = &mut unit.req.agent.lock().unwrap();
|
||||
// bring back stream here to either go into pool or dealloc
|
||||
let stream = reader.into();
|
||||
if !stream.is_poolable() {
|
||||
@@ -402,7 +402,7 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
|
||||
return;
|
||||
}
|
||||
// insert back into pool
|
||||
let key = PoolKey::new(&unit.url, &unit.proxy);
|
||||
let key = PoolKey::new(&unit.url, &unit.req.proxy);
|
||||
state.pool().add(key, stream);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,8 +343,12 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
|
||||
|
||||
let sni = webpki::DNSNameRef::try_from_ascii_str(hostname)
|
||||
.map_err(|err| Error::DnsFailed(err.to_string()))?;
|
||||
let tls_conf: &Arc<rustls::ClientConfig> =
|
||||
unit.tls_config.as_ref().map(|c| &c.0).unwrap_or(&*TLS_CONF);
|
||||
let tls_conf: &Arc<rustls::ClientConfig> = unit
|
||||
.req
|
||||
.tls_config
|
||||
.as_ref()
|
||||
.map(|c| &c.0)
|
||||
.unwrap_or(&*TLS_CONF);
|
||||
let sess = rustls::ClientSession::new(&tls_conf, sni);
|
||||
|
||||
let sock = connect_host(unit, hostname, port)?;
|
||||
@@ -362,7 +366,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
|
||||
let port = unit.url.port().unwrap_or(443);
|
||||
let sock = connect_host(unit, hostname, port)?;
|
||||
|
||||
let tls_connector: Arc<native_tls::TlsConnector> = match &unit.tls_connector {
|
||||
let tls_connector: Arc<native_tls::TlsConnector> = match &unit.req.tls_connector {
|
||||
Some(connector) => connector.0.clone(),
|
||||
None => Arc::new(native_tls::TlsConnector::new().map_err(|e| Error::TlsError(e))?),
|
||||
};
|
||||
@@ -377,14 +381,14 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
|
||||
}
|
||||
|
||||
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
|
||||
let deadline: Option<Instant> = if unit.timeout_connect > 0 {
|
||||
Instant::now().checked_add(Duration::from_millis(unit.timeout_connect))
|
||||
let deadline: Option<Instant> = if unit.req.timeout_connect > 0 {
|
||||
Instant::now().checked_add(Duration::from_millis(unit.req.timeout_connect))
|
||||
} else {
|
||||
unit.deadline
|
||||
};
|
||||
|
||||
// TODO: Find a way to apply deadline to DNS lookup.
|
||||
let sock_addrs: Vec<SocketAddr> = match unit.proxy {
|
||||
let sock_addrs: Vec<SocketAddr> = match unit.req.proxy {
|
||||
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
|
||||
None => format!("{}:{}", hostname, port),
|
||||
}
|
||||
@@ -396,7 +400,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
||||
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
|
||||
}
|
||||
|
||||
let proto = if let Some(ref proxy) = unit.proxy {
|
||||
let proto = if let Some(ref proxy) = unit.req.proxy {
|
||||
Some(proxy.proto)
|
||||
} else {
|
||||
None
|
||||
@@ -415,7 +419,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
||||
// connect with a configured timeout.
|
||||
let stream = if Some(Proto::SOCKS5) == proto {
|
||||
connect_socks5(
|
||||
unit.proxy.to_owned().unwrap(),
|
||||
unit.req.proxy.to_owned().unwrap(),
|
||||
deadline,
|
||||
sock_addr,
|
||||
hostname,
|
||||
@@ -448,9 +452,9 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
||||
stream
|
||||
.set_read_timeout(Some(time_until_deadline(deadline)?))
|
||||
.ok();
|
||||
} else if unit.timeout_read > 0 {
|
||||
} else if unit.req.timeout_read > 0 {
|
||||
stream
|
||||
.set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64)))
|
||||
.set_read_timeout(Some(Duration::from_millis(unit.req.timeout_read as u64)))
|
||||
.ok();
|
||||
} else {
|
||||
stream.set_read_timeout(None).ok();
|
||||
@@ -460,16 +464,16 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
||||
stream
|
||||
.set_write_timeout(Some(time_until_deadline(deadline)?))
|
||||
.ok();
|
||||
} else if unit.timeout_write > 0 {
|
||||
} else if unit.req.timeout_write > 0 {
|
||||
stream
|
||||
.set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64)))
|
||||
.set_write_timeout(Some(Duration::from_millis(unit.req.timeout_write as u64)))
|
||||
.ok();
|
||||
} else {
|
||||
stream.set_write_timeout(None).ok();
|
||||
}
|
||||
|
||||
if proto == Some(Proto::HTTPConnect) {
|
||||
if let Some(ref proxy) = unit.proxy {
|
||||
if let Some(ref proxy) = unit.req.proxy {
|
||||
write!(stream, "{}", proxy.connect(hostname, port)).unwrap();
|
||||
stream.flush()?;
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ fn redirect_head() {
|
||||
test::make_response(302, "Go here", vec!["Location: /redirect_head2"], vec![])
|
||||
});
|
||||
test::set_handler("/redirect_head2", |unit| {
|
||||
assert_eq!(unit.method, "HEAD");
|
||||
assert_eq!(unit.req.method, "HEAD");
|
||||
test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
|
||||
});
|
||||
let resp = head("test://host/redirect_head1").call();
|
||||
@@ -62,7 +62,7 @@ fn redirect_get() {
|
||||
test::make_response(302, "Go here", vec!["Location: /redirect_get2"], vec![])
|
||||
});
|
||||
test::set_handler("/redirect_get2", |unit| {
|
||||
assert_eq!(unit.method, "GET");
|
||||
assert_eq!(unit.req.method, "GET");
|
||||
assert!(unit.has("Range"));
|
||||
assert_eq!(unit.header("Range").unwrap(), "bytes=10-50");
|
||||
test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
|
||||
@@ -82,7 +82,7 @@ fn redirect_post() {
|
||||
test::make_response(302, "Go here", vec!["Location: /redirect_post2"], vec![])
|
||||
});
|
||||
test::set_handler("/redirect_post2", |unit| {
|
||||
assert_eq!(unit.method, "GET");
|
||||
assert_eq!(unit.req.method, "GET");
|
||||
test::make_response(200, "OK", vec!["x-foo: bar"], vec![])
|
||||
});
|
||||
let resp = post("test://host/redirect_post1").call();
|
||||
|
||||
44
src/unit.rs
44
src/unit.rs
@@ -1,5 +1,4 @@
|
||||
use std::io::{Result as IoResult, Write};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time;
|
||||
|
||||
use qstring::QString;
|
||||
@@ -12,15 +11,8 @@ use crate::agent::AgentState;
|
||||
use crate::body::{self, Payload, SizedReader};
|
||||
use crate::header;
|
||||
use crate::stream::{self, connect_test, Stream};
|
||||
use crate::Proxy;
|
||||
use crate::{Error, Header, Request, Response};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::request::TLSClientConfig;
|
||||
|
||||
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
||||
use crate::request::TLSConnector;
|
||||
|
||||
#[cfg(feature = "cookie")]
|
||||
use crate::pool::DEFAULT_HOST;
|
||||
|
||||
@@ -28,21 +20,12 @@ use crate::pool::DEFAULT_HOST;
|
||||
///
|
||||
/// *Internal API*
|
||||
pub(crate) struct Unit {
|
||||
pub agent: Arc<Mutex<AgentState>>,
|
||||
pub req: Request,
|
||||
pub url: Url,
|
||||
pub is_chunked: bool,
|
||||
pub query_string: String,
|
||||
pub headers: Vec<Header>,
|
||||
pub timeout_connect: u64,
|
||||
pub timeout_read: u64,
|
||||
pub timeout_write: u64,
|
||||
pub deadline: Option<time::Instant>,
|
||||
pub method: String,
|
||||
pub proxy: Option<Proxy>,
|
||||
#[cfg(feature = "tls")]
|
||||
pub tls_config: Option<TLSClientConfig>,
|
||||
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
||||
pub tls_connector: Option<TLSConnector>,
|
||||
}
|
||||
|
||||
impl Unit {
|
||||
@@ -99,26 +82,17 @@ impl Unit {
|
||||
};
|
||||
|
||||
Unit {
|
||||
agent: Arc::clone(&req.agent),
|
||||
req: req.clone(),
|
||||
url: url.clone(),
|
||||
is_chunked,
|
||||
query_string,
|
||||
headers,
|
||||
timeout_connect: req.timeout_connect,
|
||||
timeout_read: req.timeout_read,
|
||||
timeout_write: req.timeout_write,
|
||||
deadline,
|
||||
method: req.method.clone(),
|
||||
proxy: req.proxy.clone(),
|
||||
#[cfg(feature = "tls")]
|
||||
tls_config: req.tls_config.clone(),
|
||||
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
||||
tls_connector: req.tls_connector.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_head(&self) -> bool {
|
||||
self.method.eq_ignore_ascii_case("head")
|
||||
self.req.method.eq_ignore_ascii_case("head")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -217,8 +191,8 @@ pub(crate) fn connect(
|
||||
let mut new_unit = Unit::new(req, &new_url, false, &empty);
|
||||
// this is to follow how curl does it. POST, PUT etc change
|
||||
// to GET on a redirect.
|
||||
new_unit.method = match &unit.method[..] {
|
||||
"GET" | "HEAD" => unit.method,
|
||||
new_unit.req.method = match &unit.req.method[..] {
|
||||
"GET" | "HEAD" => unit.req.method,
|
||||
_ => "GET".into(),
|
||||
};
|
||||
return connect(req, new_unit, use_pooled, redirect_count + 1, empty, true);
|
||||
@@ -305,11 +279,11 @@ fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error
|
||||
_ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())),
|
||||
};
|
||||
if use_pooled {
|
||||
let state = &mut unit.agent.lock().unwrap();
|
||||
let state = &mut unit.req.agent.lock().unwrap();
|
||||
// The connection may have been closed by the server
|
||||
// due to idle timeout while it was sitting in the pool.
|
||||
// Loop until we find one that is still good or run out of connections.
|
||||
while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.proxy) {
|
||||
while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.req.proxy) {
|
||||
let server_closed = stream.server_closed()?;
|
||||
if !server_closed {
|
||||
return Ok((stream, true));
|
||||
@@ -337,7 +311,7 @@ fn send_prelude(unit: &Unit, stream: &mut Stream, redir: bool) -> IoResult<()> {
|
||||
write!(
|
||||
prelude,
|
||||
"{} {}{} HTTP/1.1\r\n",
|
||||
unit.method,
|
||||
unit.req.method,
|
||||
unit.url.path(),
|
||||
&unit.query_string
|
||||
)?;
|
||||
@@ -404,7 +378,7 @@ fn save_cookies(unit: &Unit, resp: &Response) {
|
||||
}
|
||||
|
||||
// only lock if we know there is something to process
|
||||
let state = &mut unit.agent.lock().unwrap();
|
||||
let state = &mut unit.req.agent.lock().unwrap();
|
||||
for raw_cookie in cookies.iter() {
|
||||
let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
|
||||
(*raw_cookie).to_string()
|
||||
|
||||
Reference in New Issue
Block a user