Box<dyn ReadWrite> for TlsConnector::connect
This commit is contained in:
@@ -31,21 +31,28 @@ struct PassThrough {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TlsConnector for PassThrough {
|
impl TlsConnector for PassThrough {
|
||||||
fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> {
|
fn connect(
|
||||||
|
&self,
|
||||||
|
_dns_name: &str,
|
||||||
|
io: Box<dyn ReadWrite>,
|
||||||
|
) -> Result<Box<dyn ReadWrite>, Error> {
|
||||||
if self.handshake_fail {
|
if self.handshake_fail {
|
||||||
let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError);
|
let io_err = io::Error::new(io::ErrorKind::InvalidData, PassThroughError);
|
||||||
return Err(io_err.into());
|
return Err(io_err.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Box::new(CustomTlsStream(tcp_stream)))
|
Ok(Box::new(CustomTlsStream(io)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct CustomTlsStream(TcpStream);
|
struct CustomTlsStream(Box<dyn ReadWrite>);
|
||||||
|
|
||||||
impl ReadWrite for CustomTlsStream {
|
impl ReadWrite for CustomTlsStream {
|
||||||
fn socket(&self) -> Option<&TcpStream> {
|
fn socket(&self) -> Option<&TcpStream> {
|
||||||
Some(&self.0)
|
self.0.socket()
|
||||||
|
}
|
||||||
|
fn is_poolable(&self) -> bool {
|
||||||
|
self.0.is_poolable()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,9 +53,14 @@ impl MbedTlsConnector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TlsConnector for MbedTlsConnector {
|
impl TlsConnector for MbedTlsConnector {
|
||||||
fn connect(&self, _dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> {
|
fn connect(
|
||||||
|
&self,
|
||||||
|
_dns_name: &str,
|
||||||
|
io: Box<dyn ReadWrite>,
|
||||||
|
) -> Result<Box<dyn ReadWrite>, Error> {
|
||||||
let mut ctx = self.context.lock().unwrap();
|
let mut ctx = self.context.lock().unwrap();
|
||||||
match ctx.establish(tcp_stream, None) {
|
let sync = SyncIo(Mutex::new(io));
|
||||||
|
match ctx.establish(sync, None) {
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError);
|
let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError);
|
||||||
return Err(io_err.into());
|
return Err(io_err.into());
|
||||||
@@ -65,6 +70,28 @@ impl TlsConnector for MbedTlsConnector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Internal wrapper to make Box<dyn ReadWrite> implement Sync
|
||||||
|
struct SyncIo(Mutex<Box<dyn ReadWrite>>);
|
||||||
|
|
||||||
|
impl io::Read for SyncIo {
|
||||||
|
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||||
|
let mut lock = self.0.lock().unwrap();
|
||||||
|
lock.read(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl io::Write for SyncIo {
|
||||||
|
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||||
|
let mut lock = self.0.lock().unwrap();
|
||||||
|
lock.write(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flush(&mut self) -> io::Result<()> {
|
||||||
|
let mut lock = self.0.lock().unwrap();
|
||||||
|
lock.flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct MbedTlsStream {
|
struct MbedTlsStream {
|
||||||
context: Arc<Mutex<Context>>, //tcp_stream: TcpStream,
|
context: Arc<Mutex<Context>>, //tcp_stream: TcpStream,
|
||||||
}
|
}
|
||||||
@@ -84,6 +111,9 @@ impl ReadWrite for MbedTlsStream {
|
|||||||
fn socket(&self) -> Option<&TcpStream> {
|
fn socket(&self) -> Option<&TcpStream> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
fn is_poolable(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl io::Read for MbedTlsStream {
|
impl io::Read for MbedTlsStream {
|
||||||
|
|||||||
@@ -356,7 +356,6 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
|
|||||||
// calls at the top of the crate (`ureq::get` etc).
|
// calls at the top of the crate (`ureq::get` etc).
|
||||||
#[cfg(not(feature = "tls"))]
|
#[cfg(not(feature = "tls"))]
|
||||||
pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
|
pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
|
||||||
use std::net::TcpStream;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
struct NoTlsConfig;
|
struct NoTlsConfig;
|
||||||
@@ -365,7 +364,7 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
|
|||||||
fn connect(
|
fn connect(
|
||||||
&self,
|
&self,
|
||||||
_dns_name: &str,
|
_dns_name: &str,
|
||||||
_tcp_stream: TcpStream,
|
_io: Box<dyn ReadWrite>,
|
||||||
) -> Result<Box<dyn ReadWrite>, crate::error::Error> {
|
) -> Result<Box<dyn ReadWrite>, crate::error::Error> {
|
||||||
Err(ErrorKind::UnknownScheme
|
Err(ErrorKind::UnknownScheme
|
||||||
.msg("cannot make HTTPS request because no TLS backend is configured"))
|
.msg("cannot make HTTPS request because no TLS backend is configured"))
|
||||||
|
|||||||
18
src/ntls.rs
18
src/ntls.rs
@@ -11,12 +11,15 @@ pub(crate) fn default_tls_config() -> std::sync::Arc<dyn TlsConnector> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl TlsConnector for native_tls::TlsConnector {
|
impl TlsConnector for native_tls::TlsConnector {
|
||||||
fn connect(&self, dns_name: &str, tcp_stream: TcpStream) -> Result<Box<dyn ReadWrite>, Error> {
|
fn connect(&self, dns_name: &str, io: Box<dyn ReadWrite>) -> Result<Box<dyn ReadWrite>, Error> {
|
||||||
let stream =
|
let stream =
|
||||||
native_tls::TlsConnector::connect(self, dns_name, tcp_stream).map_err(|e| {
|
native_tls::TlsConnector::connect(self, dns_name, io).map_err(|e| match e {
|
||||||
ErrorKind::ConnectionFailed
|
native_tls::HandshakeError::Failure(e) => ErrorKind::ConnectionFailed
|
||||||
.msg("native_tls connect failed")
|
.msg("native_tls connect failed")
|
||||||
.src(e)
|
.src(e),
|
||||||
|
native_tls::HandshakeError::WouldBlock(_) => {
|
||||||
|
ErrorKind::Io.msg("Unexpected native_tls::HandshakeError::WouldBlock")
|
||||||
|
}
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
Ok(Box::new(stream))
|
Ok(Box::new(stream))
|
||||||
@@ -24,8 +27,11 @@ impl TlsConnector for native_tls::TlsConnector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "native-tls")]
|
#[cfg(feature = "native-tls")]
|
||||||
impl ReadWrite for native_tls::TlsStream<TcpStream> {
|
impl ReadWrite for native_tls::TlsStream<Box<dyn ReadWrite>> {
|
||||||
fn socket(&self) -> Option<&TcpStream> {
|
fn socket(&self) -> Option<&TcpStream> {
|
||||||
Some(self.get_ref())
|
self.get_ref().socket()
|
||||||
|
}
|
||||||
|
fn is_poolable(&self) -> bool {
|
||||||
|
self.get_ref().is_poolable()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
13
src/rtls.rs
13
src/rtls.rs
@@ -26,11 +26,14 @@ fn is_close_notify(e: &std::io::Error) -> bool {
|
|||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RustlsStream(rustls::StreamOwned<rustls::ClientConnection, TcpStream>);
|
struct RustlsStream(rustls::StreamOwned<rustls::ClientConnection, Box<dyn ReadWrite>>);
|
||||||
|
|
||||||
impl ReadWrite for RustlsStream {
|
impl ReadWrite for RustlsStream {
|
||||||
fn socket(&self) -> Option<&TcpStream> {
|
fn socket(&self) -> Option<&TcpStream> {
|
||||||
Some(self.0.get_ref())
|
self.0.get_ref().socket()
|
||||||
|
}
|
||||||
|
fn is_poolable(&self) -> bool {
|
||||||
|
self.0.get_ref().is_poolable()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,7 +96,7 @@ impl TlsConnector for Arc<rustls::ClientConfig> {
|
|||||||
fn connect(
|
fn connect(
|
||||||
&self,
|
&self,
|
||||||
dns_name: &str,
|
dns_name: &str,
|
||||||
mut tcp_stream: TcpStream,
|
mut io: Box<dyn ReadWrite>,
|
||||||
) -> Result<Box<dyn ReadWrite>, Error> {
|
) -> Result<Box<dyn ReadWrite>, Error> {
|
||||||
let sni = rustls::ServerName::try_from(dns_name)
|
let sni = rustls::ServerName::try_from(dns_name)
|
||||||
.map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?;
|
.map_err(|e| ErrorKind::Dns.msg(format!("parsing '{}'", dns_name)).src(e))?;
|
||||||
@@ -101,12 +104,12 @@ impl TlsConnector for Arc<rustls::ClientConfig> {
|
|||||||
let mut sess = rustls::ClientConnection::new(self.clone(), sni)
|
let mut sess = rustls::ClientConnection::new(self.clone(), sni)
|
||||||
.map_err(|e| ErrorKind::Io.msg("tls connection creation failed").src(e))?;
|
.map_err(|e| ErrorKind::Io.msg("tls connection creation failed").src(e))?;
|
||||||
|
|
||||||
sess.complete_io(&mut tcp_stream).map_err(|e| {
|
sess.complete_io(&mut io).map_err(|e| {
|
||||||
ErrorKind::ConnectionFailed
|
ErrorKind::ConnectionFailed
|
||||||
.msg("tls connection init failed")
|
.msg("tls connection init failed")
|
||||||
.src(e)
|
.src(e)
|
||||||
})?;
|
})?;
|
||||||
let stream = rustls::StreamOwned::new(sess, tcp_stream);
|
let stream = rustls::StreamOwned::new(sess, io);
|
||||||
|
|
||||||
Ok(Box::new(RustlsStream(stream)))
|
Ok(Box::new(RustlsStream(stream)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,58 +20,52 @@ use crate::unit::Unit;
|
|||||||
/// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector].
|
/// Trait for things implementing [std::io::Read] + [std::io::Write]. Used in [TlsConnector].
|
||||||
pub trait ReadWrite: Read + Write + Send + 'static {
|
pub trait ReadWrite: Read + Write + Send + 'static {
|
||||||
fn socket(&self) -> Option<&TcpStream>;
|
fn socket(&self) -> Option<&TcpStream>;
|
||||||
|
fn is_poolable(&self) -> bool;
|
||||||
|
|
||||||
|
/// The bytes written to the stream as a Vec<u8>. This is used for tests only.
|
||||||
|
#[cfg(test)]
|
||||||
|
fn written_bytes(&self) -> Vec<u8> {
|
||||||
|
panic!("written_bytes on non Test stream");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReadWrite for TcpStream {
|
||||||
|
fn socket(&self) -> Option<&TcpStream> {
|
||||||
|
Some(self)
|
||||||
|
}
|
||||||
|
fn is_poolable(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait TlsConnector: Send + Sync {
|
pub trait TlsConnector: Send + Sync {
|
||||||
fn connect(
|
fn connect(
|
||||||
&self,
|
&self,
|
||||||
dns_name: &str,
|
dns_name: &str,
|
||||||
tcp_stream: TcpStream,
|
io: Box<dyn ReadWrite>,
|
||||||
) -> Result<Box<dyn ReadWrite>, crate::error::Error>;
|
) -> Result<Box<dyn ReadWrite>, crate::error::Error>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct Stream {
|
pub(crate) struct Stream {
|
||||||
inner: BufReader<Box<dyn Inner + Send + 'static>>,
|
inner: BufReader<Box<dyn ReadWrite>>,
|
||||||
}
|
|
||||||
|
|
||||||
trait Inner: Read + Write {
|
|
||||||
fn is_poolable(&self) -> bool;
|
|
||||||
fn socket(&self) -> Option<&TcpStream>;
|
|
||||||
|
|
||||||
/// The bytes written to the stream as a Vec<u8>. This is used for tests only.
|
|
||||||
fn written_bytes(&self) -> Vec<u8> {
|
|
||||||
panic!("written_bytes on non Test stream");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: ReadWrite + ?Sized> ReadWrite for Box<T> {
|
impl<T: ReadWrite + ?Sized> ReadWrite for Box<T> {
|
||||||
fn socket(&self) -> Option<&TcpStream> {
|
fn socket(&self) -> Option<&TcpStream> {
|
||||||
ReadWrite::socket(self.as_ref())
|
ReadWrite::socket(self.as_ref())
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: ReadWrite> Inner for T {
|
|
||||||
fn is_poolable(&self) -> bool {
|
fn is_poolable(&self) -> bool {
|
||||||
true
|
ReadWrite::is_poolable(self.as_ref())
|
||||||
}
|
}
|
||||||
|
#[cfg(test)]
|
||||||
fn socket(&self) -> Option<&TcpStream> {
|
fn written_bytes(&self) -> Vec<u8> {
|
||||||
ReadWrite::socket(self)
|
ReadWrite::written_bytes(self.as_ref())
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Inner for TcpStream {
|
|
||||||
fn is_poolable(&self) -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
fn socket(&self) -> Option<&TcpStream> {
|
|
||||||
Some(self)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TestStream(Box<dyn Read + Send + Sync>, Vec<u8>, bool);
|
struct TestStream(Box<dyn Read + Send + Sync>, Vec<u8>, bool);
|
||||||
|
|
||||||
impl Inner for TestStream {
|
impl ReadWrite for TestStream {
|
||||||
fn is_poolable(&self) -> bool {
|
fn is_poolable(&self) -> bool {
|
||||||
self.2
|
self.2
|
||||||
}
|
}
|
||||||
@@ -79,7 +73,7 @@ impl Inner for TestStream {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
/// For tests only
|
#[cfg(test)]
|
||||||
fn written_bytes(&self) -> Vec<u8> {
|
fn written_bytes(&self) -> Vec<u8> {
|
||||||
self.1.clone()
|
self.1.clone()
|
||||||
}
|
}
|
||||||
@@ -188,7 +182,7 @@ impl fmt::Debug for Stream {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Stream {
|
impl Stream {
|
||||||
fn new(t: impl Inner + Send + 'static) -> Stream {
|
fn new(t: impl ReadWrite) -> Stream {
|
||||||
Stream::logged_create(Stream {
|
Stream::logged_create(Stream {
|
||||||
inner: BufReader::new(Box::new(t)),
|
inner: BufReader::new(Box::new(t)),
|
||||||
})
|
})
|
||||||
@@ -348,7 +342,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
|
|||||||
let sock = connect_host(unit, hostname, port)?;
|
let sock = connect_host(unit, hostname, port)?;
|
||||||
|
|
||||||
let tls_conf = &unit.agent.config.tls_config;
|
let tls_conf = &unit.agent.config.tls_config;
|
||||||
let https_stream = tls_conf.connect(hostname, sock)?;
|
let https_stream = tls_conf.connect(hostname, Box::new(sock))?;
|
||||||
Ok(Stream::new(https_stream))
|
Ok(Stream::new(https_stream))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user