diff --git a/src/lib.rs b/src/lib.rs index b65207f..cb86b9c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -121,6 +121,8 @@ pub use serde_json::json; #[cfg(test)] mod test; +#[doc(hidden)] +mod testserver; pub use crate::agent::Agent; pub use crate::agent::AgentBuilder; @@ -137,17 +139,41 @@ pub use cookie::Cookie; #[cfg(feature = "json")] pub use serde_json::{to_value as serde_to_value, Map as SerdeMap, Value as SerdeValue}; +use once_cell::sync::Lazy; +use std::sync::atomic::{AtomicBool, Ordering}; + /// Creates an agent builder. pub fn builder() -> AgentBuilder { AgentBuilder::new() } +// is_test returns false so long as it has only ever been called with false. +// If it has ever been called with true, it will always return true after that. +// This is a public but hidden function used to allow doctests to use the test_agent. +// Note that we use this approach for doctests rather the #[cfg(test)], because +// doctests are run against a copy of the crate build without cfg(test) set. +// We also can't use #[cfg(doctest)] to do this, because cfg(doctest) is only set +// when collecting doctests, not when building the crate. +#[doc(hidden)] +pub fn is_test(is: bool) -> bool { + static IS_TEST: Lazy = Lazy::new(|| AtomicBool::new(false)); + if is { + IS_TEST.store(true, Ordering::SeqCst); + } + let x = IS_TEST.load(Ordering::SeqCst); + return x; +} + /// Agents are used to keep state between requests. pub fn agent() -> Agent { #[cfg(not(test))] - return AgentBuilder::new().build(); + if is_test(false) { + return testserver::test_agent(); + } else { + return AgentBuilder::new().build(); + } #[cfg(test)] - return test::test_agent(); + return testserver::test_agent(); } /// Make a request setting the HTTP method via a string. diff --git a/src/request.rs b/src/request.rs index 00393ca..7315898 100644 --- a/src/request.rs +++ b/src/request.rs @@ -103,14 +103,12 @@ impl Request { /// The `Content-Length` header is implicitly set to the length of the serialized value. /// /// ``` - /// #[macro_use] - /// extern crate ureq; - /// - /// fn main() { - /// let r = ureq::post("/my_page") - /// .send_json(json!({ "name": "martin", "rust": true })); - /// println!("{:?}", r); - /// } + /// # fn main() -> Result<(), ureq::Error> { + /// # ureq::is_test(true); + /// let r = ureq::post("http://example.com/form") + /// .send_json(ureq::json!({ "name": "martin", "rust": true }))?; + /// # Ok(()) + /// # } /// ``` #[cfg(feature = "json")] pub fn send_json(mut self, data: SerdeValue) -> Result { diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index b916bac..3746798 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] -use crate::test::testserver::{read_headers, TestServer}; +use crate::testserver::{read_headers, TestServer}; use std::io::{self, Read, Write}; use std::net::TcpStream; use std::time::Duration; diff --git a/src/test/mod.rs b/src/test/mod.rs index 72506a9..a26c761 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,10 +1,10 @@ use crate::unit::Unit; -use crate::{error::Error, Agent}; -use crate::{stream::Stream, AgentBuilder}; +use crate::{error::Error}; +use crate::{stream::Stream}; use once_cell::sync::Lazy; use std::io::{Cursor, Write}; use std::sync::{Arc, Mutex}; -use std::{collections::HashMap, net::ToSocketAddrs}; +use std::{collections::HashMap}; mod agent_test; mod body_read; @@ -13,44 +13,8 @@ mod query_string; mod range; mod redirect; mod simple; -pub(crate) mod testserver; mod timeout; -// An agent to be installed by default for tests and doctests, such -// that all hostnames resolve to a TestServer on localhost. -pub(crate) fn test_agent() -> Agent { - use std::io; - use std::net::{SocketAddr, TcpStream}; - let testserver = testserver::TestServer::new(|mut stream: TcpStream| -> io::Result<()> { - testserver::read_headers(&stream); - stream.write_all(b"HTTP/1.1 200 OK\r\n")?; - stream.write_all(b"Transfer-Encoding: chunked\r\n")?; - stream.write_all(b"Content-Type: text/html; charset=ISO-8859-1\r\n")?; - stream.write_all(b"\r\n")?; - stream.write_all(b"7\r\n")?; - stream.write_all(b"success\r\n")?; - stream.write_all(b"0\r\n")?; - stream.write_all(b"\r\n")?; - Ok(()) - }); - // Slightly tricky thing here: we want to make sure the TestServer lives - // as long as the agent. This is accomplished by `move`ing it into the - // closure, which becomes owned by the agent. - AgentBuilder::new() - .resolver(move |h: &str| -> io::Result> { - // Don't override resolution for HTTPS requests yet, since we - // don't have a setup for an HTTPS testserver. Also, skip localhost - // resolutions since those may come from a unittest that set up - // its own, specific testserver. - if h.ends_with(":443") || h.starts_with("localhost:") { - return Ok(h.to_socket_addrs()?.collect::>()); - } - let addr: SocketAddr = format!("127.0.0.1:{}", testserver.port).parse().unwrap(); - Ok(vec![addr]) - }) - .build() -} - type RequestHandler = dyn Fn(&Unit) -> Result + Send + 'static; pub(crate) static TEST_HANDLERS: Lazy>>>> = diff --git a/src/test/redirect.rs b/src/test/redirect.rs index 0c4e4f5..95ae558 100644 --- a/src/test/redirect.rs +++ b/src/test/redirect.rs @@ -2,7 +2,7 @@ use std::{ io::{self, Write}, net::TcpStream, }; -use test::testserver::{self, TestServer}; +use testserver::{self, TestServer}; use crate::test; diff --git a/src/test/testserver.rs b/src/test/testserver.rs deleted file mode 100644 index 720bec6..0000000 --- a/src/test/testserver.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::io::{self, BufRead, BufReader}; -use std::net::{TcpListener, TcpStream}; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::thread; - -pub struct TestServer { - pub port: u16, - pub done: Arc, -} - -pub struct TestHeaders(Vec); - -impl TestHeaders { - // Return the path for a request, e.g. /foo from "GET /foo HTTP/1.1" - #[cfg(feature = "cookies")] - pub fn path(&self) -> &str { - if self.0.len() == 0 { - "" - } else { - &self.0[0].split(" ").nth(1).unwrap() - } - } - - #[cfg(feature = "cookies")] - pub fn headers(&self) -> &[String] { - &self.0[1..] - } -} - -// Read a stream until reaching a blank line, in order to consume -// request headers. -pub fn read_headers(stream: &TcpStream) -> TestHeaders { - let mut results = vec![]; - for line in BufReader::new(stream).lines() { - match line { - Err(e) => { - eprintln!("testserver: in read_headers: {}", e); - break; - } - Ok(line) if line == "" => break, - Ok(line) => results.push(line), - }; - } - TestHeaders(results) -} - -impl TestServer { - pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self { - let listener = TcpListener::bind("localhost:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - let done = Arc::new(AtomicBool::new(false)); - let done_clone = done.clone(); - thread::spawn(move || { - for stream in listener.incoming() { - if let Err(e) = stream { - eprintln!("testserver: handling just-accepted stream: {}", e); - break; - } - if done.load(Ordering::SeqCst) { - break; - } else { - thread::spawn(move || handler(stream.unwrap())); - } - } - }); - TestServer { - port, - done: done_clone, - } - } -} - -impl Drop for TestServer { - fn drop(&mut self) { - self.done.store(true, Ordering::SeqCst); - // Connect once to unblock the listen loop. - TcpStream::connect(format!("localhost:{}", self.port)).unwrap(); - } -} diff --git a/src/test/timeout.rs b/src/test/timeout.rs index ef3d218..f377418 100644 --- a/src/test/timeout.rs +++ b/src/test/timeout.rs @@ -1,4 +1,4 @@ -use crate::test::testserver::*; +use crate::testserver::*; use std::io::{self, Write}; use std::net::TcpStream; use std::thread; diff --git a/src/testserver.rs b/src/testserver.rs new file mode 100644 index 0000000..326fc54 --- /dev/null +++ b/src/testserver.rs @@ -0,0 +1,137 @@ +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::Duration; +use std::{ + io::{self, BufRead, BufReader, Write}, + net::ToSocketAddrs, +}; + +use crate::{Agent, AgentBuilder}; + +// An agent to be installed by default for tests and doctests, such +// that all hostnames resolve to a TestServer on localhost. +pub(crate) fn test_agent() -> Agent { + let testserver = TestServer::new(|mut stream: TcpStream| -> io::Result<()> { + read_headers(&stream); + stream.write_all(b"HTTP/1.1 200 OK\r\n")?; + stream.write_all(b"Transfer-Encoding: chunked\r\n")?; + stream.write_all(b"Content-Type: text/html; charset=ISO-8859-1\r\n")?; + stream.write_all(b"\r\n")?; + stream.write_all(b"7\r\n")?; + stream.write_all(b"success\r\n")?; + stream.write_all(b"0\r\n")?; + stream.write_all(b"\r\n")?; + Ok(()) + }); + // Slightly tricky thing here: we want to make sure the TestServer lives + // as long as the agent. This is accomplished by `move`ing it into the + // closure, which becomes owned by the agent. + AgentBuilder::new() + .resolver(move |h: &str| -> io::Result> { + // Don't override resolution for HTTPS requests yet, since we + // don't have a setup for an HTTPS testserver. Also, skip localhost + // resolutions since those may come from a unittest that set up + // its own, specific testserver. + if h.ends_with(":443") || h.starts_with("localhost:") { + return Ok(h.to_socket_addrs()?.collect::>()); + } + let addr: SocketAddr = format!("127.0.0.1:{}", testserver.port).parse().unwrap(); + Ok(vec![addr]) + }) + .build() +} + +pub struct TestServer { + pub port: u16, + pub done: Arc, +} + +pub struct TestHeaders(Vec); + +#[allow(dead_code)] +impl TestHeaders { + // Return the path for a request, e.g. /foo from "GET /foo HTTP/1.1" + #[cfg(feature = "cookies")] + pub fn path(&self) -> &str { + if self.0.len() == 0 { + "" + } else { + &self.0[0].split(" ").nth(1).unwrap() + } + } + + #[cfg(feature = "cookies")] + pub fn headers(&self) -> &[String] { + &self.0[1..] + } +} + +// Read a stream until reaching a blank line, in order to consume +// request headers. +pub fn read_headers(stream: &TcpStream) -> TestHeaders { + let mut results = vec![]; + for line in BufReader::new(stream).lines() { + match line { + Err(e) => { + eprintln!("testserver: in read_headers: {}", e); + break; + } + Ok(line) if line == "" => break, + Ok(line) => results.push(line), + }; + } + TestHeaders(results) +} + +impl TestServer { + pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let done = Arc::new(AtomicBool::new(false)); + let done_clone = done.clone(); + thread::spawn(move || { + for stream in listener.incoming() { + if let Err(e) = stream { + eprintln!("testserver: handling just-accepted stream: {}", e); + break; + } + if done.load(Ordering::SeqCst) { + break; + } else { + thread::spawn(move || handler(stream.unwrap())); + } + } + }); + // before returning from new(), ensure the server is ready to accept connections + loop { + if let Err(e) = TcpStream::connect(format!("127.0.0.1:{}", port)) { + match e.kind() { + io::ErrorKind::ConnectionRefused => { + std::thread::sleep(Duration::from_millis(100)); + continue; + } + _ => eprintln!("testserver: pre-connect with error {}", e), + } + } else { + break; + } + } + TestServer { + port, + done: done_clone, + } + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.done.store(true, Ordering::SeqCst); + // Connect once to unblock the listen loop. + match TcpStream::connect(format!("localhost:{}", self.port)) { + Err(e) => eprintln!("error dropping testserver: {}", e), + _ => {} + } + } +}