diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index 66356cd..58bdd31 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -99,3 +99,50 @@ fn connection_reuse() { } assert_eq!(resp.status(), 200); } + +#[cfg(test)] +fn cookie_and_redirect(mut stream: TcpStream) -> io::Result<()> { + let headers = read_headers(&stream); + match headers.path() { + "/first" => { + stream.write_all(b"HTTP/1.1 302 Found\r\n")?; + stream.write_all(b"Location: /second\r\n")?; + stream.write_all(b"Set-Cookie: first=true\r\n")?; + stream.write_all(b"Content-Length: 0\r\n\r\n")?; + }, + "/second" => { + if headers.headers().iter().find(|&x| x.contains("Cookie: first")).is_none() { + panic!("request did not contain cookie 'first'"); + } + stream.write_all(b"HTTP/1.1 302 Found\r\n")?; + stream.write_all(b"Location: /third\r\n")?; + stream.write_all(b"Set-Cookie: second=true\r\n")?; + stream.write_all(b"Content-Length: 0\r\n\r\n")?; + }, + "/third" => { + if headers.headers().iter().find(|&x| x.contains("Cookie: first")).is_none() { + panic!("request did not contain cookie 'second'"); + } + stream.write_all(b"HTTP/1.1 200 OK\r\n")?; + stream.write_all(b"Set-Cookie: third=true\r\n")?; + stream.write_all(b"Content-Length: 0\r\n\r\n")?; + }, + _ => {}, + } + Ok(()) +} + +#[cfg(feature = "cookie")] +#[test] +fn test_cookies_on_redirect() { + let testserver = TestServer::new(cookie_and_redirect); + let url = format!("http://localhost:{}/first", testserver.port); + let agent = Agent::default().build(); + let resp = agent.post(&url).call(); + if resp.error() { + panic!("error: {} {}", resp.status(), resp.into_string().unwrap()); + } + assert!(agent.cookie("first").is_some()); + assert!(agent.cookie("second").is_some()); + assert!(agent.cookie("third").is_some()); +} diff --git a/src/test/testserver.rs b/src/test/testserver.rs index 7c563a1..45b70a8 100644 --- a/src/test/testserver.rs +++ b/src/test/testserver.rs @@ -9,14 +9,35 @@ pub struct TestServer { pub done: Arc, } -// Read a stream until reaching a blank line, in order to consume -// request headers. -pub fn read_headers(stream: &TcpStream) { - for line in BufReader::new(stream).lines() { - if line.unwrap() == "" { - break; +pub struct TestHeaders(Vec); + +impl TestHeaders { + // Return the path for a request, e.g. /foo from "GET /foo HTTP/1.1" + pub fn path(&self) -> &str { + if self.0.len() == 0 { + "" + } else { + &self.0[0].split(" ").nth(1).unwrap() } } + + pub fn headers(&self) -> &[String] { + &self.0[1..] + } +} + +// Read a stream until reaching a blank line, in order to consume +// request headers. +pub fn read_headers(stream: &TcpStream) -> TestHeaders { + let mut results = vec![]; + for line in BufReader::new(stream).lines() { + match line { + Err(e) => panic!(e), + Ok(line) if line == "" => break, + Ok(line) => results.push(line), + }; + } + TestHeaders(results) } impl TestServer {