Read buffer to avoid byte-by-byte syscalls (#141)

Fixes #140
This commit is contained in:
Martin Algesten
2020-09-13 03:27:15 +02:00
committed by GitHub
parent 960c0ff43b
commit 50c19c5484
3 changed files with 58 additions and 22 deletions

View File

@@ -1,4 +1,6 @@
use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write};
use std::io::{
BufRead, BufReader, Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write,
};
use std::net::SocketAddr;
use std::net::TcpStream;
use std::net::ToSocketAddrs;
@@ -25,14 +27,14 @@ use crate::unit::Unit;
#[allow(clippy::large_enum_variant)]
pub enum Stream {
Http(TcpStream),
Http(BufReader<TcpStream>),
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>),
Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>),
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
Https(TlsStream<TcpStream>),
Https(BufReader<TlsStream<TcpStream>>),
Cursor(Cursor<Vec<u8>>),
#[cfg(test)]
Test(Box<dyn Read + Send>, Vec<u8>),
Test(Box<dyn BufRead + Send>, Vec<u8>),
}
// DeadlineStream wraps a stream such that read() will return an error
@@ -161,9 +163,9 @@ impl Stream {
pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self {
Stream::Http(tcpstream) => Some(tcpstream),
Stream::Http(b) => Some(b.get_ref()),
#[cfg(feature = "tls")]
Stream::Https(rustls_stream) => Some(&rustls_stream.sock),
Stream::Https(b) => Some(&b.get_ref().sock),
_ => None,
}
}
@@ -193,6 +195,36 @@ impl Read for Stream {
}
}
impl BufRead for Stream {
fn fill_buf(&mut self) -> IoResult<&[u8]> {
match self {
Stream::Http(r) => r.fill_buf(),
#[cfg(any(
all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")),
))]
Stream::Https(r) => r.fill_buf(),
Stream::Cursor(r) => r.fill_buf(),
#[cfg(test)]
Stream::Test(r, _) => r.fill_buf(),
}
}
fn consume(&mut self, amt: usize) {
match self {
Stream::Http(r) => r.consume(amt),
#[cfg(any(
all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")),
))]
Stream::Https(r) => r.consume(amt),
Stream::Cursor(r) => r.consume(amt),
#[cfg(test)]
Stream::Test(r, _) => r.consume(amt),
}
}
}
impl<R: Read> From<ChunkDecoder<R>> for Stream
where
R: Read,
@@ -205,7 +237,7 @@ where
#[cfg(all(feature = "tls", not(feature = "native-tls")))]
fn read_https(
stream: &mut StreamOwned<ClientSession, TcpStream>,
stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>,
buf: &mut [u8],
) -> IoResult<usize> {
match stream.read(buf) {
@@ -216,7 +248,7 @@ fn read_https(
}
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
fn read_https(stream: &mut TlsStream<TcpStream>, buf: &mut [u8]) -> IoResult<usize> {
fn read_https(stream: &mut BufReader<TlsStream<TcpStream>>, buf: &mut [u8]) -> IoResult<usize> {
match stream.read(buf) {
Ok(size) => Ok(size),
Err(ref e) if is_close_notify(e) => Ok(0),
@@ -243,12 +275,12 @@ fn is_close_notify(e: &std::io::Error) -> bool {
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self {
Stream::Http(sock) => sock.write(buf),
Stream::Http(sock) => sock.get_mut().write(buf),
#[cfg(any(
all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")),
))]
Stream::Https(stream) => stream.write(buf),
Stream::Https(stream) => stream.get_mut().write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.write(buf),
@@ -256,12 +288,12 @@ impl Write for Stream {
}
fn flush(&mut self) -> IoResult<()> {
match self {
Stream::Http(sock) => sock.flush(),
Stream::Http(sock) => sock.get_mut().flush(),
#[cfg(any(
all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")),
))]
Stream::Https(stream) => stream.flush(),
Stream::Https(stream) => stream.get_mut().flush(),
Stream::Cursor(_) => panic!("Flush read only stream"),
#[cfg(test)]
Stream::Test(_, writer) => writer.flush(),
@@ -274,7 +306,9 @@ pub(crate) fn connect_http(unit: &Unit) -> Result<Stream, Error> {
let hostname = unit.url.host_str().unwrap();
let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port).map(Stream::Http)
connect_host(unit, hostname, port)
.map(BufReader::new)
.map(Stream::Http)
}
#[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -316,7 +350,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::Https(stream))
Ok(Stream::Https(BufReader::new(stream)))
}
#[cfg(all(feature = "native-tls", not(feature = "tls")))]
@@ -338,7 +372,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
_ => Error::BadStatusRead,
})?;
Ok(Stream::Https(stream))
Ok(Stream::Https(BufReader::new(stream)))
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {