From 50c19c54847a1a031d0ae15f6a2b91e19d8d15d2 Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Sun, 13 Sep 2020 03:27:15 +0200 Subject: [PATCH] Read buffer to avoid byte-by-byte syscalls (#141) Fixes #140 --- src/request.rs | 2 +- src/response.rs | 12 +++++---- src/stream.rs | 66 +++++++++++++++++++++++++++++++++++++------------ 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/src/request.rs b/src/request.rs index 76b67d7..036ec73 100644 --- a/src/request.rs +++ b/src/request.rs @@ -608,7 +608,7 @@ impl Request { /// /// Example: /// ``` - /// let tls_connector = std::sync::Arc::new(native_tls::TlsConnector::new()); + /// let tls_connector = std::sync::Arc::new(native_tls::TlsConnector::new().unwrap()); /// let req = ureq::post("https://cool.server") /// .set_tls_connector(tls_connector.clone()); /// ``` diff --git a/src/response.rs b/src/response.rs index 2b384d3..60c41d1 100644 --- a/src/response.rs +++ b/src/response.rs @@ -593,14 +593,16 @@ pub(crate) fn set_stream(resp: &mut Response, url: String, unit: Option, s fn read_next_line(reader: &mut R) -> IoResult { let mut buf = Vec::new(); let mut prev_byte_was_cr = false; + let mut one = [0_u8]; loop { - let byte = reader.bytes().next(); + let amt = reader.read(&mut one[..])?; - let byte = match byte { - Some(b) => b?, - None => return Err(IoError::new(ErrorKind::ConnectionAborted, "Unexpected EOF")), - }; + if amt == 0 { + return Err(IoError::new(ErrorKind::ConnectionAborted, "Unexpected EOF")); + } + + let byte = one[0]; if byte == b'\n' && prev_byte_was_cr { buf.pop(); // removing the '\r' diff --git a/src/stream.rs b/src/stream.rs index 34a685f..599f7e5 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,4 +1,6 @@ -use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write}; +use std::io::{ + BufRead, BufReader, Cursor, Error as IoError, ErrorKind, Read, Result as IoResult, Write, +}; use std::net::SocketAddr; use std::net::TcpStream; use std::net::ToSocketAddrs; @@ -25,14 +27,14 @@ use crate::unit::Unit; #[allow(clippy::large_enum_variant)] pub enum Stream { - Http(TcpStream), + Http(BufReader), #[cfg(all(feature = "tls", not(feature = "native-tls")))] - Https(rustls::StreamOwned), + Https(BufReader>), #[cfg(all(feature = "native-tls", not(feature = "tls")))] - Https(TlsStream), + Https(BufReader>), Cursor(Cursor>), #[cfg(test)] - Test(Box, Vec), + Test(Box, Vec), } // DeadlineStream wraps a stream such that read() will return an error @@ -161,9 +163,9 @@ impl Stream { pub(crate) fn socket(&self) -> Option<&TcpStream> { match self { - Stream::Http(tcpstream) => Some(tcpstream), + Stream::Http(b) => Some(b.get_ref()), #[cfg(feature = "tls")] - Stream::Https(rustls_stream) => Some(&rustls_stream.sock), + Stream::Https(b) => Some(&b.get_ref().sock), _ => None, } } @@ -193,6 +195,36 @@ impl Read for Stream { } } +impl BufRead for Stream { + fn fill_buf(&mut self) -> IoResult<&[u8]> { + match self { + Stream::Http(r) => r.fill_buf(), + #[cfg(any( + all(feature = "tls", not(feature = "native-tls")), + all(feature = "native-tls", not(feature = "tls")), + ))] + Stream::Https(r) => r.fill_buf(), + Stream::Cursor(r) => r.fill_buf(), + #[cfg(test)] + Stream::Test(r, _) => r.fill_buf(), + } + } + + fn consume(&mut self, amt: usize) { + match self { + Stream::Http(r) => r.consume(amt), + #[cfg(any( + all(feature = "tls", not(feature = "native-tls")), + all(feature = "native-tls", not(feature = "tls")), + ))] + Stream::Https(r) => r.consume(amt), + Stream::Cursor(r) => r.consume(amt), + #[cfg(test)] + Stream::Test(r, _) => r.consume(amt), + } + } +} + impl From> for Stream where R: Read, @@ -205,7 +237,7 @@ where #[cfg(all(feature = "tls", not(feature = "native-tls")))] fn read_https( - stream: &mut StreamOwned, + stream: &mut BufReader>, buf: &mut [u8], ) -> IoResult { match stream.read(buf) { @@ -216,7 +248,7 @@ fn read_https( } #[cfg(all(feature = "native-tls", not(feature = "tls")))] -fn read_https(stream: &mut TlsStream, buf: &mut [u8]) -> IoResult { +fn read_https(stream: &mut BufReader>, buf: &mut [u8]) -> IoResult { match stream.read(buf) { Ok(size) => Ok(size), Err(ref e) if is_close_notify(e) => Ok(0), @@ -243,12 +275,12 @@ fn is_close_notify(e: &std::io::Error) -> bool { impl Write for Stream { fn write(&mut self, buf: &[u8]) -> IoResult { match self { - Stream::Http(sock) => sock.write(buf), + Stream::Http(sock) => sock.get_mut().write(buf), #[cfg(any( all(feature = "tls", not(feature = "native-tls")), all(feature = "native-tls", not(feature = "tls")), ))] - Stream::Https(stream) => stream.write(buf), + Stream::Https(stream) => stream.get_mut().write(buf), Stream::Cursor(_) => panic!("Write to read only stream"), #[cfg(test)] Stream::Test(_, writer) => writer.write(buf), @@ -256,12 +288,12 @@ impl Write for Stream { } fn flush(&mut self) -> IoResult<()> { match self { - Stream::Http(sock) => sock.flush(), + Stream::Http(sock) => sock.get_mut().flush(), #[cfg(any( all(feature = "tls", not(feature = "native-tls")), all(feature = "native-tls", not(feature = "tls")), ))] - Stream::Https(stream) => stream.flush(), + Stream::Https(stream) => stream.get_mut().flush(), Stream::Cursor(_) => panic!("Flush read only stream"), #[cfg(test)] Stream::Test(_, writer) => writer.flush(), @@ -274,7 +306,9 @@ pub(crate) fn connect_http(unit: &Unit) -> Result { let hostname = unit.url.host_str().unwrap(); let port = unit.url.port().unwrap_or(80); - connect_host(unit, hostname, port).map(Stream::Http) + connect_host(unit, hostname, port) + .map(BufReader::new) + .map(Stream::Http) } #[cfg(all(feature = "tls", feature = "native-certs"))] @@ -316,7 +350,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { let stream = rustls::StreamOwned::new(sess, sock); - Ok(Stream::Https(stream)) + Ok(Stream::Https(BufReader::new(stream))) } #[cfg(all(feature = "native-tls", not(feature = "tls")))] @@ -338,7 +372,7 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { _ => Error::BadStatusRead, })?; - Ok(Stream::Https(stream)) + Ok(Stream::Https(BufReader::new(stream))) } pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result {