Box<dyn ReadWrite> for TlsConnector::connect

This commit is contained in:
Martin Algesten
2022-04-30 12:50:08 +02:00
parent 101467f13f
commit 65371c966c
6 changed files with 90 additions and 51 deletions

View File

@@ -31,21 +31,28 @@ struct PassThrough {
} }
impl TlsConnector for PassThrough { impl TlsConnector for PassThrough {
fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> { fn connect(
&self,
_dns_name: &str,
io: Box<dyn ReadWrite>,
) -> Result<Box<dyn ReadWrite>, Error> {
if self.handshake_fail { if self.handshake_fail {
let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError); let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError);
return Err(io_err.into()); return Err(io_err.into());
} }
Ok(Box::new(CustomTlsStream(tcp_stream))) Ok(Box::new(CustomTlsStream(io)))
} }
} }
struct CustomTlsStream(TcpStream); struct CustomTlsStream(Box<dyn ReadWrite>);
impl ReadWrite for CustomTlsStream { impl ReadWrite for CustomTlsStream {
fn socket(&self) -> Option<&TcpStream> { fn socket(&self) -> Option<&TcpStream> {
Some(&self.0) self.0.socket()
}
fn is_poolable(&self) -> bool {
self.0.is_poolable()
} }
} }

View File

@@ -53,9 +53,14 @@ impl MbedTlsConnector {
} }
impl TlsConnector for MbedTlsConnector { impl TlsConnector for MbedTlsConnector {
fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> { fn connect(
&self,
_dns_name: &str,
io: Box<dyn ReadWrite>,
) -> Result<Box<dyn ReadWrite>, Error> {
let mut ctx = self.context.lock().unwrap(); let mut ctx = self.context.lock().unwrap();
match ctx.establish(tcp_stream, None) { let sync = SyncIo(Mutex::new(io));
match ctx.establish(sync, None) {
Err(_) => { Err(_) => {
let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError); let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError);
return Err(io_err.into()); return Err(io_err.into());
@@ -65,6 +70,28 @@ impl TlsConnector for MbedTlsConnector {
} }
} }
/// Internal wrapper to make Box<dyn ReadWrite> implement Sync
struct SyncIo(Mutex<Box<dyn ReadWrite>>);
impl io::Read for SyncIo {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut lock = self.0.lock().unwrap();
lock.read(buf)
}
}
impl io::Write for SyncIo {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let mut lock = self.0.lock().unwrap();
lock.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
let mut lock = self.0.lock().unwrap();
lock.flush()
}
}
struct MbedTlsStream { struct MbedTlsStream {
context: Arc<Mutex<Context>>, //tcp_stream: TcpStream, context: Arc<Mutex<Context>>, //tcp_stream: TcpStream,
} }
@@ -84,6 +111,9 @@ impl ReadWrite for MbedTlsStream {
fn socket(&self) -> Option<&TcpStream> { fn socket(&self) -> Option<&TcpStream> {
None None
} }
fn is_poolable(&self) -> bool {
true
}
} }
impl io::Read for MbedTlsStream { impl io::Read for MbedTlsStream {

View File

@@ -356,7 +356,6 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
// calls at the top of the crate (`ureq::get` etc). // calls at the top of the crate (`ureq::get` etc).
#[cfg(not(feature = "tls"))] #[cfg(not(feature = "tls"))]
pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> { pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
use std::net::TcpStream;
use std::sync::Arc; use std::sync::Arc;
struct NoTlsConfig; struct NoTlsConfig;
@@ -365,7 +364,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
fn connect( fn connect(
&self, &self,
_dns_name: &str, _dns_name: &str,
_tcp_stream: TcpStream, _io: Box<dyn ReadWrite>,
) -> Result<Box<dyn ReadWrite>, crate::error::Error> { ) -> Result<Box<dyn ReadWrite>, crate::error::Error> {
Err(ErrorKind::UnknownScheme Err(ErrorKind::UnknownScheme
.msg("cannot make HTTPS request because no TLS backend is configured")) .msg("cannot make HTTPS request because no TLS backend is configured"))

View File

@@ -11,12 +11,15 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
} }
impl TlsConnector for native_tls::TlsConnector { impl TlsConnector for native_tls::TlsConnector {
fn connect(&self, dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> { fn connect(&self, dns_name: &str, io: Box<dyn ReadWrite>) -> Result<Box<dyn ReadWrite>, Error> {
let stream = let stream =
native_tls::TlsConnector::connect(self, dns_name, tcp_stream).map_err(|e| { native_tls::TlsConnector::connect(self, dns_name, io).map_err(|e| match e {
ErrorKind::ConnectionFailed native_tls::HandshakeError::Failure(e) => ErrorKind::ConnectionFailed
.msg("native_tls connect failed") .msg("native_tls connect failed")
.src(e) .src(e),
native_tls::HandshakeError::WouldBlock(_) => {
ErrorKind::Io.msg("Unexpected native_tls::HandshakeError::WouldBlock")
}
})?; })?;
Ok(Box::new(stream)) Ok(Box::new(stream))
@@ -24,8 +27,11 @@ impl TlsConnector for native_tls::TlsConnector {
} }
#[cfg(feature = "native-tls")] #[cfg(feature = "native-tls")]
impl ReadWrite for native_tls::TlsStream<TcpStream> { impl ReadWrite for native_tls::TlsStream<Box<dyn ReadWrite>> {
fn socket(&self) -> Option<&TcpStream> { fn socket(&self) -> Option<&TcpStream> {
Some(self.get_ref()) self.get_ref().socket()
}
fn is_poolable(&self) -> bool {
self.get_ref().is_poolable()
} }
} }

View File

@@ -26,11 +26,14 @@ fn is_close_notify(e: &std::io::Error) -> bool {
false false
} }
struct RustlsStream(rustls::StreamOwned<rustls::ClientConnection, TcpStream>); struct RustlsStream(rustls::StreamOwned<rustls::ClientConnection, Box<dyn ReadWrite>>);
impl ReadWrite for RustlsStream { impl ReadWrite for RustlsStream {
fn socket(&self) -> Option<&TcpStream> { fn socket(&self) -> Option<&TcpStream> {
Some(self.0.get_ref()) self.0.get_ref().socket()
}
fn is_poolable(&self) -> bool {
self.0.get_ref().is_poolable()
} }
} }
@@ -93,7 +96,7 @@ impl TlsConnector for Arc<rustls::ClientConfig> {
fn connect( fn connect(
&self, &self,
dns_name: &str, dns_name: &str,
mut tcp_stream: TcpStream, mut io: Box<dyn ReadWrite>,
) -> Result<Box<dyn ReadWrite>, Error> { ) -> Result<Box<dyn ReadWrite>, Error> {
let sni = rustls::ServerName::try_from(dns_name) let sni = rustls::ServerName::try_from(dns_name)
.map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?; .map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?;
@@ -101,12 +104,12 @@ impl TlsConnector for Arc<rustls::ClientConfig> {
let mut sess = rustls::ClientConnection::new(self.clone(), sni) let mut sess = rustls::ClientConnection::new(self.clone(), sni)
.map_err(|e| ErrorKind::Io.msg("tls connection creation failed").src(e))?; .map_err(|e| ErrorKind::Io.msg("tls connection creation failed").src(e))?;
sess.complete_io(&mut tcp_stream).map_err(|e| { sess.complete_io(&mut io).map_err(|e| {
ErrorKind::ConnectionFailed ErrorKind::ConnectionFailed
.msg("tls connection init failed") .msg("tls connection init failed")
.src(e) .src(e)
})?; })?;
let stream = rustls::StreamOwned::new(sess, tcp_stream); let stream = rustls::StreamOwned::new(sess, io);
Ok(Box::new(RustlsStream(stream))) Ok(Box::new(RustlsStream(stream)))
} }

View File

@@ -20,58 +20,52 @@ use crate::unit::Unit;
/// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector]. /// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector].
pub trait ReadWrite: Read + Write + Send + 'static { pub trait ReadWrite: Read + Write + Send + 'static {
fn socket(&self) -> Option<&TcpStream>; fn socket(&self) -> Option<&TcpStream>;
fn is_poolable(&self) -> bool;
/// The bytes written to the stream as a Vec<u8>. This is used for tests only.
#[cfg(test)]
fn written_bytes(&self) -> Vec<u8> {
panic!("written_bytes on non Test stream");
}
}
impl ReadWrite for TcpStream {
fn socket(&self) -> Option<&TcpStream> {
Some(self)
}
fn is_poolable(&self) -> bool {
true
}
} }
pub trait TlsConnector: Send + Sync { pub trait TlsConnector: Send + Sync {
fn connect( fn connect(
&self, &self,
dns_name: &str, dns_name: &str,
tcp_stream: TcpStream, io: Box<dyn ReadWrite>,
) -> Result<Box<dyn ReadWrite>, crate::error::Error>; ) -> Result<Box<dyn ReadWrite>, crate::error::Error>;
} }
pub(crate) struct Stream { pub(crate) struct Stream {
inner: BufReader<Box<dyn Inner + Send + 'static>>, inner: BufReader<Box<dyn ReadWrite>>,
}
trait Inner: Read + Write {
fn is_poolable(&self) -> bool;
fn socket(&self) -> Option<&TcpStream>;
/// The bytes written to the stream as a Vec<u8>. This is used for tests only.
fn written_bytes(&self) -> Vec<u8> {
panic!("written_bytes on non Test stream");
}
} }
impl<T: ReadWrite + ?Sized> ReadWrite for Box<T> { impl<T: ReadWrite + ?Sized> ReadWrite for Box<T> {
fn socket(&self) -> Option<&TcpStream> { fn socket(&self) -> Option<&TcpStream> {
ReadWrite::socket(self.as_ref()) ReadWrite::socket(self.as_ref())
} }
}
impl<T: ReadWrite> Inner for T {
fn is_poolable(&self) -> bool { fn is_poolable(&self) -> bool {
true ReadWrite::is_poolable(self.as_ref())
} }
#[cfg(test)]
fn socket(&self) -> Option<&TcpStream> { fn written_bytes(&self) -> Vec<u8> {
ReadWrite::socket(self) ReadWrite::written_bytes(self.as_ref())
}
}
impl Inner for TcpStream {
fn is_poolable(&self) -> bool {
true
}
fn socket(&self) -> Option<&TcpStream> {
Some(self)
} }
} }
struct TestStream(Box<dyn Read + Send + Sync>, Vec<u8>, bool); struct TestStream(Box<dyn Read + Send + Sync>, Vec<u8>, bool);
impl Inner for TestStream { impl ReadWrite for TestStream {
fn is_poolable(&self) -> bool { fn is_poolable(&self) -> bool {
self.2 self.2
} }
@@ -79,7 +73,7 @@ impl Inner for TestStream {
None None
} }
/// For tests only #[cfg(test)]
fn written_bytes(&self) -> Vec<u8> { fn written_bytes(&self) -> Vec<u8> {
self.1.clone() self.1.clone()
} }
@@ -188,7 +182,7 @@ impl fmt::Debug for Stream {
} }
impl Stream { impl Stream {
fn new(t: impl Inner + Send + 'static) -> Stream { fn new(t: impl ReadWrite) -> Stream {
Stream::logged_create(Stream { Stream::logged_create(Stream {
inner: BufReader::new(Box::new(t)), inner: BufReader::new(Box::new(t)),
}) })
@@ -348,7 +342,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
let sock = connect_host(unit, hostname, port)?; let sock = connect_host(unit, hostname, port)?;
let tls_conf = &unit.agent.config.tls_config; let tls_conf = &unit.agent.config.tls_config;
let https_stream = tls_conf.connect(hostname, sock)?; let https_stream = tls_conf.connect(hostname, Box::new(sock))?;
Ok(Stream::new(https_stream)) Ok(Stream::new(https_stream))
} }