From 049b5a5acd6766919452b3ef034807ccea65a845 Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Sun, 1 May 2022 08:33:47 +0200 Subject: [PATCH] Fixes after feedback --- examples/custom-tls.rs | 3 ++- examples/mbedtls/mbedtls_connector.rs | 33 +++++++-------------------- src/lib.rs | 2 +- src/ntls.rs | 8 +++++-- src/rtls.rs | 11 +++++++-- src/stream.rs | 12 +++++++--- 6 files changed, 35 insertions(+), 34 deletions(-) diff --git a/examples/custom-tls.rs b/examples/custom-tls.rs index 772a2f4..0480eb5 100644 --- a/examples/custom-tls.rs +++ b/examples/custom-tls.rs @@ -34,7 +34,7 @@ impl TlsConnector for PassThrough { fn connect( &self, _dns_name: &str, - io: Box, + io: Box, ) -> Result, Error> { if self.handshake_fail { let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError); @@ -45,6 +45,7 @@ impl TlsConnector for PassThrough { } } +#[derive(Debug)] struct CustomTlsStream(Box); impl ReadWrite for CustomTlsStream { diff --git a/examples/mbedtls/mbedtls_connector.rs b/examples/mbedtls/mbedtls_connector.rs index 1fdd8c5..b1cd503 100644 --- a/examples/mbedtls/mbedtls_connector.rs +++ b/examples/mbedtls/mbedtls_connector.rs @@ -56,11 +56,10 @@ impl TlsConnector for MbedTlsConnector { fn connect( &self, _dns_name: &str, - io: Box, + io: Box, ) -> Result, Error> { let mut ctx = self.context.lock().unwrap(); - let sync = SyncIo(Mutex::new(io)); - match ctx.establish(sync, None) { + match ctx.establish(io, None) { Err(_) => { let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError); return Err(io_err.into()); @@ -70,32 +69,16 @@ impl TlsConnector for MbedTlsConnector { } } -/// Internal wrapper to make Box implement Sync -struct SyncIo(Mutex>); - -impl io::Read for SyncIo { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut lock = self.0.lock().unwrap(); - lock.read(buf) - } -} - -impl io::Write for SyncIo { - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut lock = self.0.lock().unwrap(); - lock.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - let mut lock = self.0.lock().unwrap(); - lock.flush() - } -} - struct MbedTlsStream { context: Arc>, //tcp_stream: TcpStream, } +impl fmt::Debug for MbedTlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MbedTlsStream").finish() + } +} + impl MbedTlsStream { pub fn new(mtc: &MbedTlsConnector) -> Box { Box::new(MbedTlsStream { diff --git a/src/lib.rs b/src/lib.rs index 87c91bc..8856125 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -364,7 +364,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { fn connect( &self, _dns_name: &str, - _io: Box, + _io: Box, ) -> Result, crate::error::Error> { Err(ErrorKind::UnknownScheme .msg("cannot make HTTPS request because no TLS backend is configured")) diff --git a/src/ntls.rs b/src/ntls.rs index 65b2c00..bda511b 100644 --- a/src/ntls.rs +++ b/src/ntls.rs @@ -11,7 +11,11 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { } impl TlsConnector for native_tls::TlsConnector { - fn connect(&self, dns_name: &str, io: Box) -> Result, Error> { + fn connect( + &self, + dns_name: &str, + io: Box, + ) -> Result, Error> { let stream = native_tls::TlsConnector::connect(self, dns_name, io).map_err(|e| match e { native_tls::HandshakeError::Failure(e) => ErrorKind::ConnectionFailed @@ -27,7 +31,7 @@ impl TlsConnector for native_tls::TlsConnector { } #[cfg(feature = "native-tls")] -impl ReadWrite for native_tls::TlsStream> { +impl ReadWrite for native_tls::TlsStream> { fn socket(&self) -> Option<&TcpStream> { self.get_ref().socket() } diff --git a/src/rtls.rs b/src/rtls.rs index 39848d3..c9bc9be 100644 --- a/src/rtls.rs +++ b/src/rtls.rs @@ -1,4 +1,5 @@ use std::convert::TryFrom; +use std::fmt; use std::io::{self, Read, Write}; use std::net::TcpStream; use std::sync::Arc; @@ -26,7 +27,7 @@ fn is_close_notify(e: &std::io::Error) -> bool { false } -struct RustlsStream(rustls::StreamOwned>); +struct RustlsStream(rustls::StreamOwned>); impl ReadWrite for RustlsStream { fn socket(&self) -> Option<&TcpStream> { @@ -96,7 +97,7 @@ impl TlsConnector for Arc { fn connect( &self, dns_name: &str, - mut io: Box, + mut io: Box, ) -> Result, Error> { let sni = rustls::ServerName::try_from(dns_name) .map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?; @@ -125,3 +126,9 @@ pub fn default_tls_config() -> Arc { }); TLS_CONF.clone() } + +impl fmt::Debug for RustlsStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("RustlsStream").finish() + } +} diff --git a/src/stream.rs b/src/stream.rs index 396c273..7a735d4 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -18,7 +18,7 @@ use crate::error::ErrorKind; use crate::unit::Unit; /// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector]. -pub trait ReadWrite: Read + Write + Send + 'static { +pub trait ReadWrite: Read + Write + Send + fmt::Debug + 'static { fn socket(&self) -> Option<&TcpStream>; fn is_poolable(&self) -> bool; @@ -42,7 +42,7 @@ pub trait TlsConnector: Send + Sync { fn connect( &self, dns_name: &str, - io: Box, + io: Box, ) -> Result, crate::error::Error>; } @@ -95,6 +95,12 @@ impl Write for TestStream { } } +impl fmt::Debug for TestStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("TestStream").finish() + } +} + // DeadlineStream wraps a stream such that read() will return an error // after the provided deadline, and sets timeouts on the underlying // TcpStream to ensure read() doesn't block beyond the deadline. @@ -175,7 +181,7 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error { impl fmt::Debug for Stream { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.inner.get_ref().socket() { - Some(s) => write!(f, "{:?}", s), + Some(_) => write!(f, "Stream({:?})", self.inner.get_ref()), None => write!(f, "Stream(Test)"), } }