Add support for alternate TLs implementations.

This commit is contained in:
Jacob Hoffman-Andrews
2021-10-04 22:47:00 -07:00
committed by Martin Algesten
parent 1c1dfaa691
commit 56276c3742
17 changed files with 527 additions and 233 deletions

View File

@@ -8,10 +8,6 @@ use std::{fmt, io::Cursor};
use chunked_transfer::Decoder as ChunkDecoder;
#[cfg(feature = "tls")]
use rustls::ClientConnection;
#[cfg(feature = "tls")]
use rustls::StreamOwned;
#[cfg(feature = "socks-proxy")]
use socks::{TargetAddr, ToTargetAddr};
@@ -21,16 +17,83 @@ use crate::{error::Error, proxy::Proto};
use crate::error::ErrorKind;
use crate::unit::Unit;
pub(crate) struct Stream {
inner: BufReader<Inner>,
pub trait HttpsStream: Read + Write + Send + Sync + 'static {
fn socket(&self) -> Option<&TcpStream>;
}
#[allow(clippy::large_enum_variant)]
enum Inner {
Http(TcpStream),
#[cfg(feature = "tls")]
Https(rustls::StreamOwned<rustls::ClientConnection, TcpStream>),
Test(Box<dyn Read + Send + Sync>, Vec<u8>),
pub trait TlsConnector: Send + Sync {
fn connect(
&self,
dns_name: &str,
tcp_stream: TcpStream,
) -> Result<Box<dyn HttpsStream>, crate::error::Error>;
}
pub(crate) struct Stream {
inner: BufReader<Box<dyn Inner + Send + Sync + 'static>>,
}
trait Inner: Read + Write {
fn is_poolable(&self) -> bool;
fn socket(&self) -> Option<&TcpStream>;
fn as_write_vec(&self) -> &[u8] {
panic!("as_write_vec on non Test stream");
}
}
impl<T: HttpsStream + ?Sized> HttpsStream for Box<T> {
fn socket(&self) -> Option<&TcpStream> {
HttpsStream::socket(self.as_ref())
}
}
impl<T: HttpsStream> Inner for T {
fn is_poolable(&self) -> bool {
true
}
fn socket(&self) -> Option<&TcpStream> {
HttpsStream::socket(self)
}
}
impl Inner for TcpStream {
fn is_poolable(&self) -> bool {
true
}
fn socket(&self) -> Option<&TcpStream> {
Some(self)
}
}
struct TestStream(Box<dyn Read + Send + Sync>, Vec<u8>);
impl Inner for TestStream {
fn is_poolable(&self) -> bool {
false
}
fn socket(&self) -> Option<&TcpStream> {
None
}
fn as_write_vec(&self) -> &[u8] {
&self.1
}
}
impl Read for TestStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for TestStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.1.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
// DeadlineStream wraps a stream such that read() will return an error
@@ -112,16 +175,20 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
impl fmt::Debug for Stream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self.inner.get_ref() {
Inner::Http(tcpstream) => write!(f, "{:?}", tcpstream),
#[cfg(feature = "tls")]
Inner::Https(tlsstream) => write!(f, "{:?}", tlsstream.get_ref()),
Inner::Test(_, _) => write!(f, "Stream(Test)"),
match self.inner.get_ref().socket() {
Some(s) => write!(f, "{:?}", s),
None => write!(f, "Stream(Test)"),
}
}
}
impl Stream {
fn new(t: impl Inner + Send + Sync + 'static) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Box::new(t)),
})
}
fn logged_create(stream: Stream) -> Stream {
debug!("created stream: {:?}", stream);
stream
@@ -129,20 +196,13 @@ impl Stream {
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
inner: BufReader::new(Box::new(TestStream(Box::new(Cursor::new(v)), vec![]))),
})
}
fn from_tcp_stream(t: TcpStream) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Inner::Http(t)),
})
}
#[cfg(feature = "tls")]
fn from_tls_stream(t: StreamOwned<ClientConnection, TcpStream>) -> Stream {
Stream::logged_create(Stream {
inner: BufReader::new(Inner::Https(t)),
inner: BufReader::new(Box::new(t)),
})
}
@@ -186,12 +246,7 @@ impl Stream {
}
}
pub fn is_poolable(&self) -> bool {
match self.inner.get_ref() {
Inner::Http(_) => true,
#[cfg(feature = "tls")]
Inner::Https(_) => true,
_ => false,
}
self.inner.get_ref().is_poolable()
}
pub(crate) fn reset(&mut self) -> io::Result<()> {
@@ -206,12 +261,7 @@ impl Stream {
}
pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self.inner.get_ref() {
Inner::Http(b) => Some(b),
#[cfg(feature = "tls")]
Inner::Https(b) => Some(b.get_ref()),
_ => None,
}
self.inner.get_ref().socket()
}
pub(crate) fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
@@ -223,11 +273,8 @@ impl Stream {
}
#[cfg(test)]
pub fn to_write_vec(&self) -> Vec<u8> {
match self.inner.get_ref() {
Inner::Test(_, writer) => writer.clone(),
_ => panic!("to_write_vec on non Test stream"),
}
pub fn as_write_vec(&self) -> &[u8] {
self.inner.get_ref().as_write_vec()
}
}
@@ -237,17 +284,6 @@ impl Read for Stream {
}
}
impl Read for Inner {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Inner::Http(sock) => sock.read(buf),
#[cfg(feature = "tls")]
Inner::Https(stream) => read_https(stream, buf),
Inner::Test(reader, _) => reader.read(buf),
}
}
}
impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.inner.fill_buf()
@@ -268,50 +304,12 @@ where
}
}
#[cfg(feature = "tls")]
fn read_https(
stream: &mut StreamOwned<ClientConnection, TcpStream>,
buf: &mut [u8],
) -> io::Result<usize> {
match stream.read(buf) {
Ok(size) => Ok(size),
Err(ref e) if is_close_notify(e) => Ok(0),
Err(e) => Err(e),
}
}
#[allow(deprecated)]
#[cfg(feature = "tls")]
fn is_close_notify(e: &std::io::Error) -> bool {
if e.kind() != io::ErrorKind::ConnectionAborted {
return false;
}
if let Some(msg) = e.get_ref() {
// :(
return msg.description().contains("CloseNotify");
}
false
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match self.inner.get_mut() {
Inner::Http(sock) => sock.write(buf),
#[cfg(feature = "tls")]
Inner::Https(stream) => stream.write(buf),
Inner::Test(_, writer) => writer.write(buf),
}
self.inner.get_mut().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
match self.inner.get_mut() {
Inner::Http(sock) => sock.flush(),
#[cfg(feature = "tls")]
Inner::Https(stream) => stream.flush(),
Inner::Test(_, writer) => writer.flush(),
}
self.inner.get_mut().flush()
}
}
@@ -328,55 +326,14 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
connect_host(unit, hostname, port).map(Stream::from_tcp_stream)
}
#[cfg(feature = "tls")]
pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error> {
use once_cell::sync::Lazy;
use std::{convert::TryFrom, sync::Arc};
static TLS_CONF: Lazy<Arc<rustls::ClientConfig>> = Lazy::new(|| {
let mut root_store = rustls::RootCertStore::empty();
#[cfg(not(feature = "native-certs"))]
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
ta.subject,
ta.spki,
ta.name_constraints,
)
}));
#[cfg(feature = "native-certs")]
for cert in rustls_native_certs::load_native_certs().expect("Could not load platform certs")
{
root_store.add(&rustls::Certificate(cert.0)).unwrap();
}
let config = rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_store)
.with_no_client_auth();
Arc::new(config)
});
let port = unit.url.port().unwrap_or(443);
let tls_conf: Arc<rustls::ClientConfig> = unit
.agent
.config
.tls_config
.as_ref()
.map(|c| c.0.clone())
.unwrap_or_else(|| TLS_CONF.clone());
let mut sock = connect_host(unit, hostname, port)?;
let mut sess = rustls::ClientConnection::new(
tls_conf,
rustls::ServerName::try_from(hostname).map_err(|e| ErrorKind::Dns.new().src(e))?,
)
.map_err(|e| ErrorKind::Io.new().src(e))?;
let sock = connect_host(unit, hostname, port)?;
sess.complete_io(&mut sock)
.map_err(|err| ErrorKind::ConnectionFailed.new().src(err))?;
let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::from_tls_stream(stream))
let tls_conf = &unit.agent.config.tls_config;
let https_stream = tls_conf.connect(hostname, sock)?;
Ok(Stream::new(https_stream))
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
@@ -666,10 +623,3 @@ pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
Err(ErrorKind::UnknownScheme.msg(&format!("unknown scheme '{}'", unit.url.scheme())))
}
#[cfg(not(feature = "tls"))]
pub(crate) fn connect_https(unit: &Unit, _hostname: &str) -> Result<Stream, Error> {
Err(ErrorKind::UnknownScheme
.msg("URL has 'https:' scheme but ureq was build without HTTP support")
.url(unit.url.clone()))
}