Implement Pluggable Name-resolution (#148)
This defines a new trait `Resolver`, which turns an address into a Vec<SocketAddr>. It also provides an implementation of Resolver for `Fn(&str)` so it's easy to define simple resolvers with a closure. Fixes #82 Co-authored-by: Ulrik <ulrikm@spotify.com>
This commit is contained in:
31
src/agent.rs
31
src/agent.rs
@@ -6,6 +6,7 @@ use std::sync::Mutex;
|
|||||||
use crate::header::{self, Header};
|
use crate::header::{self, Header};
|
||||||
use crate::pool::ConnectionPool;
|
use crate::pool::ConnectionPool;
|
||||||
use crate::request::Request;
|
use crate::request::Request;
|
||||||
|
use crate::resolve::ArcResolver;
|
||||||
|
|
||||||
/// Agents keep state between requests.
|
/// Agents keep state between requests.
|
||||||
///
|
///
|
||||||
@@ -53,15 +54,12 @@ pub(crate) struct AgentState {
|
|||||||
/// Cookies saved between requests.
|
/// Cookies saved between requests.
|
||||||
#[cfg(feature = "cookie")]
|
#[cfg(feature = "cookie")]
|
||||||
pub(crate) jar: CookieJar,
|
pub(crate) jar: CookieJar,
|
||||||
|
pub(crate) resolver: ArcResolver,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentState {
|
impl AgentState {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
AgentState {
|
Self::default()
|
||||||
pool: ConnectionPool::new(),
|
|
||||||
#[cfg(feature = "cookie")]
|
|
||||||
jar: CookieJar::new(),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
pub fn pool(&mut self) -> &mut ConnectionPool {
|
pub fn pool(&mut self) -> &mut ConnectionPool {
|
||||||
&mut self.pool
|
&mut self.pool
|
||||||
@@ -194,6 +192,29 @@ impl Agent {
|
|||||||
.set_max_idle_connections_per_host(max_connections);
|
.set_max_idle_connections_per_host(max_connections);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Configures a custom resolver to be used by this agent. By default,
|
||||||
|
/// address-resolution is done by std::net::ToSocketAddrs. This allows you
|
||||||
|
/// to override that resolution with your own alternative. Useful for
|
||||||
|
/// testing and special-cases like DNS-based load balancing.
|
||||||
|
///
|
||||||
|
/// A `Fn(&str) -> io::Result<Vec<SocketAddr>>` is a valid resolver,
|
||||||
|
/// passing a closure is a simple way to override. Note that you might need
|
||||||
|
/// explicit type `&str` on the closure argument for type inference to
|
||||||
|
/// succeed.
|
||||||
|
/// ```
|
||||||
|
/// use std::net::ToSocketAddrs;
|
||||||
|
///
|
||||||
|
/// let mut agent = ureq::agent();
|
||||||
|
/// agent.set_resolver(|addr: &str| match addr {
|
||||||
|
/// "example.com" => Ok(vec![([127,0,0,1], 8096).into()]),
|
||||||
|
/// addr => addr.to_socket_addrs().map(Iterator::collect),
|
||||||
|
/// });
|
||||||
|
/// ```
|
||||||
|
pub fn set_resolver(&mut self, resolver: impl crate::Resolver + 'static) -> &mut Self {
|
||||||
|
self.state.lock().unwrap().resolver = resolver.into();
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Gets a cookie in this agent by name. Cookies are available
|
/// Gets a cookie in this agent by name. Cookies are available
|
||||||
/// either by setting it in the agent, or by making requests
|
/// either by setting it in the agent, or by making requests
|
||||||
/// that `Set-Cookie` in the agent.
|
/// that `Set-Cookie` in the agent.
|
||||||
|
|||||||
@@ -125,6 +125,7 @@ mod header;
|
|||||||
mod pool;
|
mod pool;
|
||||||
mod proxy;
|
mod proxy;
|
||||||
mod request;
|
mod request;
|
||||||
|
mod resolve;
|
||||||
mod response;
|
mod response;
|
||||||
mod stream;
|
mod stream;
|
||||||
mod unit;
|
mod unit;
|
||||||
@@ -140,6 +141,7 @@ pub use crate::error::Error;
|
|||||||
pub use crate::header::Header;
|
pub use crate::header::Header;
|
||||||
pub use crate::proxy::Proxy;
|
pub use crate::proxy::Proxy;
|
||||||
pub use crate::request::Request;
|
pub use crate::request::Request;
|
||||||
|
pub use crate::resolve::Resolver;
|
||||||
pub use crate::response::Response;
|
pub use crate::response::Response;
|
||||||
|
|
||||||
// re-export
|
// re-export
|
||||||
|
|||||||
14
src/pool.rs
14
src/pool.rs
@@ -74,10 +74,6 @@ impl Default for ConnectionPool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl ConnectionPool {
|
impl ConnectionPool {
|
||||||
pub fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn set_max_idle_connections(&mut self, max_connections: usize) {
|
pub fn set_max_idle_connections(&mut self, max_connections: usize) {
|
||||||
if self.max_idle_connections == max_connections {
|
if self.max_idle_connections == max_connections {
|
||||||
return;
|
return;
|
||||||
@@ -251,7 +247,7 @@ fn pool_connections_limit() {
|
|||||||
// Test inserting connections with different keys into the pool,
|
// Test inserting connections with different keys into the pool,
|
||||||
// filling and draining it. The pool should evict earlier connections
|
// filling and draining it. The pool should evict earlier connections
|
||||||
// when the connection limit is reached.
|
// when the connection limit is reached.
|
||||||
let mut pool = ConnectionPool::new();
|
let mut pool = ConnectionPool::default();
|
||||||
let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i));
|
let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i));
|
||||||
let poolkeys = hostnames.map(|hostname| PoolKey {
|
let poolkeys = hostnames.map(|hostname| PoolKey {
|
||||||
scheme: "https".to_string(),
|
scheme: "https".to_string(),
|
||||||
@@ -276,7 +272,7 @@ fn pool_per_host_connections_limit() {
|
|||||||
// Test inserting connections with the same key into the pool,
|
// Test inserting connections with the same key into the pool,
|
||||||
// filling and draining it. The pool should evict earlier connections
|
// filling and draining it. The pool should evict earlier connections
|
||||||
// when the per-host connection limit is reached.
|
// when the per-host connection limit is reached.
|
||||||
let mut pool = ConnectionPool::new();
|
let mut pool = ConnectionPool::default();
|
||||||
let poolkey = PoolKey {
|
let poolkey = PoolKey {
|
||||||
scheme: "https".to_string(),
|
scheme: "https".to_string(),
|
||||||
hostname: "example.com".to_string(),
|
hostname: "example.com".to_string(),
|
||||||
@@ -301,7 +297,7 @@ fn pool_per_host_connections_limit() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pool_update_connection_limit() {
|
fn pool_update_connection_limit() {
|
||||||
let mut pool = ConnectionPool::new();
|
let mut pool = ConnectionPool::default();
|
||||||
pool.set_max_idle_connections(50);
|
pool.set_max_idle_connections(50);
|
||||||
|
|
||||||
let hostnames = (0..pool.max_idle_connections).map(|i| format!("{}.example", i));
|
let hostnames = (0..pool.max_idle_connections).map(|i| format!("{}.example", i));
|
||||||
@@ -321,7 +317,7 @@ fn pool_update_connection_limit() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pool_update_per_host_connection_limit() {
|
fn pool_update_per_host_connection_limit() {
|
||||||
let mut pool = ConnectionPool::new();
|
let mut pool = ConnectionPool::default();
|
||||||
pool.set_max_idle_connections(50);
|
pool.set_max_idle_connections(50);
|
||||||
pool.set_max_idle_connections_per_host(50);
|
pool.set_max_idle_connections_per_host(50);
|
||||||
|
|
||||||
@@ -347,7 +343,7 @@ fn pool_update_per_host_connection_limit() {
|
|||||||
fn pool_checks_proxy() {
|
fn pool_checks_proxy() {
|
||||||
// Test inserting different poolkeys with same address but different proxies.
|
// Test inserting different poolkeys with same address but different proxies.
|
||||||
// Each insertion should result in an additional entry in the pool.
|
// Each insertion should result in an additional entry in the pool.
|
||||||
let mut pool = ConnectionPool::new();
|
let mut pool = ConnectionPool::default();
|
||||||
let url = Url::parse("zzz:///example.com").unwrap();
|
let url = Url::parse("zzz:///example.com").unwrap();
|
||||||
|
|
||||||
pool.add(
|
pool.add(
|
||||||
|
|||||||
59
src/resolve.rs
Normal file
59
src/resolve.rs
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
use std::fmt;
|
||||||
|
use std::io::Result as IoResult;
|
||||||
|
use std::net::{SocketAddr, ToSocketAddrs};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub trait Resolver: Send + Sync {
|
||||||
|
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct StdResolver;
|
||||||
|
|
||||||
|
impl Resolver for StdResolver {
|
||||||
|
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
|
||||||
|
ToSocketAddrs::to_socket_addrs(netloc).map(|iter| iter.collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Resolver for F
|
||||||
|
where
|
||||||
|
F: Fn(&str) -> IoResult<Vec<SocketAddr>>,
|
||||||
|
F: Send + Sync,
|
||||||
|
{
|
||||||
|
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
|
||||||
|
self(netloc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub(crate) struct ArcResolver(Arc<dyn Resolver>);
|
||||||
|
|
||||||
|
impl<R> From<R> for ArcResolver
|
||||||
|
where
|
||||||
|
R: Resolver + 'static,
|
||||||
|
{
|
||||||
|
fn from(r: R) -> Self {
|
||||||
|
Self(Arc::new(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for ArcResolver {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(f, "ArcResolver(...)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::ops::Deref for ArcResolver {
|
||||||
|
type Target = dyn Resolver;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
self.0.as_ref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ArcResolver {
|
||||||
|
fn default() -> Self {
|
||||||
|
StdResolver.into()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,7 +4,6 @@ use std::io::{
|
|||||||
};
|
};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::net::TcpStream;
|
use std::net::TcpStream;
|
||||||
use std::net::ToSocketAddrs;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
@@ -386,15 +385,17 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
|||||||
} else {
|
} else {
|
||||||
unit.deadline
|
unit.deadline
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: Find a way to apply deadline to DNS lookup.
|
let netloc = match unit.req.proxy {
|
||||||
let sock_addrs: Vec<SocketAddr> = match unit.req.proxy {
|
|
||||||
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
|
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
|
||||||
None => format!("{}:{}", hostname, port),
|
None => format!("{}:{}", hostname, port),
|
||||||
}
|
};
|
||||||
.to_socket_addrs()
|
|
||||||
.map_err(|e| Error::DnsFailed(format!("{}", e)))?
|
// TODO: Find a way to apply deadline to DNS lookup.
|
||||||
.collect();
|
let sock_addrs = unit
|
||||||
|
.resolver()
|
||||||
|
.resolve(&netloc)
|
||||||
|
.map_err(|e| Error::DnsFailed(format!("{}", e)))?;
|
||||||
|
|
||||||
if sock_addrs.is_empty() {
|
if sock_addrs.is_empty() {
|
||||||
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
|
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
|
||||||
@@ -419,6 +420,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
|||||||
// connect with a configured timeout.
|
// connect with a configured timeout.
|
||||||
let stream = if Some(Proto::SOCKS5) == proto {
|
let stream = if Some(Proto::SOCKS5) == proto {
|
||||||
connect_socks5(
|
connect_socks5(
|
||||||
|
&unit,
|
||||||
unit.req.proxy.to_owned().unwrap(),
|
unit.req.proxy.to_owned().unwrap(),
|
||||||
deadline,
|
deadline,
|
||||||
sock_addr,
|
sock_addr,
|
||||||
@@ -496,11 +498,15 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "socks-proxy")]
|
#[cfg(feature = "socks-proxy")]
|
||||||
fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::io::Error> {
|
fn socks5_local_nslookup(
|
||||||
let addrs: Vec<SocketAddr> = format!("{}:{}", hostname, port)
|
unit: &Unit,
|
||||||
.to_socket_addrs()
|
hostname: &str,
|
||||||
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?
|
port: u16,
|
||||||
.collect();
|
) -> Result<TargetAddr, std::io::Error> {
|
||||||
|
let addrs: Vec<SocketAddr> = unit
|
||||||
|
.resolver()
|
||||||
|
.resolve(&format!("{}:{}", hostname, port))
|
||||||
|
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?;
|
||||||
|
|
||||||
if addrs.is_empty() {
|
if addrs.is_empty() {
|
||||||
return Err(std::io::Error::new(
|
return Err(std::io::Error::new(
|
||||||
@@ -522,6 +528,7 @@ fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::i
|
|||||||
|
|
||||||
#[cfg(feature = "socks-proxy")]
|
#[cfg(feature = "socks-proxy")]
|
||||||
fn connect_socks5(
|
fn connect_socks5(
|
||||||
|
unit: &Unit,
|
||||||
proxy: Proxy,
|
proxy: Proxy,
|
||||||
deadline: Option<Instant>,
|
deadline: Option<Instant>,
|
||||||
proxy_addr: SocketAddr,
|
proxy_addr: SocketAddr,
|
||||||
@@ -533,7 +540,7 @@ fn connect_socks5(
|
|||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
let host_addr = if Ipv4Addr::from_str(host).is_ok() || Ipv6Addr::from_str(host).is_ok() {
|
let host_addr = if Ipv4Addr::from_str(host).is_ok() || Ipv6Addr::from_str(host).is_ok() {
|
||||||
match socks5_local_nslookup(host, port) {
|
match socks5_local_nslookup(unit, host, port) {
|
||||||
Ok(addr) => addr,
|
Ok(addr) => addr,
|
||||||
Err(err) => return Err(err),
|
Err(err) => return Err(err),
|
||||||
}
|
}
|
||||||
@@ -625,6 +632,7 @@ fn get_socks5_stream(
|
|||||||
|
|
||||||
#[cfg(not(feature = "socks-proxy"))]
|
#[cfg(not(feature = "socks-proxy"))]
|
||||||
fn connect_socks5(
|
fn connect_socks5(
|
||||||
|
_unit: &Unit,
|
||||||
_proxy: Proxy,
|
_proxy: Proxy,
|
||||||
_deadline: Option<Instant>,
|
_deadline: Option<Instant>,
|
||||||
_proxy_addr: SocketAddr,
|
_proxy_addr: SocketAddr,
|
||||||
|
|||||||
@@ -101,6 +101,31 @@ fn connection_reuse() {
|
|||||||
assert_eq!(resp.status(), 200);
|
assert_eq!(resp.status(), 200);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn custom_resolver() {
|
||||||
|
use std::io::Read;
|
||||||
|
use std::net::TcpListener;
|
||||||
|
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||||
|
|
||||||
|
let local_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let server = std::thread::spawn(move || {
|
||||||
|
let (mut client, _) = listener.accept().unwrap();
|
||||||
|
let mut buf = vec![0u8; 16];
|
||||||
|
let read = client.read(&mut buf).unwrap();
|
||||||
|
buf.truncate(read);
|
||||||
|
buf
|
||||||
|
});
|
||||||
|
|
||||||
|
crate::agent()
|
||||||
|
.set_resolver(move |_: &str| Ok(vec![local_addr]))
|
||||||
|
.get("http://cool.server/")
|
||||||
|
.call();
|
||||||
|
|
||||||
|
assert_eq!(&server.join().unwrap(), b"GET / HTTP/1.1\r\n");
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "cookie")]
|
#[cfg(feature = "cookie")]
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn cookie_and_redirect(mut stream: TcpStream) -> io::Result<()> {
|
fn cookie_and_redirect(mut stream: TcpStream) -> io::Result<()> {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use cookie::{Cookie, CookieJar};
|
|||||||
use crate::agent::AgentState;
|
use crate::agent::AgentState;
|
||||||
use crate::body::{self, Payload, SizedReader};
|
use crate::body::{self, Payload, SizedReader};
|
||||||
use crate::header;
|
use crate::header;
|
||||||
|
use crate::resolve::ArcResolver;
|
||||||
use crate::stream::{self, connect_test, Stream};
|
use crate::stream::{self, connect_test, Stream};
|
||||||
use crate::{Error, Header, Request, Response};
|
use crate::{Error, Header, Request, Response};
|
||||||
|
|
||||||
@@ -95,6 +96,10 @@ impl Unit {
|
|||||||
self.req.method.eq_ignore_ascii_case("head")
|
self.req.method.eq_ignore_ascii_case("head")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn resolver(&self) -> ArcResolver {
|
||||||
|
self.req.agent.lock().unwrap().resolver.clone()
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub fn header(&self, name: &str) -> Option<&str> {
|
pub fn header(&self, name: &str) -> Option<&str> {
|
||||||
header::get_header(&self.headers, name)
|
header::get_header(&self.headers, name)
|
||||||
|
|||||||
Reference in New Issue
Block a user