diff --git a/examples/count-bytes.rs b/examples/count-bytes.rs new file mode 100644 index 0000000..22e8ab4 --- /dev/null +++ b/examples/count-bytes.rs @@ -0,0 +1,67 @@ +use std::sync::{Arc, Mutex}; + +use ureq::{Error, Middleware, MiddlewareNext, Request, Response}; + +// Some state that could be shared with the main application. +#[derive(Debug, Default)] +struct CounterState { + request_count: u64, + total_bytes: u64, +} + +// Middleware wrapper working off the shared state. +struct CounterMiddleware(Arc>); + +pub fn main() -> Result<(), Error> { + // Shared state for counters. + let shared_state = Arc::new(Mutex::new(CounterState::default())); + + let agent = ureq::builder() + // Clone the state into the middleware + .middleware(CounterMiddleware(shared_state.clone())) + .build(); + + agent.get("https://httpbin.org/bytes/123").call()?; + agent.get("https://httpbin.org/bytes/123").call()?; + + { + let state = shared_state.lock().unwrap(); + + println!("State after requests:\n\n{:?}\n", state); + + assert_eq!(state.request_count, 2); + assert_eq!(state.total_bytes, 246); + } + + Ok(()) +} + +impl Middleware for CounterMiddleware { + fn handle(&self, request: Request, next: MiddlewareNext) -> Result { + // Get state before request to increase request counter. + // Extra brackets to release the lock while continuing the chain. + { + let mut state = self.0.lock().unwrap(); + + state.request_count += 1; + } // release lock + + // Continue the middleware chain + let response = next.handle(request)?; + + // Get state after response to increase byte count. + // Extra brackets not necessary, but there for symmetry with first lock. 
+ { + let mut state = self.0.lock().unwrap(); + + let len = response + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + .unwrap(); + + state.total_bytes += len; + } // release lock + + Ok(response) + } +} diff --git a/src/agent.rs b/src/agent.rs index 94c4140..371693d 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use std::time::Duration; use url::Url; +use crate::middleware::Middleware; use crate::pool::ConnectionPool; use crate::proxy::Proxy; use crate::request::Request; @@ -42,6 +43,7 @@ pub struct AgentBuilder { #[cfg(feature = "cookies")] cookie_store: Option, resolver: ArcResolver, + middleware: Option>>, } /// Config as built by AgentBuilder and then static for the lifetime of the Agent. @@ -109,6 +111,7 @@ pub(crate) struct AgentState { #[cfg(feature = "cookies")] pub(crate) cookie_tin: CookieTin, pub(crate) resolver: ArcResolver, + pub(crate) middleware: Option>>, } impl Agent { @@ -245,6 +248,7 @@ impl AgentBuilder { resolver: StdResolver.into(), #[cfg(feature = "cookies")] cookie_store: None, + middleware: None, } } @@ -264,6 +268,7 @@ impl AgentBuilder { #[cfg(feature = "cookies")] cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)), resolver: self.resolver, + middleware: self.middleware, }), } } @@ -591,6 +596,18 @@ impl AgentBuilder { self.cookie_store = Some(cookie_store); self } + + /// Add middleware handler to this agent. + /// + /// All requests made by the agent will use this middleware. Middleware is invoked + /// in the order they are added to the builder. 
+ pub fn middleware(mut self, m: impl Middleware + Send + Sync + 'static) -> Self { + if self.middleware.is_none() { + self.middleware = Some(vec![]); + } + self.middleware.as_mut().unwrap().push(Arc::new(m)); + self + } } #[cfg(feature = "tls")] diff --git a/src/lib.rs b/src/lib.rs index 5449b06..731845c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -327,6 +327,7 @@ mod agent; mod body; mod error; mod header; +mod middleware; mod pool; mod proxy; mod request; @@ -392,6 +393,7 @@ pub use crate::agent::AgentBuilder; pub use crate::agent::RedirectAuthHeaders; pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport}; pub use crate::header::Header; +pub use crate::middleware::{Middleware, MiddlewareNext}; pub use crate::proxy::Proxy; pub use crate::request::{Request, RequestUrl}; pub use crate::resolve::Resolver; diff --git a/src/middleware.rs b/src/middleware.rs new file mode 100644 index 0000000..4bcdce8 --- /dev/null +++ b/src/middleware.rs @@ -0,0 +1,169 @@ +use std::sync::Arc; + +use crate::{Error, Request, Response}; + +/// Chained processing of request (and response). +/// +/// # Middleware as `fn` +/// +/// The middleware trait is implemented for all functions that have the signature +/// +/// `Fn(Request, MiddlewareNext) -> Result` +/// +/// That means the easiest way to implement middleware is by providing a `fn`, like so +/// +/// ```no_run +/// # use ureq::{Request, Response, MiddlewareNext, Error}; +/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result { +/// // do middleware things +/// +/// // continue the middleware chain +/// next.handle(req) +/// } +/// ``` +/// +/// # Adding headers +/// +/// A common use case is to add headers to the outgoing request. Here is an example of how. 
+/// +/// ``` +/// # #[cfg(feature = "json")] +/// # fn main() -> Result<(), ureq::Error> { +/// # use ureq::{Request, Response, MiddlewareNext, Error}; +/// # ureq::is_test(true); +/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result { +/// // set my bespoke header and continue the chain +/// next.handle(req.set("X-My-Header", "value_42")) +/// } +/// +/// let agent = ureq::builder() +/// .middleware(my_middleware) +/// .build(); +/// +/// let result: serde_json::Value = +/// agent.get("http://httpbin.org/headers").call()?.into_json()?; +/// +/// assert_eq!(&result["headers"]["X-My-Header"], "value_42"); +/// +/// # Ok(()) } +/// # #[cfg(not(feature = "json"))] +/// # fn main() {} +/// ``` +/// +/// # State +/// +/// To maintain state between middleware invocations, we need to do something more elaborate than +/// the simple `fn` and implement the `Middleware` trait directly. +/// +/// ## Example with mutex lock +/// +/// In the `examples` directory there is an additional example `count-bytes.rs` which uses +/// a mutex lock as shown below. +/// +/// ``` +/// # use ureq::{Request, Response, Middleware, MiddlewareNext, Error}; +/// # use std::sync::{Arc, Mutex}; +/// struct MyState { +/// // whatever is needed +/// } +/// +/// struct MyMiddleware(Arc>); +/// +/// impl Middleware for MyMiddleware { +/// fn handle(&self, request: Request, next: MiddlewareNext) -> Result { +/// // These extra brackets ensure we release the Mutex lock before continuing the +/// // chain. There could also be scenarios where we want to maintain the lock through +/// // the invocation, which would block other requests from proceeding concurrently +/// // through the middleware. 
+/// { +/// let mut state = self.0.lock().unwrap(); +/// // do stuff with state +/// } +/// +/// // continue middleware chain +/// next.handle(request) +/// } +/// } +/// ``` +/// +/// ## Example with atomic +/// +/// This example shows how we can increase a counter for each request going +/// through the agent. +/// +/// ``` +/// # fn main() -> Result<(), ureq::Error> { +/// # ureq::is_test(true); +/// use ureq::{Request, Response, Middleware, MiddlewareNext, Error}; +/// use std::sync::atomic::{AtomicU64, Ordering}; +/// use std::sync::Arc; +/// +/// // Middleware that stores a counter state. This example uses an AtomicU64 +/// // since the middleware is potentially shared by multiple threads running +/// // requests at the same time. +/// struct MyCounter(Arc); +/// +/// impl Middleware for MyCounter { +/// fn handle(&self, req: Request, next: MiddlewareNext) -> Result { +/// // increase the counter for each invocation +/// self.0.fetch_add(1, Ordering::SeqCst); +/// +/// // continue the middleware chain +/// next.handle(req) +/// } +/// } +/// +/// let shared_counter = Arc::new(AtomicU64::new(0)); +/// +/// let agent = ureq::builder() +/// // Add our middleware +/// .middleware(MyCounter(shared_counter.clone())) +/// .build(); +/// +/// agent.get("http://httpbin.org/get").call()?; +/// agent.get("http://httpbin.org/get").call()?; +/// +/// // Check we did indeed increase the counter twice. +/// assert_eq!(shared_counter.load(Ordering::SeqCst), 2); +/// +/// # Ok(()) } +/// ``` +pub trait Middleware { + /// Handle of the middleware logic. + fn handle(&self, request: Request, next: MiddlewareNext) -> Result; +} + +/// Continuation of a [`Middleware`] chain. +pub struct MiddlewareNext<'a>(Next<'a>); + +impl<'a> MiddlewareNext<'a> { + pub(crate) fn new(n: Next<'a>) -> Self { + MiddlewareNext(n) + } +} + +pub(crate) enum Next<'a> { + /// Chained middleware. The Box around the next MiddlewareNext is to break the recursive datatype. 
+ Chain(Arc, Box>), + /// End of the middleware chain doing the actual request invocation. + End(Box Result + 'a>), +} + +impl<'a> MiddlewareNext<'a> { + /// Continue the middleware chain by providing (a possibly amended) [`Request`]. + pub fn handle(self, request: Request) -> Result { + match self.0 { + Next::Chain(mw, next) => mw.handle(request, *next), + Next::End(request_fn) => request_fn(request), + } + } +} + +impl Middleware for F +where + F: Fn(Request, MiddlewareNext) -> Result, +{ + fn handle(&self, request: Request, next: MiddlewareNext) -> Result { + (self)(request, next) + } +} diff --git a/src/request.rs b/src/request.rs index c43b50d..8f3925e 100644 --- a/src/request.rs +++ b/src/request.rs @@ -5,6 +5,7 @@ use url::{form_urlencoded, ParseError, Url}; use crate::body::Payload; use crate::header::{self, Header}; +use crate::middleware::{MiddlewareNext, Next}; use crate::unit::{self, Unit}; use crate::Response; use crate::{agent::Agent, error::Error}; @@ -117,6 +118,8 @@ impl Request { #[cfg(any(feature = "gzip", feature = "brotli"))] self.add_accept_encoding(); + let agent = &self.agent; + let deadline = match self.timeout.or(self.agent.config.timeout) { None => None, Some(timeout) => { @@ -125,16 +128,38 @@ impl Request { } }; - let reader = payload.into_read(); - let unit = Unit::new( - &self.agent, - &self.method, - &url, - self.headers, - &reader, - deadline, - ); - let response = unit::connect(unit, true, reader).map_err(|e| e.url(url.clone()))?; + let request_fn = |req: Request| { + let reader = payload.into_read(); + let unit = Unit::new( + &req.agent, + &req.method, + &url, + req.headers, + &reader, + deadline, + ); + + unit::connect(unit, true, reader).map_err(|e| e.url(url.clone())) + }; + + // This clone is quite cheap since either we are cloning the Option::None or a Vec of Arcs. 
+ let maybe_middleware = agent.state.middleware.clone(); + + let response = if let Some(middleware) = maybe_middleware { + // The request_fn is the final target in the middleware chain doing the actual invocation. + let mut chain = MiddlewareNext::new(Next::End(Box::new(request_fn))); + + // Build middleware in reverse order. + for mw in middleware.into_iter().rev() { + chain = MiddlewareNext::new(Next::Chain(mw, Box::new(chain))); + } + + // Run middleware chain + chain.handle(self)? + } else { + // Run the request_fn without any further indirection. + request_fn(self)? + }; if response.status() >= 400 { Err(Error::Status(response.status(), response))