Rename trait HttpsStream -> ReadWrite and make it public

Also provide an example of how to use it.
This commit is contained in:
Martin Algesten
2021-12-23 16:43:50 +01:00
parent 140aa5901f
commit 6e5041044b
5 changed files with 96 additions and 19 deletions

81
examples/custom-tls.rs Normal file
View File

@@ -0,0 +1,81 @@
use std::fmt;
use std::io;
use std::net::TcpStream;
use std::sync::Arc;
use ureq::{Error, ReadWrite, TlsConnector};
pub fn main() -> Result<(), Error> {
let pass = PassThrough {
handshake_fail: false,
};
let agent = ureq::builder().tls_connector(Arc::new(pass)).build();
let _response = agent.get("https://httpbin.org/").call();
// Uncomment this if handshake_fail is set to true above.
// assert_eq!(
// _response.unwrap_err().to_string(),
// "https://httpbin.org/: Network Error: Tls handshake failed"
// );
Ok(())
}
/// A pass-through tls connector that just uses the plain socket without any encryption.
/// This is not a good idea for production code. The `handshake_fail` can be set to true
/// to simulate a TLS handshake failure.
struct PassThrough {
handshake_fail: bool,
}
impl TlsConnector for PassThrough {
fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> {
if self.handshake_fail {
let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError);
return Err(io_err.into());
}
Ok(Box::new(CustomTlsStream(tcp_stream)))
}
}
struct CustomTlsStream(TcpStream);
impl ReadWrite for CustomTlsStream {
fn socket(&self) -> Option<&TcpStream> {
Some(&self.0)
}
}
impl io::Read for CustomTlsStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl io::Write for CustomTlsStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
#[derive(Debug)]
struct PassThroughError;
impl fmt::Display for PassThroughError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Tls handshake failed")
}
}
impl std::error::Error for PassThroughError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}

View File

@@ -356,7 +356,6 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
// calls at the top of the crate (`ureq::get` etc).
#[cfg(not(feature = "tls"))]
pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
use crate::stream::HttpsStream;
use std::net::TcpStream;
use std::sync::Arc;
@@ -367,7 +366,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
&self,
_dns_name: &str,
_tcp_stream: TcpStream,
) -> Result<Box<dyn HttpsStream>, crate::error::Error> {
) -> Result<Box<dyn ReadWrite>, crate::error::Error> {
Err(ErrorKind::UnknownScheme
.msg("cannot make HTTPS request because no TLS backend is configured"))
}
@@ -398,7 +397,7 @@ pub use crate::proxy::Proxy;
pub use crate::request::{Request, RequestUrl};
pub use crate::resolve::Resolver;
pub use crate::response::Response;
pub use crate::stream::TlsConnector;
pub use crate::stream::{ReadWrite, TlsConnector};
// re-export
#[cfg(feature = "cookies")]

View File

@@ -1,6 +1,6 @@
use crate::error::Error;
use crate::error::ErrorKind;
use crate::stream::{HttpsStream, TlsConnector};
use crate::stream::{ReadWrite, TlsConnector};
use std::net::TcpStream;
use std::sync::Arc;
@@ -11,11 +11,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
}
impl TlsConnector for native_tls::TlsConnector {
fn connect(
&self,
dns_name: &str,
tcp_stream: TcpStream,
) -> Result<Box<dyn HttpsStream>, Error> {
fn connect(&self, dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> {
let stream =
native_tls::TlsConnector::connect(self, dns_name, tcp_stream).map_err(|e| {
ErrorKind::ConnectionFailed
@@ -28,7 +24,7 @@ impl TlsConnector for native_tls::TlsConnector {
}
#[cfg(feature = "native-tls")]
impl HttpsStream for native_tls::TlsStream<TcpStream> {
impl ReadWrite for native_tls::TlsStream<TcpStream> {
fn socket(&self) -> Option<&TcpStream> {
Some(self.get_ref())
}

View File

@@ -7,7 +7,7 @@ use once_cell::sync::Lazy;
use crate::ErrorKind;
use crate::{
stream::{HttpsStream, TlsConnector},
stream::{ReadWrite, TlsConnector},
Error,
};
@@ -28,7 +28,7 @@ fn is_close_notify(e: &std::io::Error) -> bool {
struct RustlsStream(rustls::StreamOwned<rustls::ClientConnection, TcpStream>);
impl HttpsStream for RustlsStream {
impl ReadWrite for RustlsStream {
fn socket(&self) -> Option<&TcpStream> {
Some(self.0.get_ref())
}
@@ -94,7 +94,7 @@ impl TlsConnector for Arc<rustls::ClientConfig> {
&self,
dns_name: &str,
mut tcp_stream: TcpStream,
) -> Result<Box<dyn HttpsStream>, Error> {
) -> Result<Box<dyn ReadWrite>, Error> {
let sni = rustls::ServerName::try_from(dns_name)
.map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?;

View File

@@ -17,7 +17,8 @@ use crate::{error::Error, proxy::Proto};
use crate::error::ErrorKind;
use crate::unit::Unit;
pub trait HttpsStream: Read + Write + Send + Sync + 'static {
/// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector].
pub trait ReadWrite: Read + Write + Send + Sync + 'static {
fn socket(&self) -> Option<&TcpStream>;
}
@@ -26,7 +27,7 @@ pub trait TlsConnector: Send + Sync {
&self,
dns_name: &str,
tcp_stream: TcpStream,
) -> Result<Box<dyn HttpsStream>, crate::error::Error>;
) -> Result<Box<dyn ReadWrite>, crate::error::Error>;
}
pub(crate) struct Stream {
@@ -41,19 +42,19 @@ trait Inner: Read + Write {
}
}
impl<T: HttpsStream + ?Sized> HttpsStream for Box<T> {
impl<T: ReadWrite + ?Sized> ReadWrite for Box<T> {
fn socket(&self) -> Option<&TcpStream> {
HttpsStream::socket(self.as_ref())
ReadWrite::socket(self.as_ref())
}
}
impl<T: HttpsStream> Inner for T {
impl<T: ReadWrite> Inner for T {
fn is_poolable(&self) -> bool {
true
}
fn socket(&self) -> Option<&TcpStream> {
HttpsStream::socket(self)
ReadWrite::socket(self)
}
}