use crate::agent::Unit; use crate::error::Error; use lazy_static::lazy_static; use std::io::{Cursor, Read, Result as IoResult, Write}; use std::net::SocketAddr; use std::net::TcpStream; use std::net::ToSocketAddrs; use std::time::Duration; #[allow(clippy::large_enum_variant)] pub enum Stream { Http(TcpStream), #[cfg(feature = "tls")] Https(rustls::StreamOwned), Cursor(Cursor>), #[cfg(test)] Test(Box, Vec), } impl ::std::fmt::Debug for Stream { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::result::Result<(), ::std::fmt::Error> { write!( f, "Stream[{}]", match self { Stream::Http(_) => "http", #[cfg(feature = "tls")] Stream::Https(_) => "https", Stream::Cursor(_) => "cursor", #[cfg(test)] Stream::Test(_, _) => "test", } ) } } impl Stream { pub fn is_poolable(&self) -> bool { match self { Stream::Http(_) => true, #[cfg(feature = "tls")] Stream::Https(_) => true, _ => false, } } #[cfg(test)] pub fn to_write_vec(&self) -> Vec { match self { Stream::Test(_, writer) => writer.clone(), _ => panic!("to_write_vec on non Test stream"), } } } impl Read for Stream { fn read(&mut self, buf: &mut [u8]) -> IoResult { match self { Stream::Http(sock) => sock.read(buf), #[cfg(feature = "tls")] Stream::Https(stream) => stream.read(buf), Stream::Cursor(read) => read.read(buf), #[cfg(test)] Stream::Test(reader, _) => reader.read(buf), } } } impl Write for Stream { fn write(&mut self, buf: &[u8]) -> IoResult { match self { Stream::Http(sock) => sock.write(buf), #[cfg(feature = "tls")] Stream::Https(stream) => stream.write(buf), Stream::Cursor(_) => panic!("Write to read only stream"), #[cfg(test)] Stream::Test(_, writer) => writer.write(buf), } } fn flush(&mut self) -> IoResult<()> { match self { Stream::Http(sock) => sock.flush(), #[cfg(feature = "tls")] Stream::Https(stream) => stream.flush(), Stream::Cursor(_) => panic!("Flush read only stream"), #[cfg(test)] Stream::Test(_, writer) => writer.flush(), } } } pub fn connect_http(unit: &Unit) -> Result { // let hostname = unit.url.host_str().unwrap(); let port = unit.url.port().unwrap_or(80); connect_host(unit, hostname, port).map(Stream::Http) } #[cfg(feature = "tls")] pub fn connect_https(unit: &Unit) -> Result { use std::sync::Arc; lazy_static! { static ref TLS_CONF: Arc = { let mut config = rustls::ClientConfig::new(); config .root_store .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); Arc::new(config) }; } let hostname = unit.url.host_str().unwrap(); let port = unit.url.port().unwrap_or(443); let sni = webpki::DNSNameRef::try_from_ascii_str(hostname).unwrap(); let sess = rustls::ClientSession::new(&*TLS_CONF, sni); let sock = connect_host(unit, hostname, port)?; let stream = rustls::StreamOwned::new(sess, sock); Ok(Stream::Https(stream)) } pub fn connect_host(unit: &Unit, hostname: &str, port: u16) -> Result { // let ips: Vec = format!("{}:{}", hostname, port) .to_socket_addrs() .map_err(|e| Error::DnsFailed(format!("{}", e)))? .collect(); if ips.is_empty() { return Err(Error::DnsFailed(format!("No ip address for {}", hostname))); } // pick first ip, or should we randomize? let sock_addr = ips[0]; // connect with a configured timeout. let stream = match unit.timeout_connect { 0 => TcpStream::connect(&sock_addr), _ => TcpStream::connect_timeout( &sock_addr, Duration::from_millis(unit.timeout_connect as u64), ), } .map_err(|err| Error::ConnectionFailed(format!("{}", err)))?; // rust's absurd api returns Err if we set 0. if unit.timeout_read > 0 { stream .set_read_timeout(Some(Duration::from_millis(unit.timeout_read as u64))) .ok(); } if unit.timeout_write > 0 { stream .set_write_timeout(Some(Duration::from_millis(unit.timeout_write as u64))) .ok(); } Ok(stream) } #[cfg(test)] pub fn connect_test(unit: &Unit) -> Result { use crate::test; test::resolve_handler(unit) } #[cfg(not(test))] pub fn connect_test(unit: &Unit) -> Result { Err(Error::UnknownScheme(unit.url.scheme().to_string())) } #[cfg(not(feature = "tls"))] pub fn connect_https(unit: &Unit) -> Result { Err(Error::UnknownScheme(unit.url.scheme().to_string())) }