diff --git a/src/agent.rs b/src/agent.rs index a3594ae..ca088a5 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -2,13 +2,14 @@ use cookie::{Cookie, CookieJar}; use std::str::FromStr; use std::sync::Mutex; -use header::{add_header, Header}; +use header::{add_header, get_header, get_all_headers, has_header, Header}; // to get to share private fields include!("request.rs"); include!("response.rs"); include!("conn.rs"); include!("stream.rs"); +include!("unit.rs"); /// Agents keep state between requests. /// @@ -45,7 +46,7 @@ pub struct Agent { } #[derive(Debug)] -struct AgentState { +pub struct AgentState { pool: ConnectionPool, jar: CookieJar, } @@ -109,7 +110,7 @@ impl Agent { { let s = format!("{}: {}", header.into(), value.into()); let header = s.parse::
().expect("Failed to parse header"); - add_header(header, &mut self.headers); + add_header(&mut self.headers, header); self } @@ -145,7 +146,7 @@ impl Agent { for (k, v) in headers.into_iter() { let s = format!("{}: {}", k.into(), v.into()); let header = s.parse::
().expect("Failed to parse header"); - add_header(header, &mut self.headers); + add_header(&mut self.headers, header); } self } @@ -192,7 +193,7 @@ impl Agent { { let s = format!("Authorization: {} {}", kind.into(), pass.into()); let header = s.parse::
().expect("Failed to parse header"); - add_header(header, &mut self.headers); + add_header(&mut self.headers, header); self } diff --git a/src/conn.rs b/src/conn.rs index d352329..deb9b7d 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -7,137 +7,9 @@ const CHUNK_SIZE: usize = 1024 * 1024; pub struct ConnectionPool {} impl ConnectionPool { - fn new() -> Self { + pub fn new() -> Self { ConnectionPool {} } - - fn connect( - &mut self, - request: &Request, - method: &str, - url: &Url, - redirects: u32, - mut jar: Option<&mut CookieJar>, - body: SizedReader, - ) -> Result { - // - - let do_chunk = request.header("transfer-encoding") - // if the user has set an encoding header, obey that. - .map(|enc| enc.len() > 0) - // otherwise, no chunking. - .unwrap_or(false); - - let hostname = url.host_str().unwrap_or("localhost"); // is localhost a good alternative? - - let query_string = combine_query(&url, &request.query); - - let is_secure = url.scheme().eq_ignore_ascii_case("https"); - - let is_head = request.method.eq_ignore_ascii_case("head"); - - let cookie_headers: Vec<_> = { - match jar.as_ref() { - None => vec![], - Some(jar) => match_cookies(jar, hostname, url.path(), is_secure), - } - }; - let extra_headers = { - let mut extra = vec![]; - - // chunking and Content-Length headers are mutually exclusive - // also don't write this if the user has set it themselves - if !do_chunk && !request.has("content-length") { - if let Some(size) = body.size { - extra.push(format!("Content-Length: {}\r\n", size).parse::
()?); - } - } - extra - }; - let headers = request - .headers - .iter() - .chain(cookie_headers.iter()) - .chain(extra_headers.iter()); - - // open socket - let mut stream = match url.scheme() { - "http" => connect_http(request, &url), - "https" => connect_https(request, &url), - "test" => connect_test(request, &url), - _ => Err(Error::UnknownScheme(url.scheme().to_string())), - }?; - - // send the request start + headers - let mut prelude: Vec = vec![]; - write!(prelude, "{} {}{} HTTP/1.1\r\n", method, url.path(), query_string)?; - if !request.has("host") { - write!(prelude, "Host: {}\r\n", url.host().unwrap())?; - } - for header in headers { - write!(prelude, "{}: {}\r\n", header.name(), header.value())?; - } - write!(prelude, "\r\n")?; - - stream.write_all(&mut prelude[..])?; - - // start reading the response to process cookies and redirects. - let mut resp = Response::from_read(&mut stream); - - // squirrel away cookies - if let Some(add_jar) = jar.as_mut() { - for raw_cookie in resp.all("set-cookie").iter() { - let to_parse = if raw_cookie.to_lowercase().contains("domain=") { - raw_cookie.to_string() - } else { - format!("{}; Domain={}", raw_cookie, hostname) - }; - match Cookie::parse_encoded(&to_parse[..]) { - Err(_) => (), // ignore unparseable cookies - Ok(mut cookie) => { - let cookie = cookie.into_owned(); - add_jar.add(cookie) - } - } - } - } - - // handle redirects - if resp.redirect() { - if redirects == 0 { - return Err(Error::TooManyRedirects); - } - - // the location header - let location = resp.header("location"); - if let Some(location) = location { - // join location header to current url in case it it relative - let new_url = url.join(location) - .map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?; - - // perform the redirect differently depending on 3xx code. - return match resp.status { - 301 | 302 | 303 => { - send_body(body, do_chunk, &mut stream)?; - let empty = Payload::Empty.into_read(); - self.connect(request, "GET", &new_url, redirects - 1, jar, empty) - } - 307 | 308 | _ => { - self.connect(request, method, &new_url, redirects - 1, jar, body) - } - }; - } - } - - // send the body (which can be empty now depending on redirects) - send_body(body, do_chunk, &mut stream)?; - - // since it is not a redirect, give away the incoming stream to the response object - resp.set_stream(stream, is_head); - - // release the response - Ok(resp) - } } fn send_body(body: SizedReader, do_chunk: bool, stream: &mut Stream) -> IoResult<()> { @@ -165,41 +37,3 @@ where } Ok(()) } - -// TODO check so cookies can't be set for tld:s -fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec
{ - jar.iter() - .filter(|c| { - // if there is a domain, it must be matched. if there is no domain, then ignore cookie - let domain_ok = c.domain() - .map(|cdom| domain.contains(cdom)) - .unwrap_or(false); - // a path must match the beginning of request path. no cookie path, we say is ok. is it?! - let path_ok = c.path() - .map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false)) - .unwrap_or(true); - // either the cookie isnt secure, or we're not doing a secure request. - let secure_ok = !c.secure() || is_secure; - - domain_ok && path_ok && secure_ok - }) - .map(|c| { - let name = c.name().to_string(); - let value = c.value().to_string(); - let nameval = Cookie::new(name, value).encoded().to_string(); - let head = format!("Cookie: {}", nameval); - head.parse::
().ok() - }) - .filter(|o| o.is_some()) - .map(|o| o.unwrap()) - .collect() -} - -fn combine_query(url: &Url, query: &QString) -> String { - match (url.query(), query.len() > 0) { - (Some(urlq), true) => format!("?{}&{}", urlq, query), - (Some(urlq), false) => format!("?{}", urlq), - (None, true) => format!("?{}", query), - (None, false) => "".to_string(), - } -} diff --git a/src/header.rs b/src/header.rs index a990359..e823e68 100644 --- a/src/header.rs +++ b/src/header.rs @@ -53,6 +53,30 @@ impl Header { } } +pub fn get_header<'a, 'b>(headers: &'b Vec
, name: &'a str) -> Option<&'b str> { + headers.iter().find(|h| h.is_name(name)).map(|h| h.value()) +} + +pub fn get_all_headers<'a, 'b>(headers: &'b Vec
, name: &'a str) -> Vec<&'b str> { + headers + .iter() + .filter(|h| h.is_name(name)) + .map(|h| h.value()) + .collect() +} + +pub fn has_header(headers: &Vec
, name: &str) -> bool { + get_header(headers, name).is_some() +} + +pub fn add_header(headers: &mut Vec
, header: Header) { + if !header.name().to_lowercase().starts_with("x-") { + let name = header.name(); + headers.retain(|h| h.name() != name); + } + headers.push(header); +} + impl FromStr for Header { type Err = Error; fn from_str(s: &str) -> Result { @@ -68,11 +92,3 @@ impl FromStr for Header { Ok(Header { line, index }) } } - -pub fn add_header(header: Header, headers: &mut Vec
) { - if !header.name().to_lowercase().starts_with("x-") { - let name = header.name(); - headers.retain(|h| h.name() != name); - } - headers.push(header); -} diff --git a/src/request.rs b/src/request.rs index a290ff7..8909c0e 100644 --- a/src/request.rs +++ b/src/request.rs @@ -23,7 +23,7 @@ lazy_static! { /// ``` #[derive(Clone, Default)] pub struct Request { - state: Arc>>, + agent: Arc>>, // via agent method: String, @@ -32,9 +32,9 @@ pub struct Request { // from request itself headers: Vec
, query: QString, - timeout: u32, - timeout_read: u32, - timeout_write: u32, + timeout_connect: u64, + timeout_read: u64, + timeout_write: u64, redirects: u32, } @@ -64,7 +64,7 @@ impl Default for Payload { } } -struct SizedReader { +pub struct SizedReader { size: Option, reader: Box, } @@ -108,7 +108,7 @@ impl Payload { impl Request { fn new(agent: &Agent, method: String, path: String) -> Request { Request { - state: Arc::clone(&agent.state), + agent: Arc::clone(&agent.state), method, path, headers: agent.headers.clone(), @@ -132,11 +132,11 @@ impl Request { /// Executes the request and blocks the caller until done. /// - /// Use `.timeout()` and `.timeout_read()` to avoid blocking forever. + /// Use `.timeout_connect()` and `.timeout_read()` to avoid blocking forever. /// /// ``` /// let r = ureq::get("/my_page") - /// .timeout(10_000) // max 10 seconds + /// .timeout_connect(10_000) // max 10 seconds /// .call(); /// /// println!("{:?}", r); @@ -146,32 +146,11 @@ impl Request { } fn do_call(&mut self, payload: Payload) -> Response { - let mut state = self.state.lock().unwrap(); self.to_url() .and_then(|url| { - match state.as_mut() { - None => - // create a one off pool/jar. - ConnectionPool::new().connect( - self, - &self.method, - &url, - self.redirects, - None, - payload.into_read(), - ), - Some(state) => { - let jar = &mut state.jar; - state.pool.connect( - self, - &self.method, - &url, - self.redirects, - Some(jar), - payload.into_read(), - ) - }, - } + let reader = payload.into_read(); + let mut unit = Unit::new(&self, &url, &reader); + unit.connect(url, &self.method, self.redirects, reader) }) .unwrap_or_else(|e| e.into()) } @@ -269,7 +248,7 @@ impl Request { { let s = format!("{}: {}", header.into(), value.into()); let header = s.parse::
().expect("Failed to parse header"); - add_header(header, &mut self.headers); + add_header(&mut self.headers, header); self } @@ -282,10 +261,7 @@ impl Request { /// assert_eq!("foobar", req.header("x-api-Key").unwrap()); /// ``` pub fn header<'a>(&self, name: &'a str) -> Option<&str> { - self.headers - .iter() - .find(|h| h.is_name(name)) - .map(|h| h.value()) + get_header(&self.headers, name) } /// Tells if the header has been set. @@ -297,7 +273,7 @@ impl Request { /// assert_eq!(true, req.has("x-api-Key")); /// ``` pub fn has<'a>(&self, name: &'a str) -> bool { - self.header(name).is_some() + has_header(&self.headers, name) } /// All headers corresponding values for the give name, or empty vector. @@ -313,11 +289,7 @@ impl Request { /// ]); /// ``` pub fn all<'a>(&self, name: &'a str) -> Vec<&str> { - self.headers - .iter() - .filter(|h| h.is_name(name)) - .map(|h| h.value()) - .collect() + get_all_headers(&self.headers, name) } /// Set many headers. @@ -348,7 +320,7 @@ impl Request { for (k, v) in headers.into_iter() { let s = format!("{}: {}", k.into(), v.into()); let header = s.parse::
().expect("Failed to parse header"); - add_header(header, &mut self.headers); + add_header(&mut self.headers, header); } self } @@ -430,12 +402,12 @@ impl Request { /// /// ``` /// let r = ureq::get("/my_page") - /// .timeout(1_000) // wait max 1 second to connect + /// .timeout_connect(1_000) // wait max 1 second to connect /// .call(); /// println!("{:?}", r); /// ``` - pub fn timeout(&mut self, millis: u32) -> &mut Request { - self.timeout = millis; + pub fn timeout_connect(&mut self, millis: u64) -> &mut Request { + self.timeout_connect = millis; self } @@ -449,7 +421,7 @@ impl Request { /// .call(); /// println!("{:?}", r); /// ``` - pub fn timeout_read(&mut self, millis: u32) -> &mut Request { + pub fn timeout_read(&mut self, millis: u64) -> &mut Request { self.timeout_read = millis; self } @@ -464,7 +436,7 @@ impl Request { /// .call(); /// println!("{:?}", r); /// ``` - pub fn timeout_write(&mut self, millis: u32) -> &mut Request { + pub fn timeout_write(&mut self, millis: u64) -> &mut Request { self.timeout_write = millis; self } @@ -508,7 +480,7 @@ impl Request { { let s = format!("Authorization: {} {}", kind.into(), pass.into()); let header = s.parse::
().expect("Failed to parse header"); - add_header(header, &mut self.headers); + add_header(&mut self.headers, header); self } diff --git a/src/stream.rs b/src/stream.rs index 0f55a26..6695de6 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -60,28 +60,28 @@ impl Write for Stream { } } -fn connect_http(request: &Request, url: &Url) -> Result { +fn connect_http(unit: &Unit) -> Result { // - let hostname = url.host_str().unwrap(); - let port = url.port().unwrap_or(80); + let hostname = unit.url.host_str().unwrap(); + let port = unit.url.port().unwrap_or(80); - connect_host(request, hostname, port).map(|tcp| Stream::Http(tcp)) + connect_host(unit, hostname, port).map(|tcp| Stream::Http(tcp)) } #[cfg(feature = "tls")] -fn connect_https(request: &Request, url: &Url) -> Result { +fn connect_https(unit: &Unit) -> Result { // - let hostname = url.host_str().unwrap(); - let port = url.port().unwrap_or(443); + let hostname = unit.url.host_str().unwrap(); + let port = unit.url.port().unwrap_or(443); - let socket = connect_host(request, hostname, port)?; + let socket = connect_host(unit, hostname, port)?; let connector = TlsConnector::builder().build()?; let stream = connector.connect(hostname, socket)?; Ok(Stream::Https(stream)) } -fn connect_host(request: &Request, hostname: &str, port: u16) -> Result { +fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result { // let ips: Vec = format!("{}:{}", hostname, port).to_socket_addrs() .map_err(|e| Error::DnsFailed(format!("{}", e)))? @@ -95,20 +95,20 @@ fn connect_host(request: &Request, hostname: &str, port: u16) -> Result TcpStream::connect(&sock_addr), - _ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(request.timeout as u64)), + _ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(unit.timeout_connect as u64)), }.map_err(|err| Error::ConnectionFailed(format!("{}", err)))?; // rust's absurd api returns Err if we set 0. - if request.timeout_read > 0 { + if unit.timeout_read > 0 { stream - .set_read_timeout(Some(Duration::from_millis(request.timeout_read as u64))) + .set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64))) .ok(); } - if request.timeout_write > 0 { + if unit.timeout_write > 0 { stream - .set_write_timeout(Some(Duration::from_millis(request.timeout_write as u64))) + .set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64))) .ok(); } @@ -116,17 +116,17 @@ fn connect_host(request: &Request, hostname: &str, port: u16) -> Result Result { +fn connect_test(unit: &Unit) -> Result { use test; - test::resolve_handler(request, url) + test::resolve_handler(unit) } #[cfg(not(test))] -fn connect_test(_request: &Request, url: &Url) -> Result { - Err(Error::UnknownScheme(url.scheme().to_string())) +fn connect_test(unit: &Unit) -> Result { + Err(Error::UnknownScheme(unit.url.scheme().to_string())) } #[cfg(not(feature = "tls"))] -fn connect_https(request: &Request, url: &Url) -> Result { - Err(Error::UnknownScheme(url.scheme().to_string())) +fn connect_https(unit: &Unit) -> Result { + Err(Error::UnknownScheme(unit.url.scheme().to_string())) } diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index 7fe1b74..762f734 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -6,18 +6,18 @@ use super::super::*; fn agent_reuse_headers() { let agent = agent().set("Authorization", "Foo 12345").build(); - test::set_handler("/agent_reuse_headers", |req, _url| { - assert!(req.has("Authorization")); - assert_eq!(req.header("Authorization").unwrap(), "Foo 12345"); + test::set_handler("/agent_reuse_headers", |unit| { + assert!(unit.has("Authorization")); + assert_eq!(unit.header("Authorization").unwrap(), "Foo 12345"); test::make_response(200, "OK", vec!["X-Call: 1"], vec![]) }); let resp = agent.get("test://host/agent_reuse_headers").call(); assert_eq!(resp.header("X-Call").unwrap(), "1"); - test::set_handler("/agent_reuse_headers", |req, _url| { - assert!(req.has("Authorization")); - assert_eq!(req.header("Authorization").unwrap(), "Foo 12345"); + test::set_handler("/agent_reuse_headers", |unit| { + assert!(unit.has("Authorization")); + assert_eq!(unit.header("Authorization").unwrap(), "Foo 12345"); test::make_response(200, "OK", vec!["X-Call: 2"], vec![]) }); @@ -29,7 +29,7 @@ fn agent_reuse_headers() { fn agent_cookies() { let agent = agent().build(); - test::set_handler("/agent_cookies", |_req, _url| { + test::set_handler("/agent_cookies", |_unit| { test::make_response( 200, "OK", @@ -43,7 +43,7 @@ fn agent_cookies() { assert!(agent.cookie("foo").is_some()); assert_eq!(agent.cookie("foo").unwrap().value(), "bar baz"); - test::set_handler("/agent_cookies", |_req, _url| { + test::set_handler("/agent_cookies", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); diff --git a/src/test/auth.rs b/src/test/auth.rs index 727182b..719f205 100644 --- a/src/test/auth.rs +++ b/src/test/auth.rs @@ -4,9 +4,9 @@ use super::super::*; #[test] fn basic_auth() { - test::set_handler("/basic_auth", |req, _url| { + test::set_handler("/basic_auth", |unit| { assert_eq!( - req.header("Authorization").unwrap(), + unit.header("Authorization").unwrap(), "Basic bWFydGluOnJ1YmJlcm1hc2hndW0=" ); test::make_response(200, "OK", vec![], vec![]) @@ -19,8 +19,8 @@ fn basic_auth() { #[test] fn kind_auth() { - test::set_handler("/kind_auth", |req, _url| { - assert_eq!(req.header("Authorization").unwrap(), "Digest abcdefgh123"); + test::set_handler("/kind_auth", |unit| { + assert_eq!(unit.header("Authorization").unwrap(), "Digest abcdefgh123"); test::make_response(200, "OK", vec![], vec![]) }); let resp = get("test://host/kind_auth") diff --git a/src/test/body_read.rs b/src/test/body_read.rs index 078c86d..fae8a8c 100644 --- a/src/test/body_read.rs +++ b/src/test/body_read.rs @@ -5,7 +5,7 @@ use super::super::*; #[test] fn transfer_encoding_bogus() { - test::set_handler("/transfer_encoding_bogus", |_req, _url| { + test::set_handler("/transfer_encoding_bogus", |_unit| { test::make_response( 200, "OK", @@ -26,7 +26,7 @@ fn transfer_encoding_bogus() { #[test] fn content_length_limited() { - test::set_handler("/content_length_limited", |_req, _url| { + test::set_handler("/content_length_limited", |_unit| { test::make_response( 200, "OK", @@ -44,7 +44,7 @@ fn content_length_limited() { #[test] // content-length should be ignored when chunked fn ignore_content_length_when_chunked() { - test::set_handler("/ignore_content_length_when_chunked", |_req, _url| { + test::set_handler("/ignore_content_length_when_chunked", |_unit| { test::make_response( 200, "OK", @@ -63,7 +63,7 @@ fn ignore_content_length_when_chunked() { #[test] fn no_reader_on_head() { - test::set_handler("/no_reader_on_head", |_req, _url| { + test::set_handler("/no_reader_on_head", |_unit| { // so this is technically illegal, we return a body for the HEAD request. test::make_response( 200, diff --git a/src/test/body_send.rs b/src/test/body_send.rs index bbc7426..1c858ac 100644 --- a/src/test/body_send.rs +++ b/src/test/body_send.rs @@ -4,7 +4,7 @@ use super::super::*; #[test] fn content_length_on_str() { - test::set_handler("/content_length_on_str", |_req, _url| { + test::set_handler("/content_length_on_str", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = post("test://host/content_length_on_str").send_string("Hello World!!!"); @@ -15,7 +15,7 @@ fn content_length_on_str() { #[test] fn user_set_content_length_on_str() { - test::set_handler("/user_set_content_length_on_str", |_req, _url| { + test::set_handler("/user_set_content_length_on_str", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = post("test://host/user_set_content_length_on_str") @@ -29,7 +29,7 @@ fn user_set_content_length_on_str() { #[test] #[cfg(feature = "json")] fn content_length_on_json() { - test::set_handler("/content_length_on_json", |_req, _url| { + test::set_handler("/content_length_on_json", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let mut json = SerdeMap::new(); @@ -45,7 +45,7 @@ fn content_length_on_json() { #[test] fn content_length_and_chunked() { - test::set_handler("/content_length_and_chunked", |_req, _url| { + test::set_handler("/content_length_and_chunked", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = post("test://host/content_length_and_chunked") @@ -60,7 +60,7 @@ fn content_length_and_chunked() { #[test] #[cfg(feature = "charset")] fn str_with_encoding() { - test::set_handler("/str_with_encoding", |_req, _url| { + test::set_handler("/str_with_encoding", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = post("test://host/str_with_encoding") diff --git a/src/test/mod.rs b/src/test/mod.rs index 08d91f9..423ac7f 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,4 +1,4 @@ -use agent::Request; +use agent::Unit; use agent::Stream; use error::Error; use header::Header; @@ -6,7 +6,6 @@ use std::collections::HashMap; use std::io::Cursor; use std::io::Write; use std::sync::{Arc, Mutex}; -use url::Url; mod agent_test; mod auth; @@ -16,7 +15,7 @@ mod query_string; mod range; mod simple; -type RequestHandler = Fn(&Request, &Url) -> Result + Send + 'static; +type RequestHandler = Fn(&Unit) -> Result + Send + 'static; lazy_static! { pub static ref TEST_HANDLERS: Arc>>> = @@ -25,7 +24,7 @@ lazy_static! { pub fn set_handler(path: &str, handler: H) where - H: Fn(&Request, &Url) -> Result + Send + 'static, + H: Fn(&Unit) -> Result + Send + 'static, { let mut handlers = TEST_HANDLERS.lock().unwrap(); handlers.insert(path.to_string(), Box::new(handler)); @@ -50,9 +49,9 @@ pub fn make_response( Ok(Stream::Test(Box::new(cursor), write)) } -pub fn resolve_handler(req: &Request, url: &Url) -> Result { +pub fn resolve_handler(unit: &Unit) -> Result { let mut handlers = TEST_HANDLERS.lock().unwrap(); - let path = url.path(); + let path = unit.url.path(); let handler = handlers.remove(path).unwrap(); - handler(req, url) + handler(unit) } diff --git a/src/test/query_string.rs b/src/test/query_string.rs index 36dcd30..8e2ab01 100644 --- a/src/test/query_string.rs +++ b/src/test/query_string.rs @@ -4,7 +4,7 @@ use super::super::*; #[test] fn no_query_string() { - test::set_handler("/no_query_string", |_req, _url| { + test::set_handler("/no_query_string", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = get("test://host/no_query_string").call(); @@ -15,7 +15,7 @@ fn no_query_string() { #[test] fn escaped_query_string() { - test::set_handler("/escaped_query_string", |_req, _url| { + test::set_handler("/escaped_query_string", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = get("test://host/escaped_query_string") @@ -29,7 +29,7 @@ fn escaped_query_string() { #[test] fn query_in_path() { - test::set_handler("/query_in_path", |_req, _url| { + test::set_handler("/query_in_path", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = get("test://host/query_in_path?foo=bar").call(); @@ -40,7 +40,7 @@ fn query_in_path() { #[test] fn query_in_path_and_req() { - test::set_handler("/query_in_path_and_req", |_req, _url| { + test::set_handler("/query_in_path_and_req", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = get("test://host/query_in_path_and_req?foo=bar") diff --git a/src/test/simple.rs b/src/test/simple.rs index f689bc3..1c8e10f 100644 --- a/src/test/simple.rs +++ b/src/test/simple.rs @@ -5,9 +5,9 @@ use super::super::*; #[test] fn header_passing() { - test::set_handler("/header_passing", |req, _url| { - assert!(req.has("X-Foo")); - assert_eq!(req.header("X-Foo").unwrap(), "bar"); + test::set_handler("/header_passing", |unit| { + assert!(unit.has("X-Foo")); + assert_eq!(unit.header("X-Foo").unwrap(), "bar"); test::make_response(200, "OK", vec!["X-Bar: foo"], vec![]) }); let resp = get("test://host/header_passing").set("X-Foo", "bar").call(); @@ -18,9 +18,9 @@ fn header_passing() { #[test] fn repeat_non_x_header() { - test::set_handler("/repeat_non_x_header", |req, _url| { - assert!(req.has("Accept")); - assert_eq!(req.header("Accept").unwrap(), "baz"); + test::set_handler("/repeat_non_x_header", |unit| { + assert!(unit.has("Accept")); + assert_eq!(unit.header("Accept").unwrap(), "baz"); test::make_response(200, "OK", vec![], vec![]) }); let resp = get("test://host/repeat_non_x_header") @@ -32,11 +32,11 @@ fn repeat_non_x_header() { #[test] fn repeat_x_header() { - test::set_handler("/repeat_x_header", |req, _url| { - assert!(req.has("X-Forwarded-For")); - assert_eq!(req.header("X-Forwarded-For").unwrap(), "130.240.19.2"); + test::set_handler("/repeat_x_header", |unit| { + assert!(unit.has("X-Forwarded-For")); + assert_eq!(unit.header("X-Forwarded-For").unwrap(), "130.240.19.2"); assert_eq!( - req.all("X-Forwarded-For"), + unit.all("X-Forwarded-For"), vec!["130.240.19.2", "130.240.19.3"] ); test::make_response(200, "OK", vec![], vec![]) @@ -50,7 +50,7 @@ fn repeat_x_header() { #[test] fn body_as_text() { - test::set_handler("/body_as_text", |_req, _url| { + test::set_handler("/body_as_text", |_unit| { test::make_response(200, "OK", vec![], "Hello World!".to_string().into_bytes()) }); let resp = get("test://host/body_as_text").call(); @@ -61,7 +61,7 @@ fn body_as_text() { #[test] #[cfg(feature = "json")] fn body_as_json() { - test::set_handler("/body_as_json", |_req, _url| { + test::set_handler("/body_as_json", |_unit| { test::make_response( 200, "OK", @@ -76,7 +76,7 @@ fn body_as_json() { #[test] fn body_as_reader() { - test::set_handler("/body_as_reader", |_req, _url| { + test::set_handler("/body_as_reader", |_unit| { test::make_response(200, "OK", vec![], "abcdefgh".to_string().into_bytes()) }); let resp = get("test://host/body_as_reader").call(); @@ -88,7 +88,7 @@ fn body_as_reader() { #[test] fn escape_path() { - test::set_handler("/escape_path%20here", |_req, _url| { + test::set_handler("/escape_path%20here", |_unit| { test::make_response(200, "OK", vec![], vec![]) }); let resp = get("test://host/escape_path here").call(); diff --git a/src/unit.rs b/src/unit.rs new file mode 100644 index 0000000..c0af75b --- /dev/null +++ b/src/unit.rs @@ -0,0 +1,230 @@ +// + +pub struct Unit { + pub agent: Arc>>, + pub url: Url, + pub is_chunked: bool, + pub is_head: bool, + pub hostname: String, + pub query_string: String, + pub headers: Vec
, + pub timeout_connect: u64, + pub timeout_read: u64, + pub timeout_write: u64, +} + +impl Unit { + // + + fn new(req: &Request, url: &Url, body: &SizedReader) -> Self { + // + + let is_chunked = req.header("transfer-encoding") + // if the user has set an encoding header, obey that. + .map(|enc| enc.len() > 0) + // otherwise, no chunking. + .unwrap_or(false); + + let is_secure = url.scheme().eq_ignore_ascii_case("https"); + + let is_head = req.method.eq_ignore_ascii_case("head"); + + let hostname = url.host_str().unwrap_or("localhost").to_string(); + + let query_string = combine_query(&url, &req.query); + + let cookie_headers: Vec<_> = { + let mut state = req.agent.lock().unwrap(); + match state.as_ref().map(|state| &state.jar) { + None => vec![], + Some(jar) => match_cookies(jar, &hostname, url.path(), is_secure), + } + }; + let extra_headers = { + let mut extra = vec![]; + + // chunking and Content-Length headers are mutually exclusive + // also don't write this if the user has set it themselves + if !is_chunked && !req.has("content-length") { + if let Some(size) = body.size { + extra.push( + format!("Content-Length: {}\r\n", size) + .parse::
() + .unwrap(), + ); + } + } + extra + }; + let headers: Vec<_> = req + .headers + .iter() + .chain(cookie_headers.iter()) + .chain(extra_headers.iter()) + .cloned() + .collect(); + + Unit { + agent: Arc::clone(&req.agent), + url: url.clone(), + is_chunked, + is_head, + hostname, + query_string, + headers, + timeout_connect: req.timeout_connect, + timeout_read: req.timeout_read, + timeout_write: req.timeout_write, + } + } + + fn connect( + &mut self, + url: Url, + method: &str, + redirects: u32, + body: SizedReader, + ) -> Result { + // + + // open socket + let mut stream = match url.scheme() { + "http" => connect_http(self), + "https" => connect_https(self), + "test" => connect_test(self), + _ => Err(Error::UnknownScheme(url.scheme().to_string())), + }?; + + // send the request start + headers + let mut prelude: Vec = vec![]; + write!( + prelude, + "{} {}{} HTTP/1.1\r\n", + method, + url.path(), + self.query_string + )?; + if !has_header(&self.headers, "host") { + write!(prelude, "Host: {}\r\n", url.host().unwrap())?; + } + for header in &self.headers { + write!(prelude, "{}: {}\r\n", header.name(), header.value())?; + } + write!(prelude, "\r\n")?; + + stream.write_all(&mut prelude[..])?; + + // start reading the response to process cookies and redirects. + let mut resp = Response::from_read(&mut stream); + + // squirrel away cookies + { + let mut state = self.agent.lock().unwrap(); + if let Some(add_jar) = state.as_mut().map(|state| &mut state.jar) { + for raw_cookie in resp.all("set-cookie").iter() { + let to_parse = if raw_cookie.to_lowercase().contains("domain=") { + raw_cookie.to_string() + } else { + format!("{}; Domain={}", raw_cookie, self.hostname) + }; + match Cookie::parse_encoded(&to_parse[..]) { + Err(_) => (), // ignore unparseable cookies + Ok(mut cookie) => { + let cookie = cookie.into_owned(); + add_jar.add(cookie) + } + } + } + } + } + + // handle redirects + if resp.redirect() { + if redirects == 0 { + return Err(Error::TooManyRedirects); + } + + // the location header + let location = resp.header("location"); + if let Some(location) = location { + // join location header to current url in case it it relative + let new_url = url + .join(location) + .map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?; + + // perform the redirect differently depending on 3xx code. + return match resp.status { + 301 | 302 | 303 => { + send_body(body, self.is_chunked, &mut stream)?; + let empty = Payload::Empty.into_read(); + self.connect(new_url, "GET", redirects - 1, empty) + } + 307 | 308 | _ => self.connect(new_url, method, redirects - 1, body), + }; + } + } + + // send the body (which can be empty now depending on redirects) + send_body(body, self.is_chunked, &mut stream)?; + + // since it is not a redirect, give away the incoming stream to the response object + resp.set_stream(stream, self.is_head); + + // release the response + Ok(resp) + } + + #[cfg(test)] + pub fn header<'a>(&self, name: &'a str) -> Option<&str> { + get_header(&self.headers, name) + } + #[cfg(test)] + pub fn has<'a>(&self, name: &'a str) -> bool { + has_header(&self.headers, name) + } + #[cfg(test)] + pub fn all<'a>(&self, name: &'a str) -> Vec<&str> { + get_all_headers(&self.headers, name) + } + +} + +// TODO check so cookies can't be set for tld:s +fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec
{ + jar.iter() + .filter(|c| { + // if there is a domain, it must be matched. if there is no domain, then ignore cookie + let domain_ok = c + .domain() + .map(|cdom| domain.contains(cdom)) + .unwrap_or(false); + // a path must match the beginning of request path. no cookie path, we say is ok. is it?! + let path_ok = c + .path() + .map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false)) + .unwrap_or(true); + // either the cookie isnt secure, or we're not doing a secure request. + let secure_ok = !c.secure() || is_secure; + + domain_ok && path_ok && secure_ok + }) + .map(|c| { + let name = c.name().to_string(); + let value = c.value().to_string(); + let nameval = Cookie::new(name, value).encoded().to_string(); + let head = format!("Cookie: {}", nameval); + head.parse::
().ok() + }) + .filter(|o| o.is_some()) + .map(|o| o.unwrap()) + .collect() +} + +fn combine_query(url: &Url, query: &QString) -> String { + match (url.query(), query.len() > 0) { + (Some(urlq), true) => format!("?{}&{}", urlq, query), + (Some(urlq), false) => format!("?{}", urlq), + (None, true) => format!("?{}", query), + (None, false) => "".to_string(), + } +}