Implement middleware function
This commit is contained in:
67
examples/count-bytes.rs
Normal file
67
examples/count-bytes.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use ureq::{Error, Middleware, MiddlewareNext, Request, Response};
|
||||
|
||||
// Some state that could be shared with the main application.
|
||||
#[derive(Debug, Default)]
|
||||
struct CounterState {
|
||||
request_count: u64,
|
||||
total_bytes: u64,
|
||||
}
|
||||
|
||||
// Middleware wrapper working off the shared state.
|
||||
struct CounterMiddleware(Arc<Mutex<CounterState>>);
|
||||
|
||||
pub fn main() -> Result<(), Error> {
|
||||
// Shared state for counters.
|
||||
let shared_state = Arc::new(Mutex::new(CounterState::default()));
|
||||
|
||||
let agent = ureq::builder()
|
||||
// Clone the state into the middleware
|
||||
.middleware(CounterMiddleware(shared_state.clone()))
|
||||
.build();
|
||||
|
||||
agent.get("https://httpbin.org/bytes/123").call()?;
|
||||
agent.get("https://httpbin.org/bytes/123").call()?;
|
||||
|
||||
{
|
||||
let state = shared_state.lock().unwrap();
|
||||
|
||||
println!("State after requests:\n\n{:?}\n", state);
|
||||
|
||||
assert_eq!(state.request_count, 2);
|
||||
assert_eq!(state.total_bytes, 246);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Middleware for CounterMiddleware {
|
||||
fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||
// Get state before request to increase request counter.
|
||||
// Extra brackets to release the lock while continuing the chain.
|
||||
{
|
||||
let mut state = self.0.lock().unwrap();
|
||||
|
||||
state.request_count += 1;
|
||||
} // release lock
|
||||
|
||||
// Continue the middleware chain
|
||||
let response = next.handle(request)?;
|
||||
|
||||
// Get state after response to increase byte count.
|
||||
// Extra brackets not necessary, but there for symmetry with first lock.
|
||||
{
|
||||
let mut state = self.0.lock().unwrap();
|
||||
|
||||
let len = response
|
||||
.header("Content-Length")
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.unwrap();
|
||||
|
||||
state.total_bytes += len;
|
||||
} // release lock
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
17
src/agent.rs
17
src/agent.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use url::Url;
|
||||
|
||||
use crate::middleware::Middleware;
|
||||
use crate::pool::ConnectionPool;
|
||||
use crate::proxy::Proxy;
|
||||
use crate::request::Request;
|
||||
@@ -42,6 +43,7 @@ pub struct AgentBuilder {
|
||||
#[cfg(feature = "cookies")]
|
||||
cookie_store: Option<CookieStore>,
|
||||
resolver: ArcResolver,
|
||||
middleware: Option<Vec<Arc<dyn Middleware + Send + Sync>>>,
|
||||
}
|
||||
|
||||
/// Config as built by AgentBuilder and then static for the lifetime of the Agent.
|
||||
@@ -109,6 +111,7 @@ pub(crate) struct AgentState {
|
||||
#[cfg(feature = "cookies")]
|
||||
pub(crate) cookie_tin: CookieTin,
|
||||
pub(crate) resolver: ArcResolver,
|
||||
pub(crate) middleware: Option<Vec<Arc<dyn Middleware + Send + Sync>>>,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
@@ -245,6 +248,7 @@ impl AgentBuilder {
|
||||
resolver: StdResolver.into(),
|
||||
#[cfg(feature = "cookies")]
|
||||
cookie_store: None,
|
||||
middleware: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,6 +268,7 @@ impl AgentBuilder {
|
||||
#[cfg(feature = "cookies")]
|
||||
cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)),
|
||||
resolver: self.resolver,
|
||||
middleware: self.middleware,
|
||||
}),
|
||||
}
|
||||
}
|
||||
@@ -591,6 +596,18 @@ impl AgentBuilder {
|
||||
self.cookie_store = Some(cookie_store);
|
||||
self
|
||||
}
|
||||
|
||||
/// Add middleware handler to this agent.
|
||||
///
|
||||
/// All requests made by the agent will use this middleware. Middleware is invoked
|
||||
/// in the order they are added to the builder.
|
||||
pub fn middleware(mut self, m: impl Middleware + Send + Sync + 'static) -> Self {
|
||||
if self.middleware.is_none() {
|
||||
self.middleware = Some(vec![]);
|
||||
}
|
||||
self.middleware.as_mut().unwrap().push(Arc::new(m));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
|
||||
@@ -327,6 +327,7 @@ mod agent;
|
||||
mod body;
|
||||
mod error;
|
||||
mod header;
|
||||
mod middleware;
|
||||
mod pool;
|
||||
mod proxy;
|
||||
mod request;
|
||||
@@ -392,6 +393,7 @@ pub use crate::agent::AgentBuilder;
|
||||
pub use crate::agent::RedirectAuthHeaders;
|
||||
pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport};
|
||||
pub use crate::header::Header;
|
||||
pub use crate::middleware::{Middleware, MiddlewareNext};
|
||||
pub use crate::proxy::Proxy;
|
||||
pub use crate::request::{Request, RequestUrl};
|
||||
pub use crate::resolve::Resolver;
|
||||
|
||||
169
src/middleware.rs
Normal file
169
src/middleware.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{Error, Request, Response};
|
||||
|
||||
/// Chained processing of request (and response).
|
||||
///
|
||||
/// # Middleware as `fn`
|
||||
///
|
||||
/// The middleware trait is implemented for all functions that have the signature
|
||||
///
|
||||
/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
|
||||
///
|
||||
/// That means the easiest way to implement middleware is by providing a `fn`, like so
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ureq::{Request, Response, MiddlewareNext, Error};
|
||||
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||
/// // do middleware things
|
||||
///
|
||||
/// // continue the middleware chain
|
||||
/// next.handle(req)
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// # Adding headers
|
||||
///
|
||||
/// A common use case is to add headers to the outgoing request. Here an example of how.
|
||||
///
|
||||
/// ```
|
||||
/// # #[cfg(feature = "json")]
|
||||
/// # fn main() -> Result<(), ureq::Error> {
|
||||
/// # use ureq::{Request, Response, MiddlewareNext, Error};
|
||||
/// # ureq::is_test(true);
|
||||
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||
/// // set my bespoke header and continue the chain
|
||||
/// next.handle(req.set("X-My-Header", "value_42"))
|
||||
/// }
|
||||
///
|
||||
/// let agent = ureq::builder()
|
||||
/// .middleware(my_middleware)
|
||||
/// .build();
|
||||
///
|
||||
/// let result: serde_json::Value =
|
||||
/// agent.get("http://httpbin.org/headers").call()?.into_json()?;
|
||||
///
|
||||
/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
|
||||
///
|
||||
/// # Ok(()) }
|
||||
/// # #[cfg(not(feature = "json"))]
|
||||
/// # fn main() {}
|
||||
/// ```
|
||||
///
|
||||
/// # State
|
||||
///
|
||||
/// To maintain state between middleware invocations, we need to do something more elaborate than
|
||||
/// the simple `fn` and implement the `Middleware` trait directly.
|
||||
///
|
||||
/// ## Example with mutex lock
|
||||
///
|
||||
/// In the `examples` directory there is an additional example `count-bytes.rs` which uses
|
||||
/// a mutex lock like shown below.
|
||||
///
|
||||
/// ```
|
||||
/// # use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
|
||||
/// # use std::sync::{Arc, Mutex};
|
||||
/// struct MyState {
|
||||
/// // whatever is needed
|
||||
/// }
|
||||
///
|
||||
/// struct MyMiddleware(Arc<Mutex<MyState>>);
|
||||
///
|
||||
/// impl Middleware for MyMiddleware {
|
||||
/// fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||
/// // These extra brackets ensures we release the Mutex lock before continuing the
|
||||
/// // chain. There could also be scenarios where we want to maintain the lock through
|
||||
/// // the invocation, which would block other requests from proceeding concurrently
|
||||
/// // through the middleware.
|
||||
/// {
|
||||
/// let mut state = self.0.lock().unwrap();
|
||||
/// // do stuff with state
|
||||
/// }
|
||||
///
|
||||
/// // continue middleware chain
|
||||
/// next.handle(request)
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// ## Example with atomic
|
||||
///
|
||||
/// This example shows how we can increase a counter for each request going
|
||||
/// through the agent.
|
||||
///
|
||||
/// ```
|
||||
/// # fn main() -> Result<(), ureq::Error> {
|
||||
/// # ureq::is_test(true);
|
||||
/// use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
|
||||
/// use std::sync::atomic::{AtomicU64, Ordering};
|
||||
/// use std::sync::Arc;
|
||||
///
|
||||
/// // Middleware that stores a counter state. This example uses an AtomicU64
|
||||
/// // since the middleware is potentially shared by multiple threads running
|
||||
/// // requests at the same time.
|
||||
/// struct MyCounter(Arc<AtomicU64>);
|
||||
///
|
||||
/// impl Middleware for MyCounter {
|
||||
/// fn handle(&self, req: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||
/// // increase the counter for each invocation
|
||||
/// self.0.fetch_add(1, Ordering::SeqCst);
|
||||
///
|
||||
/// // continue the middleware chain
|
||||
/// next.handle(req)
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let shared_counter = Arc::new(AtomicU64::new(0));
|
||||
///
|
||||
/// let agent = ureq::builder()
|
||||
/// // Add our middleware
|
||||
/// .middleware(MyCounter(shared_counter.clone()))
|
||||
/// .build();
|
||||
///
|
||||
/// agent.get("http://httpbin.org/get").call()?;
|
||||
/// agent.get("http://httpbin.org/get").call()?;
|
||||
///
|
||||
/// // Check we did indeed increase the counter twice.
|
||||
/// assert_eq!(shared_counter.load(Ordering::SeqCst), 2);
|
||||
///
|
||||
/// # Ok(()) }
|
||||
/// ```
|
||||
pub trait Middleware {
|
||||
/// Handle of the middleware logic.
|
||||
fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error>;
|
||||
}
|
||||
|
||||
/// Continuation of a [`Middleware`] chain.
|
||||
pub struct MiddlewareNext<'a>(Next<'a>);
|
||||
|
||||
impl<'a> MiddlewareNext<'a> {
|
||||
pub(crate) fn new(n: Next<'a>) -> Self {
|
||||
MiddlewareNext(n)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum Next<'a> {
|
||||
/// Chained middleware. The Box around the next MiddlewareNext is to break the recursive datatype.
|
||||
Chain(Arc<dyn Middleware>, Box<MiddlewareNext<'a>>),
|
||||
/// End of the middleware chain doing the actual request invocation.
|
||||
End(Box<dyn FnOnce(Request) -> Result<Response, Error> + 'a>),
|
||||
}
|
||||
|
||||
impl<'a> MiddlewareNext<'a> {
|
||||
/// Continue the middleware chain by providing (a possibly amended) [`Request`].
|
||||
pub fn handle(self, request: Request) -> Result<Response, Error> {
|
||||
match self.0 {
|
||||
Next::Chain(mw, next) => mw.handle(request, *next),
|
||||
Next::End(request_fn) => request_fn(request),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Middleware for F
|
||||
where
|
||||
F: Fn(Request, MiddlewareNext) -> Result<Response, Error>,
|
||||
{
|
||||
fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||
(self)(request, next)
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ use url::{form_urlencoded, ParseError, Url};
|
||||
|
||||
use crate::body::Payload;
|
||||
use crate::header::{self, Header};
|
||||
use crate::middleware::{MiddlewareNext, Next};
|
||||
use crate::unit::{self, Unit};
|
||||
use crate::Response;
|
||||
use crate::{agent::Agent, error::Error};
|
||||
@@ -117,6 +118,8 @@ impl Request {
|
||||
#[cfg(any(feature = "gzip", feature = "brotli"))]
|
||||
self.add_accept_encoding();
|
||||
|
||||
let agent = &self.agent;
|
||||
|
||||
let deadline = match self.timeout.or(self.agent.config.timeout) {
|
||||
None => None,
|
||||
Some(timeout) => {
|
||||
@@ -125,16 +128,38 @@ impl Request {
|
||||
}
|
||||
};
|
||||
|
||||
let request_fn = |req: Request| {
|
||||
let reader = payload.into_read();
|
||||
let unit = Unit::new(
|
||||
&self.agent,
|
||||
&self.method,
|
||||
&req.agent,
|
||||
&req.method,
|
||||
&url,
|
||||
self.headers,
|
||||
req.headers,
|
||||
&reader,
|
||||
deadline,
|
||||
);
|
||||
let response = unit::connect(unit, true, reader).map_err(|e| e.url(url.clone()))?;
|
||||
|
||||
unit::connect(unit, true, reader).map_err(|e| e.url(url.clone()))
|
||||
};
|
||||
|
||||
// This clone is quite cheap since either we are cloning the Optional::None or a Vec<Arc<dyn Middleware>>.
|
||||
let maybe_middleware = agent.state.middleware.clone();
|
||||
|
||||
let response = if let Some(middleware) = maybe_middleware {
|
||||
// The request_fn is the final target in the middleware chain doing the actual invocation.
|
||||
let mut chain = MiddlewareNext::new(Next::End(Box::new(request_fn)));
|
||||
|
||||
// Build middleware in reverse order.
|
||||
for mw in middleware.into_iter().rev() {
|
||||
chain = MiddlewareNext::new(Next::Chain(mw, Box::new(chain)));
|
||||
}
|
||||
|
||||
// Run middleware chain
|
||||
chain.handle(self)?
|
||||
} else {
|
||||
// Run the request_fn without any further indirection.
|
||||
request_fn(self)?
|
||||
};
|
||||
|
||||
if response.status() >= 400 {
|
||||
Err(Error::Status(response.status(), response))
|
||||
|
||||
Reference in New Issue
Block a user