diff --git a/src/pool.rs b/src/pool.rs index 14b0da7..b8a7e7e 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -8,7 +8,6 @@ use crate::Proxy; use url::Url; -pub const DEFAULT_HOST: &str = "localhost"; const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 100; const DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST: usize = 1; diff --git a/src/request.rs b/src/request.rs index 31478cc..85bce0c 100644 --- a/src/request.rs +++ b/src/request.rs @@ -12,7 +12,6 @@ use crate::body::BodySize; use crate::body::{Payload, SizedReader}; use crate::error::Error; use crate::header::{self, Header}; -use crate::pool; use crate::unit::{self, Unit}; use crate::Response; @@ -500,8 +499,13 @@ impl Request { /// assert_eq!(req2.get_host().unwrap(), "localhost"); /// ``` pub fn get_host(&self) -> Result { - self.to_url() - .map(|u| u.host_str().unwrap_or(pool::DEFAULT_HOST).to_string()) + match self.to_url() { + Ok(u) => match u.host_str() { + Some(host) => Ok(host.to_string()), + None => Err(Error::BadUrl("No hostname in URL".into())), + }, + Err(e) => Err(e), + } } /// Returns the scheme for this request. @@ -637,3 +641,13 @@ impl fmt::Debug for TLSConnector { f.debug_struct("TLSConnector").finish() } } + +#[test] +fn no_hostname() { + let req = Request::new( + &Agent::default(), + "GET".to_string(), + "unix:/run/foo.socket".to_string(), + ); + assert!(req.get_host().is_err()); +} diff --git a/src/stream.rs b/src/stream.rs index c51520f..a38899c 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -301,9 +301,8 @@ impl Write for Stream { } } -pub(crate) fn connect_http(unit: &Unit) -> Result { +pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result { // - let hostname = unit.url.host_str().unwrap(); let port = unit.url.port().unwrap_or(80); connect_host(unit, hostname, port) @@ -325,7 +324,7 @@ fn configure_certs(config: &mut rustls::ClientConfig) { } #[cfg(all(feature = "tls", not(feature = "native-tls")))] -pub(crate) fn connect_https(unit: &Unit) -> Result { +pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result { use lazy_static::lazy_static; use std::sync::Arc; @@ -337,7 +336,6 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { }; } - let hostname = unit.url.host_str().unwrap(); let port = unit.url.port().unwrap_or(443); let sni = webpki::DNSNameRef::try_from_ascii_str(hostname) @@ -358,10 +356,9 @@ pub(crate) fn connect_https(unit: &Unit) -> Result { } #[cfg(all(feature = "native-tls", not(feature = "tls")))] -pub(crate) fn connect_https(unit: &Unit) -> Result { +pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result { use std::sync::Arc; - let hostname = unit.url.host_str().unwrap(); let port = unit.url.port().unwrap_or(443); let sock = connect_host(unit, hostname, port)?; @@ -657,6 +654,6 @@ pub(crate) fn connect_test(unit: &Unit) -> Result { } #[cfg(not(any(feature = "tls", feature = "native-tls")))] -pub(crate) fn connect_https(unit: &Unit) -> Result { +pub(crate) fn connect_https(unit: &Unit, _hostname: &str) -> Result { Err(Error::UnknownScheme(unit.url.scheme().to_string())) } diff --git a/src/unit.rs b/src/unit.rs index 6ca617b..58df934 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -14,9 +14,6 @@ use crate::resolve::ArcResolver; use crate::stream::{self, connect_test, Stream}; use crate::{Error, Header, Request, Response}; -#[cfg(feature = "cookie")] -use crate::pool::DEFAULT_HOST; - /// It's a "unit of work". Maybe a bad name for it? /// /// *Internal API* @@ -51,7 +48,9 @@ impl Unit { let query_string = combine_query(&url, &req.query, mix_queries); - let cookie_header: Option
= extract_cookies(&req.agent, &url); + let cookie_header: Option
= url + .host_str() + .and_then(|host_str| extract_cookies(&req.agent, &url.scheme(), host_str, &url.path())); let extra_headers = { let mut extra = vec![]; @@ -145,8 +144,9 @@ pub(crate) fn connect( ) -> Result { // + let host = req.get_host()?; // open socket - let (mut stream, is_recycled) = connect_socket(&unit, use_pooled)?; + let (mut stream, is_recycled) = connect_socket(&unit, &host, use_pooled)?; let send_result = send_prelude(&unit, &mut stream, redir); @@ -238,16 +238,25 @@ pub(crate) fn connect( } #[cfg(feature = "cookie")] -fn extract_cookies(state: &std::sync::Mutex, url: &Url) -> Option
{ +fn extract_cookies( + state: &std::sync::Mutex, + scheme: &str, + host: &str, + path: &str, +) -> Option
{ let state = state.lock().unwrap(); - let is_secure = url.scheme().eq_ignore_ascii_case("https"); - let hostname = url.host_str().unwrap_or(DEFAULT_HOST).to_string(); + let is_secure = scheme.eq_ignore_ascii_case("https"); - match_cookies(&state.jar, &hostname, url.path(), is_secure) + match_cookies(&state.jar, host, path, is_secure) } #[cfg(not(feature = "cookie"))] -fn extract_cookies(_state: &std::sync::Mutex, _url: &Url) -> Option
{ +fn extract_cookies( + _state: &std::sync::Mutex, + _scheme: &str, + _host: &str, + _path: &str, +) -> Option
{ None } @@ -298,7 +307,7 @@ pub(crate) fn combine_query(url: &Url, query: &QString, mix_queries: bool) -> St } /// Connect the socket, either by using the pool or grab a new one. -fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error> { +fn connect_socket(unit: &Unit, hostname: &str, use_pooled: bool) -> Result<(Stream, bool), Error> { match unit.url.scheme() { "http" | "https" | "test" => (), _ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())), @@ -316,8 +325,8 @@ fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error } } let stream = match unit.url.scheme() { - "http" => stream::connect_http(&unit), - "https" => stream::connect_https(&unit), + "http" => stream::connect_http(&unit, hostname), + "https" => stream::connect_https(&unit, hostname), "test" => connect_test(&unit), _ => Err(Error::UnknownScheme(unit.url.scheme().to_string())), }; @@ -408,7 +417,7 @@ fn save_cookies(unit: &Unit, resp: &Response) { let to_parse = if raw_cookie.to_lowercase().contains("domain=") { (*raw_cookie).to_string() } else { - let host = &unit.url.host_str().unwrap_or(DEFAULT_HOST).to_string(); + let host = &unit.url.host_str().unwrap().to_string(); format!("{}; Domain={}", raw_cookie, host) }; match Cookie::parse_encoded(&to_parse[..]) {