diff --git a/src/agent.rs b/src/agent.rs
index f104050..38dfb99 100644
--- a/src/agent.rs
+++ b/src/agent.rs
@@ -3,10 +3,11 @@ use cookie::Cookie;
 #[cfg(feature = "cookie")]
 use cookie_store::CookieStore;
 use std::sync::Arc;
-use std::sync::Mutex;
 #[cfg(feature = "cookie")]
 use url::Url;
 
+#[cfg(feature = "cookie")]
+use crate::cookies::CookieTin;
 use crate::header::{self, Header};
 use crate::pool::ConnectionPool;
 use crate::proxy::Proxy;
@@ -60,12 +61,15 @@ impl Default for Agent {
 /// println!("Secret is: {}", secret.unwrap().into_string().unwrap());
 /// }
 /// ```
+///
+/// Agent uses an inner Arc, so cloning an Agent results in an instance
+/// that shares the same underlying connection pool and other state.
 #[derive(Debug, Clone)]
 pub struct Agent {
     /// Copied into each request of this agent.
     pub(crate) headers: Vec<Header>,
     /// Reused agent state for repeated requests from this agent.
-    pub(crate) state: Arc<Mutex<AgentState>>,
+    pub(crate) state: Arc<AgentState>,
 }
 
 /// Container of the state
@@ -79,16 +83,10 @@ pub(crate) struct AgentState {
     /// Cookies saved between requests.
     /// Invariant: All cookies must have a nonempty domain and path.
     #[cfg(feature = "cookie")]
-    pub(crate) jar: CookieStore,
+    pub(crate) jar: CookieTin,
     pub(crate) resolver: ArcResolver,
 }
 
-impl AgentState {
-    pub(crate) fn pool(&mut self) -> &mut ConnectionPool {
-        &mut self.pool
-    }
-}
-
 impl Agent {
     /// Request by providing the HTTP verb such as `GET`, `POST`...
     ///
@@ -104,70 +102,21 @@ impl Agent {
         Request::new(&self, method.into(), path.into())
     }
 
-    /// Gets a cookie in this agent by name. Cookies are available
-    /// either by setting it in the agent, or by making requests
-    /// that `Set-Cookie` in the agent.
-    ///
-    /// Note that this will return any cookie for the given name,
-    /// regardless of which host and path that cookie was set on.
-    ///
-    /// ```
-    /// let agent = ureq::agent();
-    ///
-    /// agent.get("http://www.google.com").call();
-    ///
-    /// assert!(agent.cookie("NID").is_some());
-    /// ```
-    #[cfg(feature = "cookie")]
-    pub fn cookie(&self, name: &str) -> Option<Cookie<'static>> {
-        let state = self.state.lock().unwrap();
-        let first_found = state.jar.iter_any().find(|c| c.name() == name);
-        if let Some(first_found) = first_found {
-            let c: &Cookie = &*first_found;
-            Some(c.clone())
-        } else {
-            None
-        }
-    }
-
-    /// Set a cookie in this agent.
-    ///
-    /// Cookies without a domain, or with a malformed domain or path,
-    /// will be silently ignored.
+    /// Store a cookie in this agent.
     ///
     /// ```
     /// let agent = ureq::agent();
     ///
     /// let cookie = ureq::Cookie::build("name", "value")
-    ///     .domain("example.com")
-    ///     .path("/")
     ///     .secure(true)
     ///     .finish();
-    /// agent.set_cookie(cookie);
+    /// agent.set_cookie(cookie, &"https://example.com/".parse().unwrap());
    /// ```
     #[cfg(feature = "cookie")]
-    pub fn set_cookie(&self, cookie: Cookie<'static>) {
-        let mut cookie = cookie.clone();
-        if cookie.domain().is_none() {
-            return;
-        }
-
-        if cookie.path().is_none() {
-            cookie.set_path("/");
-        }
-        let path = cookie.path().unwrap();
-        let domain = cookie.domain().unwrap();
-
-        let fake_url: Url = match format!("http://{}{}", domain, path).parse() {
-            Ok(u) => u,
-            Err(_) => return,
-        };
-        let mut state = self.state.lock().unwrap();
-        let cs_cookie = match cookie_store::Cookie::try_from_raw_cookie(&cookie, &fake_url) {
-            Ok(c) => c,
-            Err(_) => return,
-        };
-        state.jar.insert(cs_cookie, &fake_url).ok();
+    pub fn set_cookie(&self, cookie: Cookie<'static>, url: &Url) {
+        self.state
+            .jar
+            .store_response_cookies(Some(cookie).into_iter(), url);
     }
 
     /// Make a GET request from this agent.
@@ -228,16 +177,16 @@ impl AgentBuilder {
     pub fn build(self) -> Agent {
         Agent {
             headers: self.headers.clone(),
-            state: Arc::new(Mutex::new(AgentState {
+            state: Arc::new(AgentState {
                 pool: ConnectionPool::new(
                     self.max_idle_connections,
                     self.max_idle_connections_per_host,
                 ),
                 proxy: self.proxy.clone(),
                 #[cfg(feature = "cookie")]
-                jar: self.jar,
+                jar: CookieTin::new(self.jar),
                 resolver: self.resolver,
-            })),
+            }),
         }
     }
@@ -298,6 +247,7 @@ impl AgentBuilder {
         let value = format!("{} {}", kind, pass);
         self.set("Authorization", &value)
     }
+
     /// Sets the maximum number of connections allowed in the connection pool.
     /// By default, this is set to 100. Setting this to zero would disable
     /// connection pooling.
@@ -396,8 +346,7 @@ mod tests {
         reader.read_to_end(&mut buf).unwrap();
 
         fn poolsize(agent: &Agent) -> usize {
-            let mut state = agent.state.lock().unwrap();
-            state.pool().len()
+            agent.state.pool.len()
         }
 
         assert_eq!(poolsize(&agent), 1);
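The new doc comment on Agent is the user-visible contract of this change: clones share one Arc<AgentState>. As an illustration only (not part of the patch), a sketch of what that means for callers, assuming the public ureq 1.x API used elsewhere in this diff (ureq::agent(), get(), call()):

// Illustrative sketch, not ureq code: cloning an Agent shares its state.
fn main() {
    let agent = ureq::agent();
    let clone = agent.clone();
    // Both handles wrap the same Arc<AgentState>, so a connection pooled
    // (or a cookie stored) through `clone` is visible through `agent` too.
    let resp = clone.get("http://example.com/").call();
    println!("status: {}", resp.status());
}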
diff --git a/src/cookies.rs b/src/cookies.rs
new file mode 100644
index 0000000..be6d10b
--- /dev/null
+++ b/src/cookies.rs
@@ -0,0 +1,37 @@
+#[cfg(feature = "cookie")]
+use std::sync::RwLock;
+
+#[cfg(feature = "cookie")]
+use cookie_store::CookieStore;
+#[cfg(feature = "cookie")]
+use url::Url;
+
+#[cfg(feature = "cookie")]
+#[derive(Default, Debug)]
+pub(crate) struct CookieTin {
+    inner: RwLock<CookieStore>,
+}
+
+#[cfg(feature = "cookie")]
+impl CookieTin {
+    pub(crate) fn new(store: CookieStore) -> Self {
+        CookieTin {
+            inner: RwLock::new(store),
+        }
+    }
+    pub(crate) fn get_request_cookies(&self, url: &Url) -> Vec<cookie::Cookie<'static>> {
+        let store = self.inner.read().unwrap();
+        store
+            .get_request_cookies(url)
+            .map(|c| cookie::Cookie::new(c.name().to_owned(), c.value().to_owned()))
+            .collect()
+    }
+
+    pub(crate) fn store_response_cookies<I>(&self, cookies: I, url: &Url)
+    where
+        I: Iterator<Item = cookie::Cookie<'static>>,
+    {
+        let mut store = self.inner.write().unwrap();
+        store.store_response_cookies(cookies, url)
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 2c4eab2..7bcd231 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -103,6 +103,7 @@ mod agent;
 mod body;
+mod cookies;
 mod error;
 mod header;
 mod pool;
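CookieTin moves the locking that previously lived on the whole AgentState down to the cookie jar itself: readers take a shared RwLock read guard, writers take the write guard, and callers only ever need &self. Purely as an illustration (none of this is in the patch, and the names are made up), the same pattern with a plain HashMap standing in for CookieStore:

// Hypothetical stand-in for CookieTin: interior locking behind &self methods.
use std::collections::HashMap;
use std::sync::RwLock;

#[derive(Default, Debug)]
struct Tin {
    inner: RwLock<HashMap<String, String>>,
}

impl Tin {
    // Read path: many threads may hold the read guard at the same time.
    fn get(&self, name: &str) -> Option<String> {
        self.inner.read().unwrap().get(name).cloned()
    }
    // Write path: exclusive access, but only for the duration of this call.
    fn store(&self, name: &str, value: &str) {
        self.inner
            .write()
            .unwrap()
            .insert(name.to_owned(), value.to_owned());
    }
}

fn main() {
    let tin = Tin::default();
    tin.store("session", "abc123");
    assert_eq!(tin.get("session").as_deref(), Some("abc123"));
}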
diff --git a/src/pool.rs b/src/pool.rs
index 30ff980..3023e2b 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -1,6 +1,7 @@
 use std::collections::hash_map::Entry;
 use std::collections::{HashMap, VecDeque};
 use std::io::{self, Read};
+use std::sync::Mutex;
 
 use crate::stream::Stream;
 use crate::unit::Unit;
@@ -34,8 +35,13 @@ const DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST: usize = 1;
 /// - The length of recycle[K] is less than or equal to max_idle_connections_per_host.
 ///
 /// *Internal API*
-#[derive(Debug)]
 pub(crate) struct ConnectionPool {
+    inner: Mutex<Inner>,
+    max_idle_connections: usize,
+    max_idle_connections_per_host: usize,
+}
+
+struct Inner {
     // the actual pooled connection. however only one per hostname:port.
     recycle: HashMap<PoolKey, VecDeque<Stream>>,
     // This is used to keep track of which streams to expire when the
@@ -43,10 +49,17 @@ pub(crate) struct ConnectionPool {
     // recently used Streams are added to the back of the queue;
     // old streams are removed from the front.
     lru: VecDeque<PoolKey>,
-    max_idle_connections: usize,
-    max_idle_connections_per_host: usize,
 }
 
+impl fmt::Debug for ConnectionPool {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ConnectionPool")
+            .field("max_idle", &self.max_idle_connections)
+            .field("max_idle_per_host", &self.max_idle_connections_per_host)
+            .field("connections", &self.inner.lock().unwrap().lru.len())
+            .finish()
+    }
+}
 
 fn remove_first_match(list: &mut VecDeque<PoolKey>, key: &PoolKey) -> Option<PoolKey> {
     match list.iter().position(|x| x == key) {
         Some(i) => list.remove(i),
@@ -66,8 +79,10 @@ impl Default for ConnectionPool {
         Self {
             max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS,
             max_idle_connections_per_host: DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST,
-            recycle: HashMap::default(),
-            lru: VecDeque::default(),
+            inner: Mutex::new(Inner {
+                recycle: HashMap::default(),
+                lru: VecDeque::default(),
+            }),
         }
     }
 }
@@ -75,8 +90,10 @@ impl Default for ConnectionPool {
 impl ConnectionPool {
     pub(crate) fn new(max_idle_connections: usize, max_idle_connections_per_host: usize) -> Self {
         ConnectionPool {
-            recycle: Default::default(),
-            lru: Default::default(),
+            inner: Mutex::new(Inner {
+                recycle: HashMap::default(),
+                lru: VecDeque::default(),
+            }),
             max_idle_connections,
             max_idle_connections_per_host,
         }
     }
@@ -88,13 +105,14 @@ impl ConnectionPool {
     }
 
     /// How the unit::connect tries to get a pooled connection.
-    pub fn try_get_connection(&mut self, url: &Url, proxy: &Option<Proxy>) -> Option<Stream> {
+    pub fn try_get_connection(&self, url: &Url, proxy: &Option<Proxy>) -> Option<Stream> {
         let key = PoolKey::new(url, proxy);
         self.remove(&key)
     }
 
-    fn remove(&mut self, key: &PoolKey) -> Option<Stream> {
-        match self.recycle.entry(key.clone()) {
+    fn remove(&self, key: &PoolKey) -> Option<Stream> {
+        let mut inner = self.inner.lock().unwrap();
+        match inner.recycle.entry(key.clone()) {
             Entry::Occupied(mut occupied_entry) => {
                 let streams = occupied_entry.get_mut();
                 // Take the newest stream.
@@ -107,7 +125,7 @@ impl ConnectionPool {
 
                 // Remove the newest matching PoolKey from self.lru. That
                 // corresponds to the stream we just removed from `recycle`.
-                remove_last_match(&mut self.lru, &key)
+                remove_last_match(&mut inner.lru, &key)
                     .expect("invariant failed: key in recycle but not in lru");
 
                 Some(stream)
@@ -116,19 +134,20 @@ impl ConnectionPool {
         }
     }
 
-    fn add(&mut self, key: PoolKey, stream: Stream) {
+    fn add(&self, key: PoolKey, stream: Stream) {
         if self.noop() {
             return;
         }
-        match self.recycle.entry(key.clone()) {
+        let mut inner = self.inner.lock().unwrap();
+        match inner.recycle.entry(key.clone()) {
             Entry::Occupied(mut occupied_entry) => {
                 let streams = occupied_entry.get_mut();
                 streams.push_back(stream);
                 if streams.len() > self.max_idle_connections_per_host {
                     // Remove the oldest entry
                     streams.pop_front();
-                    remove_first_match(&mut self.lru, &key)
+                    remove_first_match(&mut inner.lru, &key)
                         .expect("invariant failed: key in recycle but not in lru");
                 }
             }
@@ -136,19 +155,21 @@ impl ConnectionPool {
             Entry::Vacant(vacant_entry) => {
                 vacant_entry.insert(vec![stream].into());
             }
         }
-        self.lru.push_back(key);
-        if self.lru.len() > self.max_idle_connections {
+        inner.lru.push_back(key);
+        if inner.lru.len() > self.max_idle_connections {
+            drop(inner);
             self.remove_oldest()
         }
     }
 
     /// Find the oldest stream in the pool. Remove its representation from lru,
     /// and the stream itself from `recycle`. Drops the stream, which closes it.
-    fn remove_oldest(&mut self) {
+    fn remove_oldest(&self) {
         assert!(!self.noop(), "remove_oldest called on Pool with max of 0");
-        let key = self.lru.pop_front();
+        let mut inner = self.inner.lock().unwrap();
+        let key = inner.lru.pop_front();
         let key = key.expect("tried to remove oldest but no entries found!");
-        match self.recycle.entry(key) {
+        match inner.recycle.entry(key) {
             Entry::Occupied(mut occupied_entry) => {
                 let streams = occupied_entry.get_mut();
                 streams
@@ -164,7 +185,7 @@
 
     #[cfg(test)]
     pub fn len(&self) -> usize {
-        self.lru.len()
+        self.inner.lock().unwrap().lru.len()
     }
 }
@@ -212,7 +233,7 @@ fn pool_connections_limit() {
     // Test inserting connections with different keys into the pool,
     // filling and draining it. The pool should evict earlier connections
     // when the connection limit is reached.
-    let mut pool = ConnectionPool::default();
+    let pool = ConnectionPool::default();
     let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i));
     let poolkeys = hostnames.map(|hostname| PoolKey {
         scheme: "https".to_string(),
@@ -237,7 +258,7 @@ fn pool_per_host_connections_limit() {
     // Test inserting connections with the same key into the pool,
     // filling and draining it. The pool should evict earlier connections
     // when the per-host connection limit is reached.
-    let mut pool = ConnectionPool::default();
+    let pool = ConnectionPool::default();
     let poolkey = PoolKey {
         scheme: "https".to_string(),
         hostname: "example.com".to_string(),
@@ -264,7 +285,7 @@ fn pool_checks_proxy() {
     // Test inserting different poolkeys with same address but different proxies.
     // Each insertion should result in an additional entry in the pool.
-    let mut pool = ConnectionPool::default();
+    let pool = ConnectionPool::default();
     let url = Url::parse("zzz:///example.com").unwrap();
 
     pool.add(
@@ -311,7 +332,6 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
     fn return_connection(&mut self) -> io::Result<()> {
         // guard we only do this once.
         if let (Some(unit), Some(reader)) = (self.unit.take(), self.reader.take()) {
-            let state = &mut unit.req.agent.lock().unwrap();
             // bring back stream here to either go into pool or dealloc
             let mut stream = reader.into();
             if !stream.is_poolable() {
@@ -324,7 +344,7 @@ impl<R: Read + Sized + Into<Stream>> PoolReturnRead<R> {
 
             // insert back into pool
             let key = PoolKey::new(&unit.url, &unit.req.proxy);
-            state.pool().add(key, stream);
+            unit.req.agent.state.pool.add(key, stream);
         }
 
         Ok(())
diff --git a/src/request.rs b/src/request.rs
index c431144..a534dea 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -1,12 +1,13 @@
 use std::fmt;
 use std::io::Read;
-use std::sync::{Arc, Mutex};
+#[cfg(any(feature = "native-tls", feature = "tls"))]
+use std::sync::Arc;
 use std::time;
 
 use qstring::QString;
 use url::{form_urlencoded, Url};
 
-use crate::agent::{self, Agent, AgentState};
+use crate::agent::{self, Agent};
 use crate::body::BodySize;
 use crate::body::{Payload, SizedReader};
 use crate::error::Error;
@@ -31,7 +32,7 @@ pub type Result<T> = std::result::Result<T, Error>;
 /// ```
 #[derive(Clone, Default)]
 pub struct Request {
-    pub(crate) agent: Arc<Mutex<AgentState>>,
+    pub(crate) agent: Agent,
 
     // via agent
     pub(crate) method: String,
@@ -73,7 +74,7 @@ impl fmt::Debug for Request {
 impl Request {
     pub(crate) fn new(agent: &Agent, method: String, url: String) -> Request {
         Request {
-            agent: Arc::clone(&agent.state),
+            agent: agent.clone(),
             method,
             url,
             headers: agent.headers.clone(),
@@ -596,7 +597,7 @@ impl Request {
     pub(crate) fn proxy(&self) -> Option<Proxy> {
         if let Some(proxy) = &self.proxy {
             Some(proxy.clone())
-        } else if let Some(proxy) = &self.agent.lock().unwrap().proxy {
+        } else if let Some(proxy) = &self.agent.state.proxy {
             Some(proxy.clone())
         } else {
             None
diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs
index 08700cb..3a510c5 100644
--- a/src/test/agent_test.rs
+++ b/src/test/agent_test.rs
@@ -53,10 +53,7 @@ fn connection_reuse() {
     assert_eq!(resp.status(), 200);
     resp.into_string().unwrap();
 
-    {
-        let mut state = agent.state.lock().unwrap();
-        assert!(state.pool().len() > 0);
-    }
+    assert!(agent.state.pool.len() > 0);
 
     // wait for the server to close the connection.
     std::thread::sleep(Duration::from_secs(3));
@@ -149,9 +146,14 @@ fn test_cookies_on_redirect() -> Result<(), Error> {
     let url = format!("http://localhost:{}/first", testserver.port);
     let agent = Agent::default();
     agent.post(&url).call()?;
-    assert!(agent.cookie("first").is_some());
-    assert!(agent.cookie("second").is_some());
-    assert!(agent.cookie("third").is_some());
+    let cookies = agent.state.jar.get_request_cookies(
+        &format!("https://localhost:{}/", testserver.port)
+            .parse()
+            .unwrap(),
+    );
+    let mut cookie_names: Vec<String> = cookies.iter().map(|c| c.name().to_string()).collect();
+    cookie_names.sort();
+    assert_eq!(cookie_names, vec!["first", "second", "third"]);
     Ok(())
 }
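The pool now keeps its mutable bookkeeping behind a private Mutex<Inner>, so every method can take &self. Note how add() drops the guard before calling remove_oldest(), which locks again; with std's non-reentrant Mutex a second lock() on the same thread would deadlock. The same shape in a stand-alone sketch (illustrative only; the names are made up and none of this is ureq code):

// Hypothetical bounded "pool": interior Mutex, &self methods, drop-then-relock.
use std::collections::VecDeque;
use std::sync::Mutex;

struct Bounded {
    inner: Mutex<VecDeque<u32>>,
    max: usize,
}

impl Bounded {
    fn new(max: usize) -> Self {
        Bounded {
            inner: Mutex::new(VecDeque::new()),
            max,
        }
    }

    fn add(&self, value: u32) {
        let mut inner = self.inner.lock().unwrap();
        inner.push_back(value);
        if inner.len() > self.max {
            // Release the guard before calling another method that locks.
            drop(inner);
            self.remove_oldest();
        }
    }

    fn remove_oldest(&self) {
        let mut inner = self.inner.lock().unwrap();
        inner.pop_front();
    }

    fn len(&self) -> usize {
        self.inner.lock().unwrap().len()
    }
}

fn main() {
    let pool = Bounded::new(2);
    pool.add(1);
    pool.add(2);
    pool.add(3); // evicts 1
    assert_eq!(pool.len(), 2);
}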
diff --git a/src/unit.rs b/src/unit.rs
index 5ba1a8c..8e95a8a 100644
--- a/src/unit.rs
+++ b/src/unit.rs
@@ -8,12 +8,12 @@ use url::Url;
 
 #[cfg(feature = "cookie")]
 use cookie::Cookie;
-#[cfg(feature = "cookie")]
-use crate::agent::AgentState;
 
 use crate::body::{self, BodySize, Payload, SizedReader};
 use crate::header;
 use crate::resolve::ArcResolver;
 use crate::stream::{self, connect_test, Stream};
+#[cfg(feature = "cookie")]
+use crate::Agent;
 use crate::{Error, Header, Request, Response};
 
 /// It's a "unit of work". Maybe a bad name for it?
@@ -117,7 +117,7 @@ impl Unit {
     }
 
     pub fn resolver(&self) -> ArcResolver {
-        self.req.agent.lock().unwrap().resolver.clone()
+        self.req.agent.state.resolver.clone()
     }
 
     #[cfg(test)]
@@ -256,12 +256,13 @@ pub(crate) fn connect(
 }
 
 #[cfg(feature = "cookie")]
-fn extract_cookies(state: &std::sync::Mutex<AgentState>, url: &Url) -> Option<Header> {
-    let state = state.lock().unwrap();
-    let header_value = state
+fn extract_cookies(agent: &Agent, url: &Url) -> Option<Header> {
+    let header_value = agent
+        .state
         .jar
         .get_request_cookies(url)
-        .map(|c| Cookie::new(c.name(), c.value()).encoded().to_string())
+        .iter()
+        .map(|c| c.encoded().to_string())
         .collect::<Vec<_>>()
         .join(";");
     match header_value.as_str() {
@@ -287,11 +288,15 @@ fn connect_socket(unit: &Unit, hostname: &str, use_pooled: bool) -> Result<(Stre
         _ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())),
     };
     if use_pooled {
-        let state = &mut unit.req.agent.lock().unwrap();
+        let agent = &unit.req.agent;
         // The connection may have been closed by the server
         // due to idle timeout while it was sitting in the pool.
         // Loop until we find one that is still good or run out of connections.
-        while let Some(stream) = state.pool.try_get_connection(&unit.url, &unit.req.proxy) {
+        while let Some(stream) = agent
+            .state
+            .pool
+            .try_get_connection(&unit.url, &unit.req.proxy)
+        {
             let server_closed = stream.server_closed()?;
             if !server_closed {
                 return Ok((stream, true));
@@ -389,8 +394,11 @@ fn save_cookies(unit: &Unit, resp: &Response) {
             Ok(c) => Some(c),
         }
     });
-    let state = &mut unit.req.agent.lock().unwrap();
-    state.jar.store_response_cookies(cookies, &unit.url.clone());
+    unit.req
+        .agent
+        .state
+        .jar
+        .store_response_cookies(cookies, &unit.url.clone());
 }
 
 #[cfg(test)]
@@ -409,14 +417,12 @@ mod tests {
     let cookie2: Cookie = "cookie2=value2; Domain=crates.io; Path=/".parse().unwrap();
     agent
         .state
-        .lock()
-        .unwrap()
         .jar
         .store_response_cookies(vec![cookie1, cookie2].into_iter(), &url);
 
     // There's no guarantee to the order in which cookies are defined.
     // Ensure that they're either in one order or the other.
-    let result = extract_cookies(&agent.state, &url);
+    let result = extract_cookies(&agent, &url);
     let order1 = "cookie1=value1;cookie2=value2";
     let order2 = "cookie2=value2;cookie1=value1";
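For reference, the Cookie header assembly in extract_cookies boils down to rendering each cookie as name=value and joining with ";", which is what the test at the end checks in either order. A stand-alone sketch using the cookie crate directly (illustrative only, not part of the patch):

// Sketch of the header_value construction, assuming the cookie crate API
// already used in this diff (Cookie::new, encoded()).
use cookie::Cookie;

fn main() {
    let cookies = vec![
        Cookie::new("cookie1", "value1"),
        Cookie::new("cookie2", "value2"),
    ];
    let header_value = cookies
        .iter()
        .map(|c| c.encoded().to_string())
        .collect::<Vec<_>>()
        .join(";");
    assert_eq!(header_value, "cookie1=value1;cookie2=value2");
}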