From d0bd2d5ea9db57f7991406e3ad9baab4cf399cbd Mon Sep 17 00:00:00 2001 From: Joshua Nelson Date: Tue, 5 Jan 2021 16:55:26 -0500 Subject: [PATCH] Use iteration instead of recursion for `connect` (#291) This allows handling larger redirect chains. Fixes #290 --- src/error.rs | 27 +++++---- src/request.rs | 2 +- src/response.rs | 71 ++++++------------------ src/test/redirect.rs | 35 ++++++++++++ src/unit.rs | 129 +++++++++++++++++++++++-------------------- 5 files changed, 137 insertions(+), 127 deletions(-) diff --git a/src/error.rs b/src/error.rs index 5d1ab2c..4cc2f7c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -103,8 +103,8 @@ impl Display for Error { match self { Error::Status(status, response) => { write!(f, "{}: status code {}", response.get_url(), status)?; - if let Some(original) = response.history().last() { - write!(f, " (redirected from {})", original.get_url())?; + if let Some(original) = response.history.get(0) { + write!(f, " (redirected from {})", original)?; } } Error::Transport(err) => { @@ -300,19 +300,22 @@ fn status_code_error() { #[test] fn status_code_error_redirect() { - use std::sync::Arc; - let mut response0 = Response::new(302, "Found", "").unwrap(); - response0.set_url("http://example.org/".parse().unwrap()); - let mut response1 = Response::new(302, "Found", "").unwrap(); - response1.set_previous(Arc::new(response0)); - let mut response2 = Response::new(500, "Internal Server Error", "server overloaded").unwrap(); - response2.set_previous(Arc::new(response1)); - response2.set_url("http://example.com/".parse().unwrap()); - let err = Error::Status(response2.status(), response2); + use crate::{get, test}; + test::set_handler("/redirect_a", |unit| { + assert_eq!(unit.method, "GET"); + test::make_response(302, "Go here", vec!["Location: test://example.edu/redirect_b"], vec![]) + }); + test::set_handler("/redirect_b", |unit| { + assert_eq!(unit.method, "GET"); + test::make_response(302, "Go here", vec!["Location: http://example.com/status/500"], vec![]) + }); + + let err = get("test://example.org/redirect_a").call().unwrap_err(); + assert_eq!(err.kind(), ErrorKind::HTTP, "{:?}", err); assert_eq!( err.to_string(), - "http://example.com/: status code 500 (redirected from http://example.org/)" + "http://example.com/status/500: status code 500 (redirected from test://example.org/redirect_a)" ); } diff --git a/src/request.rs b/src/request.rs index 8bf0991..b28fcdd 100644 --- a/src/request.rs +++ b/src/request.rs @@ -118,7 +118,7 @@ impl Request { } let reader = payload.into_read(); let unit = Unit::new(&self.agent, &self.method, &url, &self.headers, &reader); - let response = unit::connect(unit, true, reader, None).map_err(|e| e.url(url.clone()))?; + let response = unit::connect(unit, true, reader).map_err(|e| e.url(url.clone()))?; if response.status() >= 400 { Err(Error::Status(response.status(), response)) diff --git a/src/response.rs b/src/response.rs index 07df455..5e4cbb0 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,9 +1,6 @@ +use std::io::{self, Read}; use std::str::FromStr; use std::{fmt, io::BufRead}; -use std::{ - io::{self, Read}, - sync::Arc, -}; use chunked_transfer::Decoder as ChunkDecoder; use url::Url; @@ -52,9 +49,12 @@ pub struct Response { headers: Vec
, unit: Option, stream: Stream, - // If this Response resulted from a redirect, the Response containing - // that redirect. - previous: Option>, + /// The redirect history of this response, if any. The history starts with + /// the first response received and ends with the response immediately + /// previous to this one. + /// + /// If this response was not redirected, the history is empty. + pub(crate) history: Vec, } /// index into status_line where we split: HTTP/1.1 200 OK @@ -390,13 +390,6 @@ impl Response { }) } - // Returns an iterator across the redirect history of this response, - // if any. The iterator starts with the response before this one. - // If this response was not redirected, the iterator is empty. - pub(crate) fn history(&self) -> Hist { - Hist::new(self.previous.as_deref()) - } - /// Create a response from a Read trait impl. /// /// This is hopefully useful for unit tests. @@ -438,18 +431,13 @@ impl Response { headers, unit, stream: stream.into(), - previous: None, + history: vec![], }) } - pub(crate) fn do_from_request( - unit: Unit, - stream: Stream, - previous: Option>, - ) -> Result { + pub(crate) fn do_from_request(unit: Unit, stream: Stream) -> Result { let url = Some(unit.url.clone()); let mut resp = Response::do_from_stream(stream, Some(unit))?; - resp.previous = previous; resp.url = url; Ok(resp) } @@ -465,8 +453,10 @@ impl Response { } #[cfg(test)] - pub fn set_previous(&mut self, previous: Arc) { - self.previous = Some(previous); + pub fn history_from_previous(&mut self, previous: Response) { + let previous_url = previous.get_url().to_string(); + self.history = previous.history; + self.history.push(previous_url); } } @@ -538,31 +528,6 @@ impl FromStr for Response { } } -// Hist is an iterator over the history of a redirected response. It -// yields the URLs that were requested in backwards order, from most recent -// to least recent. -pub(crate) struct Hist<'a> { - response: Option<&'a Response>, -} - -impl<'a> Hist<'a> { - fn new(response: Option<&'a Response>) -> Hist<'a> { - Hist { response } - } -} -impl<'a> Iterator for Hist<'a> { - type Item = &'a Response; - fn next(&mut self) -> Option<&'a Response> { - let response = match self.response { - None => return None, - Some(r) => r, - }; - - self.response = response.previous.as_deref(); - Some(response) - } -} - fn read_next_line(reader: &mut impl BufRead) -> io::Result { let mut s = String::new(); if reader.read_line(&mut s)? == 0 { @@ -774,18 +739,18 @@ mod tests { fn history() { let mut response0 = Response::new(302, "Found", "").unwrap(); response0.set_url("http://1.example.com/".parse().unwrap()); - assert_eq!(response0.history().count(), 0); + assert!(response0.history.is_empty()); let mut response1 = Response::new(302, "Found", "").unwrap(); response1.set_url("http://2.example.com/".parse().unwrap()); - response1.set_previous(Arc::new(response0)); + response1.history_from_previous(response0); let mut response2 = Response::new(404, "NotFound", "").unwrap(); response2.set_url("http://2.example.com/".parse().unwrap()); - response2.set_previous(Arc::new(response1)); + response2.history_from_previous(response1); - let hist: Vec<&str> = response2.history().map(|r| r.get_url()).collect(); - assert_eq!(hist, ["http://2.example.com/", "http://1.example.com/"]) + let hist: Vec<&str> = response2.history.iter().map(|r| &**r).collect(); + assert_eq!(hist, ["http://1.example.com/", "http://2.example.com/"]) } } diff --git a/src/test/redirect.rs b/src/test/redirect.rs index d360702..83a5e0e 100644 --- a/src/test/redirect.rs +++ b/src/test/redirect.rs @@ -35,6 +35,22 @@ fn redirect_many() { .get("test://host/redirect_many1") .call(); assert!(matches!(result, Err(e) if e.kind() == ErrorKind::TooManyRedirects)); + + test::set_handler("/redirect_many1", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_many2"], vec![]) + }); + test::set_handler("/redirect_many2", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_many3"], vec![]) + }); + test::set_handler("/redirect_many3", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_many4"], vec![]) + }); + let result = builder() + .redirects(2) + .build() + .get("test://host/redirect_many1") + .call(); + assert!(matches!(result, Err(e) if e.kind() == ErrorKind::TooManyRedirects)); } #[test] @@ -141,3 +157,22 @@ fn redirect_308() { assert_eq!(resp.status(), 200); assert_eq!(resp.get_url(), "test://host/valid_response"); } + +#[test] +fn too_many_redirects() { + for i in 0..10_000 { + test::set_handler(&format!("/malicious_redirect_{}", i), move |_| { + let location = format!("Location: /malicious_redirect_{}", i + 1); + test::make_response(302, "Go here", vec![&location], vec![]) + }); + } + + test::set_handler("/malicious_redirect_10000", |unit| { + assert_eq!(unit.method, "GET"); + test::make_response(200, "OK", vec![], vec![]) + }); + + let req = crate::builder().redirects(10001).build(); + let resp = req.get("test://host/malicious_redirect_0").call().unwrap(); + assert_eq!(resp.get_url(), "test://host/malicious_redirect_10000"); +} diff --git a/src/unit.rs b/src/unit.rs index 2fa1fa2..e42de6c 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -1,8 +1,5 @@ +use std::io::{self, Write}; use std::time; -use std::{ - io::{self, Write}, - sync::Arc, -}; use log::{debug, info}; use url::Url; @@ -163,12 +160,70 @@ impl Unit { } } -/// Perform a connection. Used recursively for redirects. +/// Perform a connection. Follows redirects. pub(crate) fn connect( - unit: Unit, + mut unit: Unit, + use_pooled: bool, + mut body: SizedReader, +) -> Result { + let mut history = vec![]; + let mut resp = loop { + let resp = connect_inner(&unit, use_pooled, body, &history)?; + + // handle redirects + if !(300..399).contains(&resp.status()) || unit.agent.config.redirects == 0 { + break resp; + } + if history.len() + 1 >= unit.agent.config.redirects as usize { + return Err(ErrorKind::TooManyRedirects.new()); + } + // the location header + let location = match resp.header("location") { + Some(l) => l, + None => break resp, + }; + + let url = &unit.url; + let method = &unit.method; + // join location header to current url in case it is relative + let new_url = url.join(location).map_err(|e| { + ErrorKind::InvalidUrl + .msg(&format!("Bad redirection: {}", location)) + .src(e) + })?; + + // perform the redirect differently depending on 3xx code. + let new_method = match resp.status() { + // this is to follow how curl does it. POST, PUT etc change + // to GET on a redirect. + 301 | 302 | 303 => match &method[..] { + "GET" | "HEAD" => unit.method, + _ => "GET".into(), + }, + // never change the method for 307/308 + // only resend the request if it cannot have a body + // NOTE: DELETE is intentionally excluded: https://stackoverflow.com/questions/299628 + 307 | 308 if ["GET", "HEAD", "OPTIONS", "TRACE"].contains(&method.as_str()) => { + unit.method + } + _ => break resp, + }; + debug!("redirect {} {} -> {}", resp.status(), url, new_url); + history.push(unit.url.to_string()); + body = Payload::Empty.into_read(); + // recreate the unit to get a new hostname and cookies for the new host. + unit = Unit::new(&unit.agent, &new_method, &new_url, &unit.headers, &body); + }; + resp.history = history; + Ok(resp) +} + +/// Perform a connection. Does not follow redirects. +fn connect_inner( + unit: &Unit, use_pooled: bool, body: SizedReader, - previous: Option>, + previous: &[String], ) -> Result { let host = unit .url @@ -185,14 +240,15 @@ pub(crate) fn connect( info!("sending request {} {}", method, url); } - let send_result = send_prelude(&unit, &mut stream, previous.is_some()); + let send_result = send_prelude(&unit, &mut stream, !previous.is_empty()); if let Err(err) = send_result { if is_recycled { debug!("retrying request early {} {}: {}", method, url, err); // we try open a new connection, this time there will be // no connection in the pool. don't use it. - return connect(unit, false, body, previous); + // NOTE: this recurses at most once because `use_pooled` is `false`. + return connect_inner(unit, false, body, previous); } else { // not a pooled connection, propagate the error. return Err(err.into()); @@ -204,7 +260,7 @@ pub(crate) fn connect( body::send_body(body, unit.is_chunked, &mut stream)?; // start reading the response to process cookies and redirects. - let result = Response::do_from_request(unit.clone(), stream, previous.clone()); + let result = Response::do_from_request(unit.clone(), stream); // https://tools.ietf.org/html/rfc7230#section-6.3.1 // When an inbound connection is closed prematurely, a client MAY @@ -220,7 +276,8 @@ pub(crate) fn connect( Err(err) if err.connection_closed() && retryable && is_recycled => { debug!("retrying request {} {}: {}", method, url, err); let empty = Payload::Empty.into_read(); - return connect(unit, false, empty, previous); + // NOTE: this recurses at most once because `use_pooled` is `false`. + return connect_inner(unit, false, empty, previous); } Err(e) => return Err(e), Ok(resp) => resp, @@ -230,56 +287,6 @@ pub(crate) fn connect( #[cfg(feature = "cookies")] save_cookies(&unit, &resp); - // handle redirects - if (300..399).contains(&resp.status()) && unit.agent.config.redirects > 0 { - if let Some(previous) = previous { - if previous.history().count() + 1 >= unit.agent.config.redirects as usize { - return Err(ErrorKind::TooManyRedirects.new()); - } - } - - // the location header - let location = resp.header("location"); - if let Some(location) = location { - // join location header to current url in case it it relative - let new_url = url.join(location).map_err(|e| { - ErrorKind::InvalidUrl - .msg(&format!("Bad redirection: {}", location)) - .src(e) - })?; - - // perform the redirect differently depending on 3xx code. - match resp.status() { - 301 | 302 | 303 => { - let empty = Payload::Empty.into_read(); - // this is to follow how curl does it. POST, PUT etc change - // to GET on a redirect. - let new_method = match &method[..] { - "GET" | "HEAD" => method.to_string(), - _ => "GET".into(), - }; - // recreate the unit to get a new hostname and cookies for the new host. - let new_unit = - Unit::new(&unit.agent, &new_method, &new_url, &unit.headers, &empty); - - debug!("redirect {} {} -> {}", resp.status(), url, new_url); - return connect(new_unit, use_pooled, empty, Some(Arc::new(resp))); - } - // never change the method for 307/308 - // only resend the request if it cannot have a body - // NOTE: DELETE is intentionally excluded: https://stackoverflow.com/questions/299628 - 307 | 308 if ["GET", "HEAD", "OPTIONS", "TRACE"].contains(&method.as_str()) => { - let empty = Payload::Empty.into_read(); - debug!("redirect {} {} -> {}", resp.status(), url, new_url); - // recreate the unit to get a new hostname and cookies for the new host. - let new_unit = Unit::new(&unit.agent, &unit.method, &new_url, &unit.headers, &empty); - return connect(new_unit, use_pooled, empty, Some(Arc::new(resp))); - } - _ => (), - }; - } - } - debug!("response {} to {} {}", resp.status(), method, url); // release the response