diff --git a/examples/custom-tls.rs b/examples/custom-tls.rs new file mode 100644 index 0000000..8d97f8d --- /dev/null +++ b/examples/custom-tls.rs @@ -0,0 +1,81 @@ +use std::fmt; +use std::io; +use std::net::TcpStream; +use std::sync::Arc; + +use ureq::{Error, ReadWrite, TlsConnector}; + +pub fn main() -> Result<(), Error> { + let pass = PassThrough { + handshake_fail: false, + }; + + let agent = ureq::builder().tls_connector(Arc::new(pass)).build(); + + let _response = agent.get("https://httpbin.org/").call(); + + // Uncomment this if handshake_fail is set to true above. + // assert_eq!( + // _response.unwrap_err().to_string(), + // "https://httpbin.org/: Network Error: Tls handshake failed" + // ); + + Ok(()) +} + +/// A pass-through tls connector that just uses the plain socket without any encryption. +/// This is not a good idea for production code. The `handshake_fail` can be set to true +/// to simulate a TLS handshake failure. +struct PassThrough { + handshake_fail: bool, +} + +impl TlsConnector for PassThrough { + fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result, Error> { + if self.handshake_fail { + let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError); + return Err(io_err.into()); + } + + Ok(Box::new(CustomTlsStream(tcp_stream))) + } +} + +struct CustomTlsStream(TcpStream); + +impl ReadWrite for CustomTlsStream { + fn socket(&self) -> Option<&TcpStream> { + Some(&self.0) + } +} + +impl io::Read for CustomTlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl io::Write for CustomTlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + self.0.flush() + } +} + +#[derive(Debug)] +struct PassThroughError; + +impl fmt::Display for PassThroughError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Tls handshake failed") + } +} + +impl std::error::Error for PassThroughError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} diff --git a/src/lib.rs b/src/lib.rs index 731845c..da55ea4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -356,7 +356,6 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { // calls at the top of the crate (`ureq::get` etc). #[cfg(not(feature = "tls"))] pub(crate) fn default_tls_config() -> std::sync::Arc { - use crate::stream::HttpsStream; use std::net::TcpStream; use std::sync::Arc; @@ -367,7 +366,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { &self, _dns_name: &str, _tcp_stream: TcpStream, - ) -> Result, crate::error::Error> { + ) -> Result, crate::error::Error> { Err(ErrorKind::UnknownScheme .msg("cannot make HTTPS request because no TLS backend is configured")) } @@ -398,7 +397,7 @@ pub use crate::proxy::Proxy; pub use crate::request::{Request, RequestUrl}; pub use crate::resolve::Resolver; pub use crate::response::Response; -pub use crate::stream::TlsConnector; +pub use crate::stream::{ReadWrite, TlsConnector}; // re-export #[cfg(feature = "cookies")] diff --git a/src/ntls.rs b/src/ntls.rs index 1dceffe..a669e3a 100644 --- a/src/ntls.rs +++ b/src/ntls.rs @@ -1,6 +1,6 @@ use crate::error::Error; use crate::error::ErrorKind; -use crate::stream::{HttpsStream, TlsConnector}; +use crate::stream::{ReadWrite, TlsConnector}; use std::net::TcpStream; use std::sync::Arc; @@ -11,11 +11,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { } impl TlsConnector for native_tls::TlsConnector { - fn connect( - &self, - dns_name: &str, - tcp_stream: TcpStream, - ) -> Result, Error> { + fn connect(&self, dns_name: &str, tcp_stream: TcpStream) -> Result, Error> { let stream = native_tls::TlsConnector::connect(self, dns_name, tcp_stream).map_err(|e| { ErrorKind::ConnectionFailed @@ -28,7 +24,7 @@ impl TlsConnector for native_tls::TlsConnector { } #[cfg(feature = "native-tls")] -impl HttpsStream for native_tls::TlsStream { +impl ReadWrite for native_tls::TlsStream { fn socket(&self) -> Option<&TcpStream> { Some(self.get_ref()) } diff --git a/src/rtls.rs b/src/rtls.rs index 3d1e8dd..277e536 100644 --- a/src/rtls.rs +++ b/src/rtls.rs @@ -7,7 +7,7 @@ use once_cell::sync::Lazy; use crate::ErrorKind; use crate::{ - stream::{HttpsStream, TlsConnector}, + stream::{ReadWrite, TlsConnector}, Error, }; @@ -28,7 +28,7 @@ fn is_close_notify(e: &std::io::Error) -> bool { struct RustlsStream(rustls::StreamOwned); -impl HttpsStream for RustlsStream { +impl ReadWrite for RustlsStream { fn socket(&self) -> Option<&TcpStream> { Some(self.0.get_ref()) } @@ -94,7 +94,7 @@ impl TlsConnector for Arc { &self, dns_name: &str, mut tcp_stream: TcpStream, - ) -> Result, Error> { + ) -> Result, Error> { let sni = rustls::ServerName::try_from(dns_name) .map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?; diff --git a/src/stream.rs b/src/stream.rs index 91156b4..1b64f0f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -17,7 +17,8 @@ use crate::{error::Error, proxy::Proto}; use crate::error::ErrorKind; use crate::unit::Unit; -pub trait HttpsStream: Read + Write + Send + Sync + 'static { +/// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector]. +pub trait ReadWrite: Read + Write + Send + Sync + 'static { fn socket(&self) -> Option<&TcpStream>; } @@ -26,7 +27,7 @@ pub trait TlsConnector: Send + Sync { &self, dns_name: &str, tcp_stream: TcpStream, - ) -> Result, crate::error::Error>; + ) -> Result, crate::error::Error>; } pub(crate) struct Stream { @@ -41,19 +42,19 @@ trait Inner: Read + Write { } } -impl HttpsStream for Box { +impl ReadWrite for Box { fn socket(&self) -> Option<&TcpStream> { - HttpsStream::socket(self.as_ref()) + ReadWrite::socket(self.as_ref()) } } -impl Inner for T { +impl Inner for T { fn is_poolable(&self) -> bool { true } fn socket(&self) -> Option<&TcpStream> { - HttpsStream::socket(self) + ReadWrite::socket(self) } }