Add optional native-tls support, clear up warnings for flag configurations

This commit is contained in:
k3d3
2020-04-13 22:03:53 -04:00
committed by Martin Algesten
parent 8a05241eac
commit 9f7f712dde
8 changed files with 65 additions and 15 deletions

View File

@@ -37,6 +37,7 @@ rustls-native-certs = { version = "0.3", optional = true }
serde = { version = "1", optional = true } serde = { version = "1", optional = true }
serde_json = { version = "1", optional = true } serde_json = { version = "1", optional = true }
encoding = { version = "0.2", optional = true } encoding = { version = "0.2", optional = true }
native-tls = { version = "0.2", optional = true }
[dev-dependencies] [dev-dependencies]
serde = { version = "1", features = ["derive"] } serde = { version = "1", features = ["derive"] }

View File

@@ -250,7 +250,7 @@ impl Agent {
self.request("PATCH", path) self.request("PATCH", path)
} }
#[cfg(test)] #[cfg(all(test, any(feature = "tls", feature = "native-tls")))]
pub(crate) fn state(&self) -> &Arc<Mutex<Option<AgentState>>> { pub(crate) fn state(&self) -> &Arc<Mutex<Option<AgentState>>> {
&self.state &self.state
} }

View File

@@ -31,6 +31,9 @@ pub enum Error {
ProxyConnect, ProxyConnect,
/// Incorrect credentials for proxy /// Incorrect credentials for proxy
InvalidProxyCreds, InvalidProxyCreds,
/// TLS Error
#[cfg(feature = "native-tls")]
TlsError(native_tls::Error),
} }
impl Error { impl Error {
@@ -59,6 +62,8 @@ impl Error {
Error::BadProxyCreds => 500, Error::BadProxyCreds => 500,
Error::ProxyConnect => 500, Error::ProxyConnect => 500,
Error::InvalidProxyCreds => 500, Error::InvalidProxyCreds => 500,
#[cfg(feature = "native-tls")]
Error::TlsError(_) => 599,
} }
} }
@@ -78,6 +83,8 @@ impl Error {
Error::BadProxyCreds => "Failed to parse proxy credentials", Error::BadProxyCreds => "Failed to parse proxy credentials",
Error::ProxyConnect => "Proxy failed to connect", Error::ProxyConnect => "Proxy failed to connect",
Error::InvalidProxyCreds => "Provided proxy credentials are incorrect", Error::InvalidProxyCreds => "Provided proxy credentials are incorrect",
#[cfg(feature = "native-tls")]
Error::TlsError(_) => "TLS Error",
} }
} }
@@ -97,6 +104,8 @@ impl Error {
Error::BadProxyCreds => "Failed to parse proxy credentials".to_string(), Error::BadProxyCreds => "Failed to parse proxy credentials".to_string(),
Error::ProxyConnect => "Proxy failed to connect".to_string(), Error::ProxyConnect => "Proxy failed to connect".to_string(),
Error::InvalidProxyCreds => "Provided proxy credentials are incorrect".to_string(), Error::InvalidProxyCreds => "Provided proxy credentials are incorrect".to_string(),
#[cfg(feature = "native-tls")]
Error::TlsError(err) => format!("TLS Error: {}", err),
} }
} }
} }

View File

@@ -7,6 +7,7 @@
//! * Obvious API //! * Obvious API
//! //!
//! ``` //! ```
//! # #[cfg(feature = "json")] {
//! // requires feature: `ureq = { version = "*", features = ["json"] }` //! // requires feature: `ureq = { version = "*", features = ["json"] }`
//! # #[cfg(feature = "json")] { //! # #[cfg(feature = "json")] {
//! use ureq::json; //! use ureq::json;
@@ -195,7 +196,7 @@ mod tests {
} }
#[test] #[test]
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
fn connect_https_google() { fn connect_https_google() {
let resp = get("https://www.google.com/").call(); let resp = get("https://www.google.com/").call();
assert_eq!( assert_eq!(
@@ -206,7 +207,7 @@ mod tests {
} }
#[test] #[test]
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
fn connect_https_invalid_name() { fn connect_https_invalid_name() {
let resp = get("https://example.com{REQUEST_URI}/").call(); let resp = get("https://example.com{REQUEST_URI}/").call();
assert_eq!(400, resp.status()); assert_eq!(400, resp.status());

View File

@@ -29,12 +29,12 @@ impl ConnectionPool {
self.recycle.remove(&PoolKey::new(url)) self.recycle.remove(&PoolKey::new(url))
} }
#[cfg(test)] #[cfg(all(test, any(feature = "tls", feature = "native-tls")))]
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.recycle.len() self.recycle.len()
} }
#[cfg(test)] #[cfg(all(test, any(feature = "tls", feature = "native-tls")))]
pub fn get(&self, hostname: &str, port: u16) -> Option<&Stream> { pub fn get(&self, hostname: &str, port: u16) -> Option<&Stream> {
let key = PoolKey { let key = PoolKey {
hostname: hostname.into(), hostname: hostname.into(),

View File

@@ -12,6 +12,9 @@ use rustls::StreamOwned;
#[cfg(feature = "socks-proxy")] #[cfg(feature = "socks-proxy")]
use socks::{TargetAddr, ToTargetAddr}; use socks::{TargetAddr, ToTargetAddr};
#[cfg(feature = "native-tls")]
use native_tls::{TlsConnector, TlsStream, HandshakeError};
use crate::proxy::Proto; use crate::proxy::Proto;
use crate::proxy::Proxy; use crate::proxy::Proxy;
@@ -23,6 +26,8 @@ pub enum Stream {
Http(TcpStream), Http(TcpStream),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>), Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
#[cfg(feature = "native-tls")]
Https(TlsStream<TcpStream>),
Cursor(Cursor<Vec<u8>>), Cursor(Cursor<Vec<u8>>),
#[cfg(test)] #[cfg(test)]
Test(Box<dyn Read + Send>, Vec<u8>), Test(Box<dyn Read + Send>, Vec<u8>),
@@ -35,7 +40,7 @@ impl ::std::fmt::Debug for Stream {
"Stream[{}]", "Stream[{}]",
match self { match self {
Stream::Http(_) => "http", Stream::Http(_) => "http",
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
Stream::Https(_) => "https", Stream::Https(_) => "https",
Stream::Cursor(_) => "cursor", Stream::Cursor(_) => "cursor",
#[cfg(test)] #[cfg(test)]
@@ -76,7 +81,7 @@ impl Stream {
pub fn is_poolable(&self) -> bool { pub fn is_poolable(&self) -> bool {
match self { match self {
Stream::Http(_) => true, Stream::Http(_) => true,
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
Stream::Https(_) => true, Stream::Https(_) => true,
_ => false, _ => false,
} }
@@ -95,7 +100,7 @@ impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> { fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match self { match self {
Stream::Http(sock) => sock.read(buf), Stream::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
Stream::Https(stream) => read_https(stream, buf), Stream::Https(stream) => read_https(stream, buf),
Stream::Cursor(read) => read.read(buf), Stream::Cursor(read) => read.read(buf),
#[cfg(test)] #[cfg(test)]
@@ -116,7 +121,20 @@ fn read_https(
} }
} }
#[cfg(feature = "native-tls")]
fn read_https(
stream: &mut TlsStream<TcpStream>,
buf: &mut [u8],
) -> IoResult<usize> {
match stream.read(buf) {
Ok(size) => Ok(size),
Err(ref e) if is_close_notify(e) => Ok(0),
Err(e) => Err(e),
}
}
#[allow(deprecated)] #[allow(deprecated)]
#[cfg(any(feature = "tls", feature = "native-tls"))]
fn is_close_notify(e: &std::io::Error) -> bool { fn is_close_notify(e: &std::io::Error) -> bool {
if e.kind() != ErrorKind::ConnectionAborted { if e.kind() != ErrorKind::ConnectionAborted {
return false; return false;
@@ -135,7 +153,7 @@ impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> { fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self { match self {
Stream::Http(sock) => sock.write(buf), Stream::Http(sock) => sock.write(buf),
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
Stream::Https(stream) => stream.write(buf), Stream::Https(stream) => stream.write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"), Stream::Cursor(_) => panic!("Write to read only stream"),
#[cfg(test)] #[cfg(test)]
@@ -145,7 +163,7 @@ impl Write for Stream {
fn flush(&mut self) -> IoResult<()> { fn flush(&mut self) -> IoResult<()> {
match self { match self {
Stream::Http(sock) => sock.flush(), Stream::Http(sock) => sock.flush(),
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
Stream::Https(stream) => stream.flush(), Stream::Https(stream) => stream.flush(),
Stream::Cursor(_) => panic!("Flush read only stream"), Stream::Cursor(_) => panic!("Flush read only stream"),
#[cfg(test)] #[cfg(test)]
@@ -162,6 +180,7 @@ pub(crate) fn connect_http(unit: &Unit) -> Result<Stream, Error> {
connect_host(unit, hostname, port).map(Stream::Http) connect_host(unit, hostname, port).map(Stream::Http)
} }
#[cfg(all(feature = "tls", feature = "native-certs"))] #[cfg(all(feature = "tls", feature = "native-certs"))]
fn configure_certs(config: &mut rustls::ClientConfig) { fn configure_certs(config: &mut rustls::ClientConfig) {
config.root_store = config.root_store =
@@ -204,6 +223,23 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
Ok(Stream::Https(stream)) Ok(Stream::Https(stream))
} }
#[cfg(feature = "native-tls")]
pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
let hostname = unit.url.host_str().unwrap();
let port = unit.url.port().unwrap_or(443);
let sock = connect_host(unit, hostname, port)?;
let tls_connector = TlsConnector::new().map_err(|e| Error::TlsError(e))?;
let stream = tls_connector.connect(hostname, sock).map_err(|e| {
match e {
HandshakeError::Failure(err) => Error::TlsError(err),
_ => Error::BadStatusRead,
}
})?;
Ok(Stream::Https(stream))
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> { pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
// //
let sock_addrs: Vec<SocketAddr> = match unit.proxy { let sock_addrs: Vec<SocketAddr> = match unit.proxy {
@@ -470,7 +506,8 @@ pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
Err(Error::UnknownScheme(unit.url.scheme().to_string())) Err(Error::UnknownScheme(unit.url.scheme().to_string()))
} }
#[cfg(not(feature = "tls"))] #[cfg(not(any(feature = "tls", feature = "native-tls")))]
pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> { pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
Err(Error::UnknownScheme(unit.url.scheme().to_string())) Err(Error::UnknownScheme(unit.url.scheme().to_string()))
} }

View File

@@ -1,9 +1,11 @@
#[cfg(any(feature = "tls", feature = "native-tls"))]
use std::io::Read; use std::io::Read;
#[cfg(any(feature = "tls", feature = "native-tls"))]
use super::super::*; use super::super::*;
#[test] #[test]
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
fn read_range() { fn read_range() {
let resp = get("https://ureq.s3.eu-central-1.amazonaws.com/sherlock.txt") let resp = get("https://ureq.s3.eu-central-1.amazonaws.com/sherlock.txt")
.set("Range", "bytes=1000-1999") .set("Range", "bytes=1000-1999")
@@ -20,7 +22,7 @@ fn read_range() {
} }
#[test] #[test]
#[cfg(feature = "tls")] #[cfg(any(feature = "tls", feature = "native-tls"))]
fn agent_pool() { fn agent_pool() {
let agent = agent(); let agent = agent();

View File

@@ -1,7 +1,7 @@
#[cfg(all(test, any(feature = "tls", feature = "native-tls")))]
use std::io::Read; use std::io::Read;
#[cfg(feature = "tls")] #[cfg(all(test, any(feature = "tls", feature = "native-tls")))]
#[test]
fn tls_connection_close() { fn tls_connection_close() {
let agent = ureq::Agent::default().build(); let agent = ureq::Agent::default().build();
let resp = agent let resp = agent