Read buffer to avoid byte-by-byte syscalls (#141)

Fixes #140
This commit is contained in:
Martin Algesten
2020-09-13 03:27:15 +02:00
committed by GitHub
parent 960c0ff43b
commit 50c19c5484
3 changed files with 58 additions and 22 deletions

View File

@@ -608,7 +608,7 @@ impl Request {
/// ///
/// Example: /// Example:
/// ``` /// ```
/// let tls_connector = std::sync::Arc::new(native_tls::TlsConnector::new()); /// let tls_connector = std::sync::Arc::new(native_tls::TlsConnector::new().unwrap());
/// let req = ureq::post("https://cool.server") /// let req = ureq::post("https://cool.server")
/// .set_tls_connector(tls_connector.clone()); /// .set_tls_connector(tls_connector.clone());
/// ``` /// ```

View File

@@ -593,14 +593,16 @@ pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option<Unit>, s
fn read_next_line<R: Read>(reader: &mut R) -> IoResult<String> { fn read_next_line<R: Read>(reader: &mut R) -> IoResult<String> {
let mut buf = Vec::new(); let mut buf = Vec::new();
let mut prev_byte_was_cr = false; let mut prev_byte_was_cr = false;
let mut one = [0_u8];
loop { loop {
let byte = reader.bytes().next(); let amt = reader.read(&mut one[..])?;
let byte = match byte { if amt == 0 {
Some(b) => b?, return Err(IoError::new(ErrorKind::ConnectionAborted, "Unexpected EOF"));
None => return Err(IoError::new(ErrorKind::ConnectionAborted, "Unexpected EOF")), }
};
let byte = one[0];
if byte == b'\n' && prev_byte_was_cr { if byte == b'\n' && prev_byte_was_cr {
buf.pop(); // removing the '\r' buf.pop(); // removing the '\r'

View File

@@ -1,4 +1,6 @@
use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write}; use std::io::{
BufRead, BufReader, Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write,
};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::net::TcpStream; use std::net::TcpStream;
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
@@ -25,14 +27,14 @@ use crate::unit::Unit;
#[allow(clippy::large_enum_variant)] #[allow(clippy::large_enum_variant)]
pub enum Stream { pub enum Stream {
Http(TcpStream), Http(BufReader<TcpStream>),
#[cfg(all(feature = "tls", not(feature = "native-tls")))] #[cfg(all(feature = "tls", not(feature = "native-tls")))]
Https(rustls::StreamOwned<rustls::ClientSession, TcpStream>), Https(BufReader<rustls::StreamOwned<rustls::ClientSession, TcpStream>>),
#[cfg(all(feature = "native-tls", not(feature = "tls")))] #[cfg(all(feature = "native-tls", not(feature = "tls")))]
Https(TlsStream<TcpStream>), Https(BufReader<TlsStream<TcpStream>>),
Cursor(Cursor<Vec<u8>>), Cursor(Cursor<Vec<u8>>),
#[cfg(test)] #[cfg(test)]
Test(Box<dyn Read + Send>, Vec<u8>), Test(Box<dyn BufRead + Send>, Vec<u8>),
} }
// DeadlineStream wraps a stream such that read() will return an error // DeadlineStream wraps a stream such that read() will return an error
@@ -161,9 +163,9 @@ impl Stream {
pub(crate) fn socket(&self) -> Option<&TcpStream> { pub(crate) fn socket(&self) -> Option<&TcpStream> {
match self { match self {
Stream::Http(tcpstream) => Some(tcpstream), Stream::Http(b) => Some(b.get_ref()),
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
Stream::Https(rustls_stream) => Some(&rustls_stream.sock), Stream::Https(b) => Some(&b.get_ref().sock),
_ => None, _ => None,
} }
} }
@@ -193,6 +195,36 @@ impl Read for Stream {
} }
} }
impl BufRead for Stream {
fn fill_buf(&mut self) -> IoResult<&[u8]> {
match self {
Stream::Http(r) => r.fill_buf(),
#[cfg(any(
all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")),
))]
Stream::Https(r) => r.fill_buf(),
Stream::Cursor(r) => r.fill_buf(),
#[cfg(test)]
Stream::Test(r, _) => r.fill_buf(),
}
}
fn consume(&mut self, amt: usize) {
match self {
Stream::Http(r) => r.consume(amt),
#[cfg(any(
all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")),
))]
Stream::Https(r) => r.consume(amt),
Stream::Cursor(r) => r.consume(amt),
#[cfg(test)]
Stream::Test(r, _) => r.consume(amt),
}
}
}
impl<R: Read> From<ChunkDecoder<R>> for Stream impl<R: Read> From<ChunkDecoder<R>> for Stream
where where
R: Read, R: Read,
@@ -205,7 +237,7 @@ where
#[cfg(all(feature = "tls", not(feature = "native-tls")))] #[cfg(all(feature = "tls", not(feature = "native-tls")))]
fn read_https( fn read_https(
stream: &mut StreamOwned<ClientSession, TcpStream>, stream: &mut BufReader<StreamOwned<ClientSession, TcpStream>>,
buf: &mut [u8], buf: &mut [u8],
) -> IoResult<usize> { ) -> IoResult<usize> {
match stream.read(buf) { match stream.read(buf) {
@@ -216,7 +248,7 @@ fn read_https(
} }
#[cfg(all(feature = "native-tls", not(feature = "tls")))] #[cfg(all(feature = "native-tls", not(feature = "tls")))]
fn read_https(stream: &mut TlsStream<TcpStream>, buf: &mut [u8]) -> IoResult<usize> { fn read_https(stream: &mut BufReader<TlsStream<TcpStream>>, buf: &mut [u8]) -> IoResult<usize> {
match stream.read(buf) { match stream.read(buf) {
Ok(size) => Ok(size), Ok(size) => Ok(size),
Err(ref e) if is_close_notify(e) => Ok(0), Err(ref e) if is_close_notify(e) => Ok(0),
@@ -243,12 +275,12 @@ fn is_close_notify(e: &std::io::Error) -> bool {
impl Write for Stream { impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> { fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self { match self {
Stream::Http(sock) => sock.write(buf), Stream::Http(sock) => sock.get_mut().write(buf),
#[cfg(any( #[cfg(any(
all(feature = "tls", not(feature = "native-tls")), all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")), all(feature = "native-tls", not(feature = "tls")),
))] ))]
Stream::Https(stream) => stream.write(buf), Stream::Https(stream) => stream.get_mut().write(buf),
Stream::Cursor(_) => panic!("Write to read only stream"), Stream::Cursor(_) => panic!("Write to read only stream"),
#[cfg(test)] #[cfg(test)]
Stream::Test(_, writer) => writer.write(buf), Stream::Test(_, writer) => writer.write(buf),
@@ -256,12 +288,12 @@ impl Write for Stream {
} }
fn flush(&mut self) -> IoResult<()> { fn flush(&mut self) -> IoResult<()> {
match self { match self {
Stream::Http(sock) => sock.flush(), Stream::Http(sock) => sock.get_mut().flush(),
#[cfg(any( #[cfg(any(
all(feature = "tls", not(feature = "native-tls")), all(feature = "tls", not(feature = "native-tls")),
all(feature = "native-tls", not(feature = "tls")), all(feature = "native-tls", not(feature = "tls")),
))] ))]
Stream::Https(stream) => stream.flush(), Stream::Https(stream) => stream.get_mut().flush(),
Stream::Cursor(_) => panic!("Flush read only stream"), Stream::Cursor(_) => panic!("Flush read only stream"),
#[cfg(test)] #[cfg(test)]
Stream::Test(_, writer) => writer.flush(), Stream::Test(_, writer) => writer.flush(),
@@ -274,7 +306,9 @@ pub(crate) fn connect_http(unit: &Unit) -> Result<Stream, Error> {
let hostname = unit.url.host_str().unwrap(); let hostname = unit.url.host_str().unwrap();
let port = unit.url.port().unwrap_or(80); let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port).map(Stream::Http) connect_host(unit, hostname, port)
.map(BufReader::new)
.map(Stream::Http)
} }
#[cfg(all(feature = "tls", feature = "native-certs"))] #[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -316,7 +350,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
let stream = rustls::StreamOwned::new(sess, sock); let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream::Https(stream)) Ok(Stream::Https(BufReader::new(stream)))
} }
#[cfg(all(feature = "native-tls", not(feature = "tls")))] #[cfg(all(feature = "native-tls", not(feature = "tls")))]
@@ -338,7 +372,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
_ => Error::BadStatusRead, _ => Error::BadStatusRead,
})?; })?;
Ok(Stream::Https(stream)) Ok(Stream::Https(BufReader::new(stream)))
} }
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> { pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {