refactor into unit
This commit is contained in:
11
src/agent.rs
11
src/agent.rs
@@ -2,13 +2,14 @@ use cookie::{Cookie, CookieJar};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use header::{add_header, Header};
|
||||
use header::{add_header, get_header, get_all_headers, has_header, Header};
|
||||
|
||||
// to get to share private fields
|
||||
include!("request.rs");
|
||||
include!("response.rs");
|
||||
include!("conn.rs");
|
||||
include!("stream.rs");
|
||||
include!("unit.rs");
|
||||
|
||||
/// Agents keep state between requests.
|
||||
///
|
||||
@@ -45,7 +46,7 @@ pub struct Agent {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AgentState {
|
||||
pub struct AgentState {
|
||||
pool: ConnectionPool,
|
||||
jar: CookieJar,
|
||||
}
|
||||
@@ -109,7 +110,7 @@ impl Agent {
|
||||
{
|
||||
let s = format!("{}: {}", header.into(), value.into());
|
||||
let header = s.parse::<Header>().expect("Failed to parse header");
|
||||
add_header(header, &mut self.headers);
|
||||
add_header(&mut self.headers, header);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -145,7 +146,7 @@ impl Agent {
|
||||
for (k, v) in headers.into_iter() {
|
||||
let s = format!("{}: {}", k.into(), v.into());
|
||||
let header = s.parse::<Header>().expect("Failed to parse header");
|
||||
add_header(header, &mut self.headers);
|
||||
add_header(&mut self.headers, header);
|
||||
}
|
||||
self
|
||||
}
|
||||
@@ -192,7 +193,7 @@ impl Agent {
|
||||
{
|
||||
let s = format!("Authorization: {} {}", kind.into(), pass.into());
|
||||
let header = s.parse::<Header>().expect("Failed to parse header");
|
||||
add_header(header, &mut self.headers);
|
||||
add_header(&mut self.headers, header);
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
168
src/conn.rs
168
src/conn.rs
@@ -7,137 +7,9 @@ const CHUNK_SIZE: usize = 1024 * 1024;
|
||||
pub struct ConnectionPool {}
|
||||
|
||||
impl ConnectionPool {
|
||||
fn new() -> Self {
|
||||
pub fn new() -> Self {
|
||||
ConnectionPool {}
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&mut self,
|
||||
request: &Request,
|
||||
method: &str,
|
||||
url: &Url,
|
||||
redirects: u32,
|
||||
mut jar: Option<&mut CookieJar>,
|
||||
body: SizedReader,
|
||||
) -> Result<Response, Error> {
|
||||
//
|
||||
|
||||
let do_chunk = request.header("transfer-encoding")
|
||||
// if the user has set an encoding header, obey that.
|
||||
.map(|enc| enc.len() > 0)
|
||||
// otherwise, no chunking.
|
||||
.unwrap_or(false);
|
||||
|
||||
let hostname = url.host_str().unwrap_or("localhost"); // is localhost a good alternative?
|
||||
|
||||
let query_string = combine_query(&url, &request.query);
|
||||
|
||||
let is_secure = url.scheme().eq_ignore_ascii_case("https");
|
||||
|
||||
let is_head = request.method.eq_ignore_ascii_case("head");
|
||||
|
||||
let cookie_headers: Vec<_> = {
|
||||
match jar.as_ref() {
|
||||
None => vec![],
|
||||
Some(jar) => match_cookies(jar, hostname, url.path(), is_secure),
|
||||
}
|
||||
};
|
||||
let extra_headers = {
|
||||
let mut extra = vec![];
|
||||
|
||||
// chunking and Content-Length headers are mutually exclusive
|
||||
// also don't write this if the user has set it themselves
|
||||
if !do_chunk && !request.has("content-length") {
|
||||
if let Some(size) = body.size {
|
||||
extra.push(format!("Content-Length: {}\r\n", size).parse::<Header>()?);
|
||||
}
|
||||
}
|
||||
extra
|
||||
};
|
||||
let headers = request
|
||||
.headers
|
||||
.iter()
|
||||
.chain(cookie_headers.iter())
|
||||
.chain(extra_headers.iter());
|
||||
|
||||
// open socket
|
||||
let mut stream = match url.scheme() {
|
||||
"http" => connect_http(request, &url),
|
||||
"https" => connect_https(request, &url),
|
||||
"test" => connect_test(request, &url),
|
||||
_ => Err(Error::UnknownScheme(url.scheme().to_string())),
|
||||
}?;
|
||||
|
||||
// send the request start + headers
|
||||
let mut prelude: Vec<u8> = vec![];
|
||||
write!(prelude, "{} {}{} HTTP/1.1\r\n", method, url.path(), query_string)?;
|
||||
if !request.has("host") {
|
||||
write!(prelude, "Host: {}\r\n", url.host().unwrap())?;
|
||||
}
|
||||
for header in headers {
|
||||
write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
|
||||
}
|
||||
write!(prelude, "\r\n")?;
|
||||
|
||||
stream.write_all(&mut prelude[..])?;
|
||||
|
||||
// start reading the response to process cookies and redirects.
|
||||
let mut resp = Response::from_read(&mut stream);
|
||||
|
||||
// squirrel away cookies
|
||||
if let Some(add_jar) = jar.as_mut() {
|
||||
for raw_cookie in resp.all("set-cookie").iter() {
|
||||
let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
|
||||
raw_cookie.to_string()
|
||||
} else {
|
||||
format!("{}; Domain={}", raw_cookie, hostname)
|
||||
};
|
||||
match Cookie::parse_encoded(&to_parse[..]) {
|
||||
Err(_) => (), // ignore unparseable cookies
|
||||
Ok(mut cookie) => {
|
||||
let cookie = cookie.into_owned();
|
||||
add_jar.add(cookie)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle redirects
|
||||
if resp.redirect() {
|
||||
if redirects == 0 {
|
||||
return Err(Error::TooManyRedirects);
|
||||
}
|
||||
|
||||
// the location header
|
||||
let location = resp.header("location");
|
||||
if let Some(location) = location {
|
||||
// join location header to current url in case it it relative
|
||||
let new_url = url.join(location)
|
||||
.map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?;
|
||||
|
||||
// perform the redirect differently depending on 3xx code.
|
||||
return match resp.status {
|
||||
301 | 302 | 303 => {
|
||||
send_body(body, do_chunk, &mut stream)?;
|
||||
let empty = Payload::Empty.into_read();
|
||||
self.connect(request, "GET", &new_url, redirects - 1, jar, empty)
|
||||
}
|
||||
307 | 308 | _ => {
|
||||
self.connect(request, method, &new_url, redirects - 1, jar, body)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// send the body (which can be empty now depending on redirects)
|
||||
send_body(body, do_chunk, &mut stream)?;
|
||||
|
||||
// since it is not a redirect, give away the incoming stream to the response object
|
||||
resp.set_stream(stream, is_head);
|
||||
|
||||
// release the response
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
|
||||
fn send_body(body: SizedReader, do_chunk: bool, stream: &mut Stream) -> IoResult<()> {
|
||||
@@ -165,41 +37,3 @@ where
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO check so cookies can't be set for tld:s
|
||||
fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec<Header> {
|
||||
jar.iter()
|
||||
.filter(|c| {
|
||||
// if there is a domain, it must be matched. if there is no domain, then ignore cookie
|
||||
let domain_ok = c.domain()
|
||||
.map(|cdom| domain.contains(cdom))
|
||||
.unwrap_or(false);
|
||||
// a path must match the beginning of request path. no cookie path, we say is ok. is it?!
|
||||
let path_ok = c.path()
|
||||
.map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false))
|
||||
.unwrap_or(true);
|
||||
// either the cookie isnt secure, or we're not doing a secure request.
|
||||
let secure_ok = !c.secure() || is_secure;
|
||||
|
||||
domain_ok && path_ok && secure_ok
|
||||
})
|
||||
.map(|c| {
|
||||
let name = c.name().to_string();
|
||||
let value = c.value().to_string();
|
||||
let nameval = Cookie::new(name, value).encoded().to_string();
|
||||
let head = format!("Cookie: {}", nameval);
|
||||
head.parse::<Header>().ok()
|
||||
})
|
||||
.filter(|o| o.is_some())
|
||||
.map(|o| o.unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn combine_query(url: &Url, query: &QString) -> String {
|
||||
match (url.query(), query.len() > 0) {
|
||||
(Some(urlq), true) => format!("?{}&{}", urlq, query),
|
||||
(Some(urlq), false) => format!("?{}", urlq),
|
||||
(None, true) => format!("?{}", query),
|
||||
(None, false) => "".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,6 +53,30 @@ impl Header {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_header<'a, 'b>(headers: &'b Vec<Header>, name: &'a str) -> Option<&'b str> {
|
||||
headers.iter().find(|h| h.is_name(name)).map(|h| h.value())
|
||||
}
|
||||
|
||||
pub fn get_all_headers<'a, 'b>(headers: &'b Vec<Header>, name: &'a str) -> Vec<&'b str> {
|
||||
headers
|
||||
.iter()
|
||||
.filter(|h| h.is_name(name))
|
||||
.map(|h| h.value())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn has_header(headers: &Vec<Header>, name: &str) -> bool {
|
||||
get_header(headers, name).is_some()
|
||||
}
|
||||
|
||||
pub fn add_header(headers: &mut Vec<Header>, header: Header) {
|
||||
if !header.name().to_lowercase().starts_with("x-") {
|
||||
let name = header.name();
|
||||
headers.retain(|h| h.name() != name);
|
||||
}
|
||||
headers.push(header);
|
||||
}
|
||||
|
||||
impl FromStr for Header {
|
||||
type Err = Error;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
@@ -68,11 +92,3 @@ impl FromStr for Header {
|
||||
Ok(Header { line, index })
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_header(header: Header, headers: &mut Vec<Header>) {
|
||||
if !header.name().to_lowercase().starts_with("x-") {
|
||||
let name = header.name();
|
||||
headers.retain(|h| h.name() != name);
|
||||
}
|
||||
headers.push(header);
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ lazy_static! {
|
||||
/// ```
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Request {
|
||||
state: Arc<Mutex<Option<AgentState>>>,
|
||||
agent: Arc<Mutex<Option<AgentState>>>,
|
||||
|
||||
// via agent
|
||||
method: String,
|
||||
@@ -32,9 +32,9 @@ pub struct Request {
|
||||
// from request itself
|
||||
headers: Vec<Header>,
|
||||
query: QString,
|
||||
timeout: u32,
|
||||
timeout_read: u32,
|
||||
timeout_write: u32,
|
||||
timeout_connect: u64,
|
||||
timeout_read: u64,
|
||||
timeout_write: u64,
|
||||
redirects: u32,
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ impl Default for Payload {
|
||||
}
|
||||
}
|
||||
|
||||
struct SizedReader {
|
||||
pub struct SizedReader {
|
||||
size: Option<usize>,
|
||||
reader: Box<Read + 'static>,
|
||||
}
|
||||
@@ -108,7 +108,7 @@ impl Payload {
|
||||
impl Request {
|
||||
fn new(agent: &Agent, method: String, path: String) -> Request {
|
||||
Request {
|
||||
state: Arc::clone(&agent.state),
|
||||
agent: Arc::clone(&agent.state),
|
||||
method,
|
||||
path,
|
||||
headers: agent.headers.clone(),
|
||||
@@ -132,11 +132,11 @@ impl Request {
|
||||
|
||||
/// Executes the request and blocks the caller until done.
|
||||
///
|
||||
/// Use `.timeout()` and `.timeout_read()` to avoid blocking forever.
|
||||
/// Use `.timeout_connect()` and `.timeout_read()` to avoid blocking forever.
|
||||
///
|
||||
/// ```
|
||||
/// let r = ureq::get("/my_page")
|
||||
/// .timeout(10_000) // max 10 seconds
|
||||
/// .timeout_connect(10_000) // max 10 seconds
|
||||
/// .call();
|
||||
///
|
||||
/// println!("{:?}", r);
|
||||
@@ -146,32 +146,11 @@ impl Request {
|
||||
}
|
||||
|
||||
fn do_call(&mut self, payload: Payload) -> Response {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
self.to_url()
|
||||
.and_then(|url| {
|
||||
match state.as_mut() {
|
||||
None =>
|
||||
// create a one off pool/jar.
|
||||
ConnectionPool::new().connect(
|
||||
self,
|
||||
&self.method,
|
||||
&url,
|
||||
self.redirects,
|
||||
None,
|
||||
payload.into_read(),
|
||||
),
|
||||
Some(state) => {
|
||||
let jar = &mut state.jar;
|
||||
state.pool.connect(
|
||||
self,
|
||||
&self.method,
|
||||
&url,
|
||||
self.redirects,
|
||||
Some(jar),
|
||||
payload.into_read(),
|
||||
)
|
||||
},
|
||||
}
|
||||
let reader = payload.into_read();
|
||||
let mut unit = Unit::new(&self, &url, &reader);
|
||||
unit.connect(url, &self.method, self.redirects, reader)
|
||||
})
|
||||
.unwrap_or_else(|e| e.into())
|
||||
}
|
||||
@@ -269,7 +248,7 @@ impl Request {
|
||||
{
|
||||
let s = format!("{}: {}", header.into(), value.into());
|
||||
let header = s.parse::<Header>().expect("Failed to parse header");
|
||||
add_header(header, &mut self.headers);
|
||||
add_header(&mut self.headers, header);
|
||||
self
|
||||
}
|
||||
|
||||
@@ -282,10 +261,7 @@ impl Request {
|
||||
/// assert_eq!("foobar", req.header("x-api-Key").unwrap());
|
||||
/// ```
|
||||
pub fn header<'a>(&self, name: &'a str) -> Option<&str> {
|
||||
self.headers
|
||||
.iter()
|
||||
.find(|h| h.is_name(name))
|
||||
.map(|h| h.value())
|
||||
get_header(&self.headers, name)
|
||||
}
|
||||
|
||||
/// Tells if the header has been set.
|
||||
@@ -297,7 +273,7 @@ impl Request {
|
||||
/// assert_eq!(true, req.has("x-api-Key"));
|
||||
/// ```
|
||||
pub fn has<'a>(&self, name: &'a str) -> bool {
|
||||
self.header(name).is_some()
|
||||
has_header(&self.headers, name)
|
||||
}
|
||||
|
||||
/// All headers corresponding values for the give name, or empty vector.
|
||||
@@ -313,11 +289,7 @@ impl Request {
|
||||
/// ]);
|
||||
/// ```
|
||||
pub fn all<'a>(&self, name: &'a str) -> Vec<&str> {
|
||||
self.headers
|
||||
.iter()
|
||||
.filter(|h| h.is_name(name))
|
||||
.map(|h| h.value())
|
||||
.collect()
|
||||
get_all_headers(&self.headers, name)
|
||||
}
|
||||
|
||||
/// Set many headers.
|
||||
@@ -348,7 +320,7 @@ impl Request {
|
||||
for (k, v) in headers.into_iter() {
|
||||
let s = format!("{}: {}", k.into(), v.into());
|
||||
let header = s.parse::<Header>().expect("Failed to parse header");
|
||||
add_header(header, &mut self.headers);
|
||||
add_header(&mut self.headers, header);
|
||||
}
|
||||
self
|
||||
}
|
||||
@@ -430,12 +402,12 @@ impl Request {
|
||||
///
|
||||
/// ```
|
||||
/// let r = ureq::get("/my_page")
|
||||
/// .timeout(1_000) // wait max 1 second to connect
|
||||
/// .timeout_connect(1_000) // wait max 1 second to connect
|
||||
/// .call();
|
||||
/// println!("{:?}", r);
|
||||
/// ```
|
||||
pub fn timeout(&mut self, millis: u32) -> &mut Request {
|
||||
self.timeout = millis;
|
||||
pub fn timeout_connect(&mut self, millis: u64) -> &mut Request {
|
||||
self.timeout_connect = millis;
|
||||
self
|
||||
}
|
||||
|
||||
@@ -449,7 +421,7 @@ impl Request {
|
||||
/// .call();
|
||||
/// println!("{:?}", r);
|
||||
/// ```
|
||||
pub fn timeout_read(&mut self, millis: u32) -> &mut Request {
|
||||
pub fn timeout_read(&mut self, millis: u64) -> &mut Request {
|
||||
self.timeout_read = millis;
|
||||
self
|
||||
}
|
||||
@@ -464,7 +436,7 @@ impl Request {
|
||||
/// .call();
|
||||
/// println!("{:?}", r);
|
||||
/// ```
|
||||
pub fn timeout_write(&mut self, millis: u32) -> &mut Request {
|
||||
pub fn timeout_write(&mut self, millis: u64) -> &mut Request {
|
||||
self.timeout_write = millis;
|
||||
self
|
||||
}
|
||||
@@ -508,7 +480,7 @@ impl Request {
|
||||
{
|
||||
let s = format!("Authorization: {} {}", kind.into(), pass.into());
|
||||
let header = s.parse::<Header>().expect("Failed to parse header");
|
||||
add_header(header, &mut self.headers);
|
||||
add_header(&mut self.headers, header);
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
@@ -60,28 +60,28 @@ impl Write for Stream {
|
||||
}
|
||||
}
|
||||
|
||||
fn connect_http(request: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
fn connect_http(unit: &Unit) -> Result<Stream, Error> {
|
||||
//
|
||||
let hostname = url.host_str().unwrap();
|
||||
let port = url.port().unwrap_or(80);
|
||||
let hostname = unit.url.host_str().unwrap();
|
||||
let port = unit.url.port().unwrap_or(80);
|
||||
|
||||
connect_host(request, hostname, port).map(|tcp| Stream::Http(tcp))
|
||||
connect_host(unit, hostname, port).map(|tcp| Stream::Http(tcp))
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
fn connect_https(request: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
fn connect_https(unit: &Unit) -> Result<Stream, Error> {
|
||||
//
|
||||
let hostname = url.host_str().unwrap();
|
||||
let port = url.port().unwrap_or(443);
|
||||
let hostname = unit.url.host_str().unwrap();
|
||||
let port = unit.url.port().unwrap_or(443);
|
||||
|
||||
let socket = connect_host(request, hostname, port)?;
|
||||
let socket = connect_host(unit, hostname, port)?;
|
||||
let connector = TlsConnector::builder().build()?;
|
||||
let stream = connector.connect(hostname, socket)?;
|
||||
|
||||
Ok(Stream::Https(stream))
|
||||
}
|
||||
|
||||
fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStream, Error> {
|
||||
fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result<TcpStream, Error> {
|
||||
//
|
||||
let ips: Vec<SocketAddr> = format!("{}:{}", hostname, port).to_socket_addrs()
|
||||
.map_err(|e| Error::DnsFailed(format!("{}", e)))?
|
||||
@@ -95,20 +95,20 @@ fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStrea
|
||||
let sock_addr = ips[0];
|
||||
|
||||
// connect with a configured timeout.
|
||||
let stream = match request.timeout {
|
||||
let stream = match unit.timeout_connect {
|
||||
0 => TcpStream::connect(&sock_addr),
|
||||
_ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(request.timeout as u64)),
|
||||
_ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(unit.timeout_connect as u64)),
|
||||
}.map_err(|err| Error::ConnectionFailed(format!("{}", err)))?;
|
||||
|
||||
// rust's absurd api returns Err if we set 0.
|
||||
if request.timeout_read > 0 {
|
||||
if unit.timeout_read > 0 {
|
||||
stream
|
||||
.set_read_timeout(Some(Duration::from_millis(request.timeout_read as u64)))
|
||||
.set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64)))
|
||||
.ok();
|
||||
}
|
||||
if request.timeout_write > 0 {
|
||||
if unit.timeout_write > 0 {
|
||||
stream
|
||||
.set_write_timeout(Some(Duration::from_millis(request.timeout_write as u64)))
|
||||
.set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64)))
|
||||
.ok();
|
||||
}
|
||||
|
||||
@@ -116,17 +116,17 @@ fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStrea
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn connect_test(request: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
fn connect_test(unit: &Unit) -> Result<Stream, Error> {
|
||||
use test;
|
||||
test::resolve_handler(request, url)
|
||||
test::resolve_handler(unit)
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
fn connect_test(_request: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
Err(Error::UnknownScheme(url.scheme().to_string()))
|
||||
fn connect_test(unit: &Unit) -> Result<Stream, Error> {
|
||||
Err(Error::UnknownScheme(unit.url.scheme().to_string()))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tls"))]
|
||||
fn connect_https(request: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
Err(Error::UnknownScheme(url.scheme().to_string()))
|
||||
fn connect_https(unit: &Unit) -> Result<Stream, Error> {
|
||||
Err(Error::UnknownScheme(unit.url.scheme().to_string()))
|
||||
}
|
||||
|
||||
@@ -6,18 +6,18 @@ use super::super::*;
|
||||
fn agent_reuse_headers() {
|
||||
let agent = agent().set("Authorization", "Foo 12345").build();
|
||||
|
||||
test::set_handler("/agent_reuse_headers", |req, _url| {
|
||||
assert!(req.has("Authorization"));
|
||||
assert_eq!(req.header("Authorization").unwrap(), "Foo 12345");
|
||||
test::set_handler("/agent_reuse_headers", |unit| {
|
||||
assert!(unit.has("Authorization"));
|
||||
assert_eq!(unit.header("Authorization").unwrap(), "Foo 12345");
|
||||
test::make_response(200, "OK", vec!["X-Call: 1"], vec![])
|
||||
});
|
||||
|
||||
let resp = agent.get("test://host/agent_reuse_headers").call();
|
||||
assert_eq!(resp.header("X-Call").unwrap(), "1");
|
||||
|
||||
test::set_handler("/agent_reuse_headers", |req, _url| {
|
||||
assert!(req.has("Authorization"));
|
||||
assert_eq!(req.header("Authorization").unwrap(), "Foo 12345");
|
||||
test::set_handler("/agent_reuse_headers", |unit| {
|
||||
assert!(unit.has("Authorization"));
|
||||
assert_eq!(unit.header("Authorization").unwrap(), "Foo 12345");
|
||||
test::make_response(200, "OK", vec!["X-Call: 2"], vec![])
|
||||
});
|
||||
|
||||
@@ -29,7 +29,7 @@ fn agent_reuse_headers() {
|
||||
fn agent_cookies() {
|
||||
let agent = agent().build();
|
||||
|
||||
test::set_handler("/agent_cookies", |_req, _url| {
|
||||
test::set_handler("/agent_cookies", |_unit| {
|
||||
test::make_response(
|
||||
200,
|
||||
"OK",
|
||||
@@ -43,7 +43,7 @@ fn agent_cookies() {
|
||||
assert!(agent.cookie("foo").is_some());
|
||||
assert_eq!(agent.cookie("foo").unwrap().value(), "bar baz");
|
||||
|
||||
test::set_handler("/agent_cookies", |_req, _url| {
|
||||
test::set_handler("/agent_cookies", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
|
||||
|
||||
@@ -4,9 +4,9 @@ use super::super::*;
|
||||
|
||||
#[test]
|
||||
fn basic_auth() {
|
||||
test::set_handler("/basic_auth", |req, _url| {
|
||||
test::set_handler("/basic_auth", |unit| {
|
||||
assert_eq!(
|
||||
req.header("Authorization").unwrap(),
|
||||
unit.header("Authorization").unwrap(),
|
||||
"Basic bWFydGluOnJ1YmJlcm1hc2hndW0="
|
||||
);
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
@@ -19,8 +19,8 @@ fn basic_auth() {
|
||||
|
||||
#[test]
|
||||
fn kind_auth() {
|
||||
test::set_handler("/kind_auth", |req, _url| {
|
||||
assert_eq!(req.header("Authorization").unwrap(), "Digest abcdefgh123");
|
||||
test::set_handler("/kind_auth", |unit| {
|
||||
assert_eq!(unit.header("Authorization").unwrap(), "Digest abcdefgh123");
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = get("test://host/kind_auth")
|
||||
|
||||
@@ -5,7 +5,7 @@ use super::super::*;
|
||||
|
||||
#[test]
|
||||
fn transfer_encoding_bogus() {
|
||||
test::set_handler("/transfer_encoding_bogus", |_req, _url| {
|
||||
test::set_handler("/transfer_encoding_bogus", |_unit| {
|
||||
test::make_response(
|
||||
200,
|
||||
"OK",
|
||||
@@ -26,7 +26,7 @@ fn transfer_encoding_bogus() {
|
||||
|
||||
#[test]
|
||||
fn content_length_limited() {
|
||||
test::set_handler("/content_length_limited", |_req, _url| {
|
||||
test::set_handler("/content_length_limited", |_unit| {
|
||||
test::make_response(
|
||||
200,
|
||||
"OK",
|
||||
@@ -44,7 +44,7 @@ fn content_length_limited() {
|
||||
#[test]
|
||||
// content-length should be ignored when chunked
|
||||
fn ignore_content_length_when_chunked() {
|
||||
test::set_handler("/ignore_content_length_when_chunked", |_req, _url| {
|
||||
test::set_handler("/ignore_content_length_when_chunked", |_unit| {
|
||||
test::make_response(
|
||||
200,
|
||||
"OK",
|
||||
@@ -63,7 +63,7 @@ fn ignore_content_length_when_chunked() {
|
||||
|
||||
#[test]
|
||||
fn no_reader_on_head() {
|
||||
test::set_handler("/no_reader_on_head", |_req, _url| {
|
||||
test::set_handler("/no_reader_on_head", |_unit| {
|
||||
// so this is technically illegal, we return a body for the HEAD request.
|
||||
test::make_response(
|
||||
200,
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::super::*;
|
||||
|
||||
#[test]
|
||||
fn content_length_on_str() {
|
||||
test::set_handler("/content_length_on_str", |_req, _url| {
|
||||
test::set_handler("/content_length_on_str", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = post("test://host/content_length_on_str").send_string("Hello World!!!");
|
||||
@@ -15,7 +15,7 @@ fn content_length_on_str() {
|
||||
|
||||
#[test]
|
||||
fn user_set_content_length_on_str() {
|
||||
test::set_handler("/user_set_content_length_on_str", |_req, _url| {
|
||||
test::set_handler("/user_set_content_length_on_str", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = post("test://host/user_set_content_length_on_str")
|
||||
@@ -29,7 +29,7 @@ fn user_set_content_length_on_str() {
|
||||
#[test]
|
||||
#[cfg(feature = "json")]
|
||||
fn content_length_on_json() {
|
||||
test::set_handler("/content_length_on_json", |_req, _url| {
|
||||
test::set_handler("/content_length_on_json", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let mut json = SerdeMap::new();
|
||||
@@ -45,7 +45,7 @@ fn content_length_on_json() {
|
||||
|
||||
#[test]
|
||||
fn content_length_and_chunked() {
|
||||
test::set_handler("/content_length_and_chunked", |_req, _url| {
|
||||
test::set_handler("/content_length_and_chunked", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = post("test://host/content_length_and_chunked")
|
||||
@@ -60,7 +60,7 @@ fn content_length_and_chunked() {
|
||||
#[test]
|
||||
#[cfg(feature = "charset")]
|
||||
fn str_with_encoding() {
|
||||
test::set_handler("/str_with_encoding", |_req, _url| {
|
||||
test::set_handler("/str_with_encoding", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = post("test://host/str_with_encoding")
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use agent::Request;
|
||||
use agent::Unit;
|
||||
use agent::Stream;
|
||||
use error::Error;
|
||||
use header::Header;
|
||||
@@ -6,7 +6,6 @@ use std::collections::HashMap;
|
||||
use std::io::Cursor;
|
||||
use std::io::Write;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use url::Url;
|
||||
|
||||
mod agent_test;
|
||||
mod auth;
|
||||
@@ -16,7 +15,7 @@ mod query_string;
|
||||
mod range;
|
||||
mod simple;
|
||||
|
||||
type RequestHandler = Fn(&Request, &Url) -> Result<Stream, Error> + Send + 'static;
|
||||
type RequestHandler = Fn(&Unit) -> Result<Stream, Error> + Send + 'static;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref TEST_HANDLERS: Arc<Mutex<HashMap<String, Box<RequestHandler>>>> =
|
||||
@@ -25,7 +24,7 @@ lazy_static! {
|
||||
|
||||
pub fn set_handler<H>(path: &str, handler: H)
|
||||
where
|
||||
H: Fn(&Request, &Url) -> Result<Stream, Error> + Send + 'static,
|
||||
H: Fn(&Unit) -> Result<Stream, Error> + Send + 'static,
|
||||
{
|
||||
let mut handlers = TEST_HANDLERS.lock().unwrap();
|
||||
handlers.insert(path.to_string(), Box::new(handler));
|
||||
@@ -50,9 +49,9 @@ pub fn make_response(
|
||||
Ok(Stream::Test(Box::new(cursor), write))
|
||||
}
|
||||
|
||||
pub fn resolve_handler(req: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
pub fn resolve_handler(unit: &Unit) -> Result<Stream, Error> {
|
||||
let mut handlers = TEST_HANDLERS.lock().unwrap();
|
||||
let path = url.path();
|
||||
let path = unit.url.path();
|
||||
let handler = handlers.remove(path).unwrap();
|
||||
handler(req, url)
|
||||
handler(unit)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use super::super::*;
|
||||
|
||||
#[test]
|
||||
fn no_query_string() {
|
||||
test::set_handler("/no_query_string", |_req, _url| {
|
||||
test::set_handler("/no_query_string", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = get("test://host/no_query_string").call();
|
||||
@@ -15,7 +15,7 @@ fn no_query_string() {
|
||||
|
||||
#[test]
|
||||
fn escaped_query_string() {
|
||||
test::set_handler("/escaped_query_string", |_req, _url| {
|
||||
test::set_handler("/escaped_query_string", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = get("test://host/escaped_query_string")
|
||||
@@ -29,7 +29,7 @@ fn escaped_query_string() {
|
||||
|
||||
#[test]
|
||||
fn query_in_path() {
|
||||
test::set_handler("/query_in_path", |_req, _url| {
|
||||
test::set_handler("/query_in_path", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = get("test://host/query_in_path?foo=bar").call();
|
||||
@@ -40,7 +40,7 @@ fn query_in_path() {
|
||||
|
||||
#[test]
|
||||
fn query_in_path_and_req() {
|
||||
test::set_handler("/query_in_path_and_req", |_req, _url| {
|
||||
test::set_handler("/query_in_path_and_req", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = get("test://host/query_in_path_and_req?foo=bar")
|
||||
|
||||
@@ -5,9 +5,9 @@ use super::super::*;
|
||||
|
||||
#[test]
|
||||
fn header_passing() {
|
||||
test::set_handler("/header_passing", |req, _url| {
|
||||
assert!(req.has("X-Foo"));
|
||||
assert_eq!(req.header("X-Foo").unwrap(), "bar");
|
||||
test::set_handler("/header_passing", |unit| {
|
||||
assert!(unit.has("X-Foo"));
|
||||
assert_eq!(unit.header("X-Foo").unwrap(), "bar");
|
||||
test::make_response(200, "OK", vec!["X-Bar: foo"], vec![])
|
||||
});
|
||||
let resp = get("test://host/header_passing").set("X-Foo", "bar").call();
|
||||
@@ -18,9 +18,9 @@ fn header_passing() {
|
||||
|
||||
#[test]
|
||||
fn repeat_non_x_header() {
|
||||
test::set_handler("/repeat_non_x_header", |req, _url| {
|
||||
assert!(req.has("Accept"));
|
||||
assert_eq!(req.header("Accept").unwrap(), "baz");
|
||||
test::set_handler("/repeat_non_x_header", |unit| {
|
||||
assert!(unit.has("Accept"));
|
||||
assert_eq!(unit.header("Accept").unwrap(), "baz");
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = get("test://host/repeat_non_x_header")
|
||||
@@ -32,11 +32,11 @@ fn repeat_non_x_header() {
|
||||
|
||||
#[test]
|
||||
fn repeat_x_header() {
|
||||
test::set_handler("/repeat_x_header", |req, _url| {
|
||||
assert!(req.has("X-Forwarded-For"));
|
||||
assert_eq!(req.header("X-Forwarded-For").unwrap(), "130.240.19.2");
|
||||
test::set_handler("/repeat_x_header", |unit| {
|
||||
assert!(unit.has("X-Forwarded-For"));
|
||||
assert_eq!(unit.header("X-Forwarded-For").unwrap(), "130.240.19.2");
|
||||
assert_eq!(
|
||||
req.all("X-Forwarded-For"),
|
||||
unit.all("X-Forwarded-For"),
|
||||
vec!["130.240.19.2", "130.240.19.3"]
|
||||
);
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
@@ -50,7 +50,7 @@ fn repeat_x_header() {
|
||||
|
||||
#[test]
|
||||
fn body_as_text() {
|
||||
test::set_handler("/body_as_text", |_req, _url| {
|
||||
test::set_handler("/body_as_text", |_unit| {
|
||||
test::make_response(200, "OK", vec![], "Hello World!".to_string().into_bytes())
|
||||
});
|
||||
let resp = get("test://host/body_as_text").call();
|
||||
@@ -61,7 +61,7 @@ fn body_as_text() {
|
||||
#[test]
|
||||
#[cfg(feature = "json")]
|
||||
fn body_as_json() {
|
||||
test::set_handler("/body_as_json", |_req, _url| {
|
||||
test::set_handler("/body_as_json", |_unit| {
|
||||
test::make_response(
|
||||
200,
|
||||
"OK",
|
||||
@@ -76,7 +76,7 @@ fn body_as_json() {
|
||||
|
||||
#[test]
|
||||
fn body_as_reader() {
|
||||
test::set_handler("/body_as_reader", |_req, _url| {
|
||||
test::set_handler("/body_as_reader", |_unit| {
|
||||
test::make_response(200, "OK", vec![], "abcdefgh".to_string().into_bytes())
|
||||
});
|
||||
let resp = get("test://host/body_as_reader").call();
|
||||
@@ -88,7 +88,7 @@ fn body_as_reader() {
|
||||
|
||||
#[test]
|
||||
fn escape_path() {
|
||||
test::set_handler("/escape_path%20here", |_req, _url| {
|
||||
test::set_handler("/escape_path%20here", |_unit| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
let resp = get("test://host/escape_path here").call();
|
||||
|
||||
230
src/unit.rs
Normal file
230
src/unit.rs
Normal file
@@ -0,0 +1,230 @@
|
||||
//
|
||||
|
||||
pub struct Unit {
|
||||
pub agent: Arc<Mutex<Option<AgentState>>>,
|
||||
pub url: Url,
|
||||
pub is_chunked: bool,
|
||||
pub is_head: bool,
|
||||
pub hostname: String,
|
||||
pub query_string: String,
|
||||
pub headers: Vec<Header>,
|
||||
pub timeout_connect: u64,
|
||||
pub timeout_read: u64,
|
||||
pub timeout_write: u64,
|
||||
}
|
||||
|
||||
impl Unit {
|
||||
//
|
||||
|
||||
fn new(req: &Request, url: &Url, body: &SizedReader) -> Self {
|
||||
//
|
||||
|
||||
let is_chunked = req.header("transfer-encoding")
|
||||
// if the user has set an encoding header, obey that.
|
||||
.map(|enc| enc.len() > 0)
|
||||
// otherwise, no chunking.
|
||||
.unwrap_or(false);
|
||||
|
||||
let is_secure = url.scheme().eq_ignore_ascii_case("https");
|
||||
|
||||
let is_head = req.method.eq_ignore_ascii_case("head");
|
||||
|
||||
let hostname = url.host_str().unwrap_or("localhost").to_string();
|
||||
|
||||
let query_string = combine_query(&url, &req.query);
|
||||
|
||||
let cookie_headers: Vec<_> = {
|
||||
let mut state = req.agent.lock().unwrap();
|
||||
match state.as_ref().map(|state| &state.jar) {
|
||||
None => vec![],
|
||||
Some(jar) => match_cookies(jar, &hostname, url.path(), is_secure),
|
||||
}
|
||||
};
|
||||
let extra_headers = {
|
||||
let mut extra = vec![];
|
||||
|
||||
// chunking and Content-Length headers are mutually exclusive
|
||||
// also don't write this if the user has set it themselves
|
||||
if !is_chunked && !req.has("content-length") {
|
||||
if let Some(size) = body.size {
|
||||
extra.push(
|
||||
format!("Content-Length: {}\r\n", size)
|
||||
.parse::<Header>()
|
||||
.unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
extra
|
||||
};
|
||||
let headers: Vec<_> = req
|
||||
.headers
|
||||
.iter()
|
||||
.chain(cookie_headers.iter())
|
||||
.chain(extra_headers.iter())
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
Unit {
|
||||
agent: Arc::clone(&req.agent),
|
||||
url: url.clone(),
|
||||
is_chunked,
|
||||
is_head,
|
||||
hostname,
|
||||
query_string,
|
||||
headers,
|
||||
timeout_connect: req.timeout_connect,
|
||||
timeout_read: req.timeout_read,
|
||||
timeout_write: req.timeout_write,
|
||||
}
|
||||
}
|
||||
|
||||
fn connect(
|
||||
&mut self,
|
||||
url: Url,
|
||||
method: &str,
|
||||
redirects: u32,
|
||||
body: SizedReader,
|
||||
) -> Result<Response, Error> {
|
||||
//
|
||||
|
||||
// open socket
|
||||
let mut stream = match url.scheme() {
|
||||
"http" => connect_http(self),
|
||||
"https" => connect_https(self),
|
||||
"test" => connect_test(self),
|
||||
_ => Err(Error::UnknownScheme(url.scheme().to_string())),
|
||||
}?;
|
||||
|
||||
// send the request start + headers
|
||||
let mut prelude: Vec<u8> = vec![];
|
||||
write!(
|
||||
prelude,
|
||||
"{} {}{} HTTP/1.1\r\n",
|
||||
method,
|
||||
url.path(),
|
||||
self.query_string
|
||||
)?;
|
||||
if !has_header(&self.headers, "host") {
|
||||
write!(prelude, "Host: {}\r\n", url.host().unwrap())?;
|
||||
}
|
||||
for header in &self.headers {
|
||||
write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
|
||||
}
|
||||
write!(prelude, "\r\n")?;
|
||||
|
||||
stream.write_all(&mut prelude[..])?;
|
||||
|
||||
// start reading the response to process cookies and redirects.
|
||||
let mut resp = Response::from_read(&mut stream);
|
||||
|
||||
// squirrel away cookies
|
||||
{
|
||||
let mut state = self.agent.lock().unwrap();
|
||||
if let Some(add_jar) = state.as_mut().map(|state| &mut state.jar) {
|
||||
for raw_cookie in resp.all("set-cookie").iter() {
|
||||
let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
|
||||
raw_cookie.to_string()
|
||||
} else {
|
||||
format!("{}; Domain={}", raw_cookie, self.hostname)
|
||||
};
|
||||
match Cookie::parse_encoded(&to_parse[..]) {
|
||||
Err(_) => (), // ignore unparseable cookies
|
||||
Ok(mut cookie) => {
|
||||
let cookie = cookie.into_owned();
|
||||
add_jar.add(cookie)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle redirects
|
||||
if resp.redirect() {
|
||||
if redirects == 0 {
|
||||
return Err(Error::TooManyRedirects);
|
||||
}
|
||||
|
||||
// the location header
|
||||
let location = resp.header("location");
|
||||
if let Some(location) = location {
|
||||
// join location header to current url in case it it relative
|
||||
let new_url = url
|
||||
.join(location)
|
||||
.map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?;
|
||||
|
||||
// perform the redirect differently depending on 3xx code.
|
||||
return match resp.status {
|
||||
301 | 302 | 303 => {
|
||||
send_body(body, self.is_chunked, &mut stream)?;
|
||||
let empty = Payload::Empty.into_read();
|
||||
self.connect(new_url, "GET", redirects - 1, empty)
|
||||
}
|
||||
307 | 308 | _ => self.connect(new_url, method, redirects - 1, body),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// send the body (which can be empty now depending on redirects)
|
||||
send_body(body, self.is_chunked, &mut stream)?;
|
||||
|
||||
// since it is not a redirect, give away the incoming stream to the response object
|
||||
resp.set_stream(stream, self.is_head);
|
||||
|
||||
// release the response
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn header<'a>(&self, name: &'a str) -> Option<&str> {
|
||||
get_header(&self.headers, name)
|
||||
}
|
||||
#[cfg(test)]
|
||||
pub fn has<'a>(&self, name: &'a str) -> bool {
|
||||
has_header(&self.headers, name)
|
||||
}
|
||||
#[cfg(test)]
|
||||
pub fn all<'a>(&self, name: &'a str) -> Vec<&str> {
|
||||
get_all_headers(&self.headers, name)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO check so cookies can't be set for tld:s
|
||||
fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec<Header> {
|
||||
jar.iter()
|
||||
.filter(|c| {
|
||||
// if there is a domain, it must be matched. if there is no domain, then ignore cookie
|
||||
let domain_ok = c
|
||||
.domain()
|
||||
.map(|cdom| domain.contains(cdom))
|
||||
.unwrap_or(false);
|
||||
// a path must match the beginning of request path. no cookie path, we say is ok. is it?!
|
||||
let path_ok = c
|
||||
.path()
|
||||
.map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false))
|
||||
.unwrap_or(true);
|
||||
// either the cookie isnt secure, or we're not doing a secure request.
|
||||
let secure_ok = !c.secure() || is_secure;
|
||||
|
||||
domain_ok && path_ok && secure_ok
|
||||
})
|
||||
.map(|c| {
|
||||
let name = c.name().to_string();
|
||||
let value = c.value().to_string();
|
||||
let nameval = Cookie::new(name, value).encoded().to_string();
|
||||
let head = format!("Cookie: {}", nameval);
|
||||
head.parse::<Header>().ok()
|
||||
})
|
||||
.filter(|o| o.is_some())
|
||||
.map(|o| o.unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn combine_query(url: &Url, query: &QString) -> String {
|
||||
match (url.query(), query.len() > 0) {
|
||||
(Some(urlq), true) => format!("?{}&{}", urlq, query),
|
||||
(Some(urlq), false) => format!("?{}", urlq),
|
||||
(None, true) => format!("?{}", query),
|
||||
(None, false) => "".to_string(),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user