From 65371c966c8d915c989d476e99d25e9b4853eaa9 Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Sat, 30 Apr 2022 12:50:08 +0200 Subject: [PATCH] Box for TlsConnector::connect --- examples/custom-tls.rs | 15 +++++-- examples/mbedtls/mbedtls_connector.rs | 34 +++++++++++++++- src/lib.rs | 3 +- src/ntls.rs | 18 ++++++--- src/rtls.rs | 13 +++--- src/stream.rs | 58 ++++++++++++--------------- 6 files changed, 90 insertions(+), 51 deletions(-) diff --git a/examples/custom-tls.rs b/examples/custom-tls.rs index 8d97f8d..772a2f4 100644 --- a/examples/custom-tls.rs +++ b/examples/custom-tls.rs @@ -31,21 +31,28 @@ struct PassThrough { } impl TlsConnector for PassThrough { - fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result, Error> { + fn connect( + &self, + _dns_name: &str, + io: Box, + ) -> Result, Error> { if self.handshake_fail { let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError); return Err(io_err.into()); } - Ok(Box::new(CustomTlsStream(tcp_stream))) + Ok(Box::new(CustomTlsStream(io))) } } -struct CustomTlsStream(TcpStream); +struct CustomTlsStream(Box); impl ReadWrite for CustomTlsStream { fn socket(&self) -> Option<&TcpStream> { - Some(&self.0) + self.0.socket() + } + fn is_poolable(&self) -> bool { + self.0.is_poolable() } } diff --git a/examples/mbedtls/mbedtls_connector.rs b/examples/mbedtls/mbedtls_connector.rs index 13855d1..1fdd8c5 100644 --- a/examples/mbedtls/mbedtls_connector.rs +++ b/examples/mbedtls/mbedtls_connector.rs @@ -53,9 +53,14 @@ impl MbedTlsConnector { } impl TlsConnector for MbedTlsConnector { - fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result, Error> { + fn connect( + &self, + _dns_name: &str, + io: Box, + ) -> Result, Error> { let mut ctx = self.context.lock().unwrap(); - match ctx.establish(tcp_stream, None) { + let sync = SyncIo(Mutex::new(io)); + match ctx.establish(sync, None) { Err(_) => { let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError); return Err(io_err.into()); @@ -65,6 +70,28 @@ impl TlsConnector for MbedTlsConnector { } } +/// Internal wrapper to make Box implement Sync +struct SyncIo(Mutex>); + +impl io::Read for SyncIo { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut lock = self.0.lock().unwrap(); + lock.read(buf) + } +} + +impl io::Write for SyncIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut lock = self.0.lock().unwrap(); + lock.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let mut lock = self.0.lock().unwrap(); + lock.flush() + } +} + struct MbedTlsStream { context: Arc>, //tcp_stream: TcpStream, } @@ -84,6 +111,9 @@ impl ReadWrite for MbedTlsStream { fn socket(&self) -> Option<&TcpStream> { None } + fn is_poolable(&self) -> bool { + true + } } impl io::Read for MbedTlsStream { diff --git a/src/lib.rs b/src/lib.rs index 69ee72f..87c91bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -356,7 +356,6 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { // calls at the top of the crate (`ureq::get` etc). #[cfg(not(feature = "tls"))] pub(crate) fn default_tls_config() -> std::sync::Arc { - use std::net::TcpStream; use std::sync::Arc; struct NoTlsConfig; @@ -365,7 +364,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { fn connect( &self, _dns_name: &str, - _tcp_stream: TcpStream, + _io: Box, ) -> Result, crate::error::Error> { Err(ErrorKind::UnknownScheme .msg("cannot make HTTPS request because no TLS backend is configured")) diff --git a/src/ntls.rs b/src/ntls.rs index a669e3a..65b2c00 100644 --- a/src/ntls.rs +++ b/src/ntls.rs @@ -11,12 +11,15 @@ pub(crate) fn default_tls_config() -> std::sync::Arc { } impl TlsConnector for native_tls::TlsConnector { - fn connect(&self, dns_name: &str, tcp_stream: TcpStream) -> Result, Error> { + fn connect(&self, dns_name: &str, io: Box) -> Result, Error> { let stream = - native_tls::TlsConnector::connect(self, dns_name, tcp_stream).map_err(|e| { - ErrorKind::ConnectionFailed + native_tls::TlsConnector::connect(self, dns_name, io).map_err(|e| match e { + native_tls::HandshakeError::Failure(e) => ErrorKind::ConnectionFailed .msg("native_tls connect failed") - .src(e) + .src(e), + native_tls::HandshakeError::WouldBlock(_) => { + ErrorKind::Io.msg("Unexpected native_tls::HandshakeError::WouldBlock") + } })?; Ok(Box::new(stream)) @@ -24,8 +27,11 @@ impl TlsConnector for native_tls::TlsConnector { } #[cfg(feature = "native-tls")] -impl ReadWrite for native_tls::TlsStream { +impl ReadWrite for native_tls::TlsStream> { fn socket(&self) -> Option<&TcpStream> { - Some(self.get_ref()) + self.get_ref().socket() + } + fn is_poolable(&self) -> bool { + self.get_ref().is_poolable() } } diff --git a/src/rtls.rs b/src/rtls.rs index 277e536..39848d3 100644 --- a/src/rtls.rs +++ b/src/rtls.rs @@ -26,11 +26,14 @@ fn is_close_notify(e: &std::io::Error) -> bool { false } -struct RustlsStream(rustls::StreamOwned); +struct RustlsStream(rustls::StreamOwned>); impl ReadWrite for RustlsStream { fn socket(&self) -> Option<&TcpStream> { - Some(self.0.get_ref()) + self.0.get_ref().socket() + } + fn is_poolable(&self) -> bool { + self.0.get_ref().is_poolable() } } @@ -93,7 +96,7 @@ impl TlsConnector for Arc { fn connect( &self, dns_name: &str, - mut tcp_stream: TcpStream, + mut io: Box, ) -> Result, Error> { let sni = rustls::ServerName::try_from(dns_name) .map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?; @@ -101,12 +104,12 @@ impl TlsConnector for Arc { let mut sess = rustls::ClientConnection::new(self.clone(), sni) .map_err(|e| ErrorKind::Io.msg("tls connection creation failed").src(e))?; - sess.complete_io(&mut tcp_stream).map_err(|e| { + sess.complete_io(&mut io).map_err(|e| { ErrorKind::ConnectionFailed .msg("tls connection init failed") .src(e) })?; - let stream = rustls::StreamOwned::new(sess, tcp_stream); + let stream = rustls::StreamOwned::new(sess, io); Ok(Box::new(RustlsStream(stream))) } diff --git a/src/stream.rs b/src/stream.rs index 643ae55..396c273 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -20,58 +20,52 @@ use crate::unit::Unit; /// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector]. pub trait ReadWrite: Read + Write + Send + 'static { fn socket(&self) -> Option<&TcpStream>; + fn is_poolable(&self) -> bool; + + /// The bytes written to the stream as a Vec. This is used for tests only. + #[cfg(test)] + fn written_bytes(&self) -> Vec { + panic!("written_bytes on non Test stream"); + } +} + +impl ReadWrite for TcpStream { + fn socket(&self) -> Option<&TcpStream> { + Some(self) + } + fn is_poolable(&self) -> bool { + true + } } pub trait TlsConnector: Send + Sync { fn connect( &self, dns_name: &str, - tcp_stream: TcpStream, + io: Box, ) -> Result, crate::error::Error>; } pub(crate) struct Stream { - inner: BufReader>, -} - -trait Inner: Read + Write { - fn is_poolable(&self) -> bool; - fn socket(&self) -> Option<&TcpStream>; - - /// The bytes written to the stream as a Vec. This is used for tests only. - fn written_bytes(&self) -> Vec { - panic!("written_bytes on non Test stream"); - } + inner: BufReader>, } impl ReadWrite for Box { fn socket(&self) -> Option<&TcpStream> { ReadWrite::socket(self.as_ref()) } -} - -impl Inner for T { fn is_poolable(&self) -> bool { - true + ReadWrite::is_poolable(self.as_ref()) } - - fn socket(&self) -> Option<&TcpStream> { - ReadWrite::socket(self) - } -} - -impl Inner for TcpStream { - fn is_poolable(&self) -> bool { - true - } - fn socket(&self) -> Option<&TcpStream> { - Some(self) + #[cfg(test)] + fn written_bytes(&self) -> Vec { + ReadWrite::written_bytes(self.as_ref()) } } struct TestStream(Box, Vec, bool); -impl Inner for TestStream { +impl ReadWrite for TestStream { fn is_poolable(&self) -> bool { self.2 } @@ -79,7 +73,7 @@ impl Inner for TestStream { None } - /// For tests only + #[cfg(test)] fn written_bytes(&self) -> Vec { self.1.clone() } @@ -188,7 +182,7 @@ impl fmt::Debug for Stream { } impl Stream { - fn new(t: impl Inner + Send + 'static) -> Stream { + fn new(t: impl ReadWrite) -> Stream { Stream::logged_create(Stream { inner: BufReader::new(Box::new(t)), }) @@ -348,7 +342,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result