refactor into unit

Martin Algesten
2018-06-30 13:05:36 +02:00
parent 0334f9608b
commit f5a4c83819
13 changed files with 349 additions and 297 deletions


@@ -2,13 +2,14 @@ use cookie::{Cookie, CookieJar};
 use std::str::FromStr;
 use std::sync::Mutex;
 
-use header::{add_header, Header};
+use header::{add_header, get_header, get_all_headers, has_header, Header};
 
 // to get to share private fields
 include!("request.rs");
 include!("response.rs");
 include!("conn.rs");
 include!("stream.rs");
+include!("unit.rs");
 
 /// Agents keep state between requests.
 ///
@@ -45,7 +46,7 @@ pub struct Agent {
 }
 
 #[derive(Debug)]
-struct AgentState {
+pub struct AgentState {
     pool: ConnectionPool,
     jar: CookieJar,
 }
@@ -109,7 +110,7 @@ impl Agent {
     {
         let s = format!("{}: {}", header.into(), value.into());
         let header = s.parse::<Header>().expect("Failed to parse header");
-        add_header(header, &mut self.headers);
+        add_header(&mut self.headers, header);
         self
     }
@@ -145,7 +146,7 @@ impl Agent {
         for (k, v) in headers.into_iter() {
            let s = format!("{}: {}", k.into(), v.into());
            let header = s.parse::<Header>().expect("Failed to parse header");
-           add_header(header, &mut self.headers);
+           add_header(&mut self.headers, header);
        }
        self
    }
@@ -192,7 +193,7 @@ impl Agent {
     {
         let s = format!("Authorization: {} {}", kind.into(), pass.into());
         let header = s.parse::<Header>().expect("Failed to parse header");
-        add_header(header, &mut self.headers);
+        add_header(&mut self.headers, header);
         self
     }


@@ -7,137 +7,9 @@ const CHUNK_SIZE: usize = 1024 * 1024;
 pub struct ConnectionPool {}
 
 impl ConnectionPool {
-    fn new() -> Self {
+    pub fn new() -> Self {
         ConnectionPool {}
     }
-
-    fn connect(
-        &mut self,
-        request: &Request,
-        method: &str,
-        url: &Url,
-        redirects: u32,
-        mut jar: Option<&mut CookieJar>,
-        body: SizedReader,
-    ) -> Result<Response, Error> {
-        //
-        let do_chunk = request.header("transfer-encoding")
-            // if the user has set an encoding header, obey that.
-            .map(|enc| enc.len() > 0)
-            // otherwise, no chunking.
-            .unwrap_or(false);
-        let hostname = url.host_str().unwrap_or("localhost"); // is localhost a good alternative?
-        let query_string = combine_query(&url, &request.query);
-        let is_secure = url.scheme().eq_ignore_ascii_case("https");
-        let is_head = request.method.eq_ignore_ascii_case("head");
-        let cookie_headers: Vec<_> = {
-            match jar.as_ref() {
-                None => vec![],
-                Some(jar) => match_cookies(jar, hostname, url.path(), is_secure),
-            }
-        };
-        let extra_headers = {
-            let mut extra = vec![];
-            // chunking and Content-Length headers are mutually exclusive
-            // also don't write this if the user has set it themselves
-            if !do_chunk && !request.has("content-length") {
-                if let Some(size) = body.size {
-                    extra.push(format!("Content-Length: {}\r\n", size).parse::<Header>()?);
-                }
-            }
-            extra
-        };
-        let headers = request
-            .headers
-            .iter()
-            .chain(cookie_headers.iter())
-            .chain(extra_headers.iter());
-        // open socket
-        let mut stream = match url.scheme() {
-            "http" => connect_http(request, &url),
-            "https" => connect_https(request, &url),
-            "test" => connect_test(request, &url),
-            _ => Err(Error::UnknownScheme(url.scheme().to_string())),
-        }?;
-        // send the request start + headers
-        let mut prelude: Vec<u8> = vec![];
-        write!(prelude, "{} {}{} HTTP/1.1\r\n", method, url.path(), query_string)?;
-        if !request.has("host") {
-            write!(prelude, "Host: {}\r\n", url.host().unwrap())?;
-        }
-        for header in headers {
-            write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
-        }
-        write!(prelude, "\r\n")?;
-        stream.write_all(&mut prelude[..])?;
-        // start reading the response to process cookies and redirects.
-        let mut resp = Response::from_read(&mut stream);
-        // squirrel away cookies
-        if let Some(add_jar) = jar.as_mut() {
-            for raw_cookie in resp.all("set-cookie").iter() {
-                let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
-                    raw_cookie.to_string()
-                } else {
-                    format!("{}; Domain={}", raw_cookie, hostname)
-                };
-                match Cookie::parse_encoded(&to_parse[..]) {
-                    Err(_) => (), // ignore unparseable cookies
-                    Ok(mut cookie) => {
-                        let cookie = cookie.into_owned();
-                        add_jar.add(cookie)
-                    }
-                }
-            }
-        }
-        // handle redirects
-        if resp.redirect() {
-            if redirects == 0 {
-                return Err(Error::TooManyRedirects);
-            }
-            // the location header
-            let location = resp.header("location");
-            if let Some(location) = location {
-                // join location header to current url in case it it relative
-                let new_url = url.join(location)
-                    .map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?;
-                // perform the redirect differently depending on 3xx code.
-                return match resp.status {
-                    301 | 302 | 303 => {
-                        send_body(body, do_chunk, &mut stream)?;
-                        let empty = Payload::Empty.into_read();
-                        self.connect(request, "GET", &new_url, redirects - 1, jar, empty)
-                    }
-                    307 | 308 | _ => {
-                        self.connect(request, method, &new_url, redirects - 1, jar, body)
-                    }
-                };
-            }
-        }
-        // send the body (which can be empty now depending on redirects)
-        send_body(body, do_chunk, &mut stream)?;
-        // since it is not a redirect, give away the incoming stream to the response object
-        resp.set_stream(stream, is_head);
-        // release the response
-        Ok(resp)
-    }
 }
 
 fn send_body(body: SizedReader, do_chunk: bool, stream: &mut Stream) -> IoResult<()> {
@@ -165,41 +37,3 @@ where
     }
     Ok(())
 }
-
-// TODO check so cookies can't be set for tld:s
-fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec<Header> {
-    jar.iter()
-        .filter(|c| {
-            // if there is a domain, it must be matched. if there is no domain, then ignore cookie
-            let domain_ok = c.domain()
-                .map(|cdom| domain.contains(cdom))
-                .unwrap_or(false);
-            // a path must match the beginning of request path. no cookie path, we say is ok. is it?!
-            let path_ok = c.path()
-                .map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false))
-                .unwrap_or(true);
-            // either the cookie isnt secure, or we're not doing a secure request.
-            let secure_ok = !c.secure() || is_secure;
-            domain_ok && path_ok && secure_ok
-        })
-        .map(|c| {
-            let name = c.name().to_string();
-            let value = c.value().to_string();
-            let nameval = Cookie::new(name, value).encoded().to_string();
-            let head = format!("Cookie: {}", nameval);
-            head.parse::<Header>().ok()
-        })
-        .filter(|o| o.is_some())
-        .map(|o| o.unwrap())
-        .collect()
-}
-
-fn combine_query(url: &Url, query: &QString) -> String {
-    match (url.query(), query.len() > 0) {
-        (Some(urlq), true) => format!("?{}&{}", urlq, query),
-        (Some(urlq), false) => format!("?{}", urlq),
-        (None, true) => format!("?{}", query),
-        (None, false) => "".to_string(),
-    }
-}


@@ -53,6 +53,30 @@ impl Header {
     }
 }
 
+pub fn get_header<'a, 'b>(headers: &'b Vec<Header>, name: &'a str) -> Option<&'b str> {
+    headers.iter().find(|h| h.is_name(name)).map(|h| h.value())
+}
+
+pub fn get_all_headers<'a, 'b>(headers: &'b Vec<Header>, name: &'a str) -> Vec<&'b str> {
+    headers
+        .iter()
+        .filter(|h| h.is_name(name))
+        .map(|h| h.value())
+        .collect()
+}
+
+pub fn has_header(headers: &Vec<Header>, name: &str) -> bool {
+    get_header(headers, name).is_some()
+}
+
+pub fn add_header(headers: &mut Vec<Header>, header: Header) {
+    if !header.name().to_lowercase().starts_with("x-") {
+        let name = header.name();
+        headers.retain(|h| h.name() != name);
+    }
+    headers.push(header);
+}
+
 impl FromStr for Header {
     type Err = Error;
     fn from_str(s: &str) -> Result<Self, Self::Err> {
@@ -68,11 +92,3 @@ impl FromStr for Header {
         Ok(Header { line, index })
     }
 }
-
-pub fn add_header(header: Header, headers: &mut Vec<Header>) {
-    if !header.name().to_lowercase().starts_with("x-") {
-        let name = header.name();
-        headers.retain(|h| h.name() != name);
-    }
-    headers.push(header);
-}


@@ -23,7 +23,7 @@ lazy_static! {
 /// ```
 #[derive(Clone, Default)]
 pub struct Request {
-    state: Arc<Mutex<Option<AgentState>>>,
+    agent: Arc<Mutex<Option<AgentState>>>,
 
     // via agent
     method: String,
@@ -32,9 +32,9 @@ pub struct Request {
     // from request itself
     headers: Vec<Header>,
     query: QString,
-    timeout: u32,
-    timeout_read: u32,
-    timeout_write: u32,
+    timeout_connect: u64,
+    timeout_read: u64,
+    timeout_write: u64,
     redirects: u32,
 }
@@ -64,7 +64,7 @@ impl Default for Payload {
     }
 }
 
-struct SizedReader {
+pub struct SizedReader {
     size: Option<usize>,
     reader: Box<Read + 'static>,
 }
@@ -108,7 +108,7 @@ impl Payload {
 impl Request {
     fn new(agent: &Agent, method: String, path: String) -> Request {
         Request {
-            state: Arc::clone(&agent.state),
+            agent: Arc::clone(&agent.state),
             method,
             path,
             headers: agent.headers.clone(),
@@ -132,11 +132,11 @@ impl Request {
     /// Executes the request and blocks the caller until done.
     ///
-    /// Use `.timeout()` and `.timeout_read()` to avoid blocking forever.
+    /// Use `.timeout_connect()` and `.timeout_read()` to avoid blocking forever.
     ///
     /// ```
     /// let r = ureq::get("/my_page")
-    ///     .timeout(10_000) // max 10 seconds
+    ///     .timeout_connect(10_000) // max 10 seconds
     ///     .call();
     ///
     /// println!("{:?}", r);
@@ -146,32 +146,11 @@ impl Request {
     }
 
     fn do_call(&mut self, payload: Payload) -> Response {
-        let mut state = self.state.lock().unwrap();
         self.to_url()
             .and_then(|url| {
-                match state.as_mut() {
-                    None =>
-                    // create a one off pool/jar.
-                        ConnectionPool::new().connect(
-                            self,
-                            &self.method,
-                            &url,
-                            self.redirects,
-                            None,
-                            payload.into_read(),
-                        ),
-                    Some(state) => {
-                        let jar = &mut state.jar;
-                        state.pool.connect(
-                            self,
-                            &self.method,
-                            &url,
-                            self.redirects,
-                            Some(jar),
-                            payload.into_read(),
-                        )
-                    },
-                }
+                let reader = payload.into_read();
+                let mut unit = Unit::new(&self, &url, &reader);
+                unit.connect(url, &self.method, self.redirects, reader)
             })
             .unwrap_or_else(|e| e.into())
     }
@@ -269,7 +248,7 @@ impl Request {
     {
         let s = format!("{}: {}", header.into(), value.into());
         let header = s.parse::<Header>().expect("Failed to parse header");
-        add_header(header, &mut self.headers);
+        add_header(&mut self.headers, header);
         self
     }
@@ -282,10 +261,7 @@ impl Request {
     /// assert_eq!("foobar", req.header("x-api-Key").unwrap());
     /// ```
     pub fn header<'a>(&self, name: &'a str) -> Option<&str> {
-        self.headers
-            .iter()
-            .find(|h| h.is_name(name))
-            .map(|h| h.value())
+        get_header(&self.headers, name)
     }
 
     /// Tells if the header has been set.
@@ -297,7 +273,7 @@ impl Request {
     /// assert_eq!(true, req.has("x-api-Key"));
     /// ```
     pub fn has<'a>(&self, name: &'a str) -> bool {
-        self.header(name).is_some()
+        has_header(&self.headers, name)
     }
 
     /// All headers corresponding values for the give name, or empty vector.
@@ -313,11 +289,7 @@ impl Request {
     /// ]);
     /// ```
     pub fn all<'a>(&self, name: &'a str) -> Vec<&str> {
-        self.headers
-            .iter()
-            .filter(|h| h.is_name(name))
-            .map(|h| h.value())
-            .collect()
+        get_all_headers(&self.headers, name)
     }
 
     /// Set many headers.
@@ -348,7 +320,7 @@ impl Request {
         for (k, v) in headers.into_iter() {
             let s = format!("{}: {}", k.into(), v.into());
             let header = s.parse::<Header>().expect("Failed to parse header");
-            add_header(header, &mut self.headers);
+            add_header(&mut self.headers, header);
         }
         self
     }
@@ -430,12 +402,12 @@ impl Request {
     ///
     /// ```
     /// let r = ureq::get("/my_page")
-    ///     .timeout(1_000) // wait max 1 second to connect
+    ///     .timeout_connect(1_000) // wait max 1 second to connect
     ///     .call();
     /// println!("{:?}", r);
     /// ```
-    pub fn timeout(&mut self, millis: u32) -> &mut Request {
-        self.timeout = millis;
+    pub fn timeout_connect(&mut self, millis: u64) -> &mut Request {
+        self.timeout_connect = millis;
         self
     }
@@ -449,7 +421,7 @@ impl Request {
     ///     .call();
     /// println!("{:?}", r);
     /// ```
-    pub fn timeout_read(&mut self, millis: u32) -> &mut Request {
+    pub fn timeout_read(&mut self, millis: u64) -> &mut Request {
         self.timeout_read = millis;
         self
     }
@@ -464,7 +436,7 @@ impl Request {
     ///     .call();
     /// println!("{:?}", r);
     /// ```
-    pub fn timeout_write(&mut self, millis: u32) -> &mut Request {
+    pub fn timeout_write(&mut self, millis: u64) -> &mut Request {
         self.timeout_write = millis;
         self
     }
@@ -508,7 +480,7 @@ impl Request {
     {
         let s = format!("Authorization: {} {}", kind.into(), pass.into());
         let header = s.parse::<Header>().expect("Failed to parse header");
-        add_header(header, &mut self.headers);
+        add_header(&mut self.headers, header);
         self
     }
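Note: the connect timeout setter is renamed in this commit and all three timeout fields widen to u64 milliseconds. A doc-style usage sketch combining the three setters; the host, path and values are placeholders:

// Illustrative only: chaining the renamed/widened timeout setters.
let r = ureq::get("http://example.com/slow_page")
    .timeout_connect(1_000) // max 1 second to establish the connection
    .timeout_read(5_000)    // max 5 seconds per blocking read
    .timeout_write(5_000)   // max 5 seconds per blocking write
    .call();
println!("{:?}", r);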


@@ -60,28 +60,28 @@ impl Write for Stream {
     }
 }
 
-fn connect_http(request: &Request, url: &Url) -> Result<Stream, Error> {
+fn connect_http(unit: &Unit) -> Result<Stream, Error> {
     //
-    let hostname = url.host_str().unwrap();
-    let port = url.port().unwrap_or(80);
-    connect_host(request, hostname, port).map(|tcp| Stream::Http(tcp))
+    let hostname = unit.url.host_str().unwrap();
+    let port = unit.url.port().unwrap_or(80);
+    connect_host(unit, hostname, port).map(|tcp| Stream::Http(tcp))
 }
 
 #[cfg(feature = "tls")]
-fn connect_https(request: &Request, url: &Url) -> Result<Stream, Error> {
+fn connect_https(unit: &Unit) -> Result<Stream, Error> {
     //
-    let hostname = url.host_str().unwrap();
-    let port = url.port().unwrap_or(443);
-    let socket = connect_host(request, hostname, port)?;
+    let hostname = unit.url.host_str().unwrap();
+    let port = unit.url.port().unwrap_or(443);
+    let socket = connect_host(unit, hostname, port)?;
     let connector = TlsConnector::builder().build()?;
     let stream = connector.connect(hostname, socket)?;
     Ok(Stream::Https(stream))
 }
 
-fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStream, Error> {
+fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
     //
     let ips: Vec<SocketAddr> = format!("{}:{}", hostname, port).to_socket_addrs()
         .map_err(|e| Error::DnsFailed(format!("{}", e)))?
@@ -95,20 +95,20 @@ fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStrea
     let sock_addr = ips[0];
 
     // connect with a configured timeout.
-    let stream = match request.timeout {
+    let stream = match unit.timeout_connect {
         0 => TcpStream::connect(&sock_addr),
-        _ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(request.timeout as u64)),
+        _ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(unit.timeout_connect as u64)),
     }.map_err(|err| Error::ConnectionFailed(format!("{}", err)))?;
 
     // rust's absurd api returns Err if we set 0.
-    if request.timeout_read > 0 {
+    if unit.timeout_read > 0 {
         stream
-            .set_read_timeout(Some(Duration::from_millis(request.timeout_read as u64)))
+            .set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64)))
             .ok();
     }
-    if request.timeout_write > 0 {
+    if unit.timeout_write > 0 {
         stream
-            .set_write_timeout(Some(Duration::from_millis(request.timeout_write as u64)))
+            .set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64)))
             .ok();
     }
@@ -116,17 +116,17 @@ fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStrea
 }
 
 #[cfg(test)]
-fn connect_test(request: &Request, url: &Url) -> Result<Stream, Error> {
+fn connect_test(unit: &Unit) -> Result<Stream, Error> {
     use test;
-    test::resolve_handler(request, url)
+    test::resolve_handler(unit)
 }
 
 #[cfg(not(test))]
-fn connect_test(_request: &Request, url: &Url) -> Result<Stream, Error> {
-    Err(Error::UnknownScheme(url.scheme().to_string()))
+fn connect_test(unit: &Unit) -> Result<Stream, Error> {
+    Err(Error::UnknownScheme(unit.url.scheme().to_string()))
 }
 
 #[cfg(not(feature = "tls"))]
-fn connect_https(request: &Request, url: &Url) -> Result<Stream, Error> {
-    Err(Error::UnknownScheme(url.scheme().to_string()))
+fn connect_https(unit: &Unit) -> Result<Stream, Error> {
+    Err(Error::UnknownScheme(unit.url.scheme().to_string()))
 }


@@ -6,18 +6,18 @@ use super::super::*;
 fn agent_reuse_headers() {
     let agent = agent().set("Authorization", "Foo 12345").build();
 
-    test::set_handler("/agent_reuse_headers", |req, _url| {
-        assert!(req.has("Authorization"));
-        assert_eq!(req.header("Authorization").unwrap(), "Foo 12345");
+    test::set_handler("/agent_reuse_headers", |unit| {
+        assert!(unit.has("Authorization"));
+        assert_eq!(unit.header("Authorization").unwrap(), "Foo 12345");
         test::make_response(200, "OK", vec!["X-Call: 1"], vec![])
     });
     let resp = agent.get("test://host/agent_reuse_headers").call();
     assert_eq!(resp.header("X-Call").unwrap(), "1");
 
-    test::set_handler("/agent_reuse_headers", |req, _url| {
-        assert!(req.has("Authorization"));
-        assert_eq!(req.header("Authorization").unwrap(), "Foo 12345");
+    test::set_handler("/agent_reuse_headers", |unit| {
+        assert!(unit.has("Authorization"));
+        assert_eq!(unit.header("Authorization").unwrap(), "Foo 12345");
         test::make_response(200, "OK", vec!["X-Call: 2"], vec![])
     });
@@ -29,7 +29,7 @@ fn agent_reuse_headers() {
 fn agent_cookies() {
     let agent = agent().build();
 
-    test::set_handler("/agent_cookies", |_req, _url| {
+    test::set_handler("/agent_cookies", |_unit| {
         test::make_response(
             200,
             "OK",
@@ -43,7 +43,7 @@ fn agent_cookies() {
     assert!(agent.cookie("foo").is_some());
     assert_eq!(agent.cookie("foo").unwrap().value(), "bar baz");
 
-    test::set_handler("/agent_cookies", |_req, _url| {
+    test::set_handler("/agent_cookies", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });


@@ -4,9 +4,9 @@ use super::super::*;
 #[test]
 fn basic_auth() {
-    test::set_handler("/basic_auth", |req, _url| {
+    test::set_handler("/basic_auth", |unit| {
         assert_eq!(
-            req.header("Authorization").unwrap(),
+            unit.header("Authorization").unwrap(),
             "Basic bWFydGluOnJ1YmJlcm1hc2hndW0="
         );
         test::make_response(200, "OK", vec![], vec![])
@@ -19,8 +19,8 @@ fn basic_auth() {
 #[test]
 fn kind_auth() {
-    test::set_handler("/kind_auth", |req, _url| {
-        assert_eq!(req.header("Authorization").unwrap(), "Digest abcdefgh123");
+    test::set_handler("/kind_auth", |unit| {
+        assert_eq!(unit.header("Authorization").unwrap(), "Digest abcdefgh123");
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = get("test://host/kind_auth")


@@ -5,7 +5,7 @@ use super::super::*;
 #[test]
 fn transfer_encoding_bogus() {
-    test::set_handler("/transfer_encoding_bogus", |_req, _url| {
+    test::set_handler("/transfer_encoding_bogus", |_unit| {
         test::make_response(
             200,
             "OK",
@@ -26,7 +26,7 @@ fn transfer_encoding_bogus() {
 #[test]
 fn content_length_limited() {
-    test::set_handler("/content_length_limited", |_req, _url| {
+    test::set_handler("/content_length_limited", |_unit| {
         test::make_response(
             200,
             "OK",
@@ -44,7 +44,7 @@ fn content_length_limited() {
 #[test]
 // content-length should be ignored when chunked
 fn ignore_content_length_when_chunked() {
-    test::set_handler("/ignore_content_length_when_chunked", |_req, _url| {
+    test::set_handler("/ignore_content_length_when_chunked", |_unit| {
         test::make_response(
             200,
             "OK",
@@ -63,7 +63,7 @@ fn ignore_content_length_when_chunked() {
 #[test]
 fn no_reader_on_head() {
-    test::set_handler("/no_reader_on_head", |_req, _url| {
+    test::set_handler("/no_reader_on_head", |_unit| {
         // so this is technically illegal, we return a body for the HEAD request.
         test::make_response(
             200,


@@ -4,7 +4,7 @@ use super::super::*;
 #[test]
 fn content_length_on_str() {
-    test::set_handler("/content_length_on_str", |_req, _url| {
+    test::set_handler("/content_length_on_str", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = post("test://host/content_length_on_str").send_string("Hello World!!!");
@@ -15,7 +15,7 @@ fn content_length_on_str() {
 #[test]
 fn user_set_content_length_on_str() {
-    test::set_handler("/user_set_content_length_on_str", |_req, _url| {
+    test::set_handler("/user_set_content_length_on_str", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = post("test://host/user_set_content_length_on_str")
@@ -29,7 +29,7 @@ fn user_set_content_length_on_str() {
 #[test]
 #[cfg(feature = "json")]
 fn content_length_on_json() {
-    test::set_handler("/content_length_on_json", |_req, _url| {
+    test::set_handler("/content_length_on_json", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let mut json = SerdeMap::new();
@@ -45,7 +45,7 @@ fn content_length_on_json() {
 #[test]
 fn content_length_and_chunked() {
-    test::set_handler("/content_length_and_chunked", |_req, _url| {
+    test::set_handler("/content_length_and_chunked", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = post("test://host/content_length_and_chunked")
@@ -60,7 +60,7 @@ fn content_length_and_chunked() {
 #[test]
 #[cfg(feature = "charset")]
 fn str_with_encoding() {
-    test::set_handler("/str_with_encoding", |_req, _url| {
+    test::set_handler("/str_with_encoding", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = post("test://host/str_with_encoding")


@@ -1,4 +1,4 @@
-use agent::Request;
+use agent::Unit;
 use agent::Stream;
 use error::Error;
 use header::Header;
@@ -6,7 +6,6 @@ use std::collections::HashMap;
 use std::io::Cursor;
 use std::io::Write;
 use std::sync::{Arc, Mutex};
-use url::Url;
 
 mod agent_test;
 mod auth;
@@ -16,7 +15,7 @@ mod query_string;
 mod range;
 mod simple;
 
-type RequestHandler = Fn(&Request, &Url) -> Result<Stream, Error> + Send + 'static;
+type RequestHandler = Fn(&Unit) -> Result<Stream, Error> + Send + 'static;
 
 lazy_static! {
     pub static ref TEST_HANDLERS: Arc<Mutex<HashMap<String, Box<RequestHandler>>>> =
@@ -25,7 +24,7 @@ lazy_static! {
 pub fn set_handler<H>(path: &str, handler: H)
 where
-    H: Fn(&Request, &Url) -> Result<Stream, Error> + Send + 'static,
+    H: Fn(&Unit) -> Result<Stream, Error> + Send + 'static,
 {
     let mut handlers = TEST_HANDLERS.lock().unwrap();
     handlers.insert(path.to_string(), Box::new(handler));
@@ -50,9 +49,9 @@ pub fn make_response(
     Ok(Stream::Test(Box::new(cursor), write))
 }
 
-pub fn resolve_handler(req: &Request, url: &Url) -> Result<Stream, Error> {
+pub fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {
     let mut handlers = TEST_HANDLERS.lock().unwrap();
-    let path = url.path();
+    let path = unit.url.path();
     let handler = handlers.remove(path).unwrap();
-    handler(req, url)
+    handler(unit)
 }
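Note: test handlers now receive the whole outgoing Unit instead of (&Request, &Url), so a test can assert on the resolved URL, the combined query string and the final header list in one place. A hedged sketch in the style of the existing tests; the path, header name and values below are made up:

// Illustrative only: a handler inspecting the Unit it is given.
test::set_handler("/unit_demo", |unit| {
    assert_eq!(unit.url.path(), "/unit_demo");
    assert_eq!(unit.query_string, "?foo=bar");
    assert_eq!(unit.header("X-Demo").unwrap(), "yes");
    test::make_response(200, "OK", vec!["X-Got: 1"], vec![])
});
let resp = get("test://host/unit_demo?foo=bar")
    .set("X-Demo", "yes")
    .call();
assert_eq!(resp.header("X-Got").unwrap(), "1");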


@@ -4,7 +4,7 @@ use super::super::*;
 #[test]
 fn no_query_string() {
-    test::set_handler("/no_query_string", |_req, _url| {
+    test::set_handler("/no_query_string", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = get("test://host/no_query_string").call();
@@ -15,7 +15,7 @@ fn no_query_string() {
 #[test]
 fn escaped_query_string() {
-    test::set_handler("/escaped_query_string", |_req, _url| {
+    test::set_handler("/escaped_query_string", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = get("test://host/escaped_query_string")
@@ -29,7 +29,7 @@ fn escaped_query_string() {
 #[test]
 fn query_in_path() {
-    test::set_handler("/query_in_path", |_req, _url| {
+    test::set_handler("/query_in_path", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = get("test://host/query_in_path?foo=bar").call();
@@ -40,7 +40,7 @@ fn query_in_path() {
 #[test]
 fn query_in_path_and_req() {
-    test::set_handler("/query_in_path_and_req", |_req, _url| {
+    test::set_handler("/query_in_path_and_req", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = get("test://host/query_in_path_and_req?foo=bar")


@@ -5,9 +5,9 @@ use super::super::*;
 #[test]
 fn header_passing() {
-    test::set_handler("/header_passing", |req, _url| {
-        assert!(req.has("X-Foo"));
-        assert_eq!(req.header("X-Foo").unwrap(), "bar");
+    test::set_handler("/header_passing", |unit| {
+        assert!(unit.has("X-Foo"));
+        assert_eq!(unit.header("X-Foo").unwrap(), "bar");
         test::make_response(200, "OK", vec!["X-Bar: foo"], vec![])
     });
     let resp = get("test://host/header_passing").set("X-Foo", "bar").call();
@@ -18,9 +18,9 @@ fn header_passing() {
 #[test]
 fn repeat_non_x_header() {
-    test::set_handler("/repeat_non_x_header", |req, _url| {
-        assert!(req.has("Accept"));
-        assert_eq!(req.header("Accept").unwrap(), "baz");
+    test::set_handler("/repeat_non_x_header", |unit| {
+        assert!(unit.has("Accept"));
+        assert_eq!(unit.header("Accept").unwrap(), "baz");
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = get("test://host/repeat_non_x_header")
@@ -32,11 +32,11 @@ fn repeat_non_x_header() {
 #[test]
 fn repeat_x_header() {
-    test::set_handler("/repeat_x_header", |req, _url| {
-        assert!(req.has("X-Forwarded-For"));
-        assert_eq!(req.header("X-Forwarded-For").unwrap(), "130.240.19.2");
+    test::set_handler("/repeat_x_header", |unit| {
+        assert!(unit.has("X-Forwarded-For"));
+        assert_eq!(unit.header("X-Forwarded-For").unwrap(), "130.240.19.2");
         assert_eq!(
-            req.all("X-Forwarded-For"),
+            unit.all("X-Forwarded-For"),
             vec!["130.240.19.2", "130.240.19.3"]
         );
         test::make_response(200, "OK", vec![], vec![])
@@ -50,7 +50,7 @@ fn repeat_x_header() {
 #[test]
 fn body_as_text() {
-    test::set_handler("/body_as_text", |_req, _url| {
+    test::set_handler("/body_as_text", |_unit| {
         test::make_response(200, "OK", vec![], "Hello World!".to_string().into_bytes())
     });
     let resp = get("test://host/body_as_text").call();
@@ -61,7 +61,7 @@ fn body_as_text() {
 #[test]
 #[cfg(feature = "json")]
 fn body_as_json() {
-    test::set_handler("/body_as_json", |_req, _url| {
+    test::set_handler("/body_as_json", |_unit| {
         test::make_response(
             200,
             "OK",
@@ -76,7 +76,7 @@ fn body_as_json() {
 #[test]
 fn body_as_reader() {
-    test::set_handler("/body_as_reader", |_req, _url| {
+    test::set_handler("/body_as_reader", |_unit| {
         test::make_response(200, "OK", vec![], "abcdefgh".to_string().into_bytes())
     });
     let resp = get("test://host/body_as_reader").call();
@@ -88,7 +88,7 @@ fn body_as_reader() {
 #[test]
 fn escape_path() {
-    test::set_handler("/escape_path%20here", |_req, _url| {
+    test::set_handler("/escape_path%20here", |_unit| {
         test::make_response(200, "OK", vec![], vec![])
     });
     let resp = get("test://host/escape_path here").call();

src/unit.rs (new file, 230 lines added)

@@ -0,0 +1,230 @@
+//
+pub struct Unit {
+    pub agent: Arc<Mutex<Option<AgentState>>>,
+    pub url: Url,
+    pub is_chunked: bool,
+    pub is_head: bool,
+    pub hostname: String,
+    pub query_string: String,
+    pub headers: Vec<Header>,
+    pub timeout_connect: u64,
+    pub timeout_read: u64,
+    pub timeout_write: u64,
+}
+
+impl Unit {
+    //
+    fn new(req: &Request, url: &Url, body: &SizedReader) -> Self {
+        //
+        let is_chunked = req.header("transfer-encoding")
+            // if the user has set an encoding header, obey that.
+            .map(|enc| enc.len() > 0)
+            // otherwise, no chunking.
+            .unwrap_or(false);
+
+        let is_secure = url.scheme().eq_ignore_ascii_case("https");
+
+        let is_head = req.method.eq_ignore_ascii_case("head");
+
+        let hostname = url.host_str().unwrap_or("localhost").to_string();
+
+        let query_string = combine_query(&url, &req.query);
+
+        let cookie_headers: Vec<_> = {
+            let mut state = req.agent.lock().unwrap();
+            match state.as_ref().map(|state| &state.jar) {
+                None => vec![],
+                Some(jar) => match_cookies(jar, &hostname, url.path(), is_secure),
+            }
+        };
+
+        let extra_headers = {
+            let mut extra = vec![];
+            // chunking and Content-Length headers are mutually exclusive
+            // also don't write this if the user has set it themselves
+            if !is_chunked && !req.has("content-length") {
+                if let Some(size) = body.size {
+                    extra.push(
+                        format!("Content-Length: {}\r\n", size)
+                            .parse::<Header>()
+                            .unwrap(),
+                    );
+                }
+            }
+            extra
+        };
+
+        let headers: Vec<_> = req
+            .headers
+            .iter()
+            .chain(cookie_headers.iter())
+            .chain(extra_headers.iter())
+            .cloned()
+            .collect();
+
+        Unit {
+            agent: Arc::clone(&req.agent),
+            url: url.clone(),
+            is_chunked,
+            is_head,
+            hostname,
+            query_string,
+            headers,
+            timeout_connect: req.timeout_connect,
+            timeout_read: req.timeout_read,
+            timeout_write: req.timeout_write,
+        }
+    }
+
+    fn connect(
+        &mut self,
+        url: Url,
+        method: &str,
+        redirects: u32,
+        body: SizedReader,
+    ) -> Result<Response, Error> {
+        //
+        // open socket
+        let mut stream = match url.scheme() {
+            "http" => connect_http(self),
+            "https" => connect_https(self),
+            "test" => connect_test(self),
+            _ => Err(Error::UnknownScheme(url.scheme().to_string())),
+        }?;
+
+        // send the request start + headers
+        let mut prelude: Vec<u8> = vec![];
+        write!(
+            prelude,
+            "{} {}{} HTTP/1.1\r\n",
+            method,
+            url.path(),
+            self.query_string
+        )?;
+        if !has_header(&self.headers, "host") {
+            write!(prelude, "Host: {}\r\n", url.host().unwrap())?;
+        }
+        for header in &self.headers {
+            write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
+        }
+        write!(prelude, "\r\n")?;
+
+        stream.write_all(&mut prelude[..])?;
+
+        // start reading the response to process cookies and redirects.
+        let mut resp = Response::from_read(&mut stream);
+
+        // squirrel away cookies
+        {
+            let mut state = self.agent.lock().unwrap();
+            if let Some(add_jar) = state.as_mut().map(|state| &mut state.jar) {
+                for raw_cookie in resp.all("set-cookie").iter() {
+                    let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
+                        raw_cookie.to_string()
+                    } else {
+                        format!("{}; Domain={}", raw_cookie, self.hostname)
+                    };
+                    match Cookie::parse_encoded(&to_parse[..]) {
+                        Err(_) => (), // ignore unparseable cookies
+                        Ok(mut cookie) => {
+                            let cookie = cookie.into_owned();
+                            add_jar.add(cookie)
+                        }
+                    }
+                }
+            }
+        }
+
+        // handle redirects
+        if resp.redirect() {
+            if redirects == 0 {
+                return Err(Error::TooManyRedirects);
+            }
+
+            // the location header
+            let location = resp.header("location");
+            if let Some(location) = location {
+                // join location header to current url in case it it relative
+                let new_url = url
+                    .join(location)
+                    .map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?;
+
+                // perform the redirect differently depending on 3xx code.
+                return match resp.status {
+                    301 | 302 | 303 => {
+                        send_body(body, self.is_chunked, &mut stream)?;
+                        let empty = Payload::Empty.into_read();
+                        self.connect(new_url, "GET", redirects - 1, empty)
+                    }
+                    307 | 308 | _ => self.connect(new_url, method, redirects - 1, body),
+                };
+            }
+        }
+
+        // send the body (which can be empty now depending on redirects)
+        send_body(body, self.is_chunked, &mut stream)?;
+
+        // since it is not a redirect, give away the incoming stream to the response object
+        resp.set_stream(stream, self.is_head);
+
+        // release the response
+        Ok(resp)
+    }
+
+    #[cfg(test)]
+    pub fn header<'a>(&self, name: &'a str) -> Option<&str> {
+        get_header(&self.headers, name)
+    }
+    #[cfg(test)]
+    pub fn has<'a>(&self, name: &'a str) -> bool {
+        has_header(&self.headers, name)
+    }
+    #[cfg(test)]
+    pub fn all<'a>(&self, name: &'a str) -> Vec<&str> {
+        get_all_headers(&self.headers, name)
+    }
+}
+
+// TODO check so cookies can't be set for tld:s
+fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec<Header> {
+    jar.iter()
+        .filter(|c| {
+            // if there is a domain, it must be matched. if there is no domain, then ignore cookie
+            let domain_ok = c
+                .domain()
+                .map(|cdom| domain.contains(cdom))
+                .unwrap_or(false);
+            // a path must match the beginning of request path. no cookie path, we say is ok. is it?!
+            let path_ok = c
+                .path()
+                .map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false))
+                .unwrap_or(true);
+            // either the cookie isnt secure, or we're not doing a secure request.
+            let secure_ok = !c.secure() || is_secure;
+            domain_ok && path_ok && secure_ok
+        })
+        .map(|c| {
+            let name = c.name().to_string();
+            let value = c.value().to_string();
+            let nameval = Cookie::new(name, value).encoded().to_string();
+            let head = format!("Cookie: {}", nameval);
+            head.parse::<Header>().ok()
+        })
+        .filter(|o| o.is_some())
+        .map(|o| o.unwrap())
+        .collect()
+}
+
+fn combine_query(url: &Url, query: &QString) -> String {
+    match (url.query(), query.len() > 0) {
+        (Some(urlq), true) => format!("?{}&{}", urlq, query),
+        (Some(urlq), false) => format!("?{}", urlq),
+        (None, true) => format!("?{}", query),
+        (None, false) => "".to_string(),
+    }
+}
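Note on combine_query above: the query already present in the URL and the pairs added on the request are merged into one string, with the leading "?" added here rather than by QString. An illustrative sketch of the four match arms, assuming q_empty is a QString with no pairs and q_pairs one that displays as "b=2"; both are placeholders, not values from the commit:

// Illustrative only: expected results of combine_query for its four match arms.
let with_query = Url::parse("http://host/path?a=1").unwrap();
let plain = Url::parse("http://host/path").unwrap();

assert_eq!(combine_query(&with_query, &q_pairs), "?a=1&b=2"); // URL query + request query
assert_eq!(combine_query(&with_query, &q_empty), "?a=1");     // URL query only
assert_eq!(combine_query(&plain, &q_pairs), "?b=2");          // request query only
assert_eq!(combine_query(&plain, &q_empty), "");              // neither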