Implement Pluggable Name-resolution (#148)

This defines a new trait `Resolver`, which turns an address into a
Vec<SocketAddr>. It also provides an implementation of Resolver for
`Fn(&str)` so it's easy to define simple resolvers with a closure.


Fixes #82

Co-authored-by: Ulrik <ulrikm@spotify.com>
This commit is contained in:
Ulrik Mikaelsson
2020-09-27 01:35:13 +02:00
committed by GitHub
parent 8bba07a9af
commit 11413726cd
7 changed files with 144 additions and 28 deletions

View File

@@ -6,6 +6,7 @@ use std::sync::Mutex;
use crate::header::{self, Header};
use crate::pool::ConnectionPool;
use crate::request::Request;
use crate::resolve::ArcResolver;
/// Agents keep state between requests.
///
@@ -53,15 +54,12 @@ pub(crate) struct AgentState {
/// Cookies saved between requests.
#[cfg(feature = "cookie")]
pub(crate) jar: CookieJar,
pub(crate) resolver: ArcResolver,
}
impl AgentState {
fn new() -> Self {
AgentState {
pool: ConnectionPool::new(),
#[cfg(feature = "cookie")]
jar: CookieJar::new(),
}
Self::default()
}
pub fn pool(&mut self) -> &mut ConnectionPool {
&mut self.pool
@@ -194,6 +192,29 @@ impl Agent {
.set_max_idle_connections_per_host(max_connections);
}
/// Configures a custom resolver to be used by this agent. By default,
/// address-resolution is done by std::net::ToSocketAddrs. This allows you
/// to override that resolution with your own alternative. Useful for
/// testing and special-cases like DNS-based load balancing.
///
/// A `Fn(&str) -> io::Result<Vec<SocketAddr>>` is a valid resolver,
/// passing a closure is a simple way to override. Note that you might need
/// explicit type `&str` on the closure argument for type inference to
/// succeed.
/// ```
/// use std::net::ToSocketAddrs;
///
/// let mut agent = ureq::agent();
/// agent.set_resolver(|addr: &str| match addr {
/// "example.com" => Ok(vec![([127,0,0,1], 8096).into()]),
/// addr => addr.to_socket_addrs().map(Iterator::collect),
/// });
/// ```
pub fn set_resolver(&mut self, resolver: impl crate::Resolver + 'static) -> &mut Self {
self.state.lock().unwrap().resolver = resolver.into();
self
}
/// Gets a cookie in this agent by name. Cookies are available
/// either by setting it in the agent, or by making requests
/// that `Set-Cookie` in the agent.

View File

@@ -125,6 +125,7 @@ mod header;
mod pool;
mod proxy;
mod request;
mod resolve;
mod response;
mod stream;
mod unit;
@@ -140,6 +141,7 @@ pub use crate::error::Error;
pub use crate::header::Header;
pub use crate::proxy::Proxy;
pub use crate::request::Request;
pub use crate::resolve::Resolver;
pub use crate::response::Response;
// re-export

View File

@@ -74,10 +74,6 @@ impl Default for ConnectionPool {
}
impl ConnectionPool {
pub fn new() -> Self {
Self::default()
}
pub fn set_max_idle_connections(&mut self, max_connections: usize) {
if self.max_idle_connections == max_connections {
return;
@@ -251,7 +247,7 @@ fn pool_connections_limit() {
// Test inserting connections with different keys into the pool,
// filling and draining it. The pool should evict earlier connections
// when the connection limit is reached.
let mut pool = ConnectionPool::new();
let mut pool = ConnectionPool::default();
let hostnames = (0..DEFAULT_MAX_IDLE_CONNECTIONS * 2).map(|i| format!("{}.example", i));
let poolkeys = hostnames.map(|hostname| PoolKey {
scheme: "https".to_string(),
@@ -276,7 +272,7 @@ fn pool_per_host_connections_limit() {
// Test inserting connections with the same key into the pool,
// filling and draining it. The pool should evict earlier connections
// when the per-host connection limit is reached.
let mut pool = ConnectionPool::new();
let mut pool = ConnectionPool::default();
let poolkey = PoolKey {
scheme: "https".to_string(),
hostname: "example.com".to_string(),
@@ -301,7 +297,7 @@ fn pool_per_host_connections_limit() {
#[test]
fn pool_update_connection_limit() {
let mut pool = ConnectionPool::new();
let mut pool = ConnectionPool::default();
pool.set_max_idle_connections(50);
let hostnames = (0..pool.max_idle_connections).map(|i| format!("{}.example", i));
@@ -321,7 +317,7 @@ fn pool_update_connection_limit() {
#[test]
fn pool_update_per_host_connection_limit() {
let mut pool = ConnectionPool::new();
let mut pool = ConnectionPool::default();
pool.set_max_idle_connections(50);
pool.set_max_idle_connections_per_host(50);
@@ -347,7 +343,7 @@ fn pool_update_per_host_connection_limit() {
fn pool_checks_proxy() {
// Test inserting different poolkeys with same address but different proxies.
// Each insertion should result in an additional entry in the pool.
let mut pool = ConnectionPool::new();
let mut pool = ConnectionPool::default();
let url = Url::parse("zzz:///example.com").unwrap();
pool.add(

59
src/resolve.rs Normal file
View File

@@ -0,0 +1,59 @@
use std::fmt;
use std::io::Result as IoResult;
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::Arc;
pub trait Resolver: Send + Sync {
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>>;
}
#[derive(Debug)]
pub(crate) struct StdResolver;
impl Resolver for StdResolver {
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
ToSocketAddrs::to_socket_addrs(netloc).map(|iter| iter.collect())
}
}
impl<F> Resolver for F
where
F: Fn(&str) -> IoResult<Vec<SocketAddr>>,
F: Send + Sync,
{
fn resolve(&self, netloc: &str) -> IoResult<Vec<SocketAddr>> {
self(netloc)
}
}
#[derive(Clone)]
pub(crate) struct ArcResolver(Arc<dyn Resolver>);
impl<R> From<R> for ArcResolver
where
R: Resolver + 'static,
{
fn from(r: R) -> Self {
Self(Arc::new(r))
}
}
impl fmt::Debug for ArcResolver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ArcResolver(...)")
}
}
impl std::ops::Deref for ArcResolver {
type Target = dyn Resolver;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl Default for ArcResolver {
fn default() -> Self {
StdResolver.into()
}
}

View File

@@ -4,7 +4,6 @@ use std::io::{
};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::net::ToSocketAddrs;
use std::time::Duration;
use std::time::Instant;
@@ -387,14 +386,16 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
unit.deadline
};
// TODO: Find a way to apply deadline to DNS lookup.
let sock_addrs: Vec<SocketAddr> = match unit.req.proxy {
let netloc = match unit.req.proxy {
Some(ref proxy) => format!("{}:{}", proxy.server, proxy.port),
None => format!("{}:{}", hostname, port),
}
.to_socket_addrs()
.map_err(|e| Error::DnsFailed(format!("{}", e)))?
.collect();
};
// TODO: Find a way to apply deadline to DNS lookup.
let sock_addrs = unit
.resolver()
.resolve(&netloc)
.map_err(|e| Error::DnsFailed(format!("{}", e)))?;
if sock_addrs.is_empty() {
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
@@ -419,6 +420,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
// connect with a configured timeout.
let stream = if Some(Proto::SOCKS5) == proto {
connect_socks5(
&unit,
unit.req.proxy.to_owned().unwrap(),
deadline,
sock_addr,
@@ -496,11 +498,15 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
}
#[cfg(feature = "socks-proxy")]
fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::io::Error> {
let addrs: Vec<SocketAddr> = format!("{}:{}", hostname, port)
.to_socket_addrs()
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?
.collect();
fn socks5_local_nslookup(
unit: &Unit,
hostname: &str,
port: u16,
) -> Result<TargetAddr, std::io::Error> {
let addrs: Vec<SocketAddr> = unit
.resolver()
.resolve(&format!("{}:{}", hostname, port))
.map_err(|e| std::io::Error::new(ErrorKind::NotFound, format!("DNS failure: {}.", e)))?;
if addrs.is_empty() {
return Err(std::io::Error::new(
@@ -522,6 +528,7 @@ fn socks5_local_nslookup(hostname: &str, port: u16) -> Result<TargetAddr, std::i
#[cfg(feature = "socks-proxy")]
fn connect_socks5(
unit: &Unit,
proxy: Proxy,
deadline: Option<Instant>,
proxy_addr: SocketAddr,
@@ -533,7 +540,7 @@ fn connect_socks5(
use std::str::FromStr;
let host_addr = if Ipv4Addr::from_str(host).is_ok() || Ipv6Addr::from_str(host).is_ok() {
match socks5_local_nslookup(host, port) {
match socks5_local_nslookup(unit, host, port) {
Ok(addr) => addr,
Err(err) => return Err(err),
}
@@ -625,6 +632,7 @@ fn get_socks5_stream(
#[cfg(not(feature = "socks-proxy"))]
fn connect_socks5(
_unit: &Unit,
_proxy: Proxy,
_deadline: Option<Instant>,
_proxy_addr: SocketAddr,

View File

@@ -101,6 +101,31 @@ fn connection_reuse() {
assert_eq!(resp.status(), 200);
}
#[test]
fn custom_resolver() {
use std::io::Read;
use std::net::TcpListener;
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
let local_addr = listener.local_addr().unwrap();
let server = std::thread::spawn(move || {
let (mut client, _) = listener.accept().unwrap();
let mut buf = vec![0u8; 16];
let read = client.read(&mut buf).unwrap();
buf.truncate(read);
buf
});
crate::agent()
.set_resolver(move |_: &str| Ok(vec![local_addr]))
.get("http://cool.server/")
.call();
assert_eq!(&server.join().unwrap(), b"GET / HTTP/1.1\r\n");
}
#[cfg(feature = "cookie")]
#[cfg(test)]
fn cookie_and_redirect(mut stream: TcpStream) -> io::Result<()> {

View File

@@ -10,6 +10,7 @@ use cookie::{Cookie, CookieJar};
use crate::agent::AgentState;
use crate::body::{self, Payload, SizedReader};
use crate::header;
use crate::resolve::ArcResolver;
use crate::stream::{self, connect_test, Stream};
use crate::{Error, Header, Request, Response};
@@ -95,6 +96,10 @@ impl Unit {
self.req.method.eq_ignore_ascii_case("head")
}
pub fn resolver(&self) -> ArcResolver {
self.req.agent.lock().unwrap().resolver.clone()
}
#[cfg(test)]
pub fn header(&self, name: &str) -> Option<&str> {
header::get_header(&self.headers, name)