|
|
|
@@ -1,4 +1,6 @@
|
|
|
|
use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write};
|
|
|
|
use std::io::{
|
|
|
|
|
|
|
|
BufRead, BufReader, Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write,
|
|
|
|
|
|
|
|
};
|
|
|
|
use std::net::SocketAddr;
|
|
|
|
use std::net::SocketAddr;
|
|
|
|
use std::net::TcpStream;
|
|
|
|
use std::net::TcpStream;
|
|
|
|
use std::net::ToSocketAddrs;
|
|
|
|
use std::net::ToSocketAddrs;
|
|
|
|
@@ -25,14 +27,14 @@ use crate::unit::Unit;
|
|
|
|
|
|
|
|
|
|
|
|
#[allow(clippy::large_enum_variant)]
|
|
|
|
#[allow(clippy::large_enum_variant)]
|
|
|
|
pub enum Stream {
|
|
|
|
pub enum Stream {
|
|
|
|
Http(TcpStream),
|
|
|
|
Http(BufReader<TcpStream>),
|
|
|
|
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
|
|
|
|
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
|
|
|
|
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
|
|
|
|
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>),
|
|
|
|
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
|
|
|
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
|
|
|
Https(TlsStream<TcpStream>),
|
|
|
|
Https(BufReader<TlsStream<TcpStream>>),
|
|
|
|
Cursor(Cursor<Vec<u8>>),
|
|
|
|
Cursor(Cursor<Vec<u8>>),
|
|
|
|
#[cfg(test)]
|
|
|
|
#[cfg(test)]
|
|
|
|
Test(Box<dyn Read + Send>, Vec<u8>),
|
|
|
|
Test(Box<dyn BufRead + Send>, Vec<u8>),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// DeadlineStream wraps a stream such that read() will return an error
|
|
|
|
// DeadlineStream wraps a stream such that read() will return an error
|
|
|
|
@@ -161,9 +163,9 @@ impl Stream {
|
|
|
|
|
|
|
|
|
|
|
|
pub(crate) fn socket(&self) -> Option<&TcpStream> {
|
|
|
|
pub(crate) fn socket(&self) -> Option<&TcpStream> {
|
|
|
|
match self {
|
|
|
|
match self {
|
|
|
|
Stream::Http(tcpstream) => Some(tcpstream),
|
|
|
|
Stream::Http(b) => Some(b.get_ref()),
|
|
|
|
#[cfg(feature = "tls")]
|
|
|
|
#[cfg(feature = "tls")]
|
|
|
|
Stream::Https(rustls_stream) => Some(&rustls_stream.sock),
|
|
|
|
Stream::Https(b) => Some(&b.get_ref().sock),
|
|
|
|
_ => None,
|
|
|
|
_ => None,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@@ -193,6 +195,36 @@ impl Read for Stream {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl BufRead for Stream {
|
|
|
|
|
|
|
|
fn fill_buf(&mut self) -> IoResult<&[u8]> {
|
|
|
|
|
|
|
|
match self {
|
|
|
|
|
|
|
|
Stream::Http(r) => r.fill_buf(),
|
|
|
|
|
|
|
|
#[cfg(any(
|
|
|
|
|
|
|
|
all(feature = "tls", not(feature = "native-tls")),
|
|
|
|
|
|
|
|
all(feature = "native-tls", not(feature = "tls")),
|
|
|
|
|
|
|
|
))]
|
|
|
|
|
|
|
|
Stream::Https(r) => r.fill_buf(),
|
|
|
|
|
|
|
|
Stream::Cursor(r) => r.fill_buf(),
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
|
|
|
Stream::Test(r, _) => r.fill_buf(),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fn consume(&mut self, amt: usize) {
|
|
|
|
|
|
|
|
match self {
|
|
|
|
|
|
|
|
Stream::Http(r) => r.consume(amt),
|
|
|
|
|
|
|
|
#[cfg(any(
|
|
|
|
|
|
|
|
all(feature = "tls", not(feature = "native-tls")),
|
|
|
|
|
|
|
|
all(feature = "native-tls", not(feature = "tls")),
|
|
|
|
|
|
|
|
))]
|
|
|
|
|
|
|
|
Stream::Https(r) => r.consume(amt),
|
|
|
|
|
|
|
|
Stream::Cursor(r) => r.consume(amt),
|
|
|
|
|
|
|
|
#[cfg(test)]
|
|
|
|
|
|
|
|
Stream::Test(r, _) => r.consume(amt),
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl<R: Read> From<ChunkDecoder<R>> for Stream
|
|
|
|
impl<R: Read> From<ChunkDecoder<R>> for Stream
|
|
|
|
where
|
|
|
|
where
|
|
|
|
R: Read,
|
|
|
|
R: Read,
|
|
|
|
@@ -205,7 +237,7 @@ where
|
|
|
|
|
|
|
|
|
|
|
|
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
|
|
|
|
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
|
|
|
|
fn read_https(
|
|
|
|
fn read_https(
|
|
|
|
stream: &mut StreamOwned<ClientSession, TcpStream>,
|
|
|
|
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>,
|
|
|
|
buf: &mut [u8],
|
|
|
|
buf: &mut [u8],
|
|
|
|
) -> IoResult<usize> {
|
|
|
|
) -> IoResult<usize> {
|
|
|
|
match stream.read(buf) {
|
|
|
|
match stream.read(buf) {
|
|
|
|
@@ -216,7 +248,7 @@ fn read_https(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
|
|
|
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
|
|
|
fn read_https(stream: &mut TlsStream<TcpStream>, buf: &mut [u8]) -> IoResult<usize> {
|
|
|
|
fn read_https(stream: &mut BufReader<TlsStream<TcpStream>>, buf: &mut [u8]) -> IoResult<usize> {
|
|
|
|
match stream.read(buf) {
|
|
|
|
match stream.read(buf) {
|
|
|
|
Ok(size) => Ok(size),
|
|
|
|
Ok(size) => Ok(size),
|
|
|
|
Err(ref e) if is_close_notify(e) => Ok(0),
|
|
|
|
Err(ref e) if is_close_notify(e) => Ok(0),
|
|
|
|
@@ -243,12 +275,12 @@ fn is_close_notify(e: &std::io::Error) -> bool {
|
|
|
|
impl Write for Stream {
|
|
|
|
impl Write for Stream {
|
|
|
|
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
|
|
|
|
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
|
|
|
|
match self {
|
|
|
|
match self {
|
|
|
|
Stream::Http(sock) => sock.write(buf),
|
|
|
|
Stream::Http(sock) => sock.get_mut().write(buf),
|
|
|
|
#[cfg(any(
|
|
|
|
#[cfg(any(
|
|
|
|
all(feature = "tls", not(feature = "native-tls")),
|
|
|
|
all(feature = "tls", not(feature = "native-tls")),
|
|
|
|
all(feature = "native-tls", not(feature = "tls")),
|
|
|
|
all(feature = "native-tls", not(feature = "tls")),
|
|
|
|
))]
|
|
|
|
))]
|
|
|
|
Stream::Https(stream) => stream.write(buf),
|
|
|
|
Stream::Https(stream) => stream.get_mut().write(buf),
|
|
|
|
Stream::Cursor(_) => panic!("Write to read only stream"),
|
|
|
|
Stream::Cursor(_) => panic!("Write to read only stream"),
|
|
|
|
#[cfg(test)]
|
|
|
|
#[cfg(test)]
|
|
|
|
Stream::Test(_, writer) => writer.write(buf),
|
|
|
|
Stream::Test(_, writer) => writer.write(buf),
|
|
|
|
@@ -256,12 +288,12 @@ impl Write for Stream {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
fn flush(&mut self) -> IoResult<()> {
|
|
|
|
fn flush(&mut self) -> IoResult<()> {
|
|
|
|
match self {
|
|
|
|
match self {
|
|
|
|
Stream::Http(sock) => sock.flush(),
|
|
|
|
Stream::Http(sock) => sock.get_mut().flush(),
|
|
|
|
#[cfg(any(
|
|
|
|
#[cfg(any(
|
|
|
|
all(feature = "tls", not(feature = "native-tls")),
|
|
|
|
all(feature = "tls", not(feature = "native-tls")),
|
|
|
|
all(feature = "native-tls", not(feature = "tls")),
|
|
|
|
all(feature = "native-tls", not(feature = "tls")),
|
|
|
|
))]
|
|
|
|
))]
|
|
|
|
Stream::Https(stream) => stream.flush(),
|
|
|
|
Stream::Https(stream) => stream.get_mut().flush(),
|
|
|
|
Stream::Cursor(_) => panic!("Flush read only stream"),
|
|
|
|
Stream::Cursor(_) => panic!("Flush read only stream"),
|
|
|
|
#[cfg(test)]
|
|
|
|
#[cfg(test)]
|
|
|
|
Stream::Test(_, writer) => writer.flush(),
|
|
|
|
Stream::Test(_, writer) => writer.flush(),
|
|
|
|
@@ -274,7 +306,9 @@ pub(crate) fn connect_http(unit: &Unit) -> Result<Stream, Error> {
|
|
|
|
let hostname = unit.url.host_str().unwrap();
|
|
|
|
let hostname = unit.url.host_str().unwrap();
|
|
|
|
let port = unit.url.port().unwrap_or(80);
|
|
|
|
let port = unit.url.port().unwrap_or(80);
|
|
|
|
|
|
|
|
|
|
|
|
connect_host(unit, hostname, port).map(Stream::Http)
|
|
|
|
connect_host(unit, hostname, port)
|
|
|
|
|
|
|
|
.map(BufReader::new)
|
|
|
|
|
|
|
|
.map(Stream::Http)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[cfg(all(feature = "tls", feature = "native-certs"))]
|
|
|
|
#[cfg(all(feature = "tls", feature = "native-certs"))]
|
|
|
|
@@ -316,7 +350,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
|
|
|
|
|
|
|
|
|
|
|
|
let stream = rustls::StreamOwned::new(sess, sock);
|
|
|
|
let stream = rustls::StreamOwned::new(sess, sock);
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Stream::Https(stream))
|
|
|
|
Ok(Stream::Https(BufReader::new(stream)))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
|
|
|
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
|
|
|
|
@@ -338,7 +372,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
|
|
|
|
_ => Error::BadStatusRead,
|
|
|
|
_ => Error::BadStatusRead,
|
|
|
|
})?;
|
|
|
|
})?;
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Stream::Https(stream))
|
|
|
|
Ok(Stream::Https(BufReader::new(stream)))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
|
|
|
|
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
|
|
|
|
|