From 8bba07a9afec72950951c53b114da7c99fba9f31 Mon Sep 17 00:00:00 2001 From: Daniel Rivas Date: Sat, 26 Sep 2020 18:22:10 +0100 Subject: [PATCH] Add `req` field to Unit and remove cloned parts from request (#158) Instead of cloning most of `Request`'s fields individually when creating a `Unit`, this PR switches to just cloning `Request` and stuffing it in `Unit`, and changes references to `unit.[field]` to `unit.req.[field]` where appropriate. Fixes #155 --- src/pool.rs | 4 ++-- src/stream.rs | 30 +++++++++++++++++------------- src/test/redirect.rs | 6 +++--- src/unit.rs | 44 +++++++++----------------------------------- 4 files changed, 31 insertions(+), 53 deletions(-) diff --git a/src/pool.rs b/src/pool.rs index 4663647..571867f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -394,7 +394,7 @@ impl> PoolReturnRead { fn return_connection(&mut self) { // guard we only do this once. if let (Some(unit), Some(reader)) = (self.unit.take(), self.reader.take()) { - let state = &mut unit.agent.lock().unwrap(); + let state = &mut unit.req.agent.lock().unwrap(); // bring back stream here to either go into pool or dealloc let stream = reader.into(); if !stream.is_poolable() { @@ -402,7 +402,7 @@ impl> PoolReturnRead { return; } // insert back into pool - let key = PoolKey::new(&unit.url, &unit.proxy); + let key = PoolKey::new(&unit.url, &unit.req.proxy); state.pool().add(key, stream); } } diff --git a/src/stream.rs b/src/stream.rs index b0f8309..7b42a6c 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -343,8 +343,12 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { let sni = webpki::DNSNameRef::try_from_ascii_str(hostname) .map_err(|err| Error::DnsFailed(err.to_string()))?; - let tls_conf: &Arc = - unit.tls_config.as_ref().map(|c| &c.0).unwrap_or(&*TLS_CONF); + let tls_conf: &Arc = unit + .req + .tls_config + .as_ref() + .map(|c| &c.0) + .unwrap_or(&*TLS_CONF); let sess = rustls::ClientSession::new(&tls_conf, sni); let sock = connect_host(unit, hostname, port)?; @@ -362,7 +366,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { let port = unit.url.port().unwrap_or(443); let sock = connect_host(unit, hostname, port)?; - let tls_connector: Arc = match &unit.tls_connector { + let tls_connector: Arc = match &unit.req.tls_connector { Some(connector) => connector.0.clone(), None => Arc::new(native_tls::TlsConnector::new().map_err(|e| Error::TlsError(e))?), }; @@ -377,14 +381,14 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { } pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result { - let deadline: Option = if unit.timeout_connect > 0 { - Instant::now().checked_add(Duration::from_millis(unit.timeout_connect)) + let deadline: Option = if unit.req.timeout_connect > 0 { + Instant::now().checked_add(Duration::from_millis(unit.req.timeout_connect)) } else { unit.deadline }; // TODO: Find a way to apply deadline to DNS lookup. - let sock_addrs: Vec = match unit.proxy { + let sock_addrs: Vec = match unit.req.proxy { Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port), None => format!("{}:{}", hostname, port), } @@ -396,7 +400,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result Result Result 0 { + } else if unit.req.timeout_read > 0 { stream - .set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64))) + .set_read_timeout(Some(Duration::from_millis(unit.req.timeout_read as u64))) .ok(); } else { stream.set_read_timeout(None).ok(); @@ -460,16 +464,16 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result 0 { + } else if unit.req.timeout_write > 0 { stream - .set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64))) + .set_write_timeout(Some(Duration::from_millis(unit.req.timeout_write as u64))) .ok(); } else { stream.set_write_timeout(None).ok(); } if proto == Some(Proto::HTTPConnect) { - if let Some(ref proxy) = unit.proxy { + if let Some(ref proxy) = unit.req.proxy { write!(stream, "{}", proxy.connect(hostname, port)).unwrap(); stream.flush()?; diff --git a/src/test/redirect.rs b/src/test/redirect.rs index 705c672..03639e2 100644 --- a/src/test/redirect.rs +++ b/src/test/redirect.rs @@ -46,7 +46,7 @@ fn redirect_head() { test::make_response(302, "Go here", vec!["Location: /redirect_head2"], vec![]) }); test::set_handler("/redirect_head2", |unit| { - assert_eq!(unit.method, "HEAD"); + assert_eq!(unit.req.method, "HEAD"); test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) }); let resp = head("test://host/redirect_head1").call(); @@ -62,7 +62,7 @@ fn redirect_get() { test::make_response(302, "Go here", vec!["Location: /redirect_get2"], vec![]) }); test::set_handler("/redirect_get2", |unit| { - assert_eq!(unit.method, "GET"); + assert_eq!(unit.req.method, "GET"); assert!(unit.has("Range")); assert_eq!(unit.header("Range").unwrap(), "bytes=10-50"); test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) @@ -82,7 +82,7 @@ fn redirect_post() { test::make_response(302, "Go here", vec!["Location: /redirect_post2"], vec![]) }); test::set_handler("/redirect_post2", |unit| { - assert_eq!(unit.method, "GET"); + assert_eq!(unit.req.method, "GET"); test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) }); let resp = post("test://host/redirect_post1").call(); diff --git a/src/unit.rs b/src/unit.rs index 1edc634..13d03a8 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -1,5 +1,4 @@ use std::io::{Result as IoResult, Write}; -use std::sync::{Arc, Mutex}; use std::time; use qstring::QString; @@ -12,15 +11,8 @@ use crate::agent::AgentState; use crate::body::{self, Payload, SizedReader}; use crate::header; use crate::stream::{self, connect_test, Stream}; -use crate::Proxy; use crate::{Error, Header, Request, Response}; -#[cfg(feature = "tls")] -use crate::request::TLSClientConfig; - -#[cfg(all(feature = "native-tls", not(feature = "tls")))] -use crate::request::TLSConnector; - #[cfg(feature = "cookie")] use crate::pool::DEFAULT_HOST; @@ -28,21 +20,12 @@ use crate::pool::DEFAULT_HOST; /// /// *Internal API* pub(crate) struct Unit { - pub agent: Arc>, + pub req: Request, pub url: Url, pub is_chunked: bool, pub query_string: String, pub headers: Vec
, - pub timeout_connect: u64, - pub timeout_read: u64, - pub timeout_write: u64, pub deadline: Option, - pub method: String, - pub proxy: Option, - #[cfg(feature = "tls")] - pub tls_config: Option, - #[cfg(all(feature = "native-tls", not(feature = "tls")))] - pub tls_connector: Option, } impl Unit { @@ -99,26 +82,17 @@ impl Unit { }; Unit { - agent: Arc::clone(&req.agent), + req: req.clone(), url: url.clone(), is_chunked, query_string, headers, - timeout_connect: req.timeout_connect, - timeout_read: req.timeout_read, - timeout_write: req.timeout_write, deadline, - method: req.method.clone(), - proxy: req.proxy.clone(), - #[cfg(feature = "tls")] - tls_config: req.tls_config.clone(), - #[cfg(all(feature = "native-tls", not(feature = "tls")))] - tls_connector: req.tls_connector.clone(), } } pub fn is_head(&self) -> bool { - self.method.eq_ignore_ascii_case("head") + self.req.method.eq_ignore_ascii_case("head") } #[cfg(test)] @@ -217,8 +191,8 @@ pub(crate) fn connect( let mut new_unit = Unit::new(req, &new_url, false, &empty); // this is to follow how curl does it. POST, PUT etc change // to GET on a redirect. - new_unit.method = match &unit.method[..] { - "GET" | "HEAD" => unit.method, + new_unit.req.method = match &unit.req.method[..] { + "GET" | "HEAD" => unit.req.method, _ => "GET".into(), }; return connect(req, new_unit, use_pooled, redirect_count + 1, empty, true); @@ -305,11 +279,11 @@ fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error _ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())), }; if use_pooled { - let state = &mut unit.agent.lock().unwrap(); + let state = &mut unit.req.agent.lock().unwrap(); // The connection may have been closed by the server // due to idle timeout while it was sitting in the pool. // Loop until we find one that is still good or run out of connections. - while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.proxy) { + while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.req.proxy) { let server_closed = stream.server_closed()?; if !server_closed { return Ok((stream, true)); @@ -337,7 +311,7 @@ fn send_prelude(unit: &Unit, stream: &mut Stream, redir: bool) -> IoResult<()> { write!( prelude, "{} {}{} HTTP/1.1\r\n", - unit.method, + unit.req.method, unit.url.path(), &unit.query_string )?; @@ -404,7 +378,7 @@ fn save_cookies(unit: &Unit, resp: &Response) { } // only lock if we know there is something to process - let state = &mut unit.agent.lock().unwrap(); + let state = &mut unit.req.agent.lock().unwrap(); for raw_cookie in cookies.iter() { let to_parse = if raw_cookie.to_lowercase().contains("domain=") { (*raw_cookie).to_string()