From 4a5944443f2de651e2979015a35f6b743f49b8f8 Mon Sep 17 00:00:00 2001
From: Martin Algesten
Date: Sat, 30 Jun 2018 16:52:54 +0200
Subject: [PATCH] connection pooling

---
 src/body.rs              |  12 ++++
 src/error.rs             |   3 +-
 src/pool.rs              |  57 ++++++++++++++--
 src/request.rs           |  11 +--
 src/response.rs          |  43 ++++++------
 src/stream.rs            |  16 +++++
 src/test/query_string.rs |   1 -
 src/unit.rs              | 143 +++++++++++++++++++++++++--------------
 8 files changed, 205 insertions(+), 81 deletions(-)

diff --git a/src/body.rs b/src/body.rs
index e575954..5d427a2 100644
--- a/src/body.rs
+++ b/src/body.rs
@@ -22,6 +22,18 @@ pub enum Payload {
     Reader(Box<Read + 'static>),
 }

+impl ::std::fmt::Debug for Payload {
+    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> {
+        write!(f, "{}", match self {
+            Payload::Empty => "Empty",
+            Payload::Text(t, _) => &t,
+            #[cfg(feature = "json")]
+            Payload::JSON(_) => "JSON",
+            Payload::Reader(_) => "Reader",
+        })
+    }
+}
+
 impl Default for Payload {
     fn default() -> Payload {
         Payload::Empty
diff --git a/src/error.rs b/src/error.rs
index b510303..38394b5 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -48,8 +48,7 @@ impl Error {
     /// For synthetic responses, this is the status text.
     pub fn status_text(&self) -> &str {
         match self {
-            Error::BadUrl(e) => {
-                println!("{}", e);
+            Error::BadUrl(_) => {
                 "Bad URL"
             }
             Error::UnknownScheme(_) => "Unknown Scheme",
diff --git a/src/pool.rs b/src/pool.rs
index eb3b13f..c6e8984 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -1,10 +1,59 @@
-//
+use agent::Unit;
+use std::collections::HashMap;
+use std::io::{Read, Result as IoResult};
+use stream::Stream;
+use url::Url;

-#[derive(Debug, Default, Clone)]
-pub struct ConnectionPool {}
+#[derive(Default, Debug)]
+pub struct ConnectionPool {
+    recycle: HashMap<Url, Stream>,
+}

 impl ConnectionPool {
     pub fn new() -> Self {
-        ConnectionPool {}
+        ConnectionPool {
+            ..Default::default()
+        }
+    }
+
+    pub fn try_get_connection(&mut self, url: &Url) -> Option<Stream> {
+        self.recycle.remove(url)
+    }
+}
+
+pub struct PoolReturnRead<R: Read> {
+    unit: Option<Unit>,
+    reader: Option<R>,
+}
+
+impl<R: Read> PoolReturnRead<R> {
+    pub fn new(unit: Option<Unit>, reader: R) -> Self {
+        PoolReturnRead {
+            unit,
+            reader: Some(reader),
+        }
+    }
+
+    fn return_connection(&mut self) {
+        if let Some(_unit) = self.unit.take() {}
+    }
+
+    fn do_read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
+        match self.reader.as_mut() {
+            None => return Ok(0),
+            Some(reader) => reader.read(buf),
+        }
+    }
+}
+
+impl<R: Read> Read for PoolReturnRead<R> {
+    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
+        let amount = self.do_read(buf)?;
+        // only if the underlying reader is exhausted can we send a new
+        // request to the same socket. hence, we only return it now.
+        if amount == 0 {
+            self.return_connection();
+        }
+        Ok(amount)
     }
 }
diff --git a/src/request.rs b/src/request.rs
index b86eab3..b7b75bb 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -6,8 +6,8 @@ use std::sync::Arc;
 use super::SerdeValue;

 lazy_static! {
-    static ref URL_BASE: Url = { Url::parse("http://localhost/")
-        .expect("Failed to parse URL_BASE") };
+    static ref URL_BASE: Url =
+        { Url::parse("http://localhost/").expect("Failed to parse URL_BASE") };
 }

 /// Request instances are builders that creates a request.
@@ -43,13 +43,14 @@ impl ::std::fmt::Debug for Request {
         write!(
             f,
             "Request({} {}{}, {:?})",
-            self.method, url.path(), query,
+            self.method,
+            url.path(),
+            query,
             self.headers
         )
     }
 }

-
 impl Request {
     fn new(agent: &Agent, method: String, path: String) -> Request {
         Request {
@@ -95,7 +96,7 @@ impl Request {
             .and_then(|url| {
                 let reader = payload.into_read();
                 let unit = Unit::new(&self, &url, &reader);
-                connect(unit, url, &self.method, self.redirects, reader)
+                connect(unit, &self.method, true, self.redirects, reader)
             })
             .unwrap_or_else(|e| e.into())
     }
diff --git a/src/response.rs b/src/response.rs
index a8c3b98..4009e19 100644
--- a/src/response.rs
+++ b/src/response.rs
@@ -1,12 +1,9 @@
 use agent::Unit;
 use ascii::AsciiString;
-use chunked_transfer;
+use chunked_transfer::Decoder as ChunkDecoder;
 use header::Header;
-use std::io::Cursor;
-use std::io::Error as IoError;
-use std::io::ErrorKind;
-use std::io::Read;
-use std::io::Result as IoResult;
+use pool::PoolReturnRead;
+use std::io::{Cursor, Error as IoError, ErrorKind, Read, Result as IoResult};
 use std::str::FromStr;
 use stream::Stream;

@@ -251,28 +248,34 @@ impl Response {
     /// assert_eq!(bytes.len(), len);
     /// ```
     pub fn into_reader(self) -> impl Read {
+        //
+
         let is_chunked = self.header("transfer-encoding")
             .map(|enc| enc.len() > 0) // whatever it says, do chunked
             .unwrap_or(false);

-        let len = self.header("content-length")
-            .and_then(|l| l.parse::<usize>().ok());
+        let is_head = (&self.unit).as_ref().map(|u| u.is_head).unwrap_or(false);
+
+        let len = if is_head {
+            // head requests never have a body
+            Some(0)
+        } else {
+            self.header("content-length")
+                .and_then(|l| l.parse::<usize>().ok())
+        };

         let reader = self.stream.expect("No reader in response?!");
-
-        // head requests never have a body
-        let is_head = self.unit.map(|u| u.is_head).unwrap_or(false);
-        if is_head {
-            return Box::new(LimitedRead::new(reader, 0)) as Box<Read>;
-        }
+        let unit = self.unit;

         // figure out how to make a reader
-        match is_chunked {
-            true => Box::new(chunked_transfer::Decoder::new(reader)),
-            false => match len {
-                Some(len) => Box::new(LimitedRead::new(reader, len)),
-                None => Box::new(reader) as Box<Read>,
-            },
+        match (is_chunked && !is_head, len) {
+            (true, _) => {
+                Box::new(PoolReturnRead::new(unit, ChunkDecoder::new(reader))) as Box<Read>
+            }
+            (false, Some(len)) => {
+                Box::new(PoolReturnRead::new(unit, LimitedRead::new(reader, len)))
+            }
+            (false, None) => Box::new(PoolReturnRead::new(unit, reader)) as Box<Read>,
         }
     }

diff --git a/src/stream.rs b/src/stream.rs
index 4da1e87..514d8ca 100644
--- a/src/stream.rs
+++ b/src/stream.rs
@@ -18,6 +18,22 @@ pub enum Stream {
     Test(Box<Read + Send>, Vec<u8>),
 }

+impl ::std::fmt::Debug for Stream {
+    fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> {
+        write!(
+            f,
+            "Stream[{}]",
+            match self {
+                Stream::Http(_) => "http",
+                Stream::Https(_) => "https",
+                Stream::Cursor(_) => "cursor",
+                #[cfg(test)]
+                Stream::Test(_, _) => "test",
+            }
+        )
+    }
+}
+
 impl Stream {
     #[cfg(test)]
     pub fn to_write_vec(&self) -> Vec<u8> {
diff --git a/src/test/query_string.rs b/src/test/query_string.rs
index 8e2ab01..a0f9dbb 100644
--- a/src/test/query_string.rs
+++ b/src/test/query_string.rs
@@ -48,6 +48,5 @@ fn query_in_path_and_req() {
         .call();
     let vec = resp.to_write_vec();
     let s = String::from_utf8_lossy(&vec);
-    println!("{}", s);
     assert!(s.contains("GET /query_in_path_and_req?foo=bar&baz=1%202%203 HTTP/1.1"))
 }
diff --git a/src/unit.rs b/src/unit.rs
index 9f95c9c..423bfe8 100644
--- a/src/unit.rs
+++ b/src/unit.rs
@@ -1,9 +1,11 @@
 use body::{send_body, Payload, SizedReader};
-use std::io::Write;
-use stream::{connect_http, connect_https, connect_test};
+use std::io::{Result as IoResult, Write};
+use stream::{connect_http, connect_https, connect_test, Stream};
 use url::Url;

 //
+/// It's a "unit of work". Maybe a bad name for it?
+#[derive(Debug)]
 pub struct Unit {
     pub agent: Arc<Mutex<Option<AgentState>>>,
     pub url: Url,
@@ -97,64 +99,35 @@ impl Unit {
 }

 pub fn connect(
-    unit: Unit,
-    url: Url,
+    mut unit: Unit,
     method: &str,
+    use_pooled: bool,
     redirects: u32,
     body: SizedReader,
 ) -> Result<Response, Error> {
     //

     // open socket
-    let mut stream = match url.scheme() {
-        "http" => connect_http(&unit),
-        "https" => connect_https(&unit),
-        "test" => connect_test(&unit),
-        _ => Err(Error::UnknownScheme(url.scheme().to_string())),
-    }?;
+    let (mut stream, is_recycled) = connect_socket(&unit, use_pooled)?;

-    // send the request start + headers
-    let mut prelude: Vec<u8> = vec![];
-    write!(
-        prelude,
-        "{} {}{} HTTP/1.1\r\n",
-        method,
-        url.path(),
-        &unit.query_string
-    )?;
-    if !has_header(&unit.headers, "host") {
-        write!(prelude, "Host: {}\r\n", url.host().unwrap())?;
-    }
-    for header in &unit.headers {
-        write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
-    }
-    write!(prelude, "\r\n")?;
+    let send_result = send_prelude(&unit, method, &mut stream);

-    stream.write_all(&mut prelude[..])?;
+    if send_result.is_err() {
+        if is_recycled {
+            // we try to open a new connection; this time we don't use
+            // the pool, so there will be no recycled connection.
+            return connect(unit, method, false, redirects, body);
+        } else {
+            // not a pooled connection, propagate the error.
+            return Err(send_result.unwrap_err().into());
+        }
+    }

     // start reading the response to process cookies and redirects.
     let mut resp = Response::from_read(&mut stream);

     // squirrel away cookies
-    {
-        let state = &mut unit.agent.lock().unwrap();
-        if let Some(add_jar) = state.as_mut().map(|state| &mut state.jar) {
-            for raw_cookie in resp.all("set-cookie").iter() {
-                let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
-                    raw_cookie.to_string()
-                } else {
-                    format!("{}; Domain={}", raw_cookie, &unit.hostname)
-                };
-                match Cookie::parse_encoded(&to_parse[..]) {
-                    Err(_) => (), // ignore unparseable cookies
-                    Ok(mut cookie) => {
-                        let cookie = cookie.into_owned();
-                        add_jar.add(cookie)
-                    }
-                }
-            }
-        }
-    }
+    save_cookies(&unit, &resp);

     // handle redirects
     if resp.redirect() {
@@ -166,18 +139,22 @@ pub fn connect(
         let location = resp.header("location");
         if let Some(location) = location {
             // join location header to current url in case it is relative
-            let new_url = url
+            let new_url = unit
+                .url
                 .join(location)
                 .map_err(|_| Error::BadUrl(format!("Bad redirection: {}", location)))?;

+            // change this for every redirect since it is used for connection pooling.
+            unit.url = new_url;
+
             // perform the redirect differently depending on 3xx code.
             return match resp.status() {
                 301 | 302 | 303 => {
                     send_body(body, unit.is_chunked, &mut stream)?;
                     let empty = Payload::Empty.into_read();
-                    connect(unit, new_url, "GET", redirects - 1, empty)
+                    connect(unit, "GET", use_pooled, redirects - 1, empty)
                 }
-                307 | 308 | _ => connect(unit, new_url, method, redirects - 1, body),
+                307 | 308 | _ => connect(unit, method, use_pooled, redirects - 1, body),
             };
         }
     }
@@ -233,3 +210,71 @@ fn combine_query(url: &Url, query: &QString) -> String {
         (None, false) => "".to_string(),
     }
 }
+
+fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error> {
+    if use_pooled {
+        let state = &mut unit.agent.lock().unwrap();
+        if let Some(agent) = state.as_mut() {
+            if let Some(stream) = agent.pool.try_get_connection(&unit.url) {
+                return Ok((stream, true));
+            }
+        }
+    }
+    let stream = match unit.url.scheme() {
+        "http" => connect_http(&unit),
+        "https" => connect_https(&unit),
+        "test" => connect_test(&unit),
+        _ => Err(Error::UnknownScheme(unit.url.scheme().to_string())),
+    };
+    Ok((stream?, false))
+}
+
+fn send_prelude(unit: &Unit, method: &str, stream: &mut Stream) -> IoResult<()> {
+    // send the request start + headers
+    let mut prelude: Vec<u8> = vec![];
+    write!(
+        prelude,
+        "{} {}{} HTTP/1.1\r\n",
+        method,
+        unit.url.path(),
+        &unit.query_string
+    )?;
+    if !has_header(&unit.headers, "host") {
+        write!(prelude, "Host: {}\r\n", unit.url.host().unwrap())?;
+    }
+    for header in &unit.headers {
+        write!(prelude, "{}: {}\r\n", header.name(), header.value())?;
+    }
+    write!(prelude, "\r\n")?;
+
+    stream.write_all(&mut prelude[..])?;
+
+    Ok(())
+}
+
+fn save_cookies(unit: &Unit, resp: &Response) {
+    //
+
+    let cookies = resp.all("set-cookie");
+    if cookies.is_empty() {
+        return;
+    }
+
+    let state = &mut unit.agent.lock().unwrap();
+    if let Some(add_jar) = state.as_mut().map(|state| &mut state.jar) {
+        for raw_cookie in cookies.iter() {
+            let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
+                raw_cookie.to_string()
+            } else {
+                format!("{}; Domain={}", raw_cookie, &unit.hostname)
+            };
+            match Cookie::parse_encoded(&to_parse[..]) {
+                Err(_) => (), // ignore unparseable cookies
+                Ok(mut cookie) => {
+                    let cookie = cookie.into_owned();
+                    add_jar.add(cookie)
+                }
+            }
+        }
+    }
+}
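
Note on the pooling design (not part of the diff above): a connection can only be reused once the response body has been read to completion, which is why into_reader() wraps every body reader in PoolReturnRead. return_connection() is still an empty stub in this commit; the actual hand-back into the agent's pool comes later. Below is a minimal, self-contained sketch of the same wrap-and-return-on-EOF idea. MiniPool, PoolReturn, the String key and the Cursor<Vec<u8>> "connection" are illustrative stand-ins, not the real Url, Stream and AgentState types, and the real code synchronizes through Arc<Mutex<Option<AgentState>>> rather than Rc<RefCell<...>>.

use std::cell::RefCell;
use std::collections::HashMap;
use std::io::{Cursor, Read, Result as IoResult};
use std::rc::Rc;

// Stand-in for the agent-owned pool (the patch keys on Url and stores Stream).
struct MiniPool {
    recycle: HashMap<String, Cursor<Vec<u8>>>,
}

// Same shape as PoolReturnRead: the connection is held in an Option and
// handed back to the pool once the wrapped reader reports EOF.
struct PoolReturn {
    pool: Rc<RefCell<MiniPool>>,
    key: String,
    conn: Option<Cursor<Vec<u8>>>,
}

impl Read for PoolReturn {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let n = match self.conn.as_mut() {
            None => return Ok(0),
            Some(conn) => conn.read(buf)?,
        };
        if n == 0 {
            // body exhausted: the socket could now carry another request,
            // so this is the point where it goes back into the pool.
            if let Some(conn) = self.conn.take() {
                self.pool.borrow_mut().recycle.insert(self.key.clone(), conn);
            }
        }
        Ok(n)
    }
}

fn main() {
    let pool = Rc::new(RefCell::new(MiniPool { recycle: HashMap::new() }));
    let mut reader = PoolReturn {
        pool: Rc::clone(&pool),
        key: "http://example.com/".to_string(),
        conn: Some(Cursor::new(b"hello".to_vec())),
    };
    let mut body = String::new();
    reader.read_to_string(&mut body).unwrap();
    assert_eq!(body, "hello");
    // only after the body is fully consumed is the "connection" reusable
    assert!(pool.borrow().recycle.contains_key("http://example.com/"));
}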