diff --git a/Cargo.toml b/Cargo.toml index c6767aa..7b9d182 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,7 @@ rustls-native-certs = { version = "0.6", optional = true } native-tls = { version = "0.2", optional = true } flate2 = { version = "1.0.22", optional = true } brotli-decompressor = { version = "2.3.2", optional = true } +mbedtls = { version = "0.8.1", optional = true } [dev-dependencies] serde = { version = "1", features = ["derive"] } @@ -55,6 +56,11 @@ rustls-pemfile = { version = "0.2" } [[example]] name = "smoke-test" +[[example]] +name = "mbedtls-req" +required-features = ["mbedtls"] + + [[example]] name = "cureq" required-features = ["charset", "cookies", "socks-proxy", "native-tls"] diff --git a/examples/mbedtls-req/main.rs b/examples/mbedtls-req/main.rs new file mode 100644 index 0000000..ae63f0c --- /dev/null +++ b/examples/mbedtls-req/main.rs @@ -0,0 +1,72 @@ +use std::io::{self, Read}; +use std::sync::{Arc}; +use std::time::Duration; +use std::{env, error, fmt, result}; + +pub mod mbedtls_connector; + +use log::{error, info}; +use ureq; + +#[derive(Debug)] +struct Oops(String); + +impl From for Oops { + fn from(e: io::Error) -> Oops { + Oops(e.to_string()) + } +} + +impl From for Oops { + fn from(e: ureq::Error) -> Oops { + Oops(e.to_string()) + } +} + +impl error::Error for Oops {} + +impl fmt::Display for Oops { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +type Result = result::Result; + +fn get(agent: &ureq::Agent, url: &str) -> Result> { + let response = agent.get(url).call()?; + let mut reader = response.into_reader(); + let mut bytes = vec![]; + reader.read_to_end(&mut bytes)?; + Ok(bytes) +} + +fn get_and_write(agent: &ureq::Agent, url: &str) { + info!("🕷️ {}", url); + match get(agent, url) { + Ok(_) => info!("Good: ✔️ {}\n", url), + Err(e) => error!("Bad: ⚠️ {} {}\n", url, e), + } +} + +fn main() -> Result<()> { + let _args = env::args(); + env_logger::init(); + + let agent = ureq::builder() + .tls_connector(Arc::new(mbedtls_connector::MbedTlsConnector::new(mbedtls::ssl::config::AuthMode::None))) + .timeout_connect(Duration::from_secs(5)) + .timeout(Duration::from_secs(20)) + .build(); + + get_and_write(&agent, "https://example.com/"); + + Ok(()) +} + +/* + * Local Variables: + * compile-command: "cargo build --example mbedtls-req --features=\"mbedtls\"" + * mode: rust + * End: + */ diff --git a/examples/mbedtls-req/mbedtls_connector.rs b/examples/mbedtls-req/mbedtls_connector.rs new file mode 100644 index 0000000..8ba5000 --- /dev/null +++ b/examples/mbedtls-req/mbedtls_connector.rs @@ -0,0 +1,116 @@ +use std::fmt; +use std::io; +use ureq::{Error, ReadWrite, TlsConnector}; + +use std::net::TcpStream; +use std::sync::{Arc, Mutex}; + +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context}; +use mbedtls::rng::CtrDrbg; + +fn entropy_new() -> mbedtls::rng::OsEntropy { + mbedtls::rng::OsEntropy::new() +} + +pub struct MbedTlsConnector { + context: Arc> +} + +#[derive(Debug)] +struct MbedTlsError; +impl fmt::Display for MbedTlsError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MedTLS handshake failed") + } +} + +impl std::error::Error for MbedTlsError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + None + } +} + +#[allow(dead_code)] +pub(crate) fn default_tls_config() -> std::sync::Arc { + Arc::new(MbedTlsConnector::new(mbedtls::ssl::config::AuthMode::Required)) +} + +impl MbedTlsConnector { + pub fn new(mode: mbedtls::ssl::config::AuthMode) -> MbedTlsConnector { + let entropy = Arc::new(entropy_new()); + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + let rng = Arc::new(CtrDrbg::new(entropy, None).unwrap()); + config.set_rng(rng); + config.set_authmode(mode); + let ctx = Context::new(Arc::new(config)); + MbedTlsConnector { + context: Arc::new(Mutex::new(ctx)) + } + } +} + +impl TlsConnector for MbedTlsConnector { + fn connect( + &self, + _dns_name: &str, + tcp_stream: TcpStream, + ) -> Result, Error> { + + let mut ctx = self.context.lock().unwrap(); + match ctx.establish(tcp_stream, None) { + Err(_) => { + let io_err = io::Error::new(io::ErrorKind::InvalidData, MbedTlsError); + return Err(io_err.into()); + } + Ok(()) => Ok(MbedTlsStream::new(self)) + } + } +} + +struct MbedTlsStream { + context: Arc> + //tcp_stream: TcpStream, +} + +impl MbedTlsStream { + pub fn new(mtc: &MbedTlsConnector) -> Box { + Box::new(MbedTlsStream { + context: mtc.context.clone() + }) + } +} + + +impl ReadWrite for MbedTlsStream { + fn socket(&self) -> Option<&TcpStream> { + None + } +} + +impl io::Read for MbedTlsStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let mut ctx = self.context.lock().unwrap(); + ctx.read(buf) + } +} + +impl io::Write for MbedTlsStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + let mut ctx = self.context.lock().unwrap(); + ctx.write(buf) + } + + fn flush(&mut self) -> io::Result<()> { + let mut ctx = self.context.lock().unwrap(); + ctx.flush() + } +} + + +/* + * Local Variables: + * compile-command: "cd ../.. && cargo build --example mbedtls-req --features=\"mbedtls\"" + * mode: rust + * End: + */