diff --git a/src/stream.rs b/src/stream.rs index a88a925..223dd9e 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -151,6 +151,15 @@ impl BufRead for DeadlineStream { impl Read for DeadlineStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { + // If the stream's BufReader has any buffered bytes, return those first. + // This avoids calling `fill_buf()` on DeadlineStream unnecessarily, + // since that call always does a syscall. This ensures DeadlineStream + // can pass through the efficiency we gain by using a BufReader in Stream. + if !self.stream.inner.buffer().is_empty() { + let n = self.stream.inner.buffer().read(buf)?; + self.stream.inner.consume(n); + return Ok(n); + } // All reads on a DeadlineStream use the BufRead impl. This ensures // that we have a chance to set the correct timeout before each recv // syscall. @@ -642,3 +651,71 @@ pub(crate) fn connect_test(unit: &Unit) -> Result { pub(crate) fn connect_test(unit: &Unit) -> Result { Err(ErrorKind::UnknownScheme.msg(format!("unknown scheme '{}'", unit.url.scheme()))) } + +#[cfg(test)] +mod tests { + use super::*; + use std::{ + io::Read, + sync::{Arc, Mutex}, + }; + + // Returns all zeroes to `.read()` and logs how many times it's called + struct ReadRecorder { + reads: Arc>>, + } + + impl Read for ReadRecorder { + fn read(&mut self, buf: &mut [u8]) -> std::result::Result { + self.reads.lock().unwrap().push(buf.len()); + buf.fill(0); + Ok(buf.len()) + } + } + + impl Write for ReadRecorder { + fn write(&mut self, _: &[u8]) -> io::Result { + unimplemented!() + } + + fn flush(&mut self) -> io::Result<()> { + unimplemented!() + } + } + + impl fmt::Debug for ReadRecorder { + fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result { + unimplemented!() + } + } + + impl ReadWrite for ReadRecorder { + fn socket(&self) -> Option<&TcpStream> { + unimplemented!() + } + + fn is_poolable(&self) -> bool { + unimplemented!() + } + } + + // Test that when a DeadlineStream wraps a Stream, and the user performs a series of + // tiny read_exacts, Stream's BufReader is used appropriately. + #[test] + fn test_deadline_stream_buffering() { + let reads = Arc::new(Mutex::new(vec![])); + let recorder = ReadRecorder { + reads: reads.clone(), + }; + let stream = Stream::new(recorder); + let mut deadline_stream = DeadlineStream::new(stream, None); + let mut buf = [0u8; 1]; + for _ in 0..8193 { + deadline_stream.read(&mut buf).unwrap(); + } + let reads = reads.lock().unwrap(); + assert_eq!(reads.len(), 2); + assert_eq!(reads[0], 8192); + assert_eq!(reads[1], 8192); + } +}