cookie jar
This commit is contained in:
61
src/agent.rs
61
src/agent.rs
@@ -1,7 +1,8 @@
|
||||
use cookie::{Cookie, CookieJar};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use header::{Header, add_header};
|
||||
use header::{add_header, Header};
|
||||
use util::*;
|
||||
|
||||
// to get to share private fields
|
||||
@@ -11,8 +12,23 @@ include!("conn.rs");
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct Agent {
|
||||
pub headers: Vec<Header>,
|
||||
pub pool: Arc<Mutex<Option<ConnectionPool>>>,
|
||||
headers: Vec<Header>,
|
||||
state: Arc<Mutex<Option<AgentState>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct AgentState {
|
||||
pool: ConnectionPool,
|
||||
jar: CookieJar,
|
||||
}
|
||||
|
||||
impl AgentState {
|
||||
fn new() -> Self {
|
||||
AgentState {
|
||||
pool: ConnectionPool::new(),
|
||||
jar: CookieJar::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
@@ -26,7 +42,7 @@ impl Agent {
|
||||
pub fn build(&self) -> Self {
|
||||
Agent {
|
||||
headers: self.headers.clone(),
|
||||
pool: Arc::new(Mutex::new(Some(ConnectionPool::new()))),
|
||||
state: Arc::new(Mutex::new(Some(AgentState::new()))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,6 +176,43 @@ impl Agent {
|
||||
Request::new(&self, method.into(), path.into())
|
||||
}
|
||||
|
||||
/// Gets a cookie in this agent by name. Cookies are available
|
||||
/// either by setting it in the agent, or by making requests
|
||||
/// that `Set-Cookie` in the agent.
|
||||
///
|
||||
/// ```
|
||||
/// let agent = ureq::agent().build();
|
||||
///
|
||||
/// agent.get("http://www.google.com").call();
|
||||
///
|
||||
/// assert!(agent.cookie("NID").is_some());
|
||||
/// ```
|
||||
pub fn cookie(&self, name: &str) -> Option<Cookie<'static>> {
|
||||
let state = self.state.lock().unwrap();
|
||||
state
|
||||
.as_ref()
|
||||
.and_then(|state| state.jar.get(name))
|
||||
.map(|c| c.clone())
|
||||
}
|
||||
|
||||
/// Set a cookie in this agent.
|
||||
///
|
||||
/// ```
|
||||
/// let agent = ureq::agent().build();
|
||||
///
|
||||
/// let cookie = ureq::Cookie::new("name", "value");
|
||||
/// agent.set_cookie(cookie);
|
||||
/// ```
|
||||
pub fn set_cookie(&self, cookie: Cookie<'static>) {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
match state.as_mut() {
|
||||
None => (),
|
||||
Some(state) => {
|
||||
state.jar.add_original(cookie);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get<S>(&self, path: S) -> Request
|
||||
where
|
||||
S: Into<String>,
|
||||
|
||||
72
src/conn.rs
72
src/conn.rs
@@ -26,10 +26,23 @@ impl ConnectionPool {
|
||||
method: &str,
|
||||
url: &Url,
|
||||
redirects: u32,
|
||||
mut jar: Option<&mut CookieJar>,
|
||||
payload: Payload,
|
||||
) -> Result<Response, Error> {
|
||||
//
|
||||
// open connection
|
||||
|
||||
let hostname = url.host_str().unwrap_or("localhost"); // is localhost a good alternative?
|
||||
let is_secure = url.scheme().eq_ignore_ascii_case("https");
|
||||
|
||||
let cookie_headers: Vec<_> = {
|
||||
match jar.as_ref() {
|
||||
None => vec![],
|
||||
Some(jar) => match_cookies(jar, hostname, url.path(), is_secure),
|
||||
}
|
||||
};
|
||||
let headers = request.headers.iter().chain(cookie_headers.iter());
|
||||
|
||||
// open socket
|
||||
let mut stream = match url.scheme() {
|
||||
"http" => connect_http(request, &url),
|
||||
"https" => connect_https(request, &url),
|
||||
@@ -43,16 +56,34 @@ impl ConnectionPool {
|
||||
if !request.has("host") {
|
||||
write!(prelude, "Host: {}\r\n", url.host().unwrap())?;
|
||||
}
|
||||
for header in request.headers.iter() {
|
||||
for header in headers {
|
||||
write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
|
||||
}
|
||||
write!(prelude, "\r\n")?;
|
||||
|
||||
stream.write_all(&mut prelude[..])?;
|
||||
|
||||
// start reading the response to check it it's a redirect
|
||||
// start reading the response to process cookies and redirects.
|
||||
let mut resp = Response::from_read(&mut stream);
|
||||
|
||||
// squirrel away cookies
|
||||
if let Some(add_jar) = jar.as_mut() {
|
||||
for raw_cookie in resp.all("set-cookie").iter() {
|
||||
let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
|
||||
raw_cookie.to_string()
|
||||
} else {
|
||||
format!("{}; Domain={}", raw_cookie, hostname)
|
||||
};
|
||||
match Cookie::parse_encoded(&to_parse[..]) {
|
||||
Err(_) => (), // ignore unparseable cookies
|
||||
Ok(mut cookie) => {
|
||||
let cookie = cookie.into_owned();
|
||||
add_jar.add(cookie)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle redirects
|
||||
if resp.redirect() {
|
||||
if redirects == 0 {
|
||||
@@ -70,10 +101,10 @@ impl ConnectionPool {
|
||||
return match resp.status {
|
||||
301 | 302 | 303 => {
|
||||
send_payload(&request, payload, &mut stream)?;
|
||||
self.connect(request, "GET", &new_url, redirects - 1, Payload::Empty)
|
||||
self.connect(request, "GET", &new_url, redirects - 1, jar, Payload::Empty)
|
||||
}
|
||||
307 | 308 | _ => {
|
||||
self.connect(request, method, &new_url, redirects - 1, payload)
|
||||
self.connect(request, method, &new_url, redirects - 1, jar, payload)
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -83,7 +114,7 @@ impl ConnectionPool {
|
||||
send_payload(&request, payload, &mut stream)?;
|
||||
|
||||
// since it is not a redirect, give away the incoming stream to the response object
|
||||
resp.set_reader(stream);
|
||||
resp.set_stream(stream);
|
||||
|
||||
// release the response
|
||||
Ok(resp)
|
||||
@@ -193,6 +224,35 @@ where
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO check so cookies can't be set for tld:s
|
||||
fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec<Header> {
|
||||
jar.iter()
|
||||
.filter(|c| {
|
||||
// if there is a domain, it must be matched. if there is no domain, then ignore cookie
|
||||
let domain_ok = c.domain()
|
||||
.map(|cdom| domain.contains(cdom))
|
||||
.unwrap_or(false);
|
||||
// a path must match the beginning of request path. no cookie path, we say is ok. is it?!
|
||||
let path_ok = c.path()
|
||||
.map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false))
|
||||
.unwrap_or(true);
|
||||
// either the cookie isnt secure, or we're not doing a secure request.
|
||||
let secure_ok = !c.secure() || is_secure;
|
||||
|
||||
domain_ok && path_ok && secure_ok
|
||||
})
|
||||
.map(|c| {
|
||||
let name = c.name().to_string();
|
||||
let value = c.value().to_string();
|
||||
let nameval = Cookie::new(name, value).encoded().to_string();
|
||||
let head = format!("Cookie: {}", nameval);
|
||||
head.parse::<Header>().ok()
|
||||
})
|
||||
.filter(|o| o.is_some())
|
||||
.map(|o| o.unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
fn connect_test(_request: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
Err(Error::UnknownScheme(url.scheme().to_string()))
|
||||
|
||||
14
src/lib.rs
14
src/lib.rs
@@ -1,6 +1,7 @@
|
||||
extern crate ascii;
|
||||
extern crate base64;
|
||||
extern crate chunked_transfer;
|
||||
extern crate cookie;
|
||||
extern crate dns_lookup;
|
||||
extern crate encoding;
|
||||
#[macro_use]
|
||||
@@ -28,6 +29,7 @@ pub use header::Header;
|
||||
|
||||
// re-export
|
||||
pub use serde_json::{to_value, Map, Value};
|
||||
pub use cookie::Cookie;
|
||||
|
||||
/// Agents keep state between requests.
|
||||
///
|
||||
@@ -131,16 +133,20 @@ mod tests {
|
||||
#[test]
|
||||
fn connect_http_google() {
|
||||
let resp = get("http://www.google.com/").call();
|
||||
println!("{:?}", resp);
|
||||
assert_eq!("text/html; charset=ISO-8859-1", resp.header("content-type").unwrap());
|
||||
assert_eq!(
|
||||
"text/html; charset=ISO-8859-1",
|
||||
resp.header("content-type").unwrap()
|
||||
);
|
||||
assert_eq!("text/html", resp.content_type());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn connect_https_google() {
|
||||
let resp = get("https://www.google.com/").call();
|
||||
println!("{:?}", resp);
|
||||
assert_eq!("text/html; charset=ISO-8859-1", resp.header("content-type").unwrap());
|
||||
assert_eq!(
|
||||
"text/html; charset=ISO-8859-1",
|
||||
resp.header("content-type").unwrap()
|
||||
);
|
||||
assert_eq!("text/html", resp.content_type());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ lazy_static! {
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct Request {
|
||||
pool: Arc<Mutex<Option<ConnectionPool>>>,
|
||||
state: Arc<Mutex<Option<AgentState>>>,
|
||||
|
||||
// via agent
|
||||
method: String,
|
||||
@@ -57,7 +57,7 @@ impl Payload {
|
||||
impl Request {
|
||||
fn new(agent: &Agent, method: String, path: String) -> Request {
|
||||
Request {
|
||||
pool: Arc::clone(&agent.pool),
|
||||
state: Arc::clone(&agent.state),
|
||||
method,
|
||||
path,
|
||||
headers: agent.headers.clone(),
|
||||
@@ -90,26 +90,31 @@ impl Request {
|
||||
///
|
||||
/// println!("{:?}", r);
|
||||
/// ```
|
||||
pub fn call(&self) -> Response {
|
||||
pub fn call(&mut self) -> Response {
|
||||
self.do_call(Payload::Empty)
|
||||
}
|
||||
|
||||
fn do_call(&self, payload: Payload) -> Response {
|
||||
let mut lock = self.pool.lock().unwrap();
|
||||
fn do_call(&mut self, payload: Payload) -> Response {
|
||||
let mut state = self.state.lock().unwrap();
|
||||
self.to_url()
|
||||
.and_then(|url| {
|
||||
if lock.is_none() {
|
||||
// create a one off pool.
|
||||
ConnectionPool::new().connect(self, &self.method, &url, self.redirects, payload)
|
||||
} else {
|
||||
// reuse connection pool.
|
||||
lock.as_mut().unwrap().connect(
|
||||
if state.is_none() {
|
||||
// create a one off pool/jar.
|
||||
ConnectionPool::new().connect(
|
||||
self,
|
||||
&self.method,
|
||||
&url,
|
||||
self.redirects,
|
||||
None,
|
||||
payload,
|
||||
)
|
||||
} else {
|
||||
// reuse connection pool.
|
||||
let state = state.as_mut().unwrap();
|
||||
let jar = &mut state.jar;
|
||||
state
|
||||
.pool
|
||||
.connect(self, &self.method, &url, self.redirects, Some(jar), payload)
|
||||
}
|
||||
})
|
||||
.unwrap_or_else(|e| e.into())
|
||||
@@ -127,7 +132,7 @@ impl Request {
|
||||
/// println!("{:?}", r);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn send_json(&self, data: serde_json::Value) -> Response {
|
||||
pub fn send_json(&mut self, data: serde_json::Value) -> Response {
|
||||
self.do_call(Payload::JSON(data))
|
||||
}
|
||||
|
||||
@@ -139,7 +144,7 @@ impl Request {
|
||||
/// .send_str("Hello World!");
|
||||
/// println!("{:?}", r);
|
||||
/// ```
|
||||
pub fn send_str<S>(&self, data: S) -> Response
|
||||
pub fn send_str<S>(&mut self, data: S) -> Response
|
||||
where
|
||||
S: Into<String>,
|
||||
{
|
||||
@@ -151,7 +156,7 @@ impl Request {
|
||||
///
|
||||
///
|
||||
///
|
||||
pub fn send<R>(&self, reader: R) -> Response
|
||||
pub fn send<R>(&mut self, reader: R) -> Response
|
||||
where
|
||||
R: Read + Send + 'static,
|
||||
{
|
||||
|
||||
@@ -17,7 +17,7 @@ pub struct Response {
|
||||
index: (usize, usize), // index into status_line where we split: HTTP/1.1 200 OK
|
||||
status: u16,
|
||||
headers: Vec<Header>,
|
||||
reader: Option<Box<Read + Send + 'static>>,
|
||||
stream: Option<Stream>,
|
||||
}
|
||||
|
||||
impl ::std::fmt::Debug for Response {
|
||||
@@ -135,13 +135,13 @@ impl Response {
|
||||
.map(|enc| enc.len() > 0) // whatever it says, do chunked
|
||||
.unwrap_or(false);
|
||||
let len = self.header("content-length").and_then(|l| l.parse::<usize>().ok());
|
||||
let reader = self.reader.expect("No reader in response?!");
|
||||
let reader = self.stream.expect("No reader in response?!");
|
||||
match is_chunked {
|
||||
true => Box::new(chunked_transfer::Decoder::new(reader)),
|
||||
false => {
|
||||
match len {
|
||||
Some(len) => Box::new(LimitedRead::new(reader, len)),
|
||||
None => reader,
|
||||
None => Box::new(reader) as Box<Read>,
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -202,12 +202,17 @@ impl Response {
|
||||
index,
|
||||
status,
|
||||
headers,
|
||||
reader: None,
|
||||
stream: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_reader<R>(&mut self, reader: R) where R: Read + Send + 'static {
|
||||
self.reader = Some(Box::new(reader));
|
||||
fn set_stream(&mut self, stream: Stream) {
|
||||
self.stream = Some(stream);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn to_write_vec(&self) -> Vec<u8> {
|
||||
self.stream.as_ref().unwrap().to_write_vec()
|
||||
}
|
||||
|
||||
}
|
||||
@@ -243,7 +248,7 @@ impl FromStr for Response {
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
let mut read = VecRead::from_str(s);
|
||||
let mut resp = Self::do_from_read(&mut read)?;
|
||||
resp.set_reader(read);
|
||||
resp.set_stream(Stream::Read(Box::new(read)));
|
||||
Ok(resp)
|
||||
}
|
||||
}
|
||||
@@ -281,13 +286,13 @@ fn read_next_line<R: Read>(reader: &mut R) -> IoResult<AsciiString> {
|
||||
}
|
||||
|
||||
struct LimitedRead {
|
||||
reader: Box<Read + Send>,
|
||||
reader: Stream,
|
||||
limit: usize,
|
||||
position: usize,
|
||||
}
|
||||
|
||||
impl LimitedRead {
|
||||
fn new(reader: Box<Read + Send>, limit: usize) -> Self {
|
||||
fn new(reader: Stream, limit: usize) -> Self {
|
||||
LimitedRead {
|
||||
reader,
|
||||
limit,
|
||||
|
||||
@@ -7,7 +7,18 @@ use std::net::TcpStream;
|
||||
pub enum Stream {
|
||||
Http(TcpStream),
|
||||
Https(rustls::ClientSession, TcpStream),
|
||||
#[cfg(test)] Test(Box<Read + Send>, Box<Write + Send>),
|
||||
Read(Box<Read>),
|
||||
#[cfg(test)] Test(Box<Read + Send>, Vec<u8>),
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
#[cfg(test)]
|
||||
pub fn to_write_vec(&self) -> Vec<u8> {
|
||||
match self {
|
||||
Stream::Test(_, writer) => writer.clone(),
|
||||
_ => panic!("to_write_vec on non Test stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for Stream {
|
||||
@@ -15,6 +26,7 @@ impl Read for Stream {
|
||||
match self {
|
||||
Stream::Http(sock) => sock.read(buf),
|
||||
Stream::Https(sess, sock) => rustls::Stream::new(sess, sock).read(buf),
|
||||
Stream::Read(read) => read.read(buf),
|
||||
#[cfg(test)] Stream::Test(reader, _) => reader.read(buf),
|
||||
}
|
||||
}
|
||||
@@ -25,6 +37,7 @@ impl Write for Stream {
|
||||
match self {
|
||||
Stream::Http(sock) => sock.write(buf),
|
||||
Stream::Https(sess, sock) => rustls::Stream::new(sess, sock).write(buf),
|
||||
Stream::Read(_) => panic!("Write to read stream"),
|
||||
#[cfg(test)] Stream::Test(_, writer) => writer.write(buf),
|
||||
}
|
||||
}
|
||||
@@ -32,6 +45,7 @@ impl Write for Stream {
|
||||
match self {
|
||||
Stream::Http(sock) => sock.flush(),
|
||||
Stream::Https(sess, sock) => rustls::Stream::new(sess, sock).flush(),
|
||||
Stream::Read(_) => panic!("Flush read stream"),
|
||||
#[cfg(test)] Stream::Test(_, writer) => writer.flush(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,10 +42,15 @@ fn agent_cookies() {
|
||||
assert!(agent.cookie("foo").is_some());
|
||||
assert_eq!(agent.cookie("foo").unwrap().value(), "bar baz");
|
||||
|
||||
test::set_handler("/agent_cookies", |req, _url| {
|
||||
test::set_handler("/agent_cookies", |_req, _url| {
|
||||
test::make_response(200, "OK", vec![], vec![])
|
||||
});
|
||||
|
||||
agent.get("test://host/agent_cookies").call();
|
||||
let resp = agent.get("test://host/agent_cookies").call();
|
||||
|
||||
let vec = resp.to_write_vec();
|
||||
let s = String::from_utf8_lossy(&vec);
|
||||
|
||||
assert!(s.contains("Cookie: foo=bar%20baz\r\n"));
|
||||
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ pub fn make_response(
|
||||
buf.append(&mut body);
|
||||
let read = VecRead::from_vec(buf);
|
||||
let write: Vec<u8> = vec![];
|
||||
Ok(Stream::Test(Box::new(read), Box::new(write)))
|
||||
Ok(Stream::Test(Box::new(read), write))
|
||||
}
|
||||
|
||||
pub fn resolve_handler(req: &Request, url: &Url) -> Result<Stream, Error> {
|
||||
|
||||
Reference in New Issue
Block a user