diff --git a/src/response.rs b/src/response.rs index fc6ca2f..9287627 100644 --- a/src/response.rs +++ b/src/response.rs @@ -20,6 +20,11 @@ use encoding_rs::Encoding; pub const DEFAULT_CONTENT_TYPE: &str = "text/plain"; pub const DEFAULT_CHARACTER_SET: &str = "utf-8"; +const INTO_STRING_LIMIT: usize = 10 * 1_024 * 1_024; +// Follow the example of curl and limit a single header to 100kB: +// https://curl.se/libcurl/c/CURLOPT_HEADERFUNCTION.html +const MAX_HEADER_SIZE: usize = 100 * 1_024; +const MAX_HEADER_COUNT: usize = 100; /// Response instances are created as results of firing off requests. /// @@ -211,6 +216,11 @@ impl Response { /// length regardless of how many bytes the server sends. /// 3. If no length header, the reader is until server stream end. /// + /// Note: If you use `read_to_end()` on the resulting reader, a malicious + /// server might return enough bytes to exhaust available memroy. If you're + /// making requests to untrusted servers, you should use `.take()` to + /// limit the response bytes read. + /// /// Example: /// /// ``` @@ -226,6 +236,7 @@ impl Response { /// /// let mut bytes: Vec = Vec::with_capacity(len); /// resp.into_reader() + /// .take(10_000_000) /// .read_to_end(&mut bytes)?; /// /// assert_eq!(bytes.len(), len); @@ -291,6 +302,8 @@ impl Response { /// implementation first reads the reader to end into a `Vec` and then /// attempts to decode it using the charset. /// + /// If the response is larger than 10 megabytes, this will return an error. + /// /// Example: /// /// ``` @@ -321,14 +334,30 @@ impl Response { .or_else(|| Encoding::for_label(DEFAULT_CHARACTER_SET.as_bytes())) .unwrap(); let mut buf: Vec = vec![]; - self.into_reader().read_to_end(&mut buf)?; + self.into_reader() + .take((INTO_STRING_LIMIT + 1) as u64) + .read_to_end(&mut buf)?; + if buf.len() > INTO_STRING_LIMIT { + return Err(io::Error::new( + io::ErrorKind::Other, + "response too big for into_string", + )); + } let (text, _, _) = encoding.decode(&buf); Ok(text.into_owned()) } #[cfg(not(feature = "charset"))] { let mut buf: Vec = vec![]; - self.into_reader().read_to_end(&mut buf)?; + self.into_reader() + .take((INTO_STRING_LIMIT + 1) as u64) + .read_to_end(&mut buf)?; + if buf.len() > INTO_STRING_LIMIT { + return Err(io::Error::new( + io::ErrorKind::Other, + "response too big for into_string", + )); + } Ok(String::from_utf8_lossy(&buf).to_string()) } } @@ -427,7 +456,7 @@ impl Response { let (index, status) = parse_status_line(status_line.as_str())?; let mut headers: Vec
= Vec::new(); - loop { + while headers.len() <= MAX_HEADER_COUNT { let line = read_next_line(&mut stream, "a header")?; if line.is_empty() { break; @@ -437,6 +466,12 @@ impl Response { } } + if headers.len() > MAX_HEADER_COUNT { + return Err(ErrorKind::BadHeader.msg( + format!("more than {} header fields in response", MAX_HEADER_COUNT).as_str(), + )); + } + Ok(Response { url: None, status_line, @@ -549,27 +584,33 @@ impl FromStr for Response { fn read_next_line(reader: &mut impl BufRead, context: &str) -> io::Result { let mut buf = Vec::new(); - let result = reader.read_until(b'\n', &mut buf); + let result = reader + .take((MAX_HEADER_SIZE + 1) as u64) + .read_until(b'\n', &mut buf); - if let Err(e) = result { - // Provide context to errors encountered while reading the line. - let reason = format!("Error encountered in {}", context); - - let kind = e.kind(); - - // Use an intermediate wrapper type which carries the error message - // as well as a .source() reference to the original error. - let wrapper = Error::new(ErrorKind::Io, Some(reason)).src(e); - - return Err(io::Error::new(kind, wrapper)); - } - - if result? == 0 { - return Err(io::Error::new( + match result { + Ok(0) => Err(io::Error::new( io::ErrorKind::ConnectionAborted, "Unexpected EOF", - )); - } + )), + Ok(n) if n > MAX_HEADER_SIZE => Err(io::Error::new( + io::ErrorKind::Other, + format!("header field longer than {} bytes", MAX_HEADER_SIZE), + )), + Ok(_) => Ok(()), + Err(e) => { + // Provide context to errors encountered while reading the line. + let reason = format!("Error encountered in {}", context); + + let kind = e.kind(); + + // Use an intermediate wrapper type which carries the error message + // as well as a .source() reference to the original error. + let wrapper = Error::new(ErrorKind::Io, Some(reason)).src(e); + + Err(io::Error::new(kind, wrapper)) + } + }?; if !buf.ends_with(b"\n") { return Err(io::Error::new( @@ -672,6 +713,8 @@ impl Read for ErrorReader { #[cfg(test)] mod tests { + use std::io::Cursor; + use super::*; #[test] @@ -755,6 +798,25 @@ mod tests { assert_eq!("hello world!!!", resp.into_string().unwrap()); } + #[test] + fn into_string_large() { + const LEN: usize = INTO_STRING_LIMIT + 1; + let s = format!( + "HTTP/1.1 200 OK\r\n\ + Content-Length: {}\r\n + \r\n + {}", + LEN, + "A".repeat(LEN), + ); + let result = s.parse::().unwrap(); + let err = result + .into_string() + .expect_err("didn't error with too-long body"); + assert_eq!(err.to_string(), "response too big for into_string"); + assert_eq!(err.kind(), io::ErrorKind::Other); + } + #[test] #[cfg(feature = "json")] fn parse_simple_json() { @@ -801,6 +863,43 @@ mod tests { assert_eq!(resp.status_text(), ""); } + #[test] + fn read_next_line_large() { + const LEN: usize = MAX_HEADER_SIZE + 1; + let s = format!("Long-Header: {}\r\n", "A".repeat(LEN),); + let mut cursor = Cursor::new(s); + let result = read_next_line(&mut cursor, "some context"); + let err = result.expect_err("did not error on too-large header"); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert_eq!( + err.to_string(), + format!("header field longer than {} bytes", MAX_HEADER_SIZE) + ); + } + + #[test] + fn too_many_headers() { + const LEN: usize = MAX_HEADER_COUNT + 1; + let s = format!( + "HTTP/1.1 200 OK\r\n\ + {} + \r\n + hi", + "Header: value\r\n".repeat(LEN), + ); + let err = s + .parse::() + .expect_err("did not error on too many headers"); + assert_eq!(err.kind(), ErrorKind::BadHeader); + assert_eq!( + err.to_string(), + format!( + "Bad Header: more than {} header fields in response", + MAX_HEADER_COUNT + ) + ); + } + #[test] #[cfg(feature = "charset")] fn read_next_line_non_ascii_reason() {