connection pooling

Martin Algesten
2018-06-30 16:52:54 +02:00
parent c5fb12a1fe
commit 4a5944443f
8 changed files with 205 additions and 81 deletions

View File

@@ -22,6 +22,18 @@ pub enum Payload {
     Reader(Box<Read + 'static>),
 }
 
+impl ::std::fmt::Debug for Payload {
+    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> {
+        write!(f, "{}", match self {
+            Payload::Empty => "Empty",
+            Payload::Text(t, _) => &t,
+            #[cfg(feature = "json")]
+            Payload::JSON(_) => "JSON",
+            Payload::Reader(_) => "Reader",
+        })
+    }
+}
+
 impl Default for Payload {
     fn default() -> Payload {
         Payload::Empty

View File

@@ -48,8 +48,7 @@ impl Error {
     /// For synthetic responses, this is the status text.
     pub fn status_text(&self) -> &str {
         match self {
-            Error::BadUrl(e) => {
-                println!("{}", e);
+            Error::BadUrl(_) => {
                 "Bad URL"
             }
             Error::UnknownScheme(_) => "Unknown Scheme",

View File

@@ -1,10 +1,59 @@
+// use agent::Unit;
+use std::collections::HashMap;
+use std::io::{Read, Result as IoResult};
+use stream::Stream;
+use url::Url;
+
-#[derive(Debug, Default, Clone)]
-pub struct ConnectionPool {}
+#[derive(Default, Debug)]
+pub struct ConnectionPool {
+    recycle: HashMap<Url, Stream>,
+}
 
 impl ConnectionPool {
     pub fn new() -> Self {
-        ConnectionPool {}
+        ConnectionPool {
+            ..Default::default()
+        }
+    }
+
+    pub fn try_get_connection(&mut self, url: &Url) -> Option<Stream> {
+        self.recycle.remove(url)
+    }
+}
+
+pub struct PoolReturnRead<R: Read + Sized> {
+    unit: Option<Unit>,
+    reader: Option<R>,
+}
+
+impl<R: Read + Sized> PoolReturnRead<R> {
+    pub fn new(unit: Option<Unit>, reader: R) -> Self {
+        PoolReturnRead {
+            unit,
+            reader: Some(reader),
+        }
+    }
+
+    fn return_connection(&mut self) {
+        if let Some(_unit) = self.unit.take() {}
+    }
+
+    fn do_read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
+        match self.reader.as_mut() {
+            None => return Ok(0),
+            Some(reader) => reader.read(buf),
+        }
+    }
+}
+
+impl<R: Read + Sized> Read for PoolReturnRead<R> {
+    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
+        let amount = self.do_read(buf)?;
+        // only if the underlying reader is exhausted can we send a new
+        // request to the same socket. hence, we only return it now.
+        if amount == 0 {
+            self.return_connection();
+        }
+        Ok(amount)
     }
 }
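The core idea of PoolReturnRead above: the response body is read through a wrapper, and only once the underlying reader reports EOF is the connection eligible to go back into the pool (in this commit return_connection is still a stub that merely takes the Unit). Below is a minimal standalone sketch of that drain-then-recycle pattern, not part of the diff; PoolReturn, Conn and the String-keyed map are invented stand-ins for the commit's Url-keyed ConnectionPool of Stream held behind the agent state.

use std::cell::RefCell;
use std::collections::HashMap;
use std::io::{Cursor, Read, Result as IoResult};
use std::rc::Rc;

// Hypothetical stand-ins: a "connection" is just a Cursor, and the pool is a
// host -> connection map shared via Rc<RefCell<..>> instead of Arc<Mutex<..>>.
type Conn = Cursor<Vec<u8>>;
type Pool = Rc<RefCell<HashMap<String, Conn>>>;

struct PoolReturn {
    host: String,
    conn: Option<Conn>,
    pool: Pool,
}

impl Read for PoolReturn {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let n = match self.conn.as_mut() {
            Some(conn) => conn.read(buf)?,
            None => return Ok(0),
        };
        if n == 0 {
            // body exhausted: the socket is now positioned after the
            // response, so it is safe to hand it back for the next request.
            if let Some(conn) = self.conn.take() {
                self.pool.borrow_mut().insert(self.host.clone(), conn);
            }
        }
        Ok(n)
    }
}

fn main() {
    let pool: Pool = Rc::new(RefCell::new(HashMap::new()));
    let mut body = PoolReturn {
        host: "example.com".to_string(),
        conn: Some(Cursor::new(b"response body".to_vec())),
        pool: pool.clone(),
    };
    let mut text = String::new();
    body.read_to_string(&mut text).unwrap();
    assert_eq!(text, "response body");
    // the connection shows up in the pool only after the body was drained
    assert!(pool.borrow().contains_key("example.com"));
}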

View File

@@ -6,8 +6,8 @@ use std::sync::Arc;
 use super::SerdeValue;
 
 lazy_static! {
-    static ref URL_BASE: Url = { Url::parse("http://localhost/")
-        .expect("Failed to parse URL_BASE") };
+    static ref URL_BASE: Url =
+        { Url::parse("http://localhost/").expect("Failed to parse URL_BASE") };
 }
 
 /// Request instances are builders that creates a request.
@@ -43,13 +43,14 @@ impl ::std::fmt::Debug for Request {
         write!(
             f,
             "Request({} {}{}, {:?})",
-            self.method, url.path(), query,
+            self.method,
+            url.path(),
+            query,
             self.headers
         )
     }
 }
 
 impl Request {
     fn new(agent: &Agent, method: String, path: String) -> Request {
         Request {
@@ -95,7 +96,7 @@ impl Request {
             .and_then(|url| {
                 let reader = payload.into_read();
                 let unit = Unit::new(&self, &url, &reader);
-                connect(unit, url, &self.method, self.redirects, reader)
+                connect(unit, &self.method, true, self.redirects, reader)
             })
             .unwrap_or_else(|e| e.into())
     }

View File

@@ -1,12 +1,9 @@
 use agent::Unit;
 use ascii::AsciiString;
-use chunked_transfer;
+use chunked_transfer::Decoder as ChunkDecoder;
 use header::Header;
-use std::io::Cursor;
-use std::io::Error as IoError;
-use std::io::ErrorKind;
-use std::io::Read;
-use std::io::Result as IoResult;
+use pool::PoolReturnRead;
+use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult};
 use std::str::FromStr;
 use stream::Stream;
@@ -251,28 +248,34 @@ impl Response {
     /// assert_eq!(bytes.len(), len);
     /// ```
     pub fn into_reader(self) -> impl Read {
+        //
        let is_chunked = self.header("transfer-encoding")
            .map(|enc| enc.len() > 0) // whatever it says, do chunked
            .unwrap_or(false);
 
-        let len = self.header("content-length")
-            .and_then(|l| l.parse::<usize>().ok());
+        let is_head = (&self.unit).as_ref().map(|u| u.is_head).unwrap_or(false);
+
+        let len = if is_head {
+            // head requests never have a body
+            Some(0)
+        } else {
+            self.header("content-length")
+                .and_then(|l| l.parse::<usize>().ok())
+        };
 
         let reader = self.stream.expect("No reader in response?!");
-
-        // head requests never have a body
-        let is_head = self.unit.map(|u| u.is_head).unwrap_or(false);
-        if is_head {
-            return Box::new(LimitedRead::new(reader, 0)) as Box<Read>;
-        }
+        let unit = self.unit;
 
         // figure out how to make a reader
-        match is_chunked {
-            true => Box::new(chunked_transfer::Decoder::new(reader)),
-            false => match len {
-                Some(len) => Box::new(LimitedRead::new(reader, len)),
-                None => Box::new(reader) as Box<Read>,
-            },
+        match (is_chunked && !is_head, len) {
+            (true, _) => {
+                Box::new(PoolReturnRead::new(unit, ChunkDecoder::new(reader))) as Box<Read>
+            }
+            (false, Some(len)) => {
+                Box::new(PoolReturnRead::new(unit, LimitedRead::new(reader, len)))
+            }
+            (false, None) => Box::new(PoolReturnRead::new(unit, reader)) as Box<Read>,
         }
     }
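Every branch of the new into_reader is wrapped in PoolReturnRead, so recycling hinges on the inner reader reporting EOF exactly at the body boundary: the chunked decoder ends at the terminating chunk, and HEAD or content-length bodies end via LimitedRead. LimitedRead itself is not part of this diff; the sketch below is an assumed shape, not the crate's implementation, and only illustrates why that synthesized EOF matters when more data may already be queued on the same socket.

use std::cmp::min;
use std::io::{Cursor, Read, Result as IoResult};

// Length-bounded reader: reports EOF after `left` bytes, which is the signal
// a pool-returning wrapper waits for before the socket can be recycled.
struct LimitedRead<R: Read> {
    inner: R,
    left: usize,
}

impl<R: Read> Read for LimitedRead<R> {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        if self.left == 0 {
            return Ok(0); // body boundary reached: synthesize EOF
        }
        let max = min(self.left, buf.len());
        let n = self.inner.read(&mut buf[..max])?;
        self.left -= n;
        Ok(n)
    }
}

fn main() {
    // a 5 byte body, with the next response already queued on the same "socket"
    let socket = Cursor::new(b"hellonext response...".to_vec());
    let mut body = LimitedRead { inner: socket, left: 5 };
    let mut s = String::new();
    body.read_to_string(&mut s).unwrap();
    assert_eq!(s, "hello");
}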

View File

@@ -18,6 +18,22 @@ pub enum Stream {
     Test(Box<Read + Send>, Vec<u8>),
 }
 
+impl ::std::fmt::Debug for Stream {
+    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> {
+        write!(
+            f,
+            "Stream[{}]",
+            match self {
+                Stream::Http(_) => "http",
+                Stream::Https(_) => "https",
+                Stream::Cursor(_) => "cursor",
+                #[cfg(test)]
+                Stream::Test(_, _) => "test",
+            }
+        )
+    }
+}
+
 impl Stream {
     #[cfg(test)]
     pub fn to_write_vec(&self) -> Vec<u8> {

View File

@@ -48,6 +48,5 @@ fn query_in_path_and_req() {
         .call();
     let vec = resp.to_write_vec();
     let s = String::from_utf8_lossy(&vec);
-    println!("{}", s);
     assert!(s.contains("GET /query_in_path_and_req?foo=bar&baz=1%202%203 HTTP/1.1"))
 }

View File

@@ -1,9 +1,11 @@
 use body::{send_body, Payload, SizedReader};
-use std::io::Write;
-use stream::{connect_http, connect_https, connect_test};
+use std::io::{Result as IoResult, Write};
+use stream::{connect_http, connect_https, connect_test, Stream};
 use url::Url;
 
 //
+/// It's a "unit of work". Maybe a bad name for it?
+#[derive(Debug)]
 pub struct Unit {
     pub agent: Arc<Mutex<Option<AgentState>>>,
     pub url: Url,
@@ -97,64 +99,35 @@ impl Unit {
 }
 
 pub fn connect(
-    unit: Unit,
-    url: Url,
+    mut unit: Unit,
     method: &str,
+    use_pooled: bool,
     redirects: u32,
     body: SizedReader,
 ) -> Result<Response, Error> {
     //
     // open socket
-    let mut stream = match url.scheme() {
-        "http" => connect_http(&unit),
-        "https" => connect_https(&unit),
-        "test" => connect_test(&unit),
-        _ => Err(Error::UnknownScheme(url.scheme().to_string())),
-    }?;
-
-    // send the request start + headers
-    let mut prelude: Vec<u8> = vec![];
-    write!(
-        prelude,
-        "{} {}{} HTTP/1.1\r\n",
-        method,
-        url.path(),
-        &unit.query_string
-    )?;
-    if !has_header(&unit.headers, "host") {
-        write!(prelude, "Host: {}\r\n", url.host().unwrap())?;
-    }
-    for header in &unit.headers {
-        write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
-    }
-    write!(prelude, "\r\n")?;
-    stream.write_all(&mut prelude[..])?;
+    let (mut stream, is_recycled) = connect_socket(&unit, use_pooled)?;
+
+    let send_result = send_prelude(&unit, method, &mut stream);
+
+    if send_result.is_err() {
+        if is_recycled {
+            // we try open a new connection, this time there will be
+            // no connection in the pool. don't use it.
+            return connect(unit, method, false, redirects, body);
+        } else {
+            // not a pooled connection, propagate the error.
+            return Err(send_result.unwrap_err().into());
+        }
+    }
 
     // start reading the response to process cookies and redirects.
     let mut resp = Response::from_read(&mut stream);
 
     // squirrel away cookies
-    {
-        let state = &mut unit.agent.lock().unwrap();
-        if let Some(add_jar) = state.as_mut().map(|state| &mut state.jar) {
-            for raw_cookie in resp.all("set-cookie").iter() {
-                let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
-                    raw_cookie.to_string()
-                } else {
-                    format!("{}; Domain={}", raw_cookie, &unit.hostname)
-                };
-                match Cookie::parse_encoded(&to_parse[..]) {
-                    Err(_) => (), // ignore unparseable cookies
-                    Ok(mut cookie) => {
-                        let cookie = cookie.into_owned();
-                        add_jar.add(cookie)
-                    }
-                }
-            }
-        }
-    }
+    save_cookies(&unit, &resp);
 
     // handle redirects
     if resp.redirect() {
@@ -166,18 +139,22 @@ pub fn connect(
         let location = resp.header("location");
         if let Some(location) = location {
             // join location header to current url in case it it relative
-            let new_url = url
+            let new_url = unit
+                .url
                 .join(location)
                 .map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?;
 
+            // change this for every redirect since it is used when connection pooling.
+            unit.url = new_url;
+
             // perform the redirect differently depending on 3xx code.
             return match resp.status() {
                 301 | 302 | 303 => {
                     send_body(body, unit.is_chunked, &mut stream)?;
                     let empty = Payload::Empty.into_read();
-                    connect(unit, new_url, "GET", redirects - 1, empty)
+                    connect(unit, "GET", use_pooled, redirects - 1, empty)
                 }
-                307 | 308 | _ => connect(unit, new_url, method, redirects - 1, body),
+                307 | 308 | _ => connect(unit, method, use_pooled, redirects - 1, body),
             };
         }
     }
@@ -233,3 +210,71 @@ fn combine_query(url: &Url, query: &QString) -> String {
         (None, false) => "".to_string(),
     }
 }
+
+fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error> {
+    if use_pooled {
+        let state = &mut unit.agent.lock().unwrap();
+        if let Some(agent) = state.as_mut() {
+            if let Some(stream) = agent.pool.try_get_connection(&unit.url) {
+                return Ok((stream, true));
+            }
+        }
+    }
+    let stream = match unit.url.scheme() {
+        "http" => connect_http(&unit),
+        "https" => connect_https(&unit),
+        "test" => connect_test(&unit),
+        _ => Err(Error::UnknownScheme(unit.url.scheme().to_string())),
+    };
+    Ok((stream?, false))
+}
+
+fn send_prelude(unit: &Unit, method: &str, stream: &mut Stream) -> IoResult<()> {
+    // send the request start + headers
+    let mut prelude: Vec<u8> = vec![];
+    write!(
+        prelude,
+        "{} {}{} HTTP/1.1\r\n",
+        method,
+        unit.url.path(),
+        &unit.query_string
+    )?;
+    if !has_header(&unit.headers, "host") {
+        write!(prelude, "Host: {}\r\n", unit.url.host().unwrap())?;
+    }
+    for header in &unit.headers {
+        write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
+    }
+    write!(prelude, "\r\n")?;
+    stream.write_all(&mut prelude[..])?;
+
+    Ok(())
+}
+
+fn save_cookies(unit: &Unit, resp: &Response) {
+    //
+    let cookies = resp.all("set-cookie");
+    if cookies.is_empty() {
+        return;
+    }
+    let state = &mut unit.agent.lock().unwrap();
+    if let Some(add_jar) = state.as_mut().map(|state| &mut state.jar) {
+        for raw_cookie in cookies.iter() {
+            let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
+                raw_cookie.to_string()
+            } else {
+                format!("{}; Domain={}", raw_cookie, &unit.hostname)
+            };
+            match Cookie::parse_encoded(&to_parse[..]) {
+                Err(_) => (), // ignore unparseable cookies
+                Ok(mut cookie) => {
+                    let cookie = cookie.into_owned();
+                    add_jar.add(cookie)
+                }
+            }
+        }
+    }
+}
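The other new behaviour in connect() is the fallback when a recycled socket turns out to be dead: if writing the prelude fails on a pooled connection, the request is retried once with use_pooled = false, while on a fresh connection the error propagates. Below is a compact sketch of that control flow only; FakeSocket and this connect_socket (which always hands out a stale pooled socket) are invented stand-ins, not the crate's code.

use std::io::{ErrorKind, Result as IoResult, Write};

// A pooled socket that the server already closed fails on the first write;
// a fresh one succeeds.
struct FakeSocket {
    dead: bool,
    sent: Vec<u8>,
}

impl Write for FakeSocket {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        if self.dead {
            return Err(ErrorKind::BrokenPipe.into());
        }
        self.sent.extend_from_slice(buf);
        Ok(buf.len())
    }
    fn flush(&mut self) -> IoResult<()> {
        Ok(())
    }
}

fn connect_socket(use_pooled: bool) -> (FakeSocket, bool) {
    // (socket, is_recycled): pretend the pool always holds a stale socket
    if use_pooled {
        (FakeSocket { dead: true, sent: vec![] }, true)
    } else {
        (FakeSocket { dead: false, sent: vec![] }, false)
    }
}

// Mirrors the retry shape of the new connect(): a failed prelude on a
// recycled socket retries once without the pool; otherwise the error is real.
fn connect(use_pooled: bool) -> IoResult<FakeSocket> {
    let (mut stream, is_recycled) = connect_socket(use_pooled);
    let sent = stream.write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n");
    if let Err(err) = sent {
        if is_recycled {
            return connect(false);
        }
        return Err(err);
    }
    Ok(stream)
}

fn main() {
    let stream = connect(true).expect("fresh connection should succeed");
    assert!(stream.sent.starts_with(b"GET / HTTP/1.1"));
}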