Add support for alternate TLs implementations.
This commit is contained in:
committed by
Martin Algesten
parent
1c1dfaa691
commit
56276c3742
242
src/stream.rs
242
src/stream.rs
@@ -8,10 +8,6 @@ use std::{fmt, io::Cursor};
|
||||
|
||||
use chunked_transfer::Decoder as ChunkDecoder;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use rustls::ClientConnection;
|
||||
#[cfg(feature = "tls")]
|
||||
use rustls::StreamOwned;
|
||||
#[cfg(feature = "socks-proxy")]
|
||||
use socks::{TargetAddr, ToTargetAddr};
|
||||
|
||||
@@ -21,16 +17,83 @@ use crate::{error::Error, proxy::Proto};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::unit::Unit;
|
||||
|
||||
pub(crate) struct Stream {
|
||||
inner: BufReader<Inner>,
|
||||
pub trait HttpsStream: Read + Write + Send + Sync + 'static {
|
||||
fn socket(&self) -> Option<&TcpStream>;
|
||||
}
|
||||
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
enum Inner {
|
||||
Http(TcpStream),
|
||||
#[cfg(feature = "tls")]
|
||||
Https(rustls::StreamOwned<rustls::ClientConnection, TcpStream>),
|
||||
Test(Box<dyn Read + Send + Sync>, Vec<u8>),
|
||||
pub trait TlsConnector: Send + Sync {
|
||||
fn connect(
|
||||
&self,
|
||||
dns_name: &str,
|
||||
tcp_stream: TcpStream,
|
||||
) -> Result<Box<dyn HttpsStream>, crate::error::Error>;
|
||||
}
|
||||
|
||||
pub(crate) struct Stream {
|
||||
inner: BufReader<Box<dyn Inner + Send + Sync + 'static>>,
|
||||
}
|
||||
|
||||
trait Inner: Read + Write {
|
||||
fn is_poolable(&self) -> bool;
|
||||
fn socket(&self) -> Option<&TcpStream>;
|
||||
fn as_write_vec(&self) -> &[u8] {
|
||||
panic!("as_write_vec on non Test stream");
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: HttpsStream + ?Sized> HttpsStream for Box<T> {
|
||||
fn socket(&self) -> Option<&TcpStream> {
|
||||
HttpsStream::socket(self.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: HttpsStream> Inner for T {
|
||||
fn is_poolable(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn socket(&self) -> Option<&TcpStream> {
|
||||
HttpsStream::socket(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Inner for TcpStream {
|
||||
fn is_poolable(&self) -> bool {
|
||||
true
|
||||
}
|
||||
fn socket(&self) -> Option<&TcpStream> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
struct TestStream(Box<dyn Read + Send + Sync>, Vec<u8>);
|
||||
|
||||
impl Inner for TestStream {
|
||||
fn is_poolable(&self) -> bool {
|
||||
false
|
||||
}
|
||||
fn socket(&self) -> Option<&TcpStream> {
|
||||
None
|
||||
}
|
||||
fn as_write_vec(&self) -> &[u8] {
|
||||
&self.1
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for TestStream {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
self.0.read(buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Write for TestStream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
self.1.write(buf)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// DeadlineStream wraps a stream such that read() will return an error
|
||||
@@ -112,16 +175,20 @@ pub(crate) fn io_err_timeout(error: String) -> io::Error {
|
||||
|
||||
impl fmt::Debug for Stream {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self.inner.get_ref() {
|
||||
Inner::Http(tcpstream) => write!(f, "{:?}", tcpstream),
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Https(tlsstream) => write!(f, "{:?}", tlsstream.get_ref()),
|
||||
Inner::Test(_, _) => write!(f, "Stream(Test)"),
|
||||
match self.inner.get_ref().socket() {
|
||||
Some(s) => write!(f, "{:?}", s),
|
||||
None => write!(f, "Stream(Test)"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
fn new(t: impl Inner + Send + Sync + 'static) -> Stream {
|
||||
Stream::logged_create(Stream {
|
||||
inner: BufReader::new(Box::new(t)),
|
||||
})
|
||||
}
|
||||
|
||||
fn logged_create(stream: Stream) -> Stream {
|
||||
debug!("created stream: {:?}", stream);
|
||||
stream
|
||||
@@ -129,20 +196,13 @@ impl Stream {
|
||||
|
||||
pub(crate) fn from_vec(v: Vec<u8>) -> Stream {
|
||||
Stream::logged_create(Stream {
|
||||
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
|
||||
inner: BufReader::new(Box::new(TestStream(Box::new(Cursor::new(v)), vec![]))),
|
||||
})
|
||||
}
|
||||
|
||||
fn from_tcp_stream(t: TcpStream) -> Stream {
|
||||
Stream::logged_create(Stream {
|
||||
inner: BufReader::new(Inner::Http(t)),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
fn from_tls_stream(t: StreamOwned<ClientConnection, TcpStream>) -> Stream {
|
||||
Stream::logged_create(Stream {
|
||||
inner: BufReader::new(Inner::Https(t)),
|
||||
inner: BufReader::new(Box::new(t)),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -186,12 +246,7 @@ impl Stream {
|
||||
}
|
||||
}
|
||||
pub fn is_poolable(&self) -> bool {
|
||||
match self.inner.get_ref() {
|
||||
Inner::Http(_) => true,
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Https(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
self.inner.get_ref().is_poolable()
|
||||
}
|
||||
|
||||
pub(crate) fn reset(&mut self) -> io::Result<()> {
|
||||
@@ -206,12 +261,7 @@ impl Stream {
|
||||
}
|
||||
|
||||
pub(crate) fn socket(&self) -> Option<&TcpStream> {
|
||||
match self.inner.get_ref() {
|
||||
Inner::Http(b) => Some(b),
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Https(b) => Some(b.get_ref()),
|
||||
_ => None,
|
||||
}
|
||||
self.inner.get_ref().socket()
|
||||
}
|
||||
|
||||
pub(crate) fn set_read_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
|
||||
@@ -223,11 +273,8 @@ impl Stream {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn to_write_vec(&self) -> Vec<u8> {
|
||||
match self.inner.get_ref() {
|
||||
Inner::Test(_, writer) => writer.clone(),
|
||||
_ => panic!("to_write_vec on non Test stream"),
|
||||
}
|
||||
pub fn as_write_vec(&self) -> &[u8] {
|
||||
self.inner.get_ref().as_write_vec()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,17 +284,6 @@ impl Read for Stream {
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for Inner {
|
||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
||||
match self {
|
||||
Inner::Http(sock) => sock.read(buf),
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Https(stream) => read_https(stream, buf),
|
||||
Inner::Test(reader, _) => reader.read(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BufRead for Stream {
|
||||
fn fill_buf(&mut self) -> io::Result<&[u8]> {
|
||||
self.inner.fill_buf()
|
||||
@@ -268,50 +304,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
fn read_https(
|
||||
stream: &mut StreamOwned<ClientConnection, TcpStream>,
|
||||
buf: &mut [u8],
|
||||
) -> io::Result<usize> {
|
||||
match stream.read(buf) {
|
||||
Ok(size) => Ok(size),
|
||||
Err(ref e) if is_close_notify(e) => Ok(0),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(deprecated)]
|
||||
#[cfg(feature = "tls")]
|
||||
fn is_close_notify(e: &std::io::Error) -> bool {
|
||||
if e.kind() != io::ErrorKind::ConnectionAborted {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(msg) = e.get_ref() {
|
||||
// :(
|
||||
|
||||
return msg.description().contains("CloseNotify");
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
impl Write for Stream {
|
||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||||
match self.inner.get_mut() {
|
||||
Inner::Http(sock) => sock.write(buf),
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Https(stream) => stream.write(buf),
|
||||
Inner::Test(_, writer) => writer.write(buf),
|
||||
}
|
||||
self.inner.get_mut().write(buf)
|
||||
}
|
||||
fn flush(&mut self) -> io::Result<()> {
|
||||
match self.inner.get_mut() {
|
||||
Inner::Http(sock) => sock.flush(),
|
||||
#[cfg(feature = "tls")]
|
||||
Inner::Https(stream) => stream.flush(),
|
||||
Inner::Test(_, writer) => writer.flush(),
|
||||
}
|
||||
self.inner.get_mut().flush()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,55 +326,14 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
|
||||
connect_host(unit, hostname, port).map(Stream::from_tcp_stream)
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error> {
|
||||
use once_cell::sync::Lazy;
|
||||
use std::{convert::TryFrom, sync::Arc};
|
||||
|
||||
static TLS_CONF: Lazy<Arc<rustls::ClientConfig>> = Lazy::new(|| {
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
#[cfg(not(feature = "native-certs"))]
|
||||
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}));
|
||||
#[cfg(feature = "native-certs")]
|
||||
for cert in rustls_native_certs::load_native_certs().expect("Could not load platform certs")
|
||||
{
|
||||
root_store.add(&rustls::Certificate(cert.0)).unwrap();
|
||||
}
|
||||
|
||||
let config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
Arc::new(config)
|
||||
});
|
||||
|
||||
let port = unit.url.port().unwrap_or(443);
|
||||
|
||||
let tls_conf: Arc<rustls::ClientConfig> = unit
|
||||
.agent
|
||||
.config
|
||||
.tls_config
|
||||
.as_ref()
|
||||
.map(|c| c.0.clone())
|
||||
.unwrap_or_else(|| TLS_CONF.clone());
|
||||
let mut sock = connect_host(unit, hostname, port)?;
|
||||
let mut sess = rustls::ClientConnection::new(
|
||||
tls_conf,
|
||||
rustls::ServerName::try_from(hostname).map_err(|e| ErrorKind::Dns.new().src(e))?,
|
||||
)
|
||||
.map_err(|e| ErrorKind::Io.new().src(e))?;
|
||||
let sock = connect_host(unit, hostname, port)?;
|
||||
|
||||
sess.complete_io(&mut sock)
|
||||
.map_err(|err| ErrorKind::ConnectionFailed.new().src(err))?;
|
||||
let stream = rustls::StreamOwned::new(sess, sock);
|
||||
|
||||
Ok(Stream::from_tls_stream(stream))
|
||||
let tls_conf = &unit.agent.config.tls_config;
|
||||
let https_stream = tls_conf.connect(hostname, sock)?;
|
||||
Ok(Stream::new(https_stream))
|
||||
}
|
||||
|
||||
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
|
||||
@@ -666,10 +623,3 @@ pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
|
||||
pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
|
||||
Err(ErrorKind::UnknownScheme.msg(&format!("unknown scheme '{}'", unit.url.scheme())))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
pub(crate) fn connect_https(unit: &Unit, _hostname: &str) -> Result<Stream, Error> {
|
||||
Err(ErrorKind::UnknownScheme
|
||||
.msg("URL has 'https:' scheme but ureq was build without HTTP support")
|
||||
.url(unit.url.clone()))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user