Factor out TestServer. (#98)

This creates a struct that encapsulates the test server setup, and adds
a shutdown method to clean up the thread when done.
This commit is contained in:
Jacob Hoffman-Andrews
2020-06-30 19:16:34 -07:00
committed by GitHub
parent dafdf6a718
commit b47f90e773
4 changed files with 76 additions and 99 deletions

View File

@@ -1,6 +1,7 @@
use crate::test; use crate::test;
use std::io::{BufRead, BufReader, Read, Write}; use crate::test::testserver::{read_headers, TestServer};
use std::thread; use std::io::{self, Write};
use std::net::TcpStream;
use std::time::Duration; use std::time::Duration;
use super::super::*; use super::super::*;
@@ -56,47 +57,25 @@ fn agent_cookies() {
agent.get("test://host/agent_cookies").call(); agent.get("test://host/agent_cookies").call();
} }
// Start a test server on an available port, that times out idle connections at 2 seconds. // Handler that answers with a simple HTTP response, and times
// Return the port this server is listening on. // out idle connections after 2 seconds.
fn start_idle_timeout_server() -> u16 { fn idle_timeout_handler(mut stream: TcpStream) -> io::Result<()> {
let listener = std::net::TcpListener::bind("localhost:0").unwrap(); read_headers(&stream);
let port = listener.local_addr().unwrap().port(); stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 8\r\n\r\nresponse")?;
thread::spawn(move || { stream.set_read_timeout(Some(Duration::from_secs(2)))?;
for stream in listener.incoming() { Ok(())
thread::spawn(move || {
let stream = stream.unwrap();
stream
.set_read_timeout(Some(Duration::from_secs(2)))
.unwrap();
let mut write_stream = stream.try_clone().unwrap();
for line in BufReader::new(stream).lines() {
let line = match line {
Ok(x) => x,
Err(_) => return,
};
if line == "" {
write_stream
.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 8\r\n\r\nresponse")
.unwrap();
}
}
});
}
});
port
} }
#[test] #[test]
fn connection_reuse() { fn connection_reuse() {
let port = start_idle_timeout_server(); let testserver = TestServer::new(idle_timeout_handler);
let url = format!("http://localhost:{}", port); let url = format!("http://localhost:{}", testserver.port);
let agent = Agent::default().build(); let agent = Agent::default().build();
let resp = agent.get(&url).call(); let resp = agent.get(&url).call();
// use up the connection so it gets returned to the pool // use up the connection so it gets returned to the pool
assert_eq!(resp.status(), 200); assert_eq!(resp.status(), 200);
let mut buf = vec![]; resp.into_string().unwrap();
resp.into_reader().read_to_end(&mut buf).unwrap();
{ {
let mut guard_state = agent.state.lock().unwrap(); let mut guard_state = agent.state.lock().unwrap();

View File

@@ -14,6 +14,7 @@ mod query_string;
mod range; mod range;
mod redirect; mod redirect;
mod simple; mod simple;
mod testserver;
mod timeout; mod timeout;
type RequestHandler = dyn Fn(&Unit) -> Result<Stream, Error> + Send + 'static; type RequestHandler = dyn Fn(&Unit) -> Result<Stream, Error> + Send + 'static;

50
src/test/testserver.rs Normal file
View File

@@ -0,0 +1,50 @@
use std::io::{self, BufRead, BufReader};
use std::net::{TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
pub struct TestServer {
pub port: u16,
pub done: Arc<AtomicBool>,
}
// Read a stream until reaching a blank line, in order to consume
// request headers.
pub fn read_headers(stream: &TcpStream) {
for line in BufReader::new(stream).lines() {
if line.unwrap() == "" {
break;
}
}
}
impl TestServer {
pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self {
let listener = TcpListener::bind("localhost:0").unwrap();
let port = listener.local_addr().unwrap().port();
let done = Arc::new(AtomicBool::new(false));
let done_clone = done.clone();
thread::spawn(move || {
for stream in listener.incoming() {
thread::spawn(move || handler(stream.unwrap()));
if done.load(Ordering::Relaxed) {
break;
}
}
println!("testserver on {} exiting", port);
});
TestServer {
port,
done: done_clone,
}
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.done.store(true, Ordering::Relaxed);
// Connect once to unblock the listen loop.
TcpStream::connect(format!("localhost:{}", self.port)).unwrap();
}
}

View File

@@ -1,4 +1,5 @@
use std::io::{self, BufRead, BufReader, Read, Write}; use crate::test::testserver::*;
use std::io::{self, Write};
use std::net::TcpStream; use std::net::TcpStream;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
@@ -7,7 +8,8 @@ use super::super::*;
// Send an HTTP response on the TcpStream at a rate of two bytes every 10 // Send an HTTP response on the TcpStream at a rate of two bytes every 10
// milliseconds, for a total of 600 bytes. // milliseconds, for a total of 600 bytes.
fn dribble_body_respond(stream: &mut TcpStream) -> io::Result<()> { fn dribble_body_respond(mut stream: TcpStream) -> io::Result<()> {
read_headers(&stream);
let contents = [b'a'; 300]; let contents = [b'a'; 300];
let headers = format!( let headers = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n",
@@ -23,49 +25,12 @@ fn dribble_body_respond(stream: &mut TcpStream) -> io::Result<()> {
Ok(()) Ok(())
} }
// Read a stream until reaching a blank line, in order to consume
// request headers.
fn read_headers(stream: &TcpStream) {
for line in BufReader::new(stream).lines() {
let line = match line {
Ok(x) => x,
Err(_) => return,
};
if line == "" {
break;
}
}
}
// Start a test server on an available port, that dribbles out a response at 1 write per 10ms.
// Return the port this server is listening on.
fn start_dribble_body_server() -> u16 {
let listener = std::net::TcpListener::bind("localhost:0").unwrap();
let port = listener.local_addr().unwrap().port();
let dribble_handler = |mut stream: TcpStream| {
read_headers(&stream);
if let Err(e) = dribble_body_respond(&mut stream) {
eprintln!("sending dribble repsonse: {}", e);
}
};
thread::spawn(move || {
for stream in listener.incoming() {
thread::spawn(move || dribble_handler(stream.unwrap()));
}
});
port
}
fn get_and_expect_timeout(url: String) { fn get_and_expect_timeout(url: String) {
let agent = Agent::default().build(); let agent = Agent::default().build();
let timeout = Duration::from_millis(500); let timeout = Duration::from_millis(500);
let resp = agent.get(&url).timeout(timeout).call(); let resp = agent.get(&url).timeout(timeout).call();
let mut reader = resp.into_reader(); match resp.into_string() {
let mut bytes = vec![];
let result = reader.read_to_end(&mut bytes);
match result {
Err(io_error) => match io_error.kind() { Err(io_error) => match io_error.kind() {
io::ErrorKind::WouldBlock => Ok(()), io::ErrorKind::WouldBlock => Ok(()),
io::ErrorKind::TimedOut => Ok(()), io::ErrorKind::TimedOut => Ok(()),
@@ -78,15 +43,15 @@ fn get_and_expect_timeout(url: String) {
#[test] #[test]
fn overall_timeout_during_body() { fn overall_timeout_during_body() {
let port = start_dribble_body_server(); // Start a test server on an available port, that dribbles out a response at 1 write per 10ms.
let url = format!("http://localhost:{}/", port); let server = TestServer::new(dribble_body_respond);
let url = format!("http://localhost:{}/", server.port);
get_and_expect_timeout(url); get_and_expect_timeout(url);
} }
// Send HTTP headers on the TcpStream at a rate of one header every 100 // Send HTTP headers on the TcpStream at a rate of one header every 100
// milliseconds, for a total of 30 headers. // milliseconds, for a total of 30 headers.
fn dribble_headers_respond(stream: &mut TcpStream) -> io::Result<()> { fn dribble_headers_respond(mut stream: TcpStream) -> io::Result<()> {
stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n")?; stream.write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n")?;
for _ in 0..30 { for _ in 0..30 {
stream.write_all(b"a: b\n")?; stream.write_all(b"a: b\n")?;
@@ -96,28 +61,10 @@ fn dribble_headers_respond(stream: &mut TcpStream) -> io::Result<()> {
Ok(()) Ok(())
} }
// Start a test server on an available port, that dribbles out response *headers* at 1 write per 10ms.
// Return the port this server is listening on.
fn start_dribble_headers_server() -> u16 {
let listener = std::net::TcpListener::bind("localhost:0").unwrap();
let port = listener.local_addr().unwrap().port();
let dribble_handler = |mut stream: TcpStream| {
read_headers(&stream);
if let Err(e) = dribble_headers_respond(&mut stream) {
eprintln!("sending dribble repsonse: {}", e);
}
};
thread::spawn(move || {
for stream in listener.incoming() {
thread::spawn(move || dribble_handler(stream.unwrap()));
}
});
port
}
#[test] #[test]
fn overall_timeout_during_headers() { fn overall_timeout_during_headers() {
let port = start_dribble_headers_server(); // Start a test server on an available port, that dribbles out a response at 1 write per 10ms.
let url = format!("http://localhost:{}/", port); let server = TestServer::new(dribble_headers_respond);
let url = format!("http://localhost:{}/", server.port);
get_and_expect_timeout(url); get_and_expect_timeout(url);
} }