diff --git a/src/agent.rs b/src/agent.rs index b5b79da..1aa5929 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -6,6 +6,7 @@ use std::sync::Mutex; use crate::header::{self, Header}; use crate::pool::ConnectionPool; use crate::request::Request; +use crate::resolve::ArcResolver; /// Agents keep state between requests. /// @@ -53,15 +54,12 @@ pub(crate) struct AgentState { /// Cookies saved between requests. #[cfg(feature = "cookie")] pub(crate) jar: CookieJar, + pub(crate) resolver: ArcResolver, } impl AgentState { fn new() -> Self { - AgentState { - pool: ConnectionPool::new(), - #[cfg(feature = "cookie")] - jar: CookieJar::new(), - } + Self::default() } pub fn pool(&mut self) -> &mut ConnectionPool { &mut self.pool @@ -194,6 +192,29 @@ impl Agent { .set_max_idle_connections_per_host(max_connections); } + /// Configures a custom resolver to be used by this agent. By default, + /// address-resolution is done by std::net::ToSocketAddrs. This allows you + /// to override that resolution with your own alternative. Useful for + /// testing and special-cases like DNS-based load balancing. + /// + /// A `Fn(&str) -> io::Result>` is a valid resolver, + /// passing a closure is a simple way to override. Note that you might need + /// explicit type `&str` on the closure argument for type inference to + /// succeed. + /// ``` + /// use std::net::ToSocketAddrs; + /// + /// let mut agent = ureq::agent(); + /// agent.set_resolver(|addr: &str| match addr { + /// "example.com" => Ok(vec![([127,0,0,1], 8096).into()]), + /// addr => addr.to_socket_addrs().map(Iterator::collect), + /// }); + /// ``` + pub fn set_resolver(&mut self, resolver: impl crate::Resolver + 'static) -> &mut Self { + self.state.lock().unwrap().resolver = resolver.into(); + self + } + /// Gets a cookie in this agent by name. Cookies are available /// either by setting it in the agent, or by making requests /// that `Set-Cookie` in the agent. diff --git a/src/lib.rs b/src/lib.rs index df3e1f2..de2129f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,6 +125,7 @@ mod header; mod pool; mod proxy; mod request; +mod resolve; mod response; mod stream; mod unit; @@ -140,6 +141,7 @@ pub use crate::error::Error; pub use crate::header::Header; pub use crate::proxy::Proxy; pub use crate::request::Request; +pub use crate::resolve::Resolver; pub use crate::response::Response; // re-export diff --git a/src/pool.rs b/src/pool.rs index 571867f..5c3482f 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -74,10 +74,6 @@ impl Default for ConnectionPool { } impl ConnectionPool { - pub fn new() -> Self { - Self::default() - } - pub fn set_max_idle_connections(&mut self, max_connections: usize) { if self.max_idle_connections == max_connections { return; @@ -251,7 +247,7 @@ fn pool_connections_limit() { // Test inserting connections with different keys into the pool, // filling and draining it. The pool should evict earlier connections // when the connection limit is reached. - let mut pool = ConnectionPool::new(); + let mut pool = ConnectionPool::default(); let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i)); let poolkeys = hostnames.map(|hostname| PoolKey { scheme: "https".to_string(), @@ -276,7 +272,7 @@ fn pool_per_host_connections_limit() { // Test inserting connections with the same key into the pool, // filling and draining it. The pool should evict earlier connections // when the per-host connection limit is reached. - let mut pool = ConnectionPool::new(); + let mut pool = ConnectionPool::default(); let poolkey = PoolKey { scheme: "https".to_string(), hostname: "example.com".to_string(), @@ -301,7 +297,7 @@ fn pool_per_host_connections_limit() { #[test] fn pool_update_connection_limit() { - let mut pool = ConnectionPool::new(); + let mut pool = ConnectionPool::default(); pool.set_max_idle_connections(50); let hostnames = (0..pool.max_idle_connections).map(|i| format!("{}.example", i)); @@ -321,7 +317,7 @@ fn pool_update_connection_limit() { #[test] fn pool_update_per_host_connection_limit() { - let mut pool = ConnectionPool::new(); + let mut pool = ConnectionPool::default(); pool.set_max_idle_connections(50); pool.set_max_idle_connections_per_host(50); @@ -347,7 +343,7 @@ fn pool_update_per_host_connection_limit() { fn pool_checks_proxy() { // Test inserting different poolkeys with same address but different proxies. // Each insertion should result in an additional entry in the pool. - let mut pool = ConnectionPool::new(); + let mut pool = ConnectionPool::default(); let url = Url::parse("zzz:///example.com").unwrap(); pool.add( diff --git a/src/resolve.rs b/src/resolve.rs new file mode 100644 index 0000000..ca6497d --- /dev/null +++ b/src/resolve.rs @@ -0,0 +1,59 @@ +use std::fmt; +use std::io::Result as IoResult; +use std::net::{SocketAddr, ToSocketAddrs}; +use std::sync::Arc; + +pub trait Resolver: Send + Sync { + fn resolve(&self, netloc: &str) -> IoResult>; +} + +#[derive(Debug)] +pub(crate) struct StdResolver; + +impl Resolver for StdResolver { + fn resolve(&self, netloc: &str) -> IoResult> { + ToSocketAddrs::to_socket_addrs(netloc).map(|iter| iter.collect()) + } +} + +impl Resolver for F +where + F: Fn(&str) -> IoResult>, + F: Send + Sync, +{ + fn resolve(&self, netloc: &str) -> IoResult> { + self(netloc) + } +} + +#[derive(Clone)] +pub(crate) struct ArcResolver(Arc); + +impl From for ArcResolver +where + R: Resolver + 'static, +{ + fn from(r: R) -> Self { + Self(Arc::new(r)) + } +} + +impl fmt::Debug for ArcResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "ArcResolver(...)") + } +} + +impl std::ops::Deref for ArcResolver { + type Target = dyn Resolver; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl Default for ArcResolver { + fn default() -> Self { + StdResolver.into() + } +} diff --git a/src/stream.rs b/src/stream.rs index 7b42a6c..c51520f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,7 +4,6 @@ use std::io::{ }; use std::net::SocketAddr; use std::net::TcpStream; -use std::net::ToSocketAddrs; use std::time::Duration; use std::time::Instant; @@ -386,15 +385,17 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result = match unit.req.proxy { + + let netloc = match unit.req.proxy { Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port), None => format!("{}:{}", hostname, port), - } - .to_socket_addrs() - .map_err(|e| Error::DnsFailed(format!("{}", e)))? - .collect(); + }; + + // TODO: Find a way to apply deadline to DNS lookup. + let sock_addrs = unit + .resolver() + .resolve(&netloc) + .map_err(|e| Error::DnsFailed(format!("{}", e)))?; if sock_addrs.is_empty() { return Err(Error::DnsFailed(format!("No ip address for {}", hostname))); @@ -419,6 +420,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result Result Result { - let addrs: Vec = format!("{}:{}", hostname, port) - .to_socket_addrs() - .map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))? - .collect(); +fn socks5_local_nslookup( + unit: &Unit, + hostname: &str, + port: u16, +) -> Result { + let addrs: Vec = unit + .resolver() + .resolve(&format!("{}:{}", hostname, port)) + .map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?; if addrs.is_empty() { return Err(std::io::Error::new( @@ -522,6 +528,7 @@ fn socks5_local_nslookup(hostname: &str, port: u16) -> Result, proxy_addr: SocketAddr, @@ -533,7 +540,7 @@ fn connect_socks5( use std::str::FromStr; let host_addr = if Ipv4Addr::from_str(host).is_ok() || Ipv6Addr::from_str(host).is_ok() { - match socks5_local_nslookup(host, port) { + match socks5_local_nslookup(unit, host, port) { Ok(addr) => addr, Err(err) => return Err(err), } @@ -625,6 +632,7 @@ fn get_socks5_stream( #[cfg(not(feature = "socks-proxy"))] fn connect_socks5( + _unit: &Unit, _proxy: Proxy, _deadline: Option, _proxy_addr: SocketAddr, diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index f080352..ea107ec 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -101,6 +101,31 @@ fn connection_reuse() { assert_eq!(resp.status(), 200); } +#[test] +fn custom_resolver() { + use std::io::Read; + use std::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + + let local_addr = listener.local_addr().unwrap(); + + let server = std::thread::spawn(move || { + let (mut client, _) = listener.accept().unwrap(); + let mut buf = vec![0u8; 16]; + let read = client.read(&mut buf).unwrap(); + buf.truncate(read); + buf + }); + + crate::agent() + .set_resolver(move |_: &str| Ok(vec![local_addr])) + .get("http://cool.server/") + .call(); + + assert_eq!(&server.join().unwrap(), b"GET / HTTP/1.1\r\n"); +} + #[cfg(feature = "cookie")] #[cfg(test)] fn cookie_and_redirect(mut stream: TcpStream) -> io::Result<()> { diff --git a/src/unit.rs b/src/unit.rs index 13d03a8..f3822f2 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -10,6 +10,7 @@ use cookie::{Cookie, CookieJar}; use crate::agent::AgentState; use crate::body::{self, Payload, SizedReader}; use crate::header; +use crate::resolve::ArcResolver; use crate::stream::{self, connect_test, Stream}; use crate::{Error, Header, Request, Response}; @@ -95,6 +96,10 @@ impl Unit { self.req.method.eq_ignore_ascii_case("head") } + pub fn resolver(&self) -> ArcResolver { + self.req.agent.lock().unwrap().resolver.clone() + } + #[cfg(test)] pub fn header(&self, name: &str) -> Option<&str> { header::get_header(&self.headers, name)