Implement middleware function

Martin Algesten
2021-12-20 10:21:19 +01:00
parent 7b2f28bbc2
commit 09ecb6ffd6
5 changed files with 290 additions and 10 deletions

examples/count-bytes.rs Normal file

@@ -0,0 +1,67 @@
use std::sync::{Arc, Mutex};

use ureq::{Error, Middleware, MiddlewareNext, Request, Response};

// Some state that could be shared with the main application.
#[derive(Debug, Default)]
struct CounterState {
    request_count: u64,
    total_bytes: u64,
}

// Middleware wrapper working off the shared state.
struct CounterMiddleware(Arc<Mutex<CounterState>>);

pub fn main() -> Result<(), Error> {
    // Shared state for counters.
    let shared_state = Arc::new(Mutex::new(CounterState::default()));

    let agent = ureq::builder()
        // Clone the state into the middleware.
        .middleware(CounterMiddleware(shared_state.clone()))
        .build();

    agent.get("https://httpbin.org/bytes/123").call()?;
    agent.get("https://httpbin.org/bytes/123").call()?;

    {
        let state = shared_state.lock().unwrap();
        println!("State after requests:\n\n{:?}\n", state);
        assert_eq!(state.request_count, 2);
        assert_eq!(state.total_bytes, 246);
    }

    Ok(())
}

impl Middleware for CounterMiddleware {
    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
        // Lock the state before the request to increase the request counter.
        // Extra braces release the lock before continuing the chain.
        {
            let mut state = self.0.lock().unwrap();
            state.request_count += 1;
        } // release lock

        // Continue the middleware chain.
        let response = next.handle(request)?;

        // Lock the state after the response to increase the byte count.
        // Extra braces are not necessary here, but kept for symmetry with the first lock.
        {
            let mut state = self.0.lock().unwrap();
            let len = response
                .header("Content-Length")
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap();
            state.total_bytes += len;
        } // release lock

        Ok(response)
    }
}
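
(Not part of the commit.) The handler above unwraps the parsed Content-Length value, which would panic if the server omits the header or sends something unparsable. A more defensive variant could fall back to zero; a sketch using a hypothetical helper name:

use ureq::Response;

// Hypothetical helper: read Content-Length, treating a missing or
// unparsable header as zero bytes instead of panicking.
fn content_length_or_zero(response: &Response) -> u64 {
    response
        .header("Content-Length")
        .and_then(|s| s.parse::<u64>().ok())
        .unwrap_or(0)
}

Inside handle, `state.total_bytes += content_length_or_zero(&response);` would then replace the unwrap-based block.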

src/agent.rs

@@ -3,6 +3,7 @@ use std::sync::Arc;
use std::time::Duration;
use url::Url;

use crate::middleware::Middleware;
use crate::pool::ConnectionPool;
use crate::proxy::Proxy;
use crate::request::Request;
@@ -42,6 +43,7 @@ pub struct AgentBuilder {
    #[cfg(feature = "cookies")]
    cookie_store: Option<CookieStore>,
    resolver: ArcResolver,
    middleware: Option<Vec<Arc<dyn Middleware + Send + Sync>>>,
}

/// Config as built by AgentBuilder and then static for the lifetime of the Agent.
@@ -109,6 +111,7 @@ pub(crate) struct AgentState {
    #[cfg(feature = "cookies")]
    pub(crate) cookie_tin: CookieTin,
    pub(crate) resolver: ArcResolver,
    pub(crate) middleware: Option<Vec<Arc<dyn Middleware + Send + Sync>>>,
}

impl Agent {
@@ -245,6 +248,7 @@ impl AgentBuilder {
            resolver: StdResolver.into(),
            #[cfg(feature = "cookies")]
            cookie_store: None,
            middleware: None,
        }
    }
@@ -264,6 +268,7 @@ impl AgentBuilder {
                #[cfg(feature = "cookies")]
                cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)),
                resolver: self.resolver,
                middleware: self.middleware,
            }),
        }
    }
@@ -591,6 +596,18 @@ impl AgentBuilder {
        self.cookie_store = Some(cookie_store);
        self
    }

    /// Add a middleware handler to this agent.
    ///
    /// All requests made by the agent will use this middleware. Middleware handlers
    /// are invoked in the order they are added to the builder.
    pub fn middleware(mut self, m: impl Middleware + Send + Sync + 'static) -> Self {
        if self.middleware.is_none() {
            self.middleware = Some(vec![]);
        }

        self.middleware.as_mut().unwrap().push(Arc::new(m));

        self
    }
}

#[cfg(feature = "tls")]
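
(Not part of the commit.) As a usage sketch of the new builder method, two function middleware can be registered; the names and headers below are made up for illustration, and registration order is invocation order:

use ureq::{Error, MiddlewareNext, Request, Response};

// Two illustrative function middleware (made-up names and header values).
fn first(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
    // Added first, so it is the outermost layer and runs first.
    next.handle(req.set("X-First", "1"))
}

fn second(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
    next.handle(req.set("X-Second", "2"))
}

fn main() -> Result<(), Error> {
    let agent = ureq::builder()
        .middleware(first)
        .middleware(second)
        .build();

    // Both middleware run for every request made through this agent.
    agent.get("http://httpbin.org/get").call()?;
    Ok(())
}

The plain `fn` form works because of the blanket `impl<F> Middleware for F` further down in src/middleware.rs.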

src/lib.rs

@@ -327,6 +327,7 @@ mod agent;
mod body;
mod error;
mod header;
mod middleware;
mod pool;
mod proxy;
mod request;
@@ -392,6 +393,7 @@ pub use crate::agent::AgentBuilder;
pub use crate::agent::RedirectAuthHeaders;
pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport};
pub use crate::header::Header;
pub use crate::middleware::{Middleware, MiddlewareNext};
pub use crate::proxy::Proxy;
pub use crate::request::{Request, RequestUrl};
pub use crate::resolve::Resolver;

src/middleware.rs Normal file

@@ -0,0 +1,169 @@
use std::sync::Arc;
use crate::{Error, Request, Response};
/// Chained processing of request (and response).
///
/// # Middleware as `fn`
///
/// The middleware trait is implemented for all functions that have the signature
///
/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
///
/// That means the easiest way to implement middleware is by providing a `fn`, like so
///
/// ```no_run
/// # use ureq::{Request, Response, MiddlewareNext, Error};
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
///     // do middleware things
///
///     // continue the middleware chain
///     next.handle(req)
/// }
/// ```
///
/// # Adding headers
///
/// A common use case is to add headers to the outgoing request. Here is an example of how.
///
/// ```
/// # #[cfg(feature = "json")]
/// # fn main() -> Result<(), ureq::Error> {
/// # use ureq::{Request, Response, MiddlewareNext, Error};
/// # ureq::is_test(true);
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
///     // set my bespoke header and continue the chain
///     next.handle(req.set("X-My-Header", "value_42"))
/// }
///
/// let agent = ureq::builder()
///     .middleware(my_middleware)
///     .build();
///
/// let result: serde_json::Value =
///     agent.get("http://httpbin.org/headers").call()?.into_json()?;
///
/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
///
/// # Ok(()) }
/// # #[cfg(not(feature = "json"))]
/// # fn main() {}
/// ```
///
/// # State
///
/// To maintain state between middleware invocations, we need to do something more elaborate than
/// the simple `fn` and implement the `Middleware` trait directly.
///
/// ## Example with mutex lock
///
/// In the `examples` directory there is an additional example `count-bytes.rs` which uses
/// a mutex lock, as shown below.
///
/// ```
/// # use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
/// # use std::sync::{Arc, Mutex};
/// struct MyState {
///     // whatever is needed
/// }
///
/// struct MyMiddleware(Arc<Mutex<MyState>>);
///
/// impl Middleware for MyMiddleware {
///     fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
///         // These extra braces ensure we release the Mutex lock before continuing the
///         // chain. There could also be scenarios where we want to maintain the lock through
///         // the invocation, which would block other requests from proceeding concurrently
///         // through the middleware.
///         {
///             let mut state = self.0.lock().unwrap();
///             // do stuff with state
///         }
///
///         // continue middleware chain
///         next.handle(request)
///     }
/// }
/// ```
///
/// ## Example with atomic
///
/// This example shows how we can increase a counter for each request going
/// through the agent.
///
/// ```
/// # fn main() -> Result<(), ureq::Error> {
/// # ureq::is_test(true);
/// use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
/// use std::sync::atomic::{AtomicU64, Ordering};
/// use std::sync::Arc;
///
/// // Middleware that stores a counter state. This example uses an AtomicU64
/// // since the middleware is potentially shared by multiple threads running
/// // requests at the same time.
/// struct MyCounter(Arc<AtomicU64>);
///
/// impl Middleware for MyCounter {
///     fn handle(&self, req: Request, next: MiddlewareNext) -> Result<Response, Error> {
///         // increase the counter for each invocation
///         self.0.fetch_add(1, Ordering::SeqCst);
///
///         // continue the middleware chain
///         next.handle(req)
///     }
/// }
///
/// let shared_counter = Arc::new(AtomicU64::new(0));
///
/// let agent = ureq::builder()
///     // Add our middleware
///     .middleware(MyCounter(shared_counter.clone()))
///     .build();
///
/// agent.get("http://httpbin.org/get").call()?;
/// agent.get("http://httpbin.org/get").call()?;
///
/// // Check we did indeed increase the counter twice.
/// assert_eq!(shared_counter.load(Ordering::SeqCst), 2);
///
/// # Ok(()) }
/// ```
pub trait Middleware {
    /// Handle the middleware logic.
    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error>;
}

/// Continuation of a [`Middleware`] chain.
pub struct MiddlewareNext<'a>(Next<'a>);

impl<'a> MiddlewareNext<'a> {
    pub(crate) fn new(n: Next<'a>) -> Self {
        MiddlewareNext(n)
    }
}

pub(crate) enum Next<'a> {
    /// Chained middleware. The Box around the next MiddlewareNext is to break the recursive datatype.
    Chain(Arc<dyn Middleware>, Box<MiddlewareNext<'a>>),
    /// End of the middleware chain, doing the actual request invocation.
    End(Box<dyn FnOnce(Request) -> Result<Response, Error> + 'a>),
}

impl<'a> MiddlewareNext<'a> {
    /// Continue the middleware chain by providing a (possibly amended) [`Request`].
    pub fn handle(self, request: Request) -> Result<Response, Error> {
        match self.0 {
            Next::Chain(mw, next) => mw.handle(request, *next),
            Next::End(request_fn) => request_fn(request),
        }
    }
}

impl<F> Middleware for F
where
    F: Fn(Request, MiddlewareNext) -> Result<Response, Error>,
{
    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
        (self)(request, next)
    }
}
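
(Not part of the commit.) Because a middleware receives control both before and after `next.handle`, it can also act on the way back, for example timing a request. A sketch against the trait above; the `Timing` type is made up:

use std::time::Instant;

use ureq::{Error, Middleware, MiddlewareNext, Request, Response};

// Illustrative middleware that logs the wall-clock duration of each request.
struct Timing;

impl Middleware for Timing {
    fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
        let start = Instant::now();

        // Continue the chain and time the full round trip, whether it succeeds or fails.
        let result = next.handle(request);

        println!("request took {:?}", start.elapsed());
        result
    }
}

It would be registered like any other handler, e.g. `ureq::builder().middleware(Timing).build()`.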

src/request.rs

@@ -5,6 +5,7 @@ use url::{form_urlencoded, ParseError, Url};
use crate::body::Payload;
use crate::header::{self, Header};
use crate::middleware::{MiddlewareNext, Next};
use crate::unit::{self, Unit};
use crate::Response;
use crate::{agent::Agent, error::Error};
@@ -117,6 +118,8 @@ impl Request {
        #[cfg(any(feature = "gzip", feature = "brotli"))]
        self.add_accept_encoding();

        let agent = &self.agent;

        let deadline = match self.timeout.or(self.agent.config.timeout) {
            None => None,
            Some(timeout) => {
@@ -125,16 +128,38 @@ impl Request {
            }
        };

        let request_fn = |req: Request| {
            let reader = payload.into_read();
            let unit = Unit::new(
                &req.agent,
                &req.method,
                &url,
                req.headers,
                &reader,
                deadline,
            );

            unit::connect(unit, true, reader).map_err(|e| e.url(url.clone()))
        };

        // This clone is quite cheap since we are either cloning an Option::None or a Vec<Arc<dyn Middleware>>.
        let maybe_middleware = agent.state.middleware.clone();

        let response = if let Some(middleware) = maybe_middleware {
            // The request_fn is the final target in the middleware chain, doing the actual invocation.
            let mut chain = MiddlewareNext::new(Next::End(Box::new(request_fn)));

            // Build the middleware chain in reverse order.
            for mw in middleware.into_iter().rev() {
                chain = MiddlewareNext::new(Next::Chain(mw, Box::new(chain)));
            }

            // Run the middleware chain.
            chain.handle(self)?
        } else {
            // Run the request_fn without any further indirection.
            request_fn(self)?
        };

        if response.status() >= 400 {
            Err(Error::Status(response.status(), response))
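
(Not part of the commit.) The `.rev()` above is what makes the first registered middleware the outermost, first-invoked layer. A minimal standalone sketch of the same fold, using plain closures instead of ureq's types:

// Each layer wraps the handler built so far; folding in reverse keeps the
// first registered layer outermost, so it runs first.
type Handler = Box<dyn FnOnce(String) -> String>;

fn wrap(name: &'static str, inner: Handler) -> Handler {
    Box::new(move |req| inner(format!("{} -> {}", req, name)))
}

fn main() {
    let layers = ["first", "second", "third"];

    // The innermost handler stands in for the actual request invocation.
    let mut chain: Handler = Box::new(|req| format!("{} -> end", req));

    // Build in reverse so that "first" ends up as the outermost wrapper.
    for name in layers.iter().copied().rev() {
        chain = wrap(name, chain);
    }

    // Prints: start -> first -> second -> third -> end
    println!("{}", chain("start".to_string()));
}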