Implement Pluggable Name-resolution (#148)
This defines a new trait `Resolver`, which turns an address into a Vec<SocketAddr>. It also provides an implementation of Resolver for `Fn(&str)` so it's easy to define simple resolvers with a closure. Fixes #82 Co-authored-by: Ulrik <ulrikm@spotify.com>
This commit is contained in:
31
src/agent.rs
31
src/agent.rs
@@ -6,6 +6,7 @@ use std::sync::Mutex;
|
||||
use crate::header::{self, Header};
|
||||
use crate::pool::ConnectionPool;
|
||||
use crate::request::Request;
|
||||
use crate::resolve::ArcResolver;
|
||||
|
||||
/// Agents keep state between requests.
|
||||
///
|
||||
@@ -53,15 +54,12 @@ pub(crate) struct AgentState {
|
||||
/// Cookies saved between requests.
|
||||
#[cfg(feature = "cookie")]
|
||||
pub(crate) jar: CookieJar,
|
||||
pub(crate) resolver: ArcResolver,
|
||||
}
|
||||
|
||||
impl AgentState {
|
||||
fn new() -> Self {
|
||||
AgentState {
|
||||
pool: ConnectionPool::new(),
|
||||
#[cfg(feature = "cookie")]
|
||||
jar: CookieJar::new(),
|
||||
}
|
||||
Self::default()
|
||||
}
|
||||
pub fn pool(&mut self) -> &mut ConnectionPool {
|
||||
&mut self.pool
|
||||
@@ -194,6 +192,29 @@ impl Agent {
|
||||
.set_max_idle_connections_per_host(max_connections);
|
||||
}
|
||||
|
||||
/// Configures a custom resolver to be used by this agent. By default,
|
||||
/// address-resolution is done by std::net::ToSocketAddrs. This allows you
|
||||
/// to override that resolution with your own alternative. Useful for
|
||||
/// testing and special-cases like DNS-based load balancing.
|
||||
///
|
||||
/// A `Fn(&str) -> io::Result<Vec<SocketAddr>>` is a valid resolver,
|
||||
/// passing a closure is a simple way to override. Note that you might need
|
||||
/// explicit type `&str` on the closure argument for type inference to
|
||||
/// succeed.
|
||||
/// ```
|
||||
/// use std::net::ToSocketAddrs;
|
||||
///
|
||||
/// let mut agent = ureq::agent();
|
||||
/// agent.set_resolver(|addr: &str| match addr {
|
||||
/// "example.com" => Ok(vec![([127,0,0,1], 8096).into()]),
|
||||
/// addr => addr.to_socket_addrs().map(Iterator::collect),
|
||||
/// });
|
||||
/// ```
|
||||
pub fn set_resolver(&mut self, resolver: impl crate::Resolver + 'static) -> &mut Self {
|
||||
self.state.lock().unwrap().resolver = resolver.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets a cookie in this agent by name. Cookies are available
|
||||
/// either by setting it in the agent, or by making requests
|
||||
/// that `Set-Cookie` in the agent.
|
||||
|
||||
@@ -125,6 +125,7 @@ mod header;
|
||||
mod pool;
|
||||
mod proxy;
|
||||
mod request;
|
||||
mod resolve;
|
||||
mod response;
|
||||
mod stream;
|
||||
mod unit;
|
||||
@@ -140,6 +141,7 @@ pub use crate::error::Error;
|
||||
pub use crate::header::Header;
|
||||
pub use crate::proxy::Proxy;
|
||||
pub use crate::request::Request;
|
||||
pub use crate::resolve::Resolver;
|
||||
pub use crate::response::Response;
|
||||
|
||||
// re-export
|
||||
|
||||
14
src/pool.rs
14
src/pool.rs
@@ -74,10 +74,6 @@ impl Default for ConnectionPool {
|
||||
}
|
||||
|
||||
impl ConnectionPool {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn set_max_idle_connections(&mut self, max_connections: usize) {
|
||||
if self.max_idle_connections == max_connections {
|
||||
return;
|
||||
@@ -251,7 +247,7 @@ fn pool_connections_limit() {
|
||||
// Test inserting connections with different keys into the pool,
|
||||
// filling and draining it. The pool should evict earlier connections
|
||||
// when the connection limit is reached.
|
||||
let mut pool = ConnectionPool::new();
|
||||
let mut pool = ConnectionPool::default();
|
||||
let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i));
|
||||
let poolkeys = hostnames.map(|hostname| PoolKey {
|
||||
scheme: "https".to_string(),
|
||||
@@ -276,7 +272,7 @@ fn pool_per_host_connections_limit() {
|
||||
// Test inserting connections with the same key into the pool,
|
||||
// filling and draining it. The pool should evict earlier connections
|
||||
// when the per-host connection limit is reached.
|
||||
let mut pool = ConnectionPool::new();
|
||||
let mut pool = ConnectionPool::default();
|
||||
let poolkey = PoolKey {
|
||||
scheme: "https".to_string(),
|
||||
hostname: "example.com".to_string(),
|
||||
@@ -301,7 +297,7 @@ fn pool_per_host_connections_limit() {
|
||||
|
||||
#[test]
|
||||
fn pool_update_connection_limit() {
|
||||
let mut pool = ConnectionPool::new();
|
||||
let mut pool = ConnectionPool::default();
|
||||
pool.set_max_idle_connections(50);
|
||||
|
||||
let hostnames = (0..pool.max_idle_connections).map(|i| format!("{}.example", i));
|
||||
@@ -321,7 +317,7 @@ fn pool_update_connection_limit() {
|
||||
|
||||
#[test]
|
||||
fn pool_update_per_host_connection_limit() {
|
||||
let mut pool = ConnectionPool::new();
|
||||
let mut pool = ConnectionPool::default();
|
||||
pool.set_max_idle_connections(50);
|
||||
pool.set_max_idle_connections_per_host(50);
|
||||
|
||||
@@ -347,7 +343,7 @@ fn pool_update_per_host_connection_limit() {
|
||||
fn pool_checks_proxy() {
|
||||
// Test inserting different poolkeys with same address but different proxies.
|
||||
// Each insertion should result in an additional entry in the pool.
|
||||
let mut pool = ConnectionPool::new();
|
||||
let mut pool = ConnectionPool::default();
|
||||
let url = Url::parse("zzz:///example.com").unwrap();
|
||||
|
||||
pool.add(
|
||||
|
||||
59
src/resolve.rs
Normal file
59
src/resolve.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use std::fmt;
|
||||
use std::io::Result as IoResult;
|
||||
use std::net::{SocketAddr, ToSocketAddrs};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub trait Resolver: Send + Sync {
|
||||
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>>;
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct StdResolver;
|
||||
|
||||
impl Resolver for StdResolver {
|
||||
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
|
||||
ToSocketAddrs::to_socket_addrs(netloc).map(|iter| iter.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Resolver for F
|
||||
where
|
||||
F: Fn(&str) -> IoResult<Vec<SocketAddr>>,
|
||||
F: Send + Sync,
|
||||
{
|
||||
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
|
||||
self(netloc)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ArcResolver(Arc<dyn Resolver>);
|
||||
|
||||
impl<R> From<R> for ArcResolver
|
||||
where
|
||||
R: Resolver + 'static,
|
||||
{
|
||||
fn from(r: R) -> Self {
|
||||
Self(Arc::new(r))
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Debug for ArcResolver {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "ArcResolver(...)")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for ArcResolver {
|
||||
type Target = dyn Resolver;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0.as_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ArcResolver {
|
||||
fn default() -> Self {
|
||||
StdResolver.into()
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,6 @@ use std::io::{
|
||||
};
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpStream;
|
||||
use std::net::ToSocketAddrs;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
@@ -386,15 +385,17 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
||||
} else {
|
||||
unit.deadline
|
||||
};
|
||||
|
||||
// TODO: Find a way to apply deadline to DNS lookup.
|
||||
let sock_addrs: Vec<SocketAddr> = match unit.req.proxy {
|
||||
|
||||
let netloc = match unit.req.proxy {
|
||||
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
|
||||
None => format!("{}:{}", hostname, port),
|
||||
}
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| Error::DnsFailed(format!("{}", e)))?
|
||||
.collect();
|
||||
};
|
||||
|
||||
// TODO: Find a way to apply deadline to DNS lookup.
|
||||
let sock_addrs = unit
|
||||
.resolver()
|
||||
.resolve(&netloc)
|
||||
.map_err(|e| Error::DnsFailed(format!("{}", e)))?;
|
||||
|
||||
if sock_addrs.is_empty() {
|
||||
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
|
||||
@@ -419,6 +420,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
||||
// connect with a configured timeout.
|
||||
let stream = if Some(Proto::SOCKS5) == proto {
|
||||
connect_socks5(
|
||||
&unit,
|
||||
unit.req.proxy.to_owned().unwrap(),
|
||||
deadline,
|
||||
sock_addr,
|
||||
@@ -496,11 +498,15 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
|
||||
}
|
||||
|
||||
#[cfg(feature = "socks-proxy")]
|
||||
fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::io::Error> {
|
||||
let addrs: Vec<SocketAddr> = format!("{}:{}", hostname, port)
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?
|
||||
.collect();
|
||||
fn socks5_local_nslookup(
|
||||
unit: &Unit,
|
||||
hostname: &str,
|
||||
port: u16,
|
||||
) -> Result<TargetAddr, std::io::Error> {
|
||||
let addrs: Vec<SocketAddr> = unit
|
||||
.resolver()
|
||||
.resolve(&format!("{}:{}", hostname, port))
|
||||
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?;
|
||||
|
||||
if addrs.is_empty() {
|
||||
return Err(std::io::Error::new(
|
||||
@@ -522,6 +528,7 @@ fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::i
|
||||
|
||||
#[cfg(feature = "socks-proxy")]
|
||||
fn connect_socks5(
|
||||
unit: &Unit,
|
||||
proxy: Proxy,
|
||||
deadline: Option<Instant>,
|
||||
proxy_addr: SocketAddr,
|
||||
@@ -533,7 +540,7 @@ fn connect_socks5(
|
||||
use std::str::FromStr;
|
||||
|
||||
let host_addr = if Ipv4Addr::from_str(host).is_ok() || Ipv6Addr::from_str(host).is_ok() {
|
||||
match socks5_local_nslookup(host, port) {
|
||||
match socks5_local_nslookup(unit, host, port) {
|
||||
Ok(addr) => addr,
|
||||
Err(err) => return Err(err),
|
||||
}
|
||||
@@ -625,6 +632,7 @@ fn get_socks5_stream(
|
||||
|
||||
#[cfg(not(feature = "socks-proxy"))]
|
||||
fn connect_socks5(
|
||||
_unit: &Unit,
|
||||
_proxy: Proxy,
|
||||
_deadline: Option<Instant>,
|
||||
_proxy_addr: SocketAddr,
|
||||
|
||||
@@ -101,6 +101,31 @@ fn connection_reuse() {
|
||||
assert_eq!(resp.status(), 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_resolver() {
|
||||
use std::io::Read;
|
||||
use std::net::TcpListener;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
|
||||
|
||||
let local_addr = listener.local_addr().unwrap();
|
||||
|
||||
let server = std::thread::spawn(move || {
|
||||
let (mut client, _) = listener.accept().unwrap();
|
||||
let mut buf = vec![0u8; 16];
|
||||
let read = client.read(&mut buf).unwrap();
|
||||
buf.truncate(read);
|
||||
buf
|
||||
});
|
||||
|
||||
crate::agent()
|
||||
.set_resolver(move |_: &str| Ok(vec![local_addr]))
|
||||
.get("http://cool.server/")
|
||||
.call();
|
||||
|
||||
assert_eq!(&server.join().unwrap(), b"GET / HTTP/1.1\r\n");
|
||||
}
|
||||
|
||||
#[cfg(feature = "cookie")]
|
||||
#[cfg(test)]
|
||||
fn cookie_and_redirect(mut stream: TcpStream) -> io::Result<()> {
|
||||
|
||||
@@ -10,6 +10,7 @@ use cookie::{Cookie, CookieJar};
|
||||
use crate::agent::AgentState;
|
||||
use crate::body::{self, Payload, SizedReader};
|
||||
use crate::header;
|
||||
use crate::resolve::ArcResolver;
|
||||
use crate::stream::{self, connect_test, Stream};
|
||||
use crate::{Error, Header, Request, Response};
|
||||
|
||||
@@ -95,6 +96,10 @@ impl Unit {
|
||||
self.req.method.eq_ignore_ascii_case("head")
|
||||
}
|
||||
|
||||
pub fn resolver(&self) -> ArcResolver {
|
||||
self.req.agent.lock().unwrap().resolver.clone()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn header(&self, name: &str) -> Option<&str> {
|
||||
header::get_header(&self.headers, name)
|
||||
|
||||
Reference in New Issue
Block a user