From d4126027c86d95984863dd27933775107fa188c7 Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Tue, 12 Jun 2018 23:09:17 +0200 Subject: [PATCH] cookie jar --- Cargo.lock | 22 +++++++++++++ Cargo.toml | 1 + README.md | 2 +- src/agent.rs | 61 ++++++++++++++++++++++++++++++++--- src/conn.rs | 72 ++++++++++++++++++++++++++++++++++++++---- src/lib.rs | 14 +++++--- src/request.rs | 33 +++++++++++-------- src/response.rs | 23 ++++++++------ src/stream.rs | 16 +++++++++- src/test/agent_test.rs | 9 ++++-- src/test/mod.rs | 2 +- 11 files changed, 213 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e31c8f4..d692fb5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -37,6 +37,15 @@ name = "chunked_transfer" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" +[[package]] +name = "cookie" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", + "url 1.7.0 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "dns-lookup" version = "0.9.1" @@ -320,6 +329,16 @@ dependencies = [ "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", ] +[[package]] +name = "time" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +dependencies = [ + "libc 0.2.42 (registry+https://github.com/rust-lang/crates.io-index)", + "redox_syscall 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)", + "winapi 0.3.5 (registry+https://github.com/rust-lang/crates.io-index)", +] + [[package]] name = "unicase" version = "1.4.2" @@ -353,6 +372,7 @@ dependencies = [ "ascii 0.9.0 (registry+https://github.com/rust-lang/crates.io-index)", "base64 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)", "chunked_transfer 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", + "cookie 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)", "dns-lookup 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)", "encoding 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)", "lazy_static 1.0.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -425,6 +445,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum cc 1.0.17 (registry+https://github.com/rust-lang/crates.io-index)" = "49ec142f5768efb5b7622aebc3fdbdbb8950a4b9ba996393cb76ef7466e8747d" "checksum cfg-if 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "405216fd8fe65f718daa7102ea808a946b6ce40c742998fbfd3463645552de18" "checksum chunked_transfer 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)" = "498d20a7aaf62625b9bf26e637cf7736417cde1d0c99f1d04d1170229a85cf87" +"checksum cookie 0.10.1 (registry+https://github.com/rust-lang/crates.io-index)" = "746858cae4eae40fff37e1998320068df317bc247dc91a67c6cfa053afdc2abb" "checksum dns-lookup 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)" = "54810764899241c707428f4a1989351f30c0c2bda5ea07ff2e43148f8935039f" "checksum dtoa 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "09c3753c3db574d215cba4ea76018483895d7bff25a31b49ba45db21c48e50ab" "checksum encoding 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)" = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" @@ -461,6 +482,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" "checksum serde_json 1.0.20 (registry+https://github.com/rust-lang/crates.io-index)" = "fc97cccc2959f39984524026d760c08ef0dd5f0f5948c8d31797dbfae458c875" "checksum siphasher 0.2.2 (registry+https://github.com/rust-lang/crates.io-index)" = "0df90a788073e8d0235a67e50441d47db7c8ad9debd91cbf43736a2a92d36537" "checksum socket2 0.3.6 (registry+https://github.com/rust-lang/crates.io-index)" = "06dc9f86ee48652b7c80f3d254e3b9accb67a928c562c64d10d7b016d3d98dab" +"checksum time 0.1.40 (registry+https://github.com/rust-lang/crates.io-index)" = "d825be0eb33fda1a7e68012d51e9c7f451dc1a69391e7fdc197060bb8c56667b" "checksum unicase 1.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "7f4765f83163b74f957c797ad9253caf97f103fb064d3999aea9568d09fc8a33" "checksum unicode-bidi 0.3.4 (registry+https://github.com/rust-lang/crates.io-index)" = "49f2bd0c6468a8230e1db229cff8029217cf623c767ea5d60bfbd42729ea54d5" "checksum unicode-normalization 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)" = "6a0180bc61fc5a987082bfa111f4cc95c4caff7f9799f3e46df09163a937aa25" diff --git a/Cargo.toml b/Cargo.toml index f535353..4bcc7ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ authors = ["Martin Algesten "] ascii = "0.9" base64 = "*" chunked_transfer = "0.3" +cookie = { version = "0.10", features = ["percent-encode"] } dns-lookup = "0.9.1" encoding = "0.2" lazy_static = "1" diff --git a/README.md b/README.md index cc8db7b..75721d8 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ - [x] Limit read length on Content-Size - [x] Auth headers - [x] Repeated headers -- [ ] Cookie jar in agent +- [x] Cookie jar in agent - [ ] Forms with application/x-www-form-urlencoded - [ ] multipart/form-data - [ ] Connection reuse/keep-alive with pool diff --git a/src/agent.rs b/src/agent.rs index 43bc5ba..db2cccb 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -1,7 +1,8 @@ +use cookie::{Cookie, CookieJar}; use std::str::FromStr; use std::sync::Mutex; -use header::{Header, add_header}; +use header::{add_header, Header}; use util::*; // to get to share private fields @@ -11,8 +12,23 @@ include!("conn.rs"); #[derive(Debug, Default, Clone)] pub struct Agent { - pub headers: Vec
, - pub pool: Arc>>, + headers: Vec
, + state: Arc>>, +} + +#[derive(Debug)] +struct AgentState { + pool: ConnectionPool, + jar: CookieJar, +} + +impl AgentState { + fn new() -> Self { + AgentState { + pool: ConnectionPool::new(), + jar: CookieJar::new(), + } + } } impl Agent { @@ -26,7 +42,7 @@ impl Agent { pub fn build(&self) -> Self { Agent { headers: self.headers.clone(), - pool: Arc::new(Mutex::new(Some(ConnectionPool::new()))), + state: Arc::new(Mutex::new(Some(AgentState::new()))), } } @@ -160,6 +176,43 @@ impl Agent { Request::new(&self, method.into(), path.into()) } + /// Gets a cookie in this agent by name. Cookies are available + /// either by setting it in the agent, or by making requests + /// that `Set-Cookie` in the agent. + /// + /// ``` + /// let agent = ureq::agent().build(); + /// + /// agent.get("http://www.google.com").call(); + /// + /// assert!(agent.cookie("NID").is_some()); + /// ``` + pub fn cookie(&self, name: &str) -> Option> { + let state = self.state.lock().unwrap(); + state + .as_ref() + .and_then(|state| state.jar.get(name)) + .map(|c| c.clone()) + } + + /// Set a cookie in this agent. + /// + /// ``` + /// let agent = ureq::agent().build(); + /// + /// let cookie = ureq::Cookie::new("name", "value"); + /// agent.set_cookie(cookie); + /// ``` + pub fn set_cookie(&self, cookie: Cookie<'static>) { + let mut state = self.state.lock().unwrap(); + match state.as_mut() { + None => (), + Some(state) => { + state.jar.add_original(cookie); + } + } + } + pub fn get(&self, path: S) -> Request where S: Into, diff --git a/src/conn.rs b/src/conn.rs index 423aedf..f1e1bd4 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -26,10 +26,23 @@ impl ConnectionPool { method: &str, url: &Url, redirects: u32, + mut jar: Option<&mut CookieJar>, payload: Payload, ) -> Result { // - // open connection + + let hostname = url.host_str().unwrap_or("localhost"); // is localhost a good alternative? + let is_secure = url.scheme().eq_ignore_ascii_case("https"); + + let cookie_headers: Vec<_> = { + match jar.as_ref() { + None => vec![], + Some(jar) => match_cookies(jar, hostname, url.path(), is_secure), + } + }; + let headers = request.headers.iter().chain(cookie_headers.iter()); + + // open socket let mut stream = match url.scheme() { "http" => connect_http(request, &url), "https" => connect_https(request, &url), @@ -43,16 +56,34 @@ impl ConnectionPool { if !request.has("host") { write!(prelude, "Host: {}\r\n", url.host().unwrap())?; } - for header in request.headers.iter() { + for header in headers { write!(prelude, "{}: {}\r\n", header.name(), header.value())?; } write!(prelude, "\r\n")?; stream.write_all(&mut prelude[..])?; - // start reading the response to check it it's a redirect + // start reading the response to process cookies and redirects. let mut resp = Response::from_read(&mut stream); + // squirrel away cookies + if let Some(add_jar) = jar.as_mut() { + for raw_cookie in resp.all("set-cookie").iter() { + let to_parse = if raw_cookie.to_lowercase().contains("domain=") { + raw_cookie.to_string() + } else { + format!("{}; Domain={}", raw_cookie, hostname) + }; + match Cookie::parse_encoded(&to_parse[..]) { + Err(_) => (), // ignore unparseable cookies + Ok(mut cookie) => { + let cookie = cookie.into_owned(); + add_jar.add(cookie) + } + } + } + } + // handle redirects if resp.redirect() { if redirects == 0 { @@ -70,10 +101,10 @@ impl ConnectionPool { return match resp.status { 301 | 302 | 303 => { send_payload(&request, payload, &mut stream)?; - self.connect(request, "GET", &new_url, redirects - 1, Payload::Empty) + self.connect(request, "GET", &new_url, redirects - 1, jar, Payload::Empty) } 307 | 308 | _ => { - self.connect(request, method, &new_url, redirects - 1, payload) + self.connect(request, method, &new_url, redirects - 1, jar, payload) } }; } @@ -83,7 +114,7 @@ impl ConnectionPool { send_payload(&request, payload, &mut stream)?; // since it is not a redirect, give away the incoming stream to the response object - resp.set_reader(stream); + resp.set_stream(stream); // release the response Ok(resp) @@ -193,6 +224,35 @@ where Ok(()) } +// TODO check so cookies can't be set for tld:s +fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bool) -> Vec
{ + jar.iter() + .filter(|c| { + // if there is a domain, it must be matched. if there is no domain, then ignore cookie + let domain_ok = c.domain() + .map(|cdom| domain.contains(cdom)) + .unwrap_or(false); + // a path must match the beginning of request path. no cookie path, we say is ok. is it?! + let path_ok = c.path() + .map(|cpath| path.find(cpath).map(|pos| pos == 0).unwrap_or(false)) + .unwrap_or(true); + // either the cookie isnt secure, or we're not doing a secure request. + let secure_ok = !c.secure() || is_secure; + + domain_ok && path_ok && secure_ok + }) + .map(|c| { + let name = c.name().to_string(); + let value = c.value().to_string(); + let nameval = Cookie::new(name, value).encoded().to_string(); + let head = format!("Cookie: {}", nameval); + head.parse::
().ok() + }) + .filter(|o| o.is_some()) + .map(|o| o.unwrap()) + .collect() +} + #[cfg(not(test))] fn connect_test(_request: &Request, url: &Url) -> Result { Err(Error::UnknownScheme(url.scheme().to_string())) diff --git a/src/lib.rs b/src/lib.rs index cc107c6..db3e3fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ extern crate ascii; extern crate base64; extern crate chunked_transfer; +extern crate cookie; extern crate dns_lookup; extern crate encoding; #[macro_use] @@ -28,6 +29,7 @@ pub use header::Header; // re-export pub use serde_json::{to_value, Map, Value}; +pub use cookie::Cookie; /// Agents keep state between requests. /// @@ -131,16 +133,20 @@ mod tests { #[test] fn connect_http_google() { let resp = get("http://www.google.com/").call(); - println!("{:?}", resp); - assert_eq!("text/html; charset=ISO-8859-1", resp.header("content-type").unwrap()); + assert_eq!( + "text/html; charset=ISO-8859-1", + resp.header("content-type").unwrap() + ); assert_eq!("text/html", resp.content_type()); } #[test] fn connect_https_google() { let resp = get("https://www.google.com/").call(); - println!("{:?}", resp); - assert_eq!("text/html; charset=ISO-8859-1", resp.header("content-type").unwrap()); + assert_eq!( + "text/html; charset=ISO-8859-1", + resp.header("content-type").unwrap() + ); assert_eq!("text/html", resp.content_type()); } } diff --git a/src/request.rs b/src/request.rs index 15af5d1..875f37b 100644 --- a/src/request.rs +++ b/src/request.rs @@ -8,7 +8,7 @@ lazy_static! { #[derive(Clone, Default)] pub struct Request { - pool: Arc>>, + state: Arc>>, // via agent method: String, @@ -57,7 +57,7 @@ impl Payload { impl Request { fn new(agent: &Agent, method: String, path: String) -> Request { Request { - pool: Arc::clone(&agent.pool), + state: Arc::clone(&agent.state), method, path, headers: agent.headers.clone(), @@ -90,26 +90,31 @@ impl Request { /// /// println!("{:?}", r); /// ``` - pub fn call(&self) -> Response { + pub fn call(&mut self) -> Response { self.do_call(Payload::Empty) } - fn do_call(&self, payload: Payload) -> Response { - let mut lock = self.pool.lock().unwrap(); + fn do_call(&mut self, payload: Payload) -> Response { + let mut state = self.state.lock().unwrap(); self.to_url() .and_then(|url| { - if lock.is_none() { - // create a one off pool. - ConnectionPool::new().connect(self, &self.method, &url, self.redirects, payload) - } else { - // reuse connection pool. - lock.as_mut().unwrap().connect( + if state.is_none() { + // create a one off pool/jar. + ConnectionPool::new().connect( self, &self.method, &url, self.redirects, + None, payload, ) + } else { + // reuse connection pool. + let state = state.as_mut().unwrap(); + let jar = &mut state.jar; + state + .pool + .connect(self, &self.method, &url, self.redirects, Some(jar), payload) } }) .unwrap_or_else(|e| e.into()) @@ -127,7 +132,7 @@ impl Request { /// println!("{:?}", r); /// } /// ``` - pub fn send_json(&self, data: serde_json::Value) -> Response { + pub fn send_json(&mut self, data: serde_json::Value) -> Response { self.do_call(Payload::JSON(data)) } @@ -139,7 +144,7 @@ impl Request { /// .send_str("Hello World!"); /// println!("{:?}", r); /// ``` - pub fn send_str(&self, data: S) -> Response + pub fn send_str(&mut self, data: S) -> Response where S: Into, { @@ -151,7 +156,7 @@ impl Request { /// /// /// - pub fn send(&self, reader: R) -> Response + pub fn send(&mut self, reader: R) -> Response where R: Read + Send + 'static, { diff --git a/src/response.rs b/src/response.rs index 08c992e..ef0b487 100644 --- a/src/response.rs +++ b/src/response.rs @@ -17,7 +17,7 @@ pub struct Response { index: (usize, usize), // index into status_line where we split: HTTP/1.1 200 OK status: u16, headers: Vec
, - reader: Option>, + stream: Option, } impl ::std::fmt::Debug for Response { @@ -135,13 +135,13 @@ impl Response { .map(|enc| enc.len() > 0) // whatever it says, do chunked .unwrap_or(false); let len = self.header("content-length").and_then(|l| l.parse::().ok()); - let reader = self.reader.expect("No reader in response?!"); + let reader = self.stream.expect("No reader in response?!"); match is_chunked { true => Box::new(chunked_transfer::Decoder::new(reader)), false => { match len { Some(len) => Box::new(LimitedRead::new(reader, len)), - None => reader, + None => Box::new(reader) as Box, } }, } @@ -202,12 +202,17 @@ impl Response { index, status, headers, - reader: None, + stream: None, }) } - fn set_reader(&mut self, reader: R) where R: Read + Send + 'static { - self.reader = Some(Box::new(reader)); + fn set_stream(&mut self, stream: Stream) { + self.stream = Some(stream); + } + + #[cfg(test)] + pub fn to_write_vec(&self) -> Vec { + self.stream.as_ref().unwrap().to_write_vec() } } @@ -243,7 +248,7 @@ impl FromStr for Response { fn from_str(s: &str) -> Result { let mut read = VecRead::from_str(s); let mut resp = Self::do_from_read(&mut read)?; - resp.set_reader(read); + resp.set_stream(Stream::Read(Box::new(read))); Ok(resp) } } @@ -281,13 +286,13 @@ fn read_next_line(reader: &mut R) -> IoResult { } struct LimitedRead { - reader: Box, + reader: Stream, limit: usize, position: usize, } impl LimitedRead { - fn new(reader: Box, limit: usize) -> Self { + fn new(reader: Stream, limit: usize) -> Self { LimitedRead { reader, limit, diff --git a/src/stream.rs b/src/stream.rs index 01680ab..267a86a 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -7,7 +7,18 @@ use std::net::TcpStream; pub enum Stream { Http(TcpStream), Https(rustls::ClientSession, TcpStream), - #[cfg(test)] Test(Box, Box), + Read(Box), + #[cfg(test)] Test(Box, Vec), +} + +impl Stream { + #[cfg(test)] + pub fn to_write_vec(&self) -> Vec { + match self { + Stream::Test(_, writer) => writer.clone(), + _ => panic!("to_write_vec on non Test stream") + } + } } impl Read for Stream { @@ -15,6 +26,7 @@ impl Read for Stream { match self { Stream::Http(sock) => sock.read(buf), Stream::Https(sess, sock) => rustls::Stream::new(sess, sock).read(buf), + Stream::Read(read) => read.read(buf), #[cfg(test)] Stream::Test(reader, _) => reader.read(buf), } } @@ -25,6 +37,7 @@ impl Write for Stream { match self { Stream::Http(sock) => sock.write(buf), Stream::Https(sess, sock) => rustls::Stream::new(sess, sock).write(buf), + Stream::Read(_) => panic!("Write to read stream"), #[cfg(test)] Stream::Test(_, writer) => writer.write(buf), } } @@ -32,6 +45,7 @@ impl Write for Stream { match self { Stream::Http(sock) => sock.flush(), Stream::Https(sess, sock) => rustls::Stream::new(sess, sock).flush(), + Stream::Read(_) => panic!("Flush read stream"), #[cfg(test)] Stream::Test(_, writer) => writer.flush(), } } diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index 55d352c..c45a305 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -42,10 +42,15 @@ fn agent_cookies() { assert!(agent.cookie("foo").is_some()); assert_eq!(agent.cookie("foo").unwrap().value(), "bar baz"); - test::set_handler("/agent_cookies", |req, _url| { + test::set_handler("/agent_cookies", |_req, _url| { test::make_response(200, "OK", vec![], vec![]) }); - agent.get("test://host/agent_cookies").call(); + let resp = agent.get("test://host/agent_cookies").call(); + + let vec = resp.to_write_vec(); + let s = String::from_utf8_lossy(&vec); + + assert!(s.contains("Cookie: foo=bar%20baz\r\n")); } diff --git a/src/test/mod.rs b/src/test/mod.rs index a013574..ad32355 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -44,7 +44,7 @@ pub fn make_response( buf.append(&mut body); let read = VecRead::from_vec(buf); let write: Vec = vec![]; - Ok(Stream::Test(Box::new(read), Box::new(write))) + Ok(Stream::Test(Box::new(read), write)) } pub fn resolve_handler(req: &Request, url: &Url) -> Result {