diff --git a/src/stream.rs b/src/stream.rs index 8ac4d9b..7057a93 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,8 +1,9 @@ -use std::io::{Cursor, ErrorKind, Read, Result as IoResult, Write}; +use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write}; use std::net::SocketAddr; use std::net::TcpStream; use std::net::ToSocketAddrs; use std::time::Duration; +use std::time::Instant; #[cfg(feature = "tls")] use rustls::ClientSession; @@ -186,6 +187,10 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result Result 0; // Find the first sock_addr that accepts a connection for sock_addr in sock_addrs { + // ensure connect timeout isn't hit overall. + if has_timeout { + let lapsed = (Instant::now() - start_time).as_millis() as u64; + if lapsed >= unit.timeout_connect { + any_err = Some(IoError::new(ErrorKind::TimedOut, "Didn't connect in time")); + break; + } else { + timeout_connect = unit.timeout_connect - lapsed; + } + } + // connect with a configured timeout. let stream = if Some(Proto::SOCKS5) == proto { connect_socks5( unit.proxy.to_owned().unwrap(), - unit.timeout_connect, + timeout_connect, sock_addr, hostname, port, ) } else { - match unit.timeout_connect { - 0 => TcpStream::connect(&sock_addr), - _ => TcpStream::connect_timeout( - &sock_addr, - Duration::from_millis(unit.timeout_connect as u64), - ), + if has_timeout { + let timeout = Duration::from_millis(timeout_connect); + TcpStream::connect_timeout(&sock_addr, timeout) + } else { + TcpStream::connect(&sock_addr) } }; @@ -227,11 +245,7 @@ pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result