ensure overall connect timeout

This commit is contained in:
Martin Algesten
2020-05-23 09:28:16 +02:00
parent f03995a72f
commit aa3e9b1ecf

View File

@@ -1,8 +1,9 @@
use std::io::{Cursor, ErrorKind, Read, Result as IoResult, Write}; use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::net::TcpStream; use std::net::TcpStream;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::time::Duration; use std::time::Duration;
use std::time::Instant;
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
use rustls::ClientSession; use rustls::ClientSession;
@@ -186,6 +187,10 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
.map_err(|e| Error::DnsFailed(format!("{}", e)))? .map_err(|e| Error::DnsFailed(format!("{}", e)))?
.collect(); .collect();
if sock_addrs.is_empty() {
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
}
let proto = if let Some(ref proxy) = unit.proxy { let proto = if let Some(ref proxy) = unit.proxy {
Some(proxy.proto) Some(proxy.proto)
} else { } else {
@@ -194,25 +199,38 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
let mut any_err = None; let mut any_err = None;
let mut any_stream = None; let mut any_stream = None;
let mut timeout_connect = unit.timeout_connect;
let start_time = Instant::now();
let has_timeout = unit.timeout_connect > 0;
// Find the first sock_addr that accepts a connection // Find the first sock_addr that accepts a connection
for sock_addr in sock_addrs { for sock_addr in sock_addrs {
// ensure connect timeout isn't hit overall.
if has_timeout {
let lapsed = (Instant::now() - start_time).as_millis() as u64;
if lapsed >= unit.timeout_connect {
any_err = Some(IoError::new(ErrorKind::TimedOut, "Didn't connect in time"));
break;
} else {
timeout_connect = unit.timeout_connect - lapsed;
}
}
// connect with a configured timeout. // connect with a configured timeout.
let stream = if Some(Proto::SOCKS5) == proto { let stream = if Some(Proto::SOCKS5) == proto {
connect_socks5( connect_socks5(
unit.proxy.to_owned().unwrap(), unit.proxy.to_owned().unwrap(),
unit.timeout_connect, timeout_connect,
sock_addr, sock_addr,
hostname, hostname,
port, port,
) )
} else { } else {
match unit.timeout_connect { if has_timeout {
0 => TcpStream::connect(&sock_addr), let timeout = Duration::from_millis(timeout_connect);
_ => TcpStream::connect_timeout( TcpStream::connect_timeout(&sock_addr, timeout)
&sock_addr, } else {
Duration::from_millis(unit.timeout_connect as u64), TcpStream::connect(&sock_addr)
),
} }
}; };
@@ -227,11 +245,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<Tcp
let mut stream = if let Some(stream) = any_stream { let mut stream = if let Some(stream) = any_stream {
stream stream
} else { } else {
let err = if let Some(err) = any_err { let err = Error::ConnectionFailed(format!("{}", any_err.expect("Connect error")));
Error::ConnectionFailed(format!("{}", err))
} else {
Error::DnsFailed(format!("No ip address for {}", hostname))
};
return Err(err); return Err(err);
}; };