From 4675d748e9b0fc7310d768448135960dddc15d5d Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Sun, 1 May 2022 09:33:43 +0200 Subject: [PATCH] Remove Sync bound from TlsConnector io arg --- examples/custom-tls.rs | 2 +- examples/mbedtls/mbedtls_connector.rs | 26 ++++++++++++++++++++++++-- src/lib.rs | 2 +- src/ntls.rs | 8 ++------ src/rtls.rs | 4 ++-- src/stream.rs | 2 +- 6 files changed, 31 insertions(+), 13 deletions(-) diff --git a/examples/custom-tls.rs b/examples/custom-tls.rs index 0480eb5..43eba25 100644 --- a/examples/custom-tls.rs +++ b/examples/custom-tls.rs @@ -34,7 +34,7 @@ impl TlsConnector for PassThrough { fn connect( &self, _dns_name: &str, - io: Box, + io: Box, ) -> Result, Error> { if self.handshake_fail { let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError); diff --git a/examples/mbedtls/mbedtls_connector.rs b/examples/mbedtls/mbedtls_connector.rs index b1cd503..941c6b4 100644 --- a/examples/mbedtls/mbedtls_connector.rs +++ b/examples/mbedtls/mbedtls_connector.rs @@ -56,10 +56,11 @@ impl TlsConnector for MbedTlsConnector { fn connect( &self, _dns_name: &str, - io: Box, + io: Box, ) -> Result, Error> { let mut ctx = self.context.lock().unwrap(); - match ctx.establish(io, None) { + let sync = SyncIo(Mutex::new(io)); + match ctx.establish(sync, None) { Err(_) => { let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError); return Err(io_err.into()); @@ -69,6 +70,27 @@ impl TlsConnector for MbedTlsConnector { } } +struct SyncIo(Mutex>); + +impl io::Read for SyncIo { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut lock = self.0.lock().unwrap(); + lock.read(buf) + } +} + +impl io::Write for SyncIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut lock = self.0.lock().unwrap(); + lock.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let mut lock = self.0.lock().unwrap(); + lock.flush() + } +} + struct MbedTlsStream { context: Arc>, //tcp_stream: TcpStream, } diff --git a/src/lib.rs b/src/lib.rs index 8856125..87c91bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -364,7 +364,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { fn connect( &self, _dns_name: &str, - _io: Box, + _io: Box, ) -> Result, crate::error::Error> { Err(ErrorKind::UnknownScheme .msg("cannot make HTTPS request because no TLS backend is configured")) diff --git a/src/ntls.rs b/src/ntls.rs index bda511b..65b2c00 100644 --- a/src/ntls.rs +++ b/src/ntls.rs @@ -11,11 +11,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { } impl TlsConnector for native_tls::TlsConnector { - fn connect( - &self, - dns_name: &str, - io: Box, - ) -> Result, Error> { + fn connect(&self, dns_name: &str, io: Box) -> Result, Error> { let stream = native_tls::TlsConnector::connect(self, dns_name, io).map_err(|e| match e { native_tls::HandshakeError::Failure(e) => ErrorKind::ConnectionFailed @@ -31,7 +27,7 @@ impl TlsConnector for native_tls::TlsConnector { } #[cfg(feature = "native-tls")] -impl ReadWrite for native_tls::TlsStream> { +impl ReadWrite for native_tls::TlsStream> { fn socket(&self) -> Option<&TcpStream> { self.get_ref().socket() } diff --git a/src/rtls.rs b/src/rtls.rs index c9bc9be..a3888c6 100644 --- a/src/rtls.rs +++ b/src/rtls.rs @@ -27,7 +27,7 @@ fn is_close_notify(e: &std::io::Error) -> bool { false } -struct RustlsStream(rustls::StreamOwned>); +struct RustlsStream(rustls::StreamOwned>); impl ReadWrite for RustlsStream { fn socket(&self) -> Option<&TcpStream> { @@ -97,7 +97,7 @@ impl TlsConnector for Arc { fn connect( &self, dns_name: &str, - mut io: Box, + mut io: Box, ) -> Result, Error> { let sni = rustls::ServerName::try_from(dns_name) .map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?; diff --git a/src/stream.rs b/src/stream.rs index 7a735d4..a88a925 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -42,7 +42,7 @@ pub trait TlsConnector: Send + Sync { fn connect( &self, dns_name: &str, - io: Box, + io: Box, ) -> Result, crate::error::Error>; }