move connect calls to stream

This commit is contained in:
Martin Algesten
2018-06-14 14:38:00 +02:00
parent ef6f8c6259
commit 7177a99d1f
5 changed files with 85 additions and 86 deletions

View File

@@ -9,6 +9,7 @@ use util::*;
include!("request.rs");
include!("response.rs");
include!("conn.rs");
include!("stream.rs");
#[derive(Debug, Default, Clone)]
pub struct Agent {

View File

@@ -1,11 +1,4 @@
use dns_lookup;
use std::io::Write;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::net::TcpStream;
use std::time::Duration;
use stream::Stream;
use native_tls::TlsConnector;
use url::Url;
const CHUNK_SIZE: usize = 1024 * 1024;
@@ -119,58 +112,6 @@ impl ConnectionPool {
}
}
fn connect_http(request: &Request, url: &Url) -> Result<Stream, Error> {
//
let hostname = url.host_str().unwrap();
let port = url.port().unwrap_or(80);
connect_host(request, hostname, port).map(|tcp| Stream::Http(tcp))
}
fn connect_https(request: &Request, url: &Url) -> Result<Stream, Error> {
//
let hostname = url.host_str().unwrap();
let port = url.port().unwrap_or(443);
let socket = connect_host(request, hostname, port)?;
let connector = TlsConnector::builder()?.build()?;
let stream = connector.connect(hostname, socket)?;
Ok(Stream::Https(stream))
}
fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStream, Error> {
//
let ips: Vec<IpAddr> =
dns_lookup::lookup_host(hostname).map_err(|e| Error::DnsFailed(format!("{}", e)))?;
if ips.len() == 0 {
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
}
// pick first ip, or should we randomize?
let sock_addr = SocketAddr::new(ips[0], port);
// connect with a configured timeout.
let stream = match request.timeout {
0 => TcpStream::connect(&sock_addr),
_ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(request.timeout as u64)),
}.map_err(|err| Error::ConnectionFailed(format!("{}", err)))?;
// rust's absurd api returns Err if we set 0.
if request.timeout_read > 0 {
stream
.set_read_timeout(Some(Duration::from_millis(request.timeout_read as u64)))
.ok();
}
if request.timeout_write > 0 {
stream
.set_write_timeout(Some(Duration::from_millis(request.timeout_write as u64)))
.ok();
}
Ok(stream)
}
fn send_payload(request: &Request, payload: Payload, stream: &mut Stream) -> IoResult<()> {
//
@@ -242,14 +183,3 @@ fn match_cookies<'a>(jar: &'a CookieJar, domain: &str, path: &str, is_secure: bo
.map(|o| o.unwrap())
.collect()
}
#[cfg(not(test))]
fn connect_test(_request: &Request, url: &Url) -> Result<Stream, Error> {
Err(Error::UnknownScheme(url.scheme().to_string()))
}
#[cfg(test)]
fn connect_test(request: &Request, url: &Url) -> Result<Stream, Error> {
use test;
test::resolve_handler(request, url)
}

View File

@@ -16,7 +16,6 @@ extern crate url;
mod agent;
mod error;
mod header;
mod stream;
mod util;
#[cfg(test)]

View File

@@ -1,15 +1,17 @@
use std::io::Read;
use std::io::Result;
use std::io::Write;
use std::net::TcpStream;
use dns_lookup;
use native_tls::TlsConnector;
use native_tls::TlsStream;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::net::TcpStream;
use std::time::Duration;
pub enum Stream {
Http(TcpStream),
Https(TlsStream<TcpStream>),
Read(Box<Read>),
#[cfg(test)] Test(Box<Read + Send>, Vec<u8>),
#[cfg(test)]
Test(Box<Read + Send>, Vec<u8>),
}
impl Stream {
@@ -17,37 +19,104 @@ impl Stream {
pub fn to_write_vec(&self) -> Vec<u8> {
match self {
Stream::Test(_, writer) => writer.clone(),
_ => panic!("to_write_vec on non Test stream")
_ => panic!("to_write_vec on non Test stream"),
}
}
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
match self {
Stream::Http(sock) => sock.read(buf),
Stream::Https(stream) => stream.read(buf),
Stream::Read(read) => read.read(buf),
#[cfg(test)] Stream::Test(reader, _) => reader.read(buf),
#[cfg(test)]
Stream::Test(reader, _) => reader.read(buf),
}
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
match self {
Stream::Http(sock) => sock.write(buf),
Stream::Https(stream) => stream.write(buf),
Stream::Read(_) => panic!("Write to read stream"),
#[cfg(test)] Stream::Test(_, writer) => writer.write(buf),
#[cfg(test)]
Stream::Test(_, writer) => writer.write(buf),
}
}
fn flush(&mut self) -> Result<()> {
fn flush(&mut self) -> IoResult<()> {
match self {
Stream::Http(sock) => sock.flush(),
Stream::Https(stream) => stream.flush(),
Stream::Read(_) => panic!("Flush read stream"),
#[cfg(test)] Stream::Test(_, writer) => writer.flush(),
#[cfg(test)]
Stream::Test(_, writer) => writer.flush(),
}
}
}
fn connect_http(request: &Request, url: &Url) -> Result<Stream, Error> {
//
let hostname = url.host_str().unwrap();
let port = url.port().unwrap_or(80);
connect_host(request, hostname, port).map(|tcp| Stream::Http(tcp))
}
fn connect_https(request: &Request, url: &Url) -> Result<Stream, Error> {
//
let hostname = url.host_str().unwrap();
let port = url.port().unwrap_or(443);
let socket = connect_host(request, hostname, port)?;
let connector = TlsConnector::builder()?.build()?;
let stream = connector.connect(hostname, socket)?;
Ok(Stream::Https(stream))
}
fn connect_host(request: &Request, hostname: &str, port: u16) -> Result<TcpStream, Error> {
//
let ips: Vec<IpAddr> =
dns_lookup::lookup_host(hostname).map_err(|e| Error::DnsFailed(format!("{}", e)))?;
if ips.len() == 0 {
return Err(Error::DnsFailed(format!("No ip address for {}", hostname)));
}
// pick first ip, or should we randomize?
let sock_addr = SocketAddr::new(ips[0], port);
// connect with a configured timeout.
let stream = match request.timeout {
0 => TcpStream::connect(&sock_addr),
_ => TcpStream::connect_timeout(&sock_addr, Duration::from_millis(request.timeout as u64)),
}.map_err(|err| Error::ConnectionFailed(format!("{}", err)))?;
// rust's absurd api returns Err if we set 0.
if request.timeout_read > 0 {
stream
.set_read_timeout(Some(Duration::from_millis(request.timeout_read as u64)))
.ok();
}
if request.timeout_write > 0 {
stream
.set_write_timeout(Some(Duration::from_millis(request.timeout_write as u64)))
.ok();
}
Ok(stream)
}
#[cfg(not(test))]
fn connect_test(_request: &Request, url: &Url) -> Result<Stream, Error> {
Err(Error::UnknownScheme(url.scheme().to_string()))
}
#[cfg(test)]
fn connect_test(request: &Request, url: &Url) -> Result<Stream, Error> {
use test;
test::resolve_handler(request, url)
}

View File

@@ -1,17 +1,17 @@
use agent::Request;
use agent::Stream;
use error::Error;
use header::Header;
use std::collections::HashMap;
use std::io::Write;
use std::sync::{Arc, Mutex};
use stream::Stream;
use url::Url;
use util::vecread::VecRead;
mod agent_test;
mod auth;
mod simple;
mod body_read;
mod simple;
type RequestHandler = Fn(&Request, &Url) -> Result<Stream, Error> + Send + 'static;