Implement middleware function
This commit is contained in:
67
examples/count-bytes.rs
Normal file
67
examples/count-bytes.rs
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use ureq::{Error, Middleware, MiddlewareNext, Request, Response};
|
||||||
|
|
||||||
|
// Some state that could be shared with the main application.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct CounterState {
|
||||||
|
request_count: u64,
|
||||||
|
total_bytes: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware wrapper working off the shared state.
|
||||||
|
struct CounterMiddleware(Arc<Mutex<CounterState>>);
|
||||||
|
|
||||||
|
pub fn main() -> Result<(), Error> {
|
||||||
|
// Shared state for counters.
|
||||||
|
let shared_state = Arc::new(Mutex::new(CounterState::default()));
|
||||||
|
|
||||||
|
let agent = ureq::builder()
|
||||||
|
// Clone the state into the middleware
|
||||||
|
.middleware(CounterMiddleware(shared_state.clone()))
|
||||||
|
.build();
|
||||||
|
|
||||||
|
agent.get("https://httpbin.org/bytes/123").call()?;
|
||||||
|
agent.get("https://httpbin.org/bytes/123").call()?;
|
||||||
|
|
||||||
|
{
|
||||||
|
let state = shared_state.lock().unwrap();
|
||||||
|
|
||||||
|
println!("State after requests:\n\n{:?}\n", state);
|
||||||
|
|
||||||
|
assert_eq!(state.request_count, 2);
|
||||||
|
assert_eq!(state.total_bytes, 246);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Middleware for CounterMiddleware {
|
||||||
|
fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||||
|
// Get state before request to increase request counter.
|
||||||
|
// Extra brackets to release the lock while continuing the chain.
|
||||||
|
{
|
||||||
|
let mut state = self.0.lock().unwrap();
|
||||||
|
|
||||||
|
state.request_count += 1;
|
||||||
|
} // release lock
|
||||||
|
|
||||||
|
// Continue the middleware chain
|
||||||
|
let response = next.handle(request)?;
|
||||||
|
|
||||||
|
// Get state after response to increase byte count.
|
||||||
|
// Extra brackets not necessary, but there for symmetry with first lock.
|
||||||
|
{
|
||||||
|
let mut state = self.0.lock().unwrap();
|
||||||
|
|
||||||
|
let len = response
|
||||||
|
.header("Content-Length")
|
||||||
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
state.total_bytes += len;
|
||||||
|
} // release lock
|
||||||
|
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
}
|
||||||
17
src/agent.rs
17
src/agent.rs
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
|
use crate::middleware::Middleware;
|
||||||
use crate::pool::ConnectionPool;
|
use crate::pool::ConnectionPool;
|
||||||
use crate::proxy::Proxy;
|
use crate::proxy::Proxy;
|
||||||
use crate::request::Request;
|
use crate::request::Request;
|
||||||
@@ -42,6 +43,7 @@ pub struct AgentBuilder {
|
|||||||
#[cfg(feature = "cookies")]
|
#[cfg(feature = "cookies")]
|
||||||
cookie_store: Option<CookieStore>,
|
cookie_store: Option<CookieStore>,
|
||||||
resolver: ArcResolver,
|
resolver: ArcResolver,
|
||||||
|
middleware: Option<Vec<Arc<dyn Middleware + Send + Sync>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Config as built by AgentBuilder and then static for the lifetime of the Agent.
|
/// Config as built by AgentBuilder and then static for the lifetime of the Agent.
|
||||||
@@ -109,6 +111,7 @@ pub(crate) struct AgentState {
|
|||||||
#[cfg(feature = "cookies")]
|
#[cfg(feature = "cookies")]
|
||||||
pub(crate) cookie_tin: CookieTin,
|
pub(crate) cookie_tin: CookieTin,
|
||||||
pub(crate) resolver: ArcResolver,
|
pub(crate) resolver: ArcResolver,
|
||||||
|
pub(crate) middleware: Option<Vec<Arc<dyn Middleware + Send + Sync>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Agent {
|
impl Agent {
|
||||||
@@ -245,6 +248,7 @@ impl AgentBuilder {
|
|||||||
resolver: StdResolver.into(),
|
resolver: StdResolver.into(),
|
||||||
#[cfg(feature = "cookies")]
|
#[cfg(feature = "cookies")]
|
||||||
cookie_store: None,
|
cookie_store: None,
|
||||||
|
middleware: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,6 +268,7 @@ impl AgentBuilder {
|
|||||||
#[cfg(feature = "cookies")]
|
#[cfg(feature = "cookies")]
|
||||||
cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)),
|
cookie_tin: CookieTin::new(self.cookie_store.unwrap_or_else(CookieStore::default)),
|
||||||
resolver: self.resolver,
|
resolver: self.resolver,
|
||||||
|
middleware: self.middleware,
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -591,6 +596,18 @@ impl AgentBuilder {
|
|||||||
self.cookie_store = Some(cookie_store);
|
self.cookie_store = Some(cookie_store);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add middleware handler to this agent.
|
||||||
|
///
|
||||||
|
/// All requests made by the agent will use this middleware. Middleware is invoked
|
||||||
|
/// in the order they are added to the builder.
|
||||||
|
pub fn middleware(mut self, m: impl Middleware + Send + Sync + 'static) -> Self {
|
||||||
|
if self.middleware.is_none() {
|
||||||
|
self.middleware = Some(vec![]);
|
||||||
|
}
|
||||||
|
self.middleware.as_mut().unwrap().push(Arc::new(m));
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
|
|||||||
@@ -327,6 +327,7 @@ mod agent;
|
|||||||
mod body;
|
mod body;
|
||||||
mod error;
|
mod error;
|
||||||
mod header;
|
mod header;
|
||||||
|
mod middleware;
|
||||||
mod pool;
|
mod pool;
|
||||||
mod proxy;
|
mod proxy;
|
||||||
mod request;
|
mod request;
|
||||||
@@ -392,6 +393,7 @@ pub use crate::agent::AgentBuilder;
|
|||||||
pub use crate::agent::RedirectAuthHeaders;
|
pub use crate::agent::RedirectAuthHeaders;
|
||||||
pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport};
|
pub use crate::error::{Error, ErrorKind, OrAnyStatus, Transport};
|
||||||
pub use crate::header::Header;
|
pub use crate::header::Header;
|
||||||
|
pub use crate::middleware::{Middleware, MiddlewareNext};
|
||||||
pub use crate::proxy::Proxy;
|
pub use crate::proxy::Proxy;
|
||||||
pub use crate::request::{Request, RequestUrl};
|
pub use crate::request::{Request, RequestUrl};
|
||||||
pub use crate::resolve::Resolver;
|
pub use crate::resolve::Resolver;
|
||||||
|
|||||||
169
src/middleware.rs
Normal file
169
src/middleware.rs
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::{Error, Request, Response};
|
||||||
|
|
||||||
|
/// Chained processing of request (and response).
|
||||||
|
///
|
||||||
|
/// # Middleware as `fn`
|
||||||
|
///
|
||||||
|
/// The middleware trait is implemented for all functions that have the signature
|
||||||
|
///
|
||||||
|
/// `Fn(Request, MiddlewareNext) -> Result<Response, Error>`
|
||||||
|
///
|
||||||
|
/// That means the easiest way to implement middleware is by providing a `fn`, like so
|
||||||
|
///
|
||||||
|
/// ```no_run
|
||||||
|
/// # use ureq::{Request, Response, MiddlewareNext, Error};
|
||||||
|
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||||
|
/// // do middleware things
|
||||||
|
///
|
||||||
|
/// // continue the middleware chain
|
||||||
|
/// next.handle(req)
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # Adding headers
|
||||||
|
///
|
||||||
|
/// A common use case is to add headers to the outgoing request. Here an example of how.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # #[cfg(feature = "json")]
|
||||||
|
/// # fn main() -> Result<(), ureq::Error> {
|
||||||
|
/// # use ureq::{Request, Response, MiddlewareNext, Error};
|
||||||
|
/// # ureq::is_test(true);
|
||||||
|
/// fn my_middleware(req: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||||
|
/// // set my bespoke header and continue the chain
|
||||||
|
/// next.handle(req.set("X-My-Header", "value_42"))
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// let agent = ureq::builder()
|
||||||
|
/// .middleware(my_middleware)
|
||||||
|
/// .build();
|
||||||
|
///
|
||||||
|
/// let result: serde_json::Value =
|
||||||
|
/// agent.get("http://httpbin.org/headers").call()?.into_json()?;
|
||||||
|
///
|
||||||
|
/// assert_eq!(&result["headers"]["X-My-Header"], "value_42");
|
||||||
|
///
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// # #[cfg(not(feature = "json"))]
|
||||||
|
/// # fn main() {}
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// # State
|
||||||
|
///
|
||||||
|
/// To maintain state between middleware invocations, we need to do something more elaborate than
|
||||||
|
/// the simple `fn` and implement the `Middleware` trait directly.
|
||||||
|
///
|
||||||
|
/// ## Example with mutex lock
|
||||||
|
///
|
||||||
|
/// In the `examples` directory there is an additional example `count-bytes.rs` which uses
|
||||||
|
/// a mutex lock like shown below.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
|
||||||
|
/// # use std::sync::{Arc, Mutex};
|
||||||
|
/// struct MyState {
|
||||||
|
/// // whatever is needed
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// struct MyMiddleware(Arc<Mutex<MyState>>);
|
||||||
|
///
|
||||||
|
/// impl Middleware for MyMiddleware {
|
||||||
|
/// fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||||
|
/// // These extra brackets ensures we release the Mutex lock before continuing the
|
||||||
|
/// // chain. There could also be scenarios where we want to maintain the lock through
|
||||||
|
/// // the invocation, which would block other requests from proceeding concurrently
|
||||||
|
/// // through the middleware.
|
||||||
|
/// {
|
||||||
|
/// let mut state = self.0.lock().unwrap();
|
||||||
|
/// // do stuff with state
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// // continue middleware chain
|
||||||
|
/// next.handle(request)
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// ## Example with atomic
|
||||||
|
///
|
||||||
|
/// This example shows how we can increase a counter for each request going
|
||||||
|
/// through the agent.
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// # fn main() -> Result<(), ureq::Error> {
|
||||||
|
/// # ureq::is_test(true);
|
||||||
|
/// use ureq::{Request, Response, Middleware, MiddlewareNext, Error};
|
||||||
|
/// use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
/// use std::sync::Arc;
|
||||||
|
///
|
||||||
|
/// // Middleware that stores a counter state. This example uses an AtomicU64
|
||||||
|
/// // since the middleware is potentially shared by multiple threads running
|
||||||
|
/// // requests at the same time.
|
||||||
|
/// struct MyCounter(Arc<AtomicU64>);
|
||||||
|
///
|
||||||
|
/// impl Middleware for MyCounter {
|
||||||
|
/// fn handle(&self, req: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||||
|
/// // increase the counter for each invocation
|
||||||
|
/// self.0.fetch_add(1, Ordering::SeqCst);
|
||||||
|
///
|
||||||
|
/// // continue the middleware chain
|
||||||
|
/// next.handle(req)
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// let shared_counter = Arc::new(AtomicU64::new(0));
|
||||||
|
///
|
||||||
|
/// let agent = ureq::builder()
|
||||||
|
/// // Add our middleware
|
||||||
|
/// .middleware(MyCounter(shared_counter.clone()))
|
||||||
|
/// .build();
|
||||||
|
///
|
||||||
|
/// agent.get("http://httpbin.org/get").call()?;
|
||||||
|
/// agent.get("http://httpbin.org/get").call()?;
|
||||||
|
///
|
||||||
|
/// // Check we did indeed increase the counter twice.
|
||||||
|
/// assert_eq!(shared_counter.load(Ordering::SeqCst), 2);
|
||||||
|
///
|
||||||
|
/// # Ok(()) }
|
||||||
|
/// ```
|
||||||
|
pub trait Middleware {
|
||||||
|
/// Handle of the middleware logic.
|
||||||
|
fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Continuation of a [`Middleware`] chain.
|
||||||
|
pub struct MiddlewareNext<'a>(Next<'a>);
|
||||||
|
|
||||||
|
impl<'a> MiddlewareNext<'a> {
|
||||||
|
pub(crate) fn new(n: Next<'a>) -> Self {
|
||||||
|
MiddlewareNext(n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) enum Next<'a> {
|
||||||
|
/// Chained middleware. The Box around the next MiddlewareNext is to break the recursive datatype.
|
||||||
|
Chain(Arc<dyn Middleware>, Box<MiddlewareNext<'a>>),
|
||||||
|
/// End of the middleware chain doing the actual request invocation.
|
||||||
|
End(Box<dyn FnOnce(Request) -> Result<Response, Error> + 'a>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> MiddlewareNext<'a> {
|
||||||
|
/// Continue the middleware chain by providing (a possibly amended) [`Request`].
|
||||||
|
pub fn handle(self, request: Request) -> Result<Response, Error> {
|
||||||
|
match self.0 {
|
||||||
|
Next::Chain(mw, next) => mw.handle(request, *next),
|
||||||
|
Next::End(request_fn) => request_fn(request),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> Middleware for F
|
||||||
|
where
|
||||||
|
F: Fn(Request, MiddlewareNext) -> Result<Response, Error>,
|
||||||
|
{
|
||||||
|
fn handle(&self, request: Request, next: MiddlewareNext) -> Result<Response, Error> {
|
||||||
|
(self)(request, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ use url::{form_urlencoded, ParseError, Url};
|
|||||||
|
|
||||||
use crate::body::Payload;
|
use crate::body::Payload;
|
||||||
use crate::header::{self, Header};
|
use crate::header::{self, Header};
|
||||||
|
use crate::middleware::{MiddlewareNext, Next};
|
||||||
use crate::unit::{self, Unit};
|
use crate::unit::{self, Unit};
|
||||||
use crate::Response;
|
use crate::Response;
|
||||||
use crate::{agent::Agent, error::Error};
|
use crate::{agent::Agent, error::Error};
|
||||||
@@ -117,6 +118,8 @@ impl Request {
|
|||||||
#[cfg(any(feature = "gzip", feature = "brotli"))]
|
#[cfg(any(feature = "gzip", feature = "brotli"))]
|
||||||
self.add_accept_encoding();
|
self.add_accept_encoding();
|
||||||
|
|
||||||
|
let agent = &self.agent;
|
||||||
|
|
||||||
let deadline = match self.timeout.or(self.agent.config.timeout) {
|
let deadline = match self.timeout.or(self.agent.config.timeout) {
|
||||||
None => None,
|
None => None,
|
||||||
Some(timeout) => {
|
Some(timeout) => {
|
||||||
@@ -125,16 +128,38 @@ impl Request {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let reader = payload.into_read();
|
let request_fn = |req: Request| {
|
||||||
let unit = Unit::new(
|
let reader = payload.into_read();
|
||||||
&self.agent,
|
let unit = Unit::new(
|
||||||
&self.method,
|
&req.agent,
|
||||||
&url,
|
&req.method,
|
||||||
self.headers,
|
&url,
|
||||||
&reader,
|
req.headers,
|
||||||
deadline,
|
&reader,
|
||||||
);
|
deadline,
|
||||||
let response = unit::connect(unit, true, reader).map_err(|e| e.url(url.clone()))?;
|
);
|
||||||
|
|
||||||
|
unit::connect(unit, true, reader).map_err(|e| e.url(url.clone()))
|
||||||
|
};
|
||||||
|
|
||||||
|
// This clone is quite cheap since either we are cloning the Optional::None or a Vec<Arc<dyn Middleware>>.
|
||||||
|
let maybe_middleware = agent.state.middleware.clone();
|
||||||
|
|
||||||
|
let response = if let Some(middleware) = maybe_middleware {
|
||||||
|
// The request_fn is the final target in the middleware chain doing the actual invocation.
|
||||||
|
let mut chain = MiddlewareNext::new(Next::End(Box::new(request_fn)));
|
||||||
|
|
||||||
|
// Build middleware in reverse order.
|
||||||
|
for mw in middleware.into_iter().rev() {
|
||||||
|
chain = MiddlewareNext::new(Next::Chain(mw, Box::new(chain)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run middleware chain
|
||||||
|
chain.handle(self)?
|
||||||
|
} else {
|
||||||
|
// Run the request_fn without any further indirection.
|
||||||
|
request_fn(self)?
|
||||||
|
};
|
||||||
|
|
||||||
if response.status() >= 400 {
|
if response.status() >= 400 {
|
||||||
Err(Error::Status(response.status(), response))
|
Err(Error::Status(response.status(), response))
|
||||||
|
|||||||
Reference in New Issue
Block a user