diff --git a/src/agent.rs b/src/agent.rs index 5d2ef15..b7a43ee 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -166,6 +166,38 @@ impl Agent { Request::new(&self, method.into(), path.into()) } + /// Sets the maximum number of connections allowed in the connection pool. + /// By default, this is set to 100. Setting this to zero would disable + /// connection pooling. + /// + /// ``` + /// let agent = ureq::agent(); + /// agent.set_max_pool_connections(200); + /// ``` + pub fn set_max_pool_connections(&self, max_connections: usize) { + let mut optional_state = self.state.lock().unwrap(); + if let Some(state) = optional_state.as_mut() { + state.pool.set_max_idle_connections(max_connections); + } + } + + /// Sets the maximum number of connections per host to keep in the + /// connection pool. By default, this is set to 1. Setting this to zero + /// would disable connection pooling. + /// + /// ``` + /// let agent = ureq::agent(); + /// agent.set_max_pool_connections_per_host(10); + /// ``` + pub fn set_max_pool_connections_per_host(&self, max_connections: usize) { + let mut optional_state = self.state.lock().unwrap(); + if let Some(state) = optional_state.as_mut() { + state + .pool + .set_max_idle_connections_per_host(max_connections); + } + } + /// Gets a cookie in this agent by name. Cookies are available /// either by setting it in the agent, or by making requests /// that `Set-Cookie` in the agent. diff --git a/src/pool.rs b/src/pool.rs index d9ca195..ae2dd3f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,3 +1,4 @@ +use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; use std::io::{Read, Result as IoResult}; @@ -8,7 +9,8 @@ use crate::Proxy; use url::Url; pub const DEFAULT_HOST: &str = "localhost"; -const MAX_IDLE_CONNECTIONS: usize = 100; +const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 100; +const DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST: usize = 1; /// Holder of recycled connections. /// @@ -20,21 +22,71 @@ const MAX_IDLE_CONNECTIONS: usize = 100; #[derive(Default, Debug)] pub(crate) struct ConnectionPool { // the actual pooled connection. however only one per hostname:port. - recycle: HashMap, + recycle: HashMap>, // This is used to keep track of which streams to expire when the // pool reaches MAX_IDLE_CONNECTIONS. The corresponding PoolKeys for // recently used Streams are added to the back of the queue; // old streams are removed from the front. lru: VecDeque, + max_idle_connections: usize, + max_idle_connections_per_host: usize, } impl ConnectionPool { pub fn new() -> Self { ConnectionPool { + max_idle_connections: DEFAULT_MAX_IDLE_CONNECTIONS, + max_idle_connections_per_host: DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST, ..Default::default() } } + pub fn set_max_idle_connections(&mut self, max_connections: usize) { + if self.max_idle_connections == max_connections { + return; + } + self.max_idle_connections = max_connections; + + if max_connections == 0 { + // Clear the connection pool, caching is disabled. + self.lru.clear(); + self.recycle.clear(); + return; + } + + // Remove any extra connections if the number was decreased. + while self.lru.len() > max_connections { + self.remove_oldest(); + } + } + + pub fn set_max_idle_connections_per_host(&mut self, max_connections: usize) { + if self.max_idle_connections_per_host == max_connections { + return; + } + self.max_idle_connections_per_host = max_connections; + + if max_connections == 0 { + // Clear the connection pool, caching is disabled. + self.lru.clear(); + self.recycle.clear(); + return; + } + + // Remove any extra streams if the number was decreased. + for (key, val) in self.recycle.iter_mut() { + while val.len() > max_connections { + val.pop_front(); + let index = self + .lru + .iter() + .position(|x| x == key) + .expect("PoolKey not found in lru"); + self.lru.remove(index); + } + } + } + /// How the unit::connect tries to get a pooled connection. pub fn try_get_connection(&mut self, url: &Url, proxy: &Option) -> Option { let key = PoolKey::new(url, proxy); @@ -42,36 +94,85 @@ impl ConnectionPool { } fn remove(&mut self, key: &PoolKey) -> Option { - if !self.recycle.contains_key(&key) { - return None; + match self.recycle.entry(key.clone()) { + Entry::Occupied(mut occupied_entry) => { + let streams = occupied_entry.get_mut(); + // Take the newest stream. + let stream = streams.pop_back(); + assert!( + stream.is_some(), + "key existed in recycle but no streams available" + ); + + if streams.len() == 0 { + occupied_entry.remove(); + } + + // Remove the oldest matching PoolKey from self.lru. + // since this PoolKey was most recently used, removing the oldest + // PoolKey would delay other streams with this address from + // being removed. + self.remove_from_lru(key); + + stream + } + Entry::Vacant(_) => None, } - let index = self.lru.iter().position(|k| k == key); - assert!( - index.is_some(), - "invariant failed: key existed in recycle but not lru" - ); - self.lru.remove(index.unwrap()); - self.recycle.remove(&key) + } + + fn remove_from_lru(&mut self, key: &PoolKey) { + let index = self + .lru + .iter() + .position(|x| x == key) + .expect("PoolKey not found in lru"); + self.lru.remove(index); } fn add(&mut self, key: PoolKey, stream: Stream) { - // If an entry with the same key already exists, remove it. - // The more recently used stream is likely to live longer. - self.remove(&key); - if self.recycle.len() + 1 > MAX_IDLE_CONNECTIONS { - self.remove_oldest(); + if self.max_idle_connections == 0 || self.max_idle_connections_per_host == 0 { + return; + } + + match self.recycle.entry(key.clone()) { + Entry::Occupied(mut occupied_entry) => { + let streams = occupied_entry.get_mut(); + streams.push_back(stream); + if streams.len() > self.max_idle_connections_per_host { + streams.pop_front(); + self.remove_from_lru(&key); + } + } + Entry::Vacant(vacant_entry) => { + let mut new_deque = VecDeque::new(); + new_deque.push_back(stream); + vacant_entry.insert(new_deque); + } + } + self.lru.push_back(key); + if self.lru.len() > self.max_idle_connections { + self.remove_oldest() } - self.lru.push_back(key.clone()); - self.recycle.insert(key, stream); } fn remove_oldest(&mut self) { if let Some(key) = self.lru.pop_front() { - let removed = self.recycle.remove(&key); - assert!( - removed.is_some(), - "invariant failed: key existed in lru but not in recycle" - ); + match self.recycle.entry(key) { + Entry::Occupied(mut occupied_entry) => { + let streams = occupied_entry.get_mut(); + let removed_stream = streams.pop_front(); + assert!( + removed_stream.is_some(), + "key existed in recycle but no streams available" + ); + if streams.len() == 0 { + occupied_entry.remove(); + } + } + Entry::Vacant(_) => { + panic!("invariant failed: key existed in lru but not in recycle") + } + } } else { panic!("tried to remove oldest but no entries found!"); } @@ -79,7 +180,7 @@ impl ConnectionPool { #[cfg(test)] pub fn len(&self) -> usize { - self.recycle.len() + self.lru.len() } } @@ -123,10 +224,12 @@ fn poolkey_new() { } #[test] -fn pool_size_limit() { - assert_eq!(MAX_IDLE_CONNECTIONS, 100); +fn pool_connections_limit() { + // Test inserting connections with different keys into the pool, + // filling and draining it. The pool should evict earlier connections + // when the connection limit is reached. let mut pool = ConnectionPool::new(); - let hostnames = (0..200).map(|i| format!("{}.example", i)); + let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i)); let poolkeys = hostnames.map(|hostname| PoolKey { scheme: "https".to_string(), hostname, @@ -136,22 +239,49 @@ fn pool_size_limit() { for key in poolkeys.clone() { pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![]))); } - assert_eq!(pool.len(), 100); + assert_eq!(pool.len(), DEFAULT_MAX_IDLE_CONNECTIONS); - for key in poolkeys.skip(100) { + for key in poolkeys.skip(DEFAULT_MAX_IDLE_CONNECTIONS) { let result = pool.remove(&key); assert!(result.is_some(), "expected key was not in pool"); } + assert_eq!(pool.len(), 0) } #[test] -fn pool_duplicates_limit() { - // Test inserting duplicates into the pool, and subsequently - // filling and draining it. The duplicates should evict earlier - // entries with the same key. - assert_eq!(MAX_IDLE_CONNECTIONS, 100); +fn pool_per_host_connections_limit() { + // Test inserting connections with the same key into the pool, + // filling and draining it. The pool should evict earlier connections + // when the per-host connection limit is reached. let mut pool = ConnectionPool::new(); - let hostnames = (0..100).map(|i| format!("{}.example", i)); + let poolkey = PoolKey { + scheme: "https".to_string(), + hostname: "example.com".to_string(), + port: Some(999), + proxy: None, + }; + + for _ in 0..pool.max_idle_connections_per_host * 2 { + pool.add( + poolkey.clone(), + Stream::Cursor(std::io::Cursor::new(vec![])), + ); + } + assert_eq!(pool.len(), DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST); + + for _ in 0..DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST { + let result = pool.remove(&poolkey); + assert!(result.is_some(), "expected key was not in pool"); + } + assert_eq!(pool.len(), 0); +} + +#[test] +fn pool_update_connection_limit() { + let mut pool = ConnectionPool::new(); + pool.set_max_idle_connections(50); + + let hostnames = (0..pool.max_idle_connections).map(|i| format!("{}.example", i)); let poolkeys = hostnames.map(|hostname| PoolKey { scheme: "https".to_string(), hostname, @@ -159,15 +289,35 @@ fn pool_duplicates_limit() { proxy: None, }); for key in poolkeys.clone() { - pool.add(key.clone(), Stream::Cursor(std::io::Cursor::new(vec![]))); pool.add(key, Stream::Cursor(std::io::Cursor::new(vec![]))); } - assert_eq!(pool.len(), 100); + assert_eq!(pool.len(), 50); + pool.set_max_idle_connections(25); + assert_eq!(pool.len(), 25); +} - for key in poolkeys { - let result = pool.remove(&key); - assert!(result.is_some(), "expected key was not in pool"); +#[test] +fn pool_update_per_host_connection_limit() { + let mut pool = ConnectionPool::new(); + pool.set_max_idle_connections(50); + pool.set_max_idle_connections_per_host(50); + + let poolkey = PoolKey { + scheme: "https".to_string(), + hostname: "example.com".to_string(), + port: Some(999), + proxy: None, + }; + + for _ in 0..50 { + pool.add( + poolkey.clone(), + Stream::Cursor(std::io::Cursor::new(vec![])), + ); } + assert_eq!(pool.len(), 50); + pool.set_max_idle_connections_per_host(25); + assert_eq!(pool.len(), 25); } #[test]