Skip to content

Commit 09c56c7

Browse files
committed
Major improvements
- Catchers now carry `TypeId` and type name for collision detection - Transient updated to 0.3, with new derive macro - Added `Transient` or `Static` implementations for error types - CI should now pass
1 parent 99e2109 commit 09c56c7

File tree

17 files changed

+157
-51
lines changed

17 files changed

+157
-51
lines changed

core/codegen/src/attribute/catch/mod.rs

+18-5
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
mod parse;
22

33
use devise::ext::SpanDiagnosticExt;
4-
use devise::{Spanned, Result};
4+
use devise::{Diagnostic, Level, Result, Spanned};
55
use proc_macro2::{TokenStream, Span};
66

77
use crate::http_codegen::Optional;
88
use crate::syn_ext::ReturnTypeExt;
99
use crate::exports::*;
1010

11+
fn arg_ty(arg: &syn::FnArg) -> Result<&syn::Type> {
12+
match arg {
13+
syn::FnArg::Receiver(_) => Err(Diagnostic::spanned(
14+
arg.span(),
15+
Level::Error,
16+
"Catcher cannot have self as a parameter"
17+
)),
18+
syn::FnArg::Typed(syn::PatType {ty, ..})=> Ok(ty.as_ref()),
19+
}
20+
}
21+
1122
pub fn _catch(
1223
args: proc_macro::TokenStream,
1324
input: proc_macro::TokenStream
@@ -45,16 +56,17 @@ pub fn _catch(
4556
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
4657
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
4758
}).rev();
48-
let make_error = if catch.function.sig.inputs.len() >= 3 {
59+
let (make_error, error_type) = if catch.function.sig.inputs.len() >= 3 {
4960
let arg = catch.function.sig.inputs.first().unwrap();
50-
quote_spanned!(arg.span() =>
61+
let ty = arg_ty(arg)?;
62+
(quote_spanned!(arg.span() =>
5163
let #__error = match ::rocket::catcher::downcast(__error_init.as_ref()) {
5264
Some(v) => v,
5365
None => return #_Result::Err((#__status, __error_init)),
5466
};
55-
)
67+
), quote! {Some((#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>()))})
5668
} else {
57-
quote! {}
69+
(quote! {}, quote! {None})
5870
};
5971

6072
// We append `.await` to the function call if this is `async`.
@@ -99,6 +111,7 @@ pub fn _catch(
99111
#_catcher::StaticInfo {
100112
name: ::core::stringify!(#user_catcher_fn_name),
101113
code: #status_code,
114+
error_type: #error_type,
102115
handler: monomorphized_function,
103116
location: (::core::file!(), ::core::line!(), ::core::column!()),
104117
}

core/codegen/src/attribute/catch/parse.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,6 @@ impl Attribute {
5656
.map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \
5757
`#[catch(404)]` or `#[catch(default)]`"))?;
5858

59-
Ok(Attribute { status: status.code.0, function, error: status.error })
59+
Ok(Attribute { status: status.code.0, function, error: None })
6060
}
6161
}

core/codegen/src/attribute/route/mod.rs

+14-5
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,15 @@ fn query_decls(route: &Route) -> Option<TokenStream> {
115115
);
116116
::rocket::trace::info!(
117117
target: concat!("rocket::codegen::route::", module_path!()),
118-
error_type = ::std::any::type_name_of_val(&__error),
118+
error_type = ::std::any::type_name_of_val(&__e),
119119
"Forwarding error"
120120
);
121121

122-
return #Outcome::Forward((#__data, #Status::UnprocessableEntity, #resolve_error!(__e)));
122+
return #Outcome::Forward((
123+
#__data,
124+
#Status::UnprocessableEntity,
125+
#resolve_error!(__e)
126+
));
123127
}
124128

125129
(#(#ident.unwrap()),*)
@@ -207,7 +211,11 @@ fn param_guard_decl(guard: &Guard) -> TokenStream {
207211
#i
208212
);
209213

210-
return #Outcome::Forward((#__data, #Status::InternalServerError, #resolve_error!()));
214+
return #Outcome::Forward((
215+
#__data,
216+
#Status::InternalServerError,
217+
#resolve_error!()
218+
));
211219
}
212220
}
213221
},
@@ -226,7 +234,8 @@ fn param_guard_decl(guard: &Guard) -> TokenStream {
226234

227235
fn data_guard_decl(guard: &Guard) -> TokenStream {
228236
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
229-
define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome, resolve_error);
237+
define_spanned_export!(ty.span() =>
238+
__req, __data, display_hack, FromData, Outcome, resolve_error);
230239

231240
quote_spanned! { ty.span() =>
232241
let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await {
@@ -251,7 +260,7 @@ fn data_guard_decl(guard: &Guard) -> TokenStream {
251260
parameter = stringify!(#ident),
252261
type_name = stringify!(#ty),
253262
reason = %#display_hack!(&__e),
254-
error_type = ::std::any::type_name_of_val(&__error),
263+
error_type = ::std::any::type_name_of_val(&__e),
255264
"data guard failed"
256265
);
257266

core/codegen/tests/catcher.rs

+20
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,23 @@ fn test_status_param() {
5858
assert_eq!(response.into_string().unwrap(), code.to_string());
5959
}
6060
}
61+
62+
#[catch(404)]
63+
fn bad_req_untyped(_: Status, _: &Request<'_>) -> &'static str { "404" }
64+
#[catch(404)]
65+
fn bad_req_string(_: &String, _: Status, _: &Request<'_>) -> &'static str { "404 String" }
66+
#[catch(404)]
67+
fn bad_req_tuple(_: &(), _: Status, _: &Request<'_>) -> &'static str { "404 ()" }
68+
69+
#[test]
70+
fn test_typed_catchers() {
71+
fn rocket() -> Rocket<Build> {
72+
rocket::build()
73+
.register("/", catchers![bad_req_untyped, bad_req_string, bad_req_tuple])
74+
}
75+
76+
// Assert the catchers do not collide. They are only differentiated by their error type.
77+
let client = Client::debug(rocket()).unwrap();
78+
let response = client.get("/").dispatch();
79+
assert_eq!(response.status(), Status::NotFound);
80+
}

core/http/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ memchr = "2"
3636
stable-pattern = "0.1"
3737
cookie = { version = "0.18", features = ["percent-encode"] }
3838
state = "0.6"
39-
transient = { version = "0.2.1" }
39+
transient = { version = "0.3" }
4040

4141
[dependencies.serde]
4242
version = "1.0"

core/lib/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
7474
cookie = { version = "0.18", features = ["percent-encode"] }
7575
futures = { version = "0.3.30", default-features = false, features = ["std"] }
7676
state = "0.6"
77-
transient = { version = "0.2.1" }
77+
transient = { version = "0.3" }
7878

7979
# tracing
8080
tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] }

core/lib/src/catcher/catcher.rs

+44-18
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::fmt;
22
use std::io::Cursor;
33

4+
use transient::TypeId;
5+
46
use crate::http::uri::Path;
57
use crate::http::ext::IntoOwned;
68
use crate::response::Response;
@@ -122,6 +124,9 @@ pub struct Catcher {
122124
/// The catcher's associated error handler.
123125
pub handler: Box<dyn Handler>,
124126

127+
/// Catcher error type
128+
pub(crate) error_type: Option<(TypeId, &'static str)>,
129+
125130
/// The mount point.
126131
pub(crate) base: uri::Origin<'static>,
127132

@@ -134,10 +139,11 @@ pub struct Catcher {
134139
pub(crate) location: Option<(&'static str, u32, u32)>,
135140
}
136141

137-
// The rank is computed as -(number of nonempty segments in base) => catchers
142+
// The rank is computed as -(number of nonempty segments in base) *2 => catchers
138143
// with more nonempty segments have lower ranks => higher precedence.
144+
// Doubled to provide space between for typed catchers.
139145
fn rank(base: Path<'_>) -> isize {
140-
-(base.segments().filter(|s| !s.is_empty()).count() as isize)
146+
-(base.segments().filter(|s| !s.is_empty()).count() as isize) * 2
141147
}
142148

143149
impl Catcher {
@@ -149,22 +155,26 @@ impl Catcher {
149155
///
150156
/// ```rust
151157
/// use rocket::request::Request;
152-
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
158+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
153159
/// use rocket::response::Responder;
154160
/// use rocket::http::Status;
155161
///
156-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
162+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
163+
/// -> BoxFuture<'r>
164+
/// {
157165
/// let res = (status, format!("404: {}", req.uri()));
158-
/// Box::pin(async move { res.respond_to(req) })
166+
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
159167
/// }
160168
///
161-
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
162-
/// Box::pin(async move{ "Whoops, we messed up!".respond_to(req) })
169+
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>, _e: ErasedError<'r>) -> BoxFuture<'r> {
170+
/// Box::pin(async move{ "Whoops, we messed up!".respond_to(req).map_err(|s| (s, _e)) })
163171
/// }
164172
///
165-
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
173+
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
174+
/// -> BoxFuture<'r>
175+
/// {
166176
/// let res = (status, format!("{}: {}", status, req.uri()));
167-
/// Box::pin(async move { res.respond_to(req) })
177+
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
168178
/// }
169179
///
170180
/// let not_found_catcher = Catcher::new(404, handle_404);
@@ -189,6 +199,7 @@ impl Catcher {
189199
name: None,
190200
base: uri::Origin::root().clone(),
191201
handler: Box::new(handler),
202+
error_type: None,
192203
rank: rank(uri::Origin::root().path()),
193204
code,
194205
location: None,
@@ -201,13 +212,15 @@ impl Catcher {
201212
///
202213
/// ```rust
203214
/// use rocket::request::Request;
204-
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
215+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
205216
/// use rocket::response::Responder;
206217
/// use rocket::http::Status;
207218
///
208-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
219+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
220+
/// -> BoxFuture<'r>
221+
/// {
209222
/// let res = (status, format!("404: {}", req.uri()));
210-
/// Box::pin(async move { res.respond_to(req) })
223+
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
211224
/// }
212225
///
213226
/// let catcher = Catcher::new(404, handle_404);
@@ -227,14 +240,16 @@ impl Catcher {
227240
///
228241
/// ```rust
229242
/// use rocket::request::Request;
230-
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
243+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
231244
/// use rocket::response::Responder;
232245
/// use rocket::http::Status;
233246
/// # use rocket::uri;
234247
///
235-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
248+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
249+
/// -> BoxFuture<'r>
250+
/// {
236251
/// let res = (status, format!("404: {}", req.uri()));
237-
/// Box::pin(async move { res.respond_to(req) })
252+
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
238253
/// }
239254
///
240255
/// let catcher = Catcher::new(404, handle_404);
@@ -281,13 +296,15 @@ impl Catcher {
281296
///
282297
/// ```rust
283298
/// use rocket::request::Request;
284-
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
299+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedError};
285300
/// use rocket::response::Responder;
286301
/// use rocket::http::Status;
287302
///
288-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
303+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
304+
/// -> BoxFuture<'r>
305+
/// {
289306
/// let res = (status, format!("404: {}", req.uri()));
290-
/// Box::pin(async move { res.respond_to(req) })
307+
/// Box::pin(async move { res.respond_to(req).map_err(|s| (s, _e)) })
291308
/// }
292309
///
293310
/// let catcher = Catcher::new(404, handle_404);
@@ -332,6 +349,8 @@ pub struct StaticInfo {
332349
pub name: &'static str,
333350
/// The catcher's status code.
334351
pub code: Option<u16>,
352+
/// The catcher's error type.
353+
pub error_type: Option<(TypeId, &'static str)>,
335354
/// The catcher's handler, i.e, the annotated function.
336355
pub handler: for<'r> fn(Status, &'r Request<'_>, ErasedError<'r>) -> BoxFuture<'r>,
337356
/// The file, line, and column where the catcher was defined.
@@ -343,7 +362,13 @@ impl From<StaticInfo> for Catcher {
343362
#[inline]
344363
fn from(info: StaticInfo) -> Catcher {
345364
let mut catcher = Catcher::new(info.code, info.handler);
365+
if info.error_type.is_some() {
366+
// Lower rank if the error_type is defined, to ensure typed catchers
367+
// are always tried first
368+
catcher.rank -= 1;
369+
}
346370
catcher.name = Some(info.name.into());
371+
catcher.error_type = info.error_type;
347372
catcher.location = Some(info.location);
348373
catcher
349374
}
@@ -354,6 +379,7 @@ impl fmt::Debug for Catcher {
354379
f.debug_struct("Catcher")
355380
.field("name", &self.name)
356381
.field("base", &self.base)
382+
.field("error_type", &self.error_type.as_ref().map(|(_, n)| n))
357383
.field("code", &self.code)
358384
.field("rank", &self.rank)
359385
.finish()

core/lib/src/catcher/handler.rs

+9-6
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
3131
/// and used as follows:
3232
///
3333
/// ```rust,no_run
34-
/// use rocket::{Request, Catcher, catcher::{self, ErasedErrorRef}};
34+
/// use rocket::{Request, Catcher, catcher::{self, ErasedError}};
3535
/// use rocket::response::{Response, Responder};
3636
/// use rocket::http::Status;
3737
///
@@ -47,11 +47,13 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
4747
///
4848
/// #[rocket::async_trait]
4949
/// impl catcher::Handler for CustomHandler {
50-
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> catcher::Result<'r> {
50+
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, _e: ErasedError<'r>)
51+
/// -> catcher::Result<'r>
52+
/// {
5153
/// let inner = match self.0 {
52-
/// Kind::Simple => "simple".respond_to(req)?,
53-
/// Kind::Intermediate => "intermediate".respond_to(req)?,
54-
/// Kind::Complex => "complex".respond_to(req)?,
54+
/// Kind::Simple => "simple".respond_to(req).map_err(|e| (e, _e))?,
55+
/// Kind::Intermediate => "intermediate".respond_to(req).map_err(|e| (e, _e))?,
56+
/// Kind::Complex => "complex".respond_to(req).map_err(|e| (e, _e))?,
5557
/// };
5658
///
5759
/// Response::build_from(inner).status(status).ok()
@@ -99,7 +101,8 @@ pub trait Handler: Cloneable + Send + Sync + 'static {
99101
/// Nevertheless, failure is allowed, both for convenience and necessity. If
100102
/// an error handler fails, Rocket's default `500` catcher is invoked. If it
101103
/// succeeds, the returned `Response` is used to respond to the client.
102-
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, error: ErasedError<'r>) -> Result<'r>;
104+
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, error: ErasedError<'r>)
105+
-> Result<'r>;
103106
}
104107

105108
// We write this manually to avoid double-boxing.

core/lib/src/catcher/types.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use transient::{Any, CanRecoverFrom, Co, Downcast};
22
#[doc(inline)]
3-
pub use transient::{Static, Transient};
3+
pub use transient::{Static, Transient, TypeId};
44

55
pub type ErasedError<'r> = Box<dyn Any<Co<'r>> + Send + Sync + 'r>;
66
pub type ErasedErrorRef<'r> = dyn Any<Co<'r>> + Send + Sync + 'r;

core/lib/src/error.rs

+5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::error::Error as StdError;
55
use std::sync::Arc;
66

77
use figment::Profile;
8+
use transient::Static;
89

910
use crate::listener::Endpoint;
1011
use crate::{Catcher, Ignite, Orbit, Phase, Rocket, Route};
@@ -85,10 +86,14 @@ pub enum ErrorKind {
8586
Shutdown(Arc<Rocket<Orbit>>),
8687
}
8788

89+
impl Static for ErrorKind {}
90+
8891
/// An error that occurs when a value was unexpectedly empty.
8992
#[derive(Clone, Copy, Default, PartialEq, Eq, Hash, PartialOrd, Ord)]
9093
pub struct Empty;
9194

95+
impl Static for Empty {}
96+
9297
impl Error {
9398
#[inline(always)]
9499
pub(crate) fn new(kind: ErrorKind) -> Error {

0 commit comments

Comments
 (0)