Fix buffered DeadlineStream

This commit is contained in:
Jacob Hoffman-Andrews
2020-11-28 17:47:17 -08:00
parent 131a0264d1
commit 50cb5cecd1
2 changed files with 68 additions and 44 deletions

View File

@@ -54,6 +54,32 @@ impl From<DeadlineStream> for Stream {
}
}
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
if let Some(deadline) = self.deadline {
let timeout = time_until_deadline(deadline)?;
if let Some(socket) = self.stream.socket() {
socket.set_read_timeout(Some(timeout))?;
socket.set_write_timeout(Some(timeout))?;
}
}
self.stream.fill_buf().map_err(|e| {
// On unix-y platforms set_read_timeout and set_write_timeout
// causes ErrorKind::WouldBlock instead of ErrorKind::TimedOut.
// Since the socket most definitely not set_nonblocking(true),
// we can safely normalize WouldBlock to TimedOut
if e.kind() == io::ErrorKind::WouldBlock {
return io_err_timeout("timed out reading response".to_string());
}
e
})
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl Read for DeadlineStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if let Some(deadline) = self.deadline {
@@ -109,6 +135,19 @@ impl Stream {
inner: BufReader::new(Inner::Test(Box::new(Cursor::new(v)), vec![])),
}
}
fn from_tcp_stream(t: TcpStream) -> Stream {
Stream {
inner: BufReader::with_capacity(1000, Inner::Http(t)),
}
}
fn from_tls_stream(t: StreamOwned<ClientSession, TcpStream>) -> Stream {
Stream {
inner: BufReader::with_capacity(1000, Inner::Https(t)),
}
}
// Check if the server has closed a stream by performing a one-byte
// non-blocking read. If this returns EOF, the server has closed the
// connection: return true. If this returns WouldBlock (aka EAGAIN),
@@ -198,16 +237,6 @@ impl Read for Inner {
}
}
impl BufRead for DeadlineStream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.stream.fill_buf()
}
fn consume(&mut self, amt: usize) {
self.stream.consume(amt)
}
}
impl BufRead for Stream {
fn fill_buf(&mut self) -> io::Result<&[u8]> {
self.inner.fill_buf()
@@ -279,11 +308,7 @@ pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error>
//
let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port)
.map(Inner::Http)
.map(|h| Stream {
inner: BufReader::new(h),
})
connect_host(unit, hostname, port).map(Stream::from_tcp_stream)
}
#[cfg(all(feature = "tls", feature = "native-certs"))]
@@ -327,9 +352,7 @@ pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error
let stream = rustls::StreamOwned::new(sess, sock);
Ok(Stream {
inner: BufReader::new(Inner::Https(stream)),
})
Ok(Stream::from_tls_stream(stream))
}
pub(crate) fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {