diff --git a/src/request.rs b/src/request.rs index 02fb7b2..ddc5ca3 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,7 +1,7 @@ +use lazy_static::lazy_static; use qstring::QString; use std::io::Read; use std::sync::Arc; -use lazy_static::lazy_static; #[cfg(feature = "json")] use super::SerdeValue; @@ -39,19 +39,17 @@ pub struct Request { impl ::std::fmt::Debug for Request { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> { - let (path, query) = self.to_url() + let (path, query) = self + .to_url() .map(|u| { - let query = combine_query(&u, &self.query, true); - (u.path().to_string(), query) - }) + let query = combine_query(&u, &self.query, true); + (u.path().to_string(), query) + }) .unwrap_or_else(|_| ("BAD_URL".to_string(), "BAD_URL".to_string())); write!( f, "Request({} {}{}, {:?})", - self.method, - path, - query, - self.headers + self.method, path, query, self.headers ) } } @@ -101,7 +99,7 @@ impl Request { .and_then(|url| { let reader = payload.into_read(); let unit = Unit::new(&self, &url, true, &reader); - connect(&self, unit, &self.method, true, 0, reader) + connect(&self, unit, true, 0, reader, false) }) .unwrap_or_else(|e| e.into()) } @@ -148,8 +146,7 @@ impl Request { /// .send_string("Hällo Wörld!"); /// println!("{:?}", r); /// ``` - pub fn send_string(&mut self, data: &str) -> Response - { + pub fn send_string(&mut self, data: &str) -> Response { let text = data.into(); let charset = response::charset_from_content_type(self.header("content-type")).to_string(); self.do_call(Payload::Text(text, charset)) @@ -169,8 +166,7 @@ impl Request { /// .set("Content-Type", "text/plain") /// .send(read); /// ``` - pub fn send(&mut self, reader: impl Read + 'static) -> Response - { + pub fn send(&mut self, reader: impl Read + 'static) -> Response { self.do_call(Payload::Reader(Box::new(reader))) } @@ -188,8 +184,7 @@ impl Request { /// println!("Oh no error!"); /// } /// ``` - pub fn set(&mut self, header: &str, value: &str) -> &mut Request - { + pub fn set(&mut self, header: &str, value: &str) -> &mut Request { add_header(&mut self.headers, Header::new(header, value)); self } @@ -246,8 +241,7 @@ impl Request { /// /// println!("{:?}", r); /// ``` - pub fn query(&mut self, param: &str, value: &str) -> &mut Request - { + pub fn query(&mut self, param: &str, value: &str) -> &mut Request { self.query.add_pair((param, value)); self } @@ -262,8 +256,7 @@ impl Request { /// .call(); /// println!("{:?}", r); /// ``` - pub fn query_str(&mut self, query: &str) -> &mut Request - { + pub fn query_str(&mut self, query: &str) -> &mut Request { self.query.add_str(query); self } @@ -326,8 +319,7 @@ impl Request { /// let r2 = ureq::get("http://martin:rubbermashgum@localhost/my_page").call(); /// println!("{:?}", r2); /// ``` - pub fn auth(&mut self, user: &str, pass: &str) -> &mut Request - { + pub fn auth(&mut self, user: &str, pass: &str) -> &mut Request { let pass = basic_auth(user, pass); self.auth_kind("Basic", &pass) } @@ -340,8 +332,7 @@ impl Request { /// .call(); /// println!("{:?}", r); /// ``` - pub fn auth_kind(&mut self, kind: &str, pass: &str) -> &mut Request - { + pub fn auth_kind(&mut self, kind: &str, pass: &str) -> &mut Request { let value = format!("{} {}", kind, pass); self.set("Authorization", &value); self @@ -445,8 +436,7 @@ impl Request { /// assert_eq!(req.get_scheme().unwrap(), "https"); /// ``` pub fn get_scheme(&self) -> Result { - self.to_url() - .map(|u| u.scheme().to_string()) + self.to_url().map(|u| u.scheme().to_string()) } /// The complete query for this request. @@ -459,8 +449,7 @@ impl Request { /// assert_eq!(req.get_query().unwrap(), "?foo=bar&format=json"); /// ``` pub fn get_query(&self) -> Result { - self.to_url() - .map(|u| combine_query(&u, &self.query, true)) + self.to_url().map(|u| combine_query(&u, &self.query, true)) } /// The normalized path of this request. @@ -472,8 +461,7 @@ impl Request { /// assert_eq!(req.get_path().unwrap(), "/innit"); /// ``` pub fn get_path(&self) -> Result { - self.to_url() - .map(|u| u.path().to_string()) + self.to_url().map(|u| u.path().to_string()) } fn to_url(&self) -> Result { diff --git a/src/response.rs b/src/response.rs index 5cc4f50..85acd4c 100644 --- a/src/response.rs +++ b/src/response.rs @@ -256,7 +256,7 @@ impl Response { .map(|c| c.eq_ignore_ascii_case("close")) .unwrap_or(false); - let is_head = (&self.unit).as_ref().map(|u| u.is_head).unwrap_or(false); + let is_head = (&self.unit).as_ref().map(|u| u.is_head()).unwrap_or(false); let is_chunked = self .header("transfer-encoding") diff --git a/src/test/mod.rs b/src/test/mod.rs index 7df408c..9984362 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -12,6 +12,7 @@ mod body_read; mod body_send; mod query_string; mod range; +mod redirect; mod simple; type RequestHandler = Fn(&Unit) -> Result + Send + 'static; diff --git a/src/test/redirect.rs b/src/test/redirect.rs new file mode 100644 index 0000000..3ea329c --- /dev/null +++ b/src/test/redirect.rs @@ -0,0 +1,90 @@ +use crate::test; + +use super::super::*; + +#[test] +fn redirect_on() { + test::set_handler("/redirect_on1", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_on2"], vec![]) + }); + test::set_handler("/redirect_on2", |_| { + test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) + }); + let resp = get("test://host/redirect_on1").call(); + assert_eq!(resp.status(), 200); + assert!(resp.has("x-foo")); + assert_eq!(resp.header("x-foo").unwrap(), "bar"); +} + +#[test] +fn redirect_many() { + test::set_handler("/redirect_many1", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_many2"], vec![]) + }); + test::set_handler("/redirect_many2", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_many3"], vec![]) + }); + let resp = get("test://host/redirect_many1").redirects(1).call(); + assert_eq!(resp.status(), 500); + assert_eq!(resp.status_text(), "Too Many Redirects"); +} + +#[test] +fn redirect_off() { + test::set_handler("/redirect_off", |_| { + test::make_response(302, "Go here", vec!["Location: somewhere.else"], vec![]) + }); + let resp = get("test://host/redirect_off").redirects(0).call(); + assert_eq!(resp.status(), 302); + assert!(resp.has("Location")); + assert_eq!(resp.header("Location").unwrap(), "somewhere.else"); +} + +#[test] +fn redirect_head() { + test::set_handler("/redirect_head1", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_head2"], vec![]) + }); + test::set_handler("/redirect_head2", |unit| { + assert_eq!(unit.method, "HEAD"); + test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) + }); + let resp = head("test://host/redirect_head1").call(); + assert_eq!(resp.status(), 200); + assert!(resp.has("x-foo")); + assert_eq!(resp.header("x-foo").unwrap(), "bar"); +} + +#[test] +fn redirect_get() { + test::set_handler("/redirect_get1", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_get2"], vec![]) + }); + test::set_handler("/redirect_get2", |unit| { + assert_eq!(unit.method, "GET"); + assert!(unit.has("Range")); + assert_eq!(unit.header("Range").unwrap(), "bytes=10-50"); + test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) + }); + let resp = get("test://host/redirect_get1") + .set("Range", "bytes=10-50") + .call(); + assert_eq!(resp.status(), 200); + assert!(resp.has("x-foo")); + assert_eq!(resp.header("x-foo").unwrap(), "bar"); +} + +#[test] +fn redirect_post() { + test::set_handler("/redirect_post1", |_| { + test::make_response(302, "Go here", vec!["Location: /redirect_post2"], vec![]) + }); + test::set_handler("/redirect_post2", |unit| { + assert_eq!(unit.method, "GET"); + test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) + }); + let resp = post("test://host/redirect_post1").call(); + assert_eq!(resp.status(), 200); + assert!(resp.has("x-foo")); + assert_eq!(resp.header("x-foo").unwrap(), "bar"); +} diff --git a/src/test/simple.rs b/src/test/simple.rs index 2f1c0ca..000f0f3 100644 --- a/src/test/simple.rs +++ b/src/test/simple.rs @@ -138,41 +138,3 @@ fn non_ascii_header() { assert_eq!(resp.status(), 500); assert_eq!(resp.status_text(), "Bad Header"); } - -#[test] -fn redirect_on() { - test::set_handler("/redirect_on1", |_| { - test::make_response(302, "Go here", vec!["Location: /redirect_on2"], vec![]) - }); - test::set_handler("/redirect_on2", |_| { - test::make_response(200, "OK", vec!["x-foo: bar"], vec![]) - }); - let resp = get("test://host/redirect_on1").call(); - assert_eq!(resp.status(), 200); - assert!(resp.has("x-foo")); - assert_eq!(resp.header("x-foo").unwrap(), "bar"); -} - -#[test] -fn redirect_many() { - test::set_handler("/redirect_many1", |_| { - test::make_response(302, "Go here", vec!["Location: /redirect_many2"], vec![]) - }); - test::set_handler("/redirect_many2", |_| { - test::make_response(302, "Go here", vec!["Location: /redirect_many3"], vec![]) - }); - let resp = get("test://host/redirect_many1").redirects(1).call(); - assert_eq!(resp.status(), 500); - assert_eq!(resp.status_text(), "Too Many Redirects"); -} - -#[test] -fn redirect_off() { - test::set_handler("/redirect_off", |_| { - test::make_response(302, "Go here", vec!["Location: somewhere.else"], vec![]) - }); - let resp = get("test://host/redirect_off").redirects(0).call(); - assert_eq!(resp.status(), 302); - assert!(resp.has("Location")); - assert_eq!(resp.header("Location").unwrap(), "somewhere.else"); -} diff --git a/src/unit.rs b/src/unit.rs index 536d0e8..c0e25eb 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -1,7 +1,7 @@ -use base64; use crate::body::{send_body, Payload, SizedReader}; -use std::io::{Result as IoResult, Write}; use crate::stream::{connect_http, connect_https, connect_test, Stream}; +use base64; +use std::io::{Result as IoResult, Write}; use url::Url; // @@ -15,12 +15,12 @@ pub struct Unit { pub agent: Arc>>, pub url: Url, pub is_chunked: bool, - pub is_head: bool, pub query_string: String, pub headers: Vec
, pub timeout_connect: u64, pub timeout_read: u64, pub timeout_write: u64, + pub method: String, } impl Unit { @@ -29,7 +29,8 @@ impl Unit { fn new(req: &Request, url: &Url, mix_queries: bool, body: &SizedReader) -> Self { // - let is_chunked = req.header("transfer-encoding") + let is_chunked = req + .header("transfer-encoding") // if the user has set an encoding header, obey that. .map(|enc| !enc.is_empty()) // otherwise, no chunking. @@ -37,8 +38,6 @@ impl Unit { let is_secure = url.scheme().eq_ignore_ascii_case("https"); - let is_head = req.method.eq_ignore_ascii_case("head"); - let hostname = url.host_str().unwrap_or(DEFAULT_HOST).to_string(); let query_string = combine_query(&url, &req.query, mix_queries); @@ -82,15 +81,19 @@ impl Unit { agent: Arc::clone(&req.agent), url: url.clone(), is_chunked, - is_head, query_string, headers, timeout_connect: req.timeout_connect, timeout_read: req.timeout_read, timeout_write: req.timeout_write, + method: req.method.clone(), } } + pub fn is_head(&self) -> bool { + self.method.eq_ignore_ascii_case("head") + } + #[cfg(test)] pub fn header<'a>(&self, name: &'a str) -> Option<&str> { get_header(&self.headers, name) @@ -109,23 +112,23 @@ impl Unit { pub fn connect( req: &Request, unit: Unit, - method: &str, use_pooled: bool, redirect_count: u32, body: SizedReader, + redir: bool, ) -> Result { // // open socket let (mut stream, is_recycled) = connect_socket(&unit, use_pooled)?; - let send_result = send_prelude(&unit, method, &mut stream); + let send_result = send_prelude(&unit, &mut stream, redir); if send_result.is_err() { if is_recycled { // we try open a new connection, this time there will be // no connection in the pool. don't use it. - return connect(req, unit, method, false, redirect_count, body); + return connect(req, unit, false, redirect_count, body, redir); } else { // not a pooled connection, propagate the error. return Err(send_result.unwrap_err().into()); @@ -161,10 +164,16 @@ pub fn connect( 301 | 302 | 303 => { let empty = Payload::Empty.into_read(); // recreate the unit to get a new hostname and cookies for the new host. - let new_unit = Unit::new(req, &new_url, false, &empty); - return connect(req, new_unit, "GET", use_pooled, redirect_count + 1, empty); + let mut new_unit = Unit::new(req, &new_url, false, &empty); + // this is to follow how curl does it. POST, PUT etc change + // to GET on a redirect. + new_unit.method = match &unit.method[..] { + "GET" | "HEAD" => unit.method, + _ => "GET".into(), + }; + return connect(req, new_unit, use_pooled, redirect_count + 1, empty, true); } - , _ => (), + _ => (), // reinstate this with expect-100 // 307 | 308 | _ => connect(unit, method, use_pooled, redirects - 1, body), }; @@ -239,7 +248,7 @@ fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error } /// Send request line + headers (all up until the body). -fn send_prelude(unit: &Unit, method: &str, stream: &mut Stream) -> IoResult<()> { +fn send_prelude(unit: &Unit, stream: &mut Stream, redir: bool) -> IoResult<()> { // // build into a buffer and send in one go. @@ -249,7 +258,7 @@ fn send_prelude(unit: &Unit, method: &str, stream: &mut Stream) -> IoResult<()> write!( prelude, "{} {}{} HTTP/1.1\r\n", - method, + unit.method, unit.url.path(), &unit.query_string )?; @@ -267,7 +276,9 @@ fn send_prelude(unit: &Unit, method: &str, stream: &mut Stream) -> IoResult<()> // other headers for header in &unit.headers { - write!(prelude, "{}: {}\r\n", header.name(), header.value())?; + if !redir || !header.is_name("Authorization") { + write!(prelude, "{}: {}\r\n", header.name(), header.value())?; + } } // finish