From f395a726e3c9073211312883ecb20aa17a22f35f Mon Sep 17 00:00:00 2001 From: Kade Robertson Date: Fri, 10 Feb 2023 16:28:00 -0500 Subject: [PATCH] feat: bidrectional http response conversion --- .github/workflows/test.yml | 1 + Cargo.toml | 2 + src/response.rs | 143 +++++++++++++++++++++++++++++++++++++ test.sh | 2 +- 4 files changed, 147 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 139f29f..e3c54ed 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -58,6 +58,7 @@ jobs: - native-certs - gzip - brotli + - http env: RUST_BACKTRACE: "1" RUSTFLAGS: "-D dead_code -D unused-variables -D unused" diff --git a/Cargo.toml b/Cargo.toml index 418ae58..f6f0375 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ cookies = ["cookie", "cookie_store"] socks-proxy = ["socks"] gzip = ["flate2"] brotli = ["brotli-decompressor"] +http = ["dep:http"] [dependencies] base64 = "0.21" @@ -43,6 +44,7 @@ rustls-native-certs = { version = "0.6", optional = true } native-tls = { version = "0.2", optional = true } flate2 = { version = "1.0.22", optional = true } brotli-decompressor = { version = "2.3.2", optional = true } +http = { version = "0.2", optional = true } [dev-dependencies] serde = { version = "1", features = ["derive"] } diff --git a/src/response.rs b/src/response.rs index d5535d4..3f5f67e 100644 --- a/src/response.rs +++ b/src/response.rs @@ -28,6 +28,9 @@ use flate2::read::MultiGzDecoder; #[cfg(feature = "brotli")] use brotli_decompressor::Decompressor as BrotliDecoder; +#[cfg(feature = "http")] +use std::net::{IpAddr, Ipv4Addr}; + pub const DEFAULT_CONTENT_TYPE: &str = "text/plain"; pub const DEFAULT_CHARACTER_SET: &str = "utf-8"; const INTO_STRING_LIMIT: usize = 10 * 1_024 * 1_024; @@ -872,6 +875,82 @@ impl Read for ErrorReader { } } +#[cfg(feature = "http")] +impl + Send + Sync + 'static> From> for Response { + fn from(value: http::Response) -> Self { + let version_str = format!("{:?}", value.version()); + let status_line = format!("{} {}", version_str, value.status()); + let status_num = u16::from(value.status()); + Response { + url: "https://example.com/".parse().unwrap(), + status_line, + index: ResponseStatusIndex { + http_version: version_str.len(), + response_code: version_str.len() + status_num.to_string().len(), + }, + status: status_num, + headers: value + .headers() + .iter() + .filter_map(|(name, value)| { + let mut raw_header: Vec = name.to_string().into_bytes(); + raw_header.extend([0x3a, 0x20]); // ": " + raw_header.extend(value.as_bytes()); + + HeaderLine::from(raw_header).into_header().ok() + }) + .collect::>(), + reader: Box::new(std::io::Cursor::new(value.into_body())), + remote_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80), + history: vec![], + } + } +} + +#[cfg(feature = "http")] +fn create_builder(response: &Response) -> http::response::Builder { + let http_version = match response.http_version() { + "HTTP/0.9" => http::Version::HTTP_09, + "HTTP/1.0" => http::Version::HTTP_10, + "HTTP/1.1" => http::Version::HTTP_11, + "HTTP/2.0" => http::Version::HTTP_2, + "HTTP/3.0" => http::Version::HTTP_3, + _ => unreachable!(), + }; + + let response_builder = response + .headers + .iter() + .filter_map(|header| { + header + .value() + .map(|safe_value| (header.name().to_owned(), safe_value.to_owned())) + }) + .fold(http::Response::builder(), |builder, header| { + builder.header(header.0, header.1) + }) + .status(response.status()) + .version(http_version); + + response_builder +} + +#[cfg(feature = "http")] +impl From for http::Response> { + fn from(value: Response) -> Self { + create_builder(&value).body(value.into_reader()).unwrap() + } +} + +#[cfg(feature = "http")] +impl From for http::Response { + fn from(value: Response) -> Self { + create_builder(&value) + .body(value.into_string().unwrap()) + .unwrap() + } +} + #[cfg(test)] mod tests { use std::io::Cursor; @@ -1210,4 +1289,68 @@ mod tests { let body = resp.into_string().unwrap(); assert_eq!(body, "hi\n"); } + + #[test] + #[cfg(feature = "http")] + fn convert_http_response() { + use http::{Response, StatusCode, Version}; + + let http_response_body = (0..10240).into_iter().map(|_| 0xaa).collect::>(); + let http_response = Response::builder() + .version(Version::HTTP_2) + .header("Custom-Header", "custom value") + .header("Content-Type", "application/octet-stream") + .status(StatusCode::IM_A_TEAPOT) + .body(http_response_body.clone()) + .unwrap(); + + let response: super::Response = http_response.into(); + assert_eq!(response.get_url(), "https://example.com/"); + assert_eq!(response.http_version(), "HTTP/2.0"); + assert_eq!(response.status(), u16::from(StatusCode::IM_A_TEAPOT)); + assert_eq!(response.status_text(), "I'm a teapot"); + assert_eq!(response.remote_addr().to_string().as_str(), "127.0.0.1:80"); + assert_eq!(response.header("Custom-Header"), Some("custom value")); + assert_eq!(response.content_type(), "application/octet-stream"); + + let mut body_buf: Vec = vec![]; + response.into_reader().read_to_end(&mut body_buf).unwrap(); + assert_eq!(body_buf, http_response_body); + } + + #[test] + #[cfg(feature = "http")] + fn convert_http_response_string() { + use http::{Response, StatusCode, Version}; + + let http_response_body = "Some body string".to_string(); + let http_response = Response::builder() + .version(Version::HTTP_11) + .status(StatusCode::OK) + .body(http_response_body.clone()) + .unwrap(); + + let response: super::Response = http_response.into(); + assert_eq!(response.get_url(), "https://example.com/"); + assert_eq!(response.content_type(), "text/plain"); + assert_eq!(response.into_string().unwrap(), http_response_body); + } + + #[test] + #[cfg(feature = "http")] + fn convert_http_response_bad_header() { + use http::{Response, StatusCode, Version}; + + let http_response = Response::builder() + .version(Version::HTTP_11) + .status(StatusCode::OK) + .header("Some-Invalid-Header", vec![0xde, 0xad, 0xbe, 0xef]) + .header("Some-Valid-Header", vec![0x48, 0x45, 0x4c, 0x4c, 0x4f]) + .body(vec![]) + .unwrap(); + + let response: super::Response = http_response.into(); + assert_eq!(response.header("Some-Invalid-Header"), None); + assert_eq!(response.header("Some-Valid-Header"), Some("HELLO")); + } } diff --git a/test.sh b/test.sh index 48dbb03..76bc205 100755 --- a/test.sh +++ b/test.sh @@ -4,7 +4,7 @@ set -eu export RUST_BACKTRACE=1 export RUSTFLAGS="-D dead_code -D unused-variables -D unused" -for feature in "" tls json charset cookies socks-proxy "tls native-certs" native-tls gzip brotli; do +for feature in "" tls json charset cookies socks-proxy "tls native-certs" native-tls gzip brotli http; do if ! cargo test --no-default-features --features "${feature}" ; then echo Command failed: cargo test --no-default-features --features \"${feature}\" exit 1