Remove DEFAULT_HOST (#153)

In a few places we relied on "localhost" as a default if a URL's host
was not set, but I think it's better to error out in these cases.

In general, there are a few places in Unit that assume there is a
host as part of the URL. I've made that explicit by doing a check
at the beginning of `connect()`. I've also tried to plumb through
the semantics of "host is always present" by changing the parameter
types of some of the functions that use the hostname.

I considered a more thorough way to express this with types - for
instance implementing an `HttpUrl` struct that embeds a `Url`, and
exports most of the same methods, but guarantees that host is always
present. However, that would have been more invasive than this change, so
I made a smaller change to start.
This commit is contained in:
Jacob Hoffman-Andrews
2020-09-27 10:07:13 -07:00
committed by GitHub
parent fec79dcef3
commit e8c3403f7b
4 changed files with 44 additions and 25 deletions

View File

@@ -8,7 +8,6 @@ use crate::Proxy;
use url::Url; use url::Url;
pub const DEFAULT_HOST: &str = "localhost";
const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 100; const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 100;
const DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST: usize = 1; const DEFAULT_MAX_IDLE_CONNECTIONS_PER_HOST: usize = 1;

View File

@@ -12,7 +12,6 @@ use crate::body::BodySize;
use crate::body::{Payload, SizedReader}; use crate::body::{Payload, SizedReader};
use crate::error::Error; use crate::error::Error;
use crate::header::{self, Header}; use crate::header::{self, Header};
use crate::pool;
use crate::unit::{self, Unit}; use crate::unit::{self, Unit};
use crate::Response; use crate::Response;
@@ -500,8 +499,13 @@ impl Request {
/// assert_eq!(req2.get_host().unwrap(), "localhost"); /// assert_eq!(req2.get_host().unwrap(), "localhost");
/// ``` /// ```
pub fn get_host(&self) -> Result<String, Error> { pub fn get_host(&self) -> Result<String, Error> {
self.to_url() match self.to_url() {
.map(|u| u.host_str().unwrap_or(pool::DEFAULT_HOST).to_string()) Ok(u) => match u.host_str() {
Some(host) => Ok(host.to_string()),
None => Err(Error::BadUrl("No hostname in URL".into())),
},
Err(e) => Err(e),
}
} }
/// Returns the scheme for this request. /// Returns the scheme for this request.
@@ -637,3 +641,13 @@ impl fmt::Debug for TLSConnector {
f.debug_struct("TLSConnector").finish() f.debug_struct("TLSConnector").finish()
} }
} }
/// A request whose URL has no host component (here a `unix:` URL) must
/// report an error from `get_host()`.
#[test]
fn no_hostname() {
    let agent = Agent::default();
    let method = String::from("GET");
    let target = String::from("unix:/run/foo.socket");
    let req = Request::new(&agent, method, target);

    let host = req.get_host();
    assert!(host.is_err());
}

View File

@@ -301,9 +301,8 @@ impl Write for Stream {
} }
} }
pub(crate) fn connect_http(unit: &Unit) -> Result<Stream, Error> { pub(crate) fn connect_http(unit: &Unit, hostname: &str) -> Result<Stream, Error> {
// //
let hostname = unit.url.host_str().unwrap();
let port = unit.url.port().unwrap_or(80); let port = unit.url.port().unwrap_or(80);
connect_host(unit, hostname, port) connect_host(unit, hostname, port)
@@ -325,7 +324,7 @@ fn configure_certs(config: &mut rustls::ClientConfig) {
} }
#[cfg(all(feature = "tls", not(feature = "native-tls")))] #[cfg(all(feature = "tls", not(feature = "native-tls")))]
pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> { pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error> {
use lazy_static::lazy_static; use lazy_static::lazy_static;
use std::sync::Arc; use std::sync::Arc;
@@ -337,7 +336,6 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
}; };
} }
let hostname = unit.url.host_str().unwrap();
let port = unit.url.port().unwrap_or(443); let port = unit.url.port().unwrap_or(443);
let sni = webpki::DNSNameRef::try_from_ascii_str(hostname) let sni = webpki::DNSNameRef::try_from_ascii_str(hostname)
@@ -358,10 +356,9 @@ pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> {
} }
#[cfg(all(feature = "native-tls", not(feature = "tls")))] #[cfg(all(feature = "native-tls", not(feature = "tls")))]
pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> { pub(crate) fn connect_https(unit: &Unit, hostname: &str) -> Result<Stream, Error> {
use std::sync::Arc; use std::sync::Arc;
let hostname = unit.url.host_str().unwrap();
let port = unit.url.port().unwrap_or(443); let port = unit.url.port().unwrap_or(443);
let sock = connect_host(unit, hostname, port)?; let sock = connect_host(unit, hostname, port)?;
@@ -657,6 +654,6 @@ pub(crate) fn connect_test(unit: &Unit) -> Result<Stream, Error> {
} }
#[cfg(not(any(feature = "tls", feature = "native-tls")))] #[cfg(not(any(feature = "tls", feature = "native-tls")))]
pub(crate) fn connect_https(unit: &Unit) -> Result<Stream, Error> { pub(crate) fn connect_https(unit: &Unit, _hostname: &str) -> Result<Stream, Error> {
Err(Error::UnknownScheme(unit.url.scheme().to_string())) Err(Error::UnknownScheme(unit.url.scheme().to_string()))
} }

View File

@@ -14,9 +14,6 @@ use crate::resolve::ArcResolver;
use crate::stream::{self, connect_test, Stream}; use crate::stream::{self, connect_test, Stream};
use crate::{Error, Header, Request, Response}; use crate::{Error, Header, Request, Response};
#[cfg(feature = "cookie")]
use crate::pool::DEFAULT_HOST;
/// It's a "unit of work". Maybe a bad name for it? /// It's a "unit of work". Maybe a bad name for it?
/// ///
/// *Internal API* /// *Internal API*
@@ -51,7 +48,9 @@ impl Unit {
let query_string = combine_query(&url, &req.query, mix_queries); let query_string = combine_query(&url, &req.query, mix_queries);
let cookie_header: Option<Header> = extract_cookies(&req.agent, &url); let cookie_header: Option<Header> = url
.host_str()
.and_then(|host_str| extract_cookies(&req.agent, &url.scheme(), host_str, &url.path()));
let extra_headers = { let extra_headers = {
let mut extra = vec![]; let mut extra = vec![];
@@ -145,8 +144,9 @@ pub(crate) fn connect(
) -> Result<Response, Error> { ) -> Result<Response, Error> {
// //
let host = req.get_host()?;
// open socket // open socket
let (mut stream, is_recycled) = connect_socket(&unit, use_pooled)?; let (mut stream, is_recycled) = connect_socket(&unit, &host, use_pooled)?;
let send_result = send_prelude(&unit, &mut stream, redir); let send_result = send_prelude(&unit, &mut stream, redir);
@@ -238,16 +238,25 @@ pub(crate) fn connect(
} }
#[cfg(feature = "cookie")] #[cfg(feature = "cookie")]
fn extract_cookies(state: &std::sync::Mutex<AgentState>, url: &Url) -> Option<Header> { fn extract_cookies(
state: &std::sync::Mutex<AgentState>,
scheme: &str,
host: &str,
path: &str,
) -> Option<Header> {
let state = state.lock().unwrap(); let state = state.lock().unwrap();
let is_secure = url.scheme().eq_ignore_ascii_case("https"); let is_secure = scheme.eq_ignore_ascii_case("https");
let hostname = url.host_str().unwrap_or(DEFAULT_HOST).to_string();
match_cookies(&state.jar, &hostname, url.path(), is_secure) match_cookies(&state.jar, host, path, is_secure)
} }
#[cfg(not(feature = "cookie"))] #[cfg(not(feature = "cookie"))]
fn extract_cookies(_state: &std::sync::Mutex<AgentState>, _url: &Url) -> Option<Header> { fn extract_cookies(
_state: &std::sync::Mutex<AgentState>,
_scheme: &str,
_host: &str,
_path: &str,
) -> Option<Header> {
None None
} }
@@ -298,7 +307,7 @@ pub(crate) fn combine_query(url: &Url, query: &QString, mix_queries: bool) -> St
} }
/// Connect the socket, either by using the pool or grab a new one. /// Connect the socket, either by using the pool or grab a new one.
fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error> { fn connect_socket(unit: &Unit, hostname: &str, use_pooled: bool) -> Result<(Stream, bool), Error> {
match unit.url.scheme() { match unit.url.scheme() {
"http" | "https" | "test" => (), "http" | "https" | "test" => (),
_ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())), _ => return Err(Error::UnknownScheme(unit.url.scheme().to_string())),
@@ -316,8 +325,8 @@ fn connect_socket(unit: &Unit, use_pooled: bool) -> Result<(Stream, bool), Error
} }
} }
let stream = match unit.url.scheme() { let stream = match unit.url.scheme() {
"http" => stream::connect_http(&unit), "http" => stream::connect_http(&unit, hostname),
"https" => stream::connect_https(&unit), "https" => stream::connect_https(&unit, hostname),
"test" => connect_test(&unit), "test" => connect_test(&unit),
_ => Err(Error::UnknownScheme(unit.url.scheme().to_string())), _ => Err(Error::UnknownScheme(unit.url.scheme().to_string())),
}; };
@@ -408,7 +417,7 @@ fn save_cookies(unit: &Unit, resp: &Response) {
let to_parse = if raw_cookie.to_lowercase().contains("domain=") { let to_parse = if raw_cookie.to_lowercase().contains("domain=") {
(*raw_cookie).to_string() (*raw_cookie).to_string()
} else { } else {
let host = &unit.url.host_str().unwrap_or(DEFAULT_HOST).to_string(); let host = &unit.url.host_str().unwrap().to_string();
format!("{}; Domain={}", raw_cookie, host) format!("{}; Domain={}", raw_cookie, host)
}; };
match Cookie::parse_encoded(&to_parse[..]) { match Cookie::parse_encoded(&to_parse[..]) {