diff --git a/src/request.rs b/src/request.rs index bf444d1..3e72d3c 100644 --- a/src/request.rs +++ b/src/request.rs @@ -8,6 +8,9 @@ use url::{form_urlencoded, Url}; #[cfg(feature = "tls")] use std::fmt; +#[cfg(all(feature = "native-tls", not(feature = "tls")))] +use std::fmt; + use crate::agent::{self, Agent, AgentState}; use crate::body::{Payload, SizedReader}; use crate::error::Error; @@ -47,6 +50,8 @@ pub struct Request { pub(crate) proxy: Option, #[cfg(feature = "tls")] pub(crate) tls_config: Option, + #[cfg(all(feature = "native-tls", not(feature = "tls")))] + pub(crate) tls_connector: Option, } impl ::std::fmt::Debug for Request { @@ -599,6 +604,20 @@ impl Request { self } + /// Sets the TLS connector that will be used for the connection. + /// + /// Example: + /// ``` + /// let tls_connector = std::sync::Arc::new(native_tls::TlsConnector::new()); + /// let req = ureq::post("https://cool.server") + /// .set_tls_connector(tls_connector.clone()); + /// ``` + #[cfg(all(feature = "native-tls", not(feature = "tls")))] + pub fn set_tls_connector(&mut self, tls_connector: Arc) -> &mut Request { + self.tls_connector = Some(TLSConnector(tls_connector)); + self + } + // Returns true if this request, with the provided body, is retryable. pub(crate) fn is_retryable(&self, body: &SizedReader) -> bool { // Per https://tools.ietf.org/html/rfc7231#section-8.1.3 @@ -626,3 +645,14 @@ impl fmt::Debug for TLSClientConfig { f.debug_struct("TLSClientConfig").finish() } } + +#[cfg(all(feature = "native-tls", not(feature = "tls")))] +#[derive(Clone)] +pub(crate) struct TLSConnector(pub(crate) Arc); + +#[cfg(all(feature = "native-tls", not(feature = "tls")))] +impl fmt::Debug for TLSConnector { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TLSConnector").finish() + } +} diff --git a/src/stream.rs b/src/stream.rs index fd3800c..042821e 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -15,7 +15,7 @@ use rustls::StreamOwned; use socks::{TargetAddr, ToTargetAddr}; #[cfg(feature = "native-tls")] -use native_tls::{HandshakeError, TlsConnector, TlsStream}; +use native_tls::{HandshakeError, TlsStream}; use crate::proxy::Proto; use crate::proxy::Proxy; @@ -311,15 +311,22 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { #[cfg(all(feature = "native-tls", not(feature = "tls")))] pub(crate) fn connect_https(unit: &Unit) -> Result { + use std::sync::Arc; + let hostname = unit.url.host_str().unwrap(); let port = unit.url.port().unwrap_or(443); let sock = connect_host(unit, hostname, port)?; - let tls_connector = TlsConnector::new().map_err(|e| Error::TlsError(e))?; - let stream = tls_connector.connect(&hostname.trim_matches(|c| c == '[' || c == ']'), sock).map_err(|e| match e { - HandshakeError::Failure(err) => Error::TlsError(err), - _ => Error::BadStatusRead, - })?; + let tls_connector: Arc = match &unit.tls_connector { + Some(connector) => connector.0.clone(), + None => Arc::new(native_tls::TlsConnector::new().map_err(|e| Error::TlsError(e))?), + }; + let stream = tls_connector + .connect(&hostname.trim_matches(|c| c == '[' || c == ']'), sock) + .map_err(|e| match e { + HandshakeError::Failure(err) => Error::TlsError(err), + _ => Error::BadStatusRead, + })?; Ok(Stream::Https(stream)) } diff --git a/src/unit.rs b/src/unit.rs index 794353a..964214c 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -18,6 +18,9 @@ use crate::{Error, Header, Request, Response}; #[cfg(feature = "tls")] use crate::request::TLSClientConfig; +#[cfg(all(feature = "native-tls", not(feature = "tls")))] +use crate::request::TLSConnector; + #[cfg(feature = "cookie")] use crate::pool::DEFAULT_HOST; @@ -39,6 +42,8 @@ pub(crate) struct Unit { pub proxy: Option, #[cfg(feature = "tls")] pub tls_config: Option, + #[cfg(all(feature = "native-tls", not(feature = "tls")))] + pub tls_connector: Option, } impl Unit { @@ -108,6 +113,8 @@ impl Unit { proxy: req.proxy.clone(), #[cfg(feature = "tls")] tls_config: req.tls_config.clone(), + #[cfg(all(feature = "native-tls", not(feature = "tls")))] + tls_connector: req.tls_connector.clone(), } }