From 7177a99d1f47161bb8518275413de120e854432f Mon Sep 17 00:00:00 2001 From: Martin Algesten Date: Thu, 14 Jun 2018 14:38:00 +0200 Subject: [PATCH] move connect calls to stream --- src/agent.rs | 1 + src/conn.rs | 70 ------------------------------------ src/lib.rs | 1 - src/stream.rs | 95 ++++++++++++++++++++++++++++++++++++++++++------- src/test/mod.rs | 4 +-- 5 files changed, 85 insertions(+), 86 deletions(-) diff --git a/src/agent.rs b/src/agent.rs index db2cccb..f1c0d34 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -9,6 +9,7 @@ use util::*; include!("request.rs"); include!("response.rs"); include!("conn.rs"); +include!("stream.rs"); #[derive(Debug, Default, Clone)] pub struct Agent { diff --git a/src/conn.rs b/src/conn.rs index b538eb8..e6dd4cc 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -1,11 +1,4 @@ -use dns_lookup; use std::io::Write; -use std::net::IpAddr; -use std::net::SocketAddr; -use std::net::TcpStream; -use std::time::Duration; -use stream::Stream; -use native_tls::TlsConnector; use url::Url; const CHUNK_SIZE: usize = 1024 * 1024; @@ -119,58 +112,6 @@ impl ConnectionPool { } } -fn connect_http(request: &Request, url: &Url) -> Result { - // - let hostname = url.host_str().unwrap(); - let port = url.port().unwrap_or(80); - - connect_host(request, hostname, port).map(|tcp| Stream::Http(tcp)) -} - -fn connect_https(request: &Request, url: &Url) -> Result { - // - let hostname = url.host_str().unwrap(); - let port = url.port().unwrap_or(443); - - let socket = connect_host(request, hostname, port)?; - let connector = TlsConnector::builder()?.build()?; - let stream = connector.connect(hostname, socket)?; - - Ok(Stream::Https(stream)) -} - -fn connect_host(request: &Request, hostname: &str, port: u16) -> Result { - // - let ips: Vec = - dns_lookup::lookup_host(hostname).map_err(|e| Error::DnsFailed(format!("{}", e)))?; - - if ips.len() == 0 { - return Err(Error::DnsFailed(format!("No ip address for {}", hostname))); - } - - // pick first ip, or should we randomize? - let sock_addr = SocketAddr::new(ips[0], port); - - // connect with a configured timeout. - let stream = match request.timeout { - 0 => TcpStream::connect(&sock_addr), - _ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(request.timeout as u64)), - }.map_err(|err| Error::ConnectionFailed(format!("{}", err)))?; - - // rust's absurd api returns Err if we set 0. - if request.timeout_read > 0 { - stream - .set_read_timeout(Some(Duration::from_millis(request.timeout_read as u64))) - .ok(); - } - if request.timeout_write > 0 { - stream - .set_write_timeout(Some(Duration::from_millis(request.timeout_write as u64))) - .ok(); - } - - Ok(stream) -} fn send_payload(request: &Request, payload: Payload, stream: &mut Stream) -> IoResult<()> { // @@ -242,14 +183,3 @@ fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bo .map(|o| o.unwrap()) .collect() } - -#[cfg(not(test))] -fn connect_test(_request: &Request, url: &Url) -> Result { - Err(Error::UnknownScheme(url.scheme().to_string())) -} - -#[cfg(test)] -fn connect_test(request: &Request, url: &Url) -> Result { - use test; - test::resolve_handler(request, url) -} diff --git a/src/lib.rs b/src/lib.rs index dc28d55..6faa920 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,7 +16,6 @@ extern crate url; mod agent; mod error; mod header; -mod stream; mod util; #[cfg(test)] diff --git a/src/stream.rs b/src/stream.rs index 83dd9bf..7bd12dd 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,15 +1,17 @@ -use std::io::Read; -use std::io::Result; -use std::io::Write; -use std::net::TcpStream; +use dns_lookup; +use native_tls::TlsConnector; use native_tls::TlsStream; - +use std::net::IpAddr; +use std::net::SocketAddr; +use std::net::TcpStream; +use std::time::Duration; pub enum Stream { Http(TcpStream), Https(TlsStream), Read(Box), - #[cfg(test)] Test(Box, Vec), + #[cfg(test)] + Test(Box, Vec), } impl Stream { @@ -17,37 +19,104 @@ impl Stream { pub fn to_write_vec(&self) -> Vec { match self { Stream::Test(_, writer) => writer.clone(), - _ => panic!("to_write_vec on non Test stream") + _ => panic!("to_write_vec on non Test stream"), } } } impl Read for Stream { - fn read(&mut self, buf: &mut [u8]) -> Result { + fn read(&mut self, buf: &mut [u8]) -> IoResult { match self { Stream::Http(sock) => sock.read(buf), Stream::Https(stream) => stream.read(buf), Stream::Read(read) => read.read(buf), - #[cfg(test)] Stream::Test(reader, _) => reader.read(buf), + #[cfg(test)] + Stream::Test(reader, _) => reader.read(buf), } } } impl Write for Stream { - fn write(&mut self, buf: &[u8]) -> Result { + fn write(&mut self, buf: &[u8]) -> IoResult { match self { Stream::Http(sock) => sock.write(buf), Stream::Https(stream) => stream.write(buf), Stream::Read(_) => panic!("Write to read stream"), - #[cfg(test)] Stream::Test(_, writer) => writer.write(buf), + #[cfg(test)] + Stream::Test(_, writer) => writer.write(buf), } } - fn flush(&mut self) -> Result<()> { + fn flush(&mut self) -> IoResult<()> { match self { Stream::Http(sock) => sock.flush(), Stream::Https(stream) => stream.flush(), Stream::Read(_) => panic!("Flush read stream"), - #[cfg(test)] Stream::Test(_, writer) => writer.flush(), + #[cfg(test)] + Stream::Test(_, writer) => writer.flush(), } } } + +fn connect_http(request: &Request, url: &Url) -> Result { + // + let hostname = url.host_str().unwrap(); + let port = url.port().unwrap_or(80); + + connect_host(request, hostname, port).map(|tcp| Stream::Http(tcp)) +} + +fn connect_https(request: &Request, url: &Url) -> Result { + // + let hostname = url.host_str().unwrap(); + let port = url.port().unwrap_or(443); + + let socket = connect_host(request, hostname, port)?; + let connector = TlsConnector::builder()?.build()?; + let stream = connector.connect(hostname, socket)?; + + Ok(Stream::Https(stream)) +} + +fn connect_host(request: &Request, hostname: &str, port: u16) -> Result { + // + let ips: Vec = + dns_lookup::lookup_host(hostname).map_err(|e| Error::DnsFailed(format!("{}", e)))?; + + if ips.len() == 0 { + return Err(Error::DnsFailed(format!("No ip address for {}", hostname))); + } + + // pick first ip, or should we randomize? + let sock_addr = SocketAddr::new(ips[0], port); + + // connect with a configured timeout. + let stream = match request.timeout { + 0 => TcpStream::connect(&sock_addr), + _ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(request.timeout as u64)), + }.map_err(|err| Error::ConnectionFailed(format!("{}", err)))?; + + // rust's absurd api returns Err if we set 0. + if request.timeout_read > 0 { + stream + .set_read_timeout(Some(Duration::from_millis(request.timeout_read as u64))) + .ok(); + } + if request.timeout_write > 0 { + stream + .set_write_timeout(Some(Duration::from_millis(request.timeout_write as u64))) + .ok(); + } + + Ok(stream) +} + +#[cfg(not(test))] +fn connect_test(_request: &Request, url: &Url) -> Result { + Err(Error::UnknownScheme(url.scheme().to_string())) +} + +#[cfg(test)] +fn connect_test(request: &Request, url: &Url) -> Result { + use test; + test::resolve_handler(request, url) +} diff --git a/src/test/mod.rs b/src/test/mod.rs index ad32355..ce78344 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -1,17 +1,17 @@ use agent::Request; +use agent::Stream; use error::Error; use header::Header; use std::collections::HashMap; use std::io::Write; use std::sync::{Arc, Mutex}; -use stream::Stream; use url::Url; use util::vecread::VecRead; mod agent_test; mod auth; -mod simple; mod body_read; +mod simple; type RequestHandler = Fn(&Request, &Url) -> Result + Send + 'static;