From b47f90e773fcd17cc562818c01d31af9f8e9e45d Mon Sep 17 00:00:00 2001 From: Jacob Hoffman-Andrews Date: Tue, 30 Jun 2020 19:16:34 -0700 Subject: [PATCH] Factor out TestServer. (#98) This creates a struct that encapsulates the test server setup, and adds a shutdown method to clean up the thread when done. --- src/test/agent_test.rs | 47 +++++++------------------- src/test/mod.rs | 1 + src/test/testserver.rs | 50 +++++++++++++++++++++++++++ src/test/timeout.rs | 77 +++++++----------------------------------- 4 files changed, 76 insertions(+), 99 deletions(-) create mode 100644 src/test/testserver.rs diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index 55934b9..66356cd 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -1,6 +1,7 @@ use crate::test; -use std::io::{BufRead, BufReader, Read, Write}; -use std::thread; +use crate::test::testserver::{read_headers, TestServer}; +use std::io::{self, Write}; +use std::net::TcpStream; use std::time::Duration; use super::super::*; @@ -56,47 +57,25 @@ fn agent_cookies() { agent.get("test://host/agent_cookies").call(); } -// Start a test server on an available port, that times out idle connections at 2 seconds. -// Return the port this server is listening on. -fn start_idle_timeout_server() -> u16 { - let listener = std::net::TcpListener::bind("localhost:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - thread::spawn(move || { - for stream in listener.incoming() { - thread::spawn(move || { - let stream = stream.unwrap(); - stream - .set_read_timeout(Some(Duration::from_secs(2))) - .unwrap(); - let mut write_stream = stream.try_clone().unwrap(); - for line in BufReader::new(stream).lines() { - let line = match line { - Ok(x) => x, - Err(_) => return, - }; - if line == "" { - write_stream - .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 8\r\n\r\nresponse") - .unwrap(); - } - } - }); - } - }); - port +// Handler that answers with a simple HTTP response, and times +// out idle connections after 2 seconds. +fn idle_timeout_handler(mut stream: TcpStream) -> io::Result<()> { + read_headers(&stream); + stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 8\r\n\r\nresponse")?; + stream.set_read_timeout(Some(Duration::from_secs(2)))?; + Ok(()) } #[test] fn connection_reuse() { - let port = start_idle_timeout_server(); - let url = format!("http://localhost:{}", port); + let testserver = TestServer::new(idle_timeout_handler); + let url = format!("http://localhost:{}", testserver.port); let agent = Agent::default().build(); let resp = agent.get(&url).call(); // use up the connection so it gets returned to the pool assert_eq!(resp.status(), 200); - let mut buf = vec![]; - resp.into_reader().read_to_end(&mut buf).unwrap(); + resp.into_string().unwrap(); { let mut guard_state = agent.state.lock().unwrap(); diff --git a/src/test/mod.rs b/src/test/mod.rs index 2baa08f..801b7c9 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -14,6 +14,7 @@ mod query_string; mod range; mod redirect; mod simple; +mod testserver; mod timeout; type RequestHandler = dyn Fn(&Unit) -> Result + Send + 'static; diff --git a/src/test/testserver.rs b/src/test/testserver.rs new file mode 100644 index 0000000..7c563a1 --- /dev/null +++ b/src/test/testserver.rs @@ -0,0 +1,50 @@ +use std::io::{self, BufRead, BufReader}; +use std::net::{TcpListener, TcpStream}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; + +pub struct TestServer { + pub port: u16, + pub done: Arc, +} + +// Read a stream until reaching a blank line, in order to consume +// request headers. +pub fn read_headers(stream: &TcpStream) { + for line in BufReader::new(stream).lines() { + if line.unwrap() == "" { + break; + } + } +} + +impl TestServer { + pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self { + let listener = TcpListener::bind("localhost:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let done = Arc::new(AtomicBool::new(false)); + let done_clone = done.clone(); + thread::spawn(move || { + for stream in listener.incoming() { + thread::spawn(move || handler(stream.unwrap())); + if done.load(Ordering::Relaxed) { + break; + } + } + println!("testserver on {} exiting", port); + }); + TestServer { + port, + done: done_clone, + } + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.done.store(true, Ordering::Relaxed); + // Connect once to unblock the listen loop. + TcpStream::connect(format!("localhost:{}", self.port)).unwrap(); + } +} diff --git a/src/test/timeout.rs b/src/test/timeout.rs index 4e14df5..0ccafca 100644 --- a/src/test/timeout.rs +++ b/src/test/timeout.rs @@ -1,4 +1,5 @@ -use std::io::{self, BufRead, BufReader, Read, Write}; +use crate::test::testserver::*; +use std::io::{self, Write}; use std::net::TcpStream; use std::thread; use std::time::Duration; @@ -7,7 +8,8 @@ use super::super::*; // Send an HTTP response on the TcpStream at a rate of two bytes every 10 // milliseconds, for a total of 600 bytes. -fn dribble_body_respond(stream: &mut TcpStream) -> io::Result<()> { +fn dribble_body_respond(mut stream: TcpStream) -> io::Result<()> { + read_headers(&stream); let contents = [b'a'; 300]; let headers = format!( "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", @@ -23,49 +25,12 @@ fn dribble_body_respond(stream: &mut TcpStream) -> io::Result<()> { Ok(()) } -// Read a stream until reaching a blank line, in order to consume -// request headers. -fn read_headers(stream: &TcpStream) { - for line in BufReader::new(stream).lines() { - let line = match line { - Ok(x) => x, - Err(_) => return, - }; - if line == "" { - break; - } - } -} - -// Start a test server on an available port, that dribbles out a response at 1 write per 10ms. -// Return the port this server is listening on. -fn start_dribble_body_server() -> u16 { - let listener = std::net::TcpListener::bind("localhost:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - let dribble_handler = |mut stream: TcpStream| { - read_headers(&stream); - if let Err(e) = dribble_body_respond(&mut stream) { - eprintln!("sending dribble repsonse: {}", e); - } - }; - thread::spawn(move || { - for stream in listener.incoming() { - thread::spawn(move || dribble_handler(stream.unwrap())); - } - }); - port -} - fn get_and_expect_timeout(url: String) { let agent = Agent::default().build(); let timeout = Duration::from_millis(500); let resp = agent.get(&url).timeout(timeout).call(); - let mut reader = resp.into_reader(); - let mut bytes = vec![]; - let result = reader.read_to_end(&mut bytes); - - match result { + match resp.into_string() { Err(io_error) => match io_error.kind() { io::ErrorKind::WouldBlock => Ok(()), io::ErrorKind::TimedOut => Ok(()), @@ -78,15 +43,15 @@ fn get_and_expect_timeout(url: String) { #[test] fn overall_timeout_during_body() { - let port = start_dribble_body_server(); - let url = format!("http://localhost:{}/", port); - + // Start a test server on an available port, that dribbles out a response at 1 write per 10ms. + let server = TestServer::new(dribble_body_respond); + let url = format!("http://localhost:{}/", server.port); get_and_expect_timeout(url); } // Send HTTP headers on the TcpStream at a rate of one header every 100 // milliseconds, for a total of 30 headers. -fn dribble_headers_respond(stream: &mut TcpStream) -> io::Result<()> { +fn dribble_headers_respond(mut stream: TcpStream) -> io::Result<()> { stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n")?; for _ in 0..30 { stream.write_all(b"a: b\n")?; @@ -96,28 +61,10 @@ fn dribble_headers_respond(stream: &mut TcpStream) -> io::Result<()> { Ok(()) } -// Start a test server on an available port, that dribbles out response *headers* at 1 write per 10ms. -// Return the port this server is listening on. -fn start_dribble_headers_server() -> u16 { - let listener = std::net::TcpListener::bind("localhost:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - let dribble_handler = |mut stream: TcpStream| { - read_headers(&stream); - if let Err(e) = dribble_headers_respond(&mut stream) { - eprintln!("sending dribble repsonse: {}", e); - } - }; - thread::spawn(move || { - for stream in listener.incoming() { - thread::spawn(move || dribble_handler(stream.unwrap())); - } - }); - port -} - #[test] fn overall_timeout_during_headers() { - let port = start_dribble_headers_server(); - let url = format!("http://localhost:{}/", port); + // Start a test server on an available port, that dribbles out a response at 1 write per 10ms. + let server = TestServer::new(dribble_headers_respond); + let url = format!("http://localhost:{}/", server.port); get_and_expect_timeout(url); }