Fixes after feedback

This commit is contained in:
Martin Algesten
2022-05-01 08:33:47 +02:00
parent 65371c966c
commit 049b5a5acd
6 changed files with 35 additions and 34 deletions

View File

@@ -34,7 +34,7 @@ impl TlsConnector for PassThrough {
fn connect( fn connect(
&self, &self,
_dns_name: &str, _dns_name: &str,
io: Box<dyn ReadWrite>, io: Box<dyn ReadWrite + Sync>,
) -> Result<Box<dyn ReadWrite>, Error> { ) -> Result<Box<dyn ReadWrite>, Error> {
if self.handshake_fail { if self.handshake_fail {
let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError); let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError);
@@ -45,6 +45,7 @@ impl TlsConnector for PassThrough {
} }
} }
#[derive(Debug)]
struct CustomTlsStream(Box<dyn ReadWrite>); struct CustomTlsStream(Box<dyn ReadWrite>);
impl ReadWrite for CustomTlsStream { impl ReadWrite for CustomTlsStream {

View File

@@ -56,11 +56,10 @@ impl TlsConnector for MbedTlsConnector {
fn connect( fn connect(
&self, &self,
_dns_name: &str, _dns_name: &str,
io: Box<dyn ReadWrite>, io: Box<dyn ReadWrite + Sync>,
) -> Result<Box<dyn ReadWrite>, Error> { ) -> Result<Box<dyn ReadWrite>, Error> {
let mut ctx = self.context.lock().unwrap(); let mut ctx = self.context.lock().unwrap();
let sync = SyncIo(Mutex::new(io)); match ctx.establish(io, None) {
match ctx.establish(sync, None) {
Err(_) => { Err(_) => {
let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError); let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError);
return Err(io_err.into()); return Err(io_err.into());
@@ -70,32 +69,16 @@ impl TlsConnector for MbedTlsConnector {
} }
} }
/// Internal wrapper to make Box<dyn ReadWrite> implement Sync
struct SyncIo(Mutex<Box<dyn ReadWrite>>);
impl io::Read for SyncIo {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut lock = self.0.lock().unwrap();
lock.read(buf)
}
}
impl io::Write for SyncIo {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut lock = self.0.lock().unwrap();
lock.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
let mut lock = self.0.lock().unwrap();
lock.flush()
}
}
struct MbedTlsStream { struct MbedTlsStream {
context: Arc<Mutex<Context>>, //tcp_stream: TcpStream, context: Arc<Mutex<Context>>, //tcp_stream: TcpStream,
} }
impl fmt::Debug for MbedTlsStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MbedTlsStream").finish()
}
}
impl MbedTlsStream { impl MbedTlsStream {
pub fn new(mtc: &MbedTlsConnector) -> Box<MbedTlsStream> { pub fn new(mtc: &MbedTlsConnector) -> Box<MbedTlsStream> {
Box::new(MbedTlsStream { Box::new(MbedTlsStream {

View File

@@ -364,7 +364,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
fn connect( fn connect(
&self, &self,
_dns_name: &str, _dns_name: &str,
_io: Box<dyn ReadWrite>, _io: Box<dyn ReadWrite + Sync>,
) -> Result<Box<dyn ReadWrite>, crate::error::Error> { ) -> Result<Box<dyn ReadWrite>, crate::error::Error> {
Err(ErrorKind::UnknownScheme Err(ErrorKind::UnknownScheme
.msg("cannot make HTTPS request because no TLS backend is configured")) .msg("cannot make HTTPS request because no TLS backend is configured"))

View File

@@ -11,7 +11,11 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
} }
impl TlsConnector for native_tls::TlsConnector { impl TlsConnector for native_tls::TlsConnector {
fn connect(&self, dns_name: &str, io: Box<dyn ReadWrite>) -> Result<Box<dyn ReadWrite>, Error> { fn connect(
&self,
dns_name: &str,
io: Box<dyn ReadWrite + Sync>,
) -> Result<Box<dyn ReadWrite>, Error> {
let stream = let stream =
native_tls::TlsConnector::connect(self, dns_name, io).map_err(|e| match e { native_tls::TlsConnector::connect(self, dns_name, io).map_err(|e| match e {
native_tls::HandshakeError::Failure(e) => ErrorKind::ConnectionFailed native_tls::HandshakeError::Failure(e) => ErrorKind::ConnectionFailed
@@ -27,7 +31,7 @@ impl TlsConnector for native_tls::TlsConnector {
} }
#[cfg(feature = "native-tls")] #[cfg(feature = "native-tls")]
impl ReadWrite for native_tls::TlsStream<Box<dyn ReadWrite>> { impl ReadWrite for native_tls::TlsStream<Box<dyn ReadWrite + Sync>> {
fn socket(&self) -> Option<&TcpStream> { fn socket(&self) -> Option<&TcpStream> {
self.get_ref().socket() self.get_ref().socket()
} }

View File

@@ -1,4 +1,5 @@
use std::convert::TryFrom; use std::convert::TryFrom;
use std::fmt;
use std::io::{self, Read, Write}; use std::io::{self, Read, Write};
use std::net::TcpStream; use std::net::TcpStream;
use std::sync::Arc; use std::sync::Arc;
@@ -26,7 +27,7 @@ fn is_close_notify(e: &std::io::Error) -> bool {
false false
} }
struct RustlsStream(rustls::StreamOwned<rustls::ClientConnection, Box<dyn ReadWrite>>); struct RustlsStream(rustls::StreamOwned<rustls::ClientConnection, Box<dyn ReadWrite + Sync>>);
impl ReadWrite for RustlsStream { impl ReadWrite for RustlsStream {
fn socket(&self) -> Option<&TcpStream> { fn socket(&self) -> Option<&TcpStream> {
@@ -96,7 +97,7 @@ impl TlsConnector for Arc<rustls::ClientConfig> {
fn connect( fn connect(
&self, &self,
dns_name: &str, dns_name: &str,
mut io: Box<dyn ReadWrite>, mut io: Box<dyn ReadWrite + Sync>,
) -> Result<Box<dyn ReadWrite>, Error> { ) -> Result<Box<dyn ReadWrite>, Error> {
let sni = rustls::ServerName::try_from(dns_name) let sni = rustls::ServerName::try_from(dns_name)
.map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?; .map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?;
@@ -125,3 +126,9 @@ pub fn default_tls_config() -> Arc<dyn TlsConnector> {
}); });
TLS_CONF.clone() TLS_CONF.clone()
} }
impl fmt::Debug for RustlsStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("RustlsStream").finish()
}
}

View File

@@ -18,7 +18,7 @@ use crate::error::ErrorKind;
use crate::unit::Unit; use crate::unit::Unit;
/// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector]. /// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector].
pub trait ReadWrite: Read + Write + Send + 'static { pub trait ReadWrite: Read + Write + Send + fmt::Debug + 'static {
fn socket(&self) -> Option<&TcpStream>; fn socket(&self) -> Option<&TcpStream>;
fn is_poolable(&self) -> bool; fn is_poolable(&self) -> bool;
@@ -42,7 +42,7 @@ pub trait TlsConnector: Send + Sync {
fn connect( fn connect(
&self, &self,
dns_name: &str, dns_name: &str,
io: Box<dyn ReadWrite>, io: Box<dyn ReadWrite + Sync>,
) -> Result<Box<dyn ReadWrite>, crate::error::Error>; ) -> Result<Box<dyn ReadWrite>, crate::error::Error>;
} }
@@ -95,6 +95,12 @@ impl Write for TestStream {
} }
} }
impl fmt::Debug for TestStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("TestStream").finish()
}
}
// DeadlineStream wraps a stream such that read() will return an error // DeadlineStream wraps a stream such that read() will return an error
// after the provided deadline, and sets timeouts on the underlying // after the provided deadline, and sets timeouts on the underlying
// TcpStream to ensure read() doesn't block beyond the deadline. // TcpStream to ensure read() doesn't block beyond the deadline.
@@ -175,7 +181,7 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
impl fmt::Debug for Stream { impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner.get_ref().socket() { match self.inner.get_ref().socket() {
Some(s) => write!(f, "{:?}", s), Some(_) => write!(f, "Stream({:?})", self.inner.get_ref()),
None => write!(f, "Stream(Test)"), None => write!(f, "Stream(Test)"),
} }
} }