Skip to content

Commit 56e7fa6

Browse files
committed
Initial brush
1 parent 6857b82 commit 56e7fa6

File tree

15 files changed

+212
-46
lines changed

15 files changed

+212
-46
lines changed

core/codegen/src/attribute/catch/mod.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,25 @@ pub fn _catch(
3434
.map(|ty| ty.span())
3535
.unwrap_or_else(Span::call_site);
3636

37+
// TODO: how to handle request?
38+
// - Right now: (), (&Req), (Status, &Req) allowed
39+
// - New: (), (&E), (&Req, &E), (Status, &Req, &E)
3740
// Set the `req` and `status` spans to that of their respective function
3841
// arguments for a more correct `wrong type` error span. `rev` to be cute.
39-
let codegen_args = &[__req, __status];
42+
let codegen_args = &[__req, __status, __error];
4043
let inputs = catch.function.sig.inputs.iter().rev()
4144
.zip(codegen_args.iter())
4245
.map(|(fn_arg, codegen_arg)| match fn_arg {
4346
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
4447
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
4548
}).rev();
49+
let make_error = if let Some(arg) = catch.function.sig.inputs.iter().rev().next() {
50+
quote_spanned!(arg.span() =>
51+
// let
52+
)
53+
} else {
54+
quote! {}
55+
};
4656

4757
// We append `.await` to the function call if this is `async`.
4858
let dot_await = catch.function.sig.asyncness
@@ -68,9 +78,11 @@ pub fn _catch(
6878
fn into_info(self) -> #_catcher::StaticInfo {
6979
fn monomorphized_function<'__r>(
7080
#__status: #Status,
71-
#__req: &'__r #Request<'_>
81+
#__req: &'__r #Request<'_>,
82+
__error_init: &#ErasedErrorRef<'__r>,
7283
) -> #_catcher::BoxFuture<'__r> {
7384
#_Box::pin(async move {
85+
#make_error
7486
let __response = #catcher_response;
7587
#Response::build()
7688
.status(#__status)

core/codegen/src/attribute/route/mod.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ fn query_decls(route: &Route) -> Option<TokenStream> {
125125
fn request_guard_decl(guard: &Guard) -> TokenStream {
126126
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
127127
define_spanned_export!(ty.span() =>
128-
__req, __data, _request, display_hack, FromRequest, Outcome
128+
__req, __data, _request, display_hack, FromRequest, Outcome, ErrorResolver, ErrorDefault
129129
);
130130

131131
quote_spanned! { ty.span() =>
@@ -150,11 +150,13 @@ fn request_guard_decl(guard: &Guard) -> TokenStream {
150150
target: concat!("rocket::codegen::route::", module_path!()),
151151
parameter = stringify!(#ident),
152152
type_name = stringify!(#ty),
153-
reason = %#display_hack!(__e),
153+
reason = %#display_hack!(&__e),
154154
"request guard failed"
155155
);
156156

157-
return #Outcome::Error(__c);
157+
#[allow(unused)]
158+
use #ErrorDefault;
159+
return #Outcome::Error((__c, #ErrorResolver::new(__e).cast()));
158160
}
159161
};
160162
}
@@ -219,7 +221,7 @@ fn param_guard_decl(guard: &Guard) -> TokenStream {
219221

220222
fn data_guard_decl(guard: &Guard) -> TokenStream {
221223
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
222-
define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome);
224+
define_spanned_export!(ty.span() => __req, __data, display_hack, FromData, Outcome, ErrorResolver, ErrorDefault);
223225

224226
quote_spanned! { ty.span() =>
225227
let #ident: #ty = match <#ty as #FromData>::from_data(#__req, #__data).await {
@@ -243,11 +245,13 @@ fn data_guard_decl(guard: &Guard) -> TokenStream {
243245
target: concat!("rocket::codegen::route::", module_path!()),
244246
parameter = stringify!(#ident),
245247
type_name = stringify!(#ty),
246-
reason = %#display_hack!(__e),
248+
reason = %#display_hack!(&__e),
247249
"data guard failed"
248250
);
249251

250-
return #Outcome::Error(__c);
252+
#[allow(unused)]
253+
use #ErrorDefault;
254+
return #Outcome::Error((__c, #ErrorResolver::new(__e).cast()));
251255
}
252256
};
253257
}

core/codegen/src/exports.rs

+3
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ define_exported_paths! {
102102
Route => ::rocket::Route,
103103
Catcher => ::rocket::Catcher,
104104
Status => ::rocket::http::Status,
105+
ErrorResolver => ::rocket::catcher::resolution::Resolve,
106+
ErrorDefault => ::rocket::catcher::resolution::DefaultTypeErase,
107+
ErasedErrorRef => ::rocket::catcher::ErasedErrorRef,
105108
}
106109

107110
macro_rules! define_spanned_export {

core/lib/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] }
7474
cookie = { version = "0.18", features = ["percent-encode"] }
7575
futures = { version = "0.3.30", default-features = false, features = ["std"] }
7676
state = "0.6"
77+
transient = { version = "0.2.0", path = "../../../transient" }
7778

7879
# tracing
7980
tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] }

core/lib/src/catcher/catcher.rs

+15-13
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use crate::request::Request;
88
use crate::http::{Status, ContentType, uri};
99
use crate::catcher::{Handler, BoxFuture};
1010

11+
use super::ErasedErrorRef;
12+
1113
/// An error catching route.
1214
///
1315
/// Catchers are routes that run when errors are produced by the application.
@@ -147,20 +149,20 @@ impl Catcher {
147149
///
148150
/// ```rust
149151
/// use rocket::request::Request;
150-
/// use rocket::catcher::{Catcher, BoxFuture};
152+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
151153
/// use rocket::response::Responder;
152154
/// use rocket::http::Status;
153155
///
154-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
156+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
155157
/// let res = (status, format!("404: {}", req.uri()));
156158
/// Box::pin(async move { res.respond_to(req) })
157159
/// }
158160
///
159-
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
161+
/// fn handle_500<'r>(_: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
160162
/// Box::pin(async move{ "Whoops, we messed up!".respond_to(req) })
161163
/// }
162164
///
163-
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
165+
/// fn handle_default<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
164166
/// let res = (status, format!("{}: {}", status, req.uri()));
165167
/// Box::pin(async move { res.respond_to(req) })
166168
/// }
@@ -199,11 +201,11 @@ impl Catcher {
199201
///
200202
/// ```rust
201203
/// use rocket::request::Request;
202-
/// use rocket::catcher::{Catcher, BoxFuture};
204+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
203205
/// use rocket::response::Responder;
204206
/// use rocket::http::Status;
205207
///
206-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
208+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
207209
/// let res = (status, format!("404: {}", req.uri()));
208210
/// Box::pin(async move { res.respond_to(req) })
209211
/// }
@@ -225,12 +227,12 @@ impl Catcher {
225227
///
226228
/// ```rust
227229
/// use rocket::request::Request;
228-
/// use rocket::catcher::{Catcher, BoxFuture};
230+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
229231
/// use rocket::response::Responder;
230232
/// use rocket::http::Status;
231233
/// # use rocket::uri;
232234
///
233-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
235+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
234236
/// let res = (status, format!("404: {}", req.uri()));
235237
/// Box::pin(async move { res.respond_to(req) })
236238
/// }
@@ -279,11 +281,11 @@ impl Catcher {
279281
///
280282
/// ```rust
281283
/// use rocket::request::Request;
282-
/// use rocket::catcher::{Catcher, BoxFuture};
284+
/// use rocket::catcher::{Catcher, BoxFuture, ErasedErrorRef};
283285
/// use rocket::response::Responder;
284286
/// use rocket::http::Status;
285287
///
286-
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
288+
/// fn handle_404<'r>(status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
287289
/// let res = (status, format!("404: {}", req.uri()));
288290
/// Box::pin(async move { res.respond_to(req) })
289291
/// }
@@ -313,7 +315,7 @@ impl Catcher {
313315

314316
impl Default for Catcher {
315317
fn default() -> Self {
316-
fn handler<'r>(s: Status, req: &'r Request<'_>) -> BoxFuture<'r> {
318+
fn handler<'r>(s: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
317319
Box::pin(async move { Ok(default_handler(s, req)) })
318320
}
319321

@@ -331,7 +333,7 @@ pub struct StaticInfo {
331333
/// The catcher's status code.
332334
pub code: Option<u16>,
333335
/// The catcher's handler, i.e, the annotated function.
334-
pub handler: for<'r> fn(Status, &'r Request<'_>) -> BoxFuture<'r>,
336+
pub handler: for<'r> fn(Status, &'r Request<'_>, &ErasedErrorRef<'r>) -> BoxFuture<'r>,
335337
/// The file, line, and column where the catcher was defined.
336338
pub location: (&'static str, u32, u32),
337339
}
@@ -418,7 +420,7 @@ macro_rules! default_handler_fn {
418420

419421
pub(crate) fn default_handler<'r>(
420422
status: Status,
421-
req: &'r Request<'_>
423+
req: &'r Request<'_>,
422424
) -> Response<'r> {
423425
let preferred = req.accept().map(|a| a.preferred());
424426
let (mime, text) = if preferred.map_or(false, |a| a.is_json()) {

core/lib/src/catcher/handler.rs

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use crate::{Request, Response};
22
use crate::http::Status;
33

4+
use super::ErasedErrorRef;
5+
46
/// Type alias for the return type of a [`Catcher`](crate::Catcher)'s
57
/// [`Handler::handle()`].
68
pub type Result<'r> = std::result::Result<Response<'r>, crate::http::Status>;
@@ -29,7 +31,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
2931
/// and used as follows:
3032
///
3133
/// ```rust,no_run
32-
/// use rocket::{Request, Catcher, catcher};
34+
/// use rocket::{Request, Catcher, catcher::{self, ErasedErrorRef}};
3335
/// use rocket::response::{Response, Responder};
3436
/// use rocket::http::Status;
3537
///
@@ -45,7 +47,7 @@ pub type BoxFuture<'r, T = Result<'r>> = futures::future::BoxFuture<'r, T>;
4547
///
4648
/// #[rocket::async_trait]
4749
/// impl catcher::Handler for CustomHandler {
48-
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> catcher::Result<'r> {
50+
/// async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, _e: &ErasedErrorRef<'r>) -> catcher::Result<'r> {
4951
/// let inner = match self.0 {
5052
/// Kind::Simple => "simple".respond_to(req)?,
5153
/// Kind::Intermediate => "intermediate".respond_to(req)?,
@@ -97,30 +99,32 @@ pub trait Handler: Cloneable + Send + Sync + 'static {
9799
/// Nevertheless, failure is allowed, both for convenience and necessity. If
98100
/// an error handler fails, Rocket's default `500` catcher is invoked. If it
99101
/// succeeds, the returned `Response` is used to respond to the client.
100-
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>) -> Result<'r>;
102+
async fn handle<'r>(&self, status: Status, req: &'r Request<'_>, error: &ErasedErrorRef<'r>) -> Result<'r>;
101103
}
102104

103105
// We write this manually to avoid double-boxing.
104106
impl<F: Clone + Sync + Send + 'static> Handler for F
105-
where for<'x> F: Fn(Status, &'x Request<'_>) -> BoxFuture<'x>,
107+
where for<'x> F: Fn(Status, &'x Request<'_>, &ErasedErrorRef<'x>) -> BoxFuture<'x>,
106108
{
107-
fn handle<'r, 'life0, 'life1, 'async_trait>(
109+
fn handle<'r, 'life0, 'life1, 'life2, 'async_trait>(
108110
&'life0 self,
109111
status: Status,
110112
req: &'r Request<'life1>,
113+
error: &'life2 ErasedErrorRef<'r>,
111114
) -> BoxFuture<'r>
112115
where 'r: 'async_trait,
113116
'life0: 'async_trait,
114117
'life1: 'async_trait,
118+
'life2: 'async_trait,
115119
Self: 'async_trait,
116120
{
117-
self(status, req)
121+
self(status, req, error)
118122
}
119123
}
120124

121125
// Used in tests! Do not use, please.
122126
#[doc(hidden)]
123-
pub fn dummy_handler<'r>(_: Status, _: &'r Request<'_>) -> BoxFuture<'r> {
127+
pub fn dummy_handler<'r>(_: Status, _: &'r Request<'_>, _: &ErasedErrorRef<'r>) -> BoxFuture<'r> {
124128
Box::pin(async move { Ok(Response::new()) })
125129
}
126130

core/lib/src/catcher/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
33
mod catcher;
44
mod handler;
5+
mod types;
56

67
pub use catcher::*;
78
pub use handler::*;
9+
pub use types::*;

core/lib/src/catcher/types.rs

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
use transient::{Any, CanRecoverFrom, Co, Transient, Downcast};
2+
3+
pub type ErasedError<'r> = Box<dyn Any<Co<'r>> + Send + Sync + 'r>;
4+
pub type ErasedErrorRef<'r> = dyn Any<Co<'r>> + Send + Sync + 'r;
5+
6+
pub fn default_error_type<'r>() -> ErasedError<'r> {
7+
Box::new(())
8+
}
9+
10+
pub fn downcast<'a, 'r, T: Transient + 'r>(v: &'a ErasedErrorRef<'r>) -> Option<&'a T>
11+
where T::Transience: CanRecoverFrom<Co<'r>>
12+
{
13+
v.downcast_ref()
14+
}
15+
16+
// /// Chosen not to expose this macro, since it's pretty short and sweet
17+
// #[doc(hidden)]
18+
// #[macro_export]
19+
// macro_rules! resolve_typed_catcher {
20+
// ($T:expr) => ({
21+
// #[allow(unused_imports)]
22+
// use $crate::catcher::types::Resolve;
23+
//
24+
// Resolve::new($T).cast()
25+
// })
26+
// }
27+
28+
// pub use resolve_typed_catcher;
29+
30+
pub mod resolution {
31+
use std::marker::PhantomData;
32+
33+
use transient::{CanTranscendTo, Transient};
34+
35+
use super::*;
36+
37+
/// The *magic*.
38+
///
39+
/// `Resolve<T>::item` for `T: Transient` is `<T as Transient>::item`.
40+
/// `Resolve<T>::item` for `T: !Transient` is `DefaultTypeErase::item`.
41+
///
42+
/// This _must_ be used as `Resolve::<T>:item` for resolution to work. This
43+
/// is a fun, static dispatch hack for "specialization" that works because
44+
/// Rust prefers inherent methods over blanket trait impl methods.
45+
pub struct Resolve<'r, T: 'r>(T, PhantomData<&'r ()>);
46+
47+
impl<'r, T: 'r> Resolve<'r, T> {
48+
pub fn new(val: T) -> Self {
49+
Self(val, PhantomData)
50+
}
51+
}
52+
53+
/// Fallback trait "implementing" `Transient` for all types. This is what
54+
/// Rust will resolve `Resolve<T>::item` to when `T: !Transient`.
55+
pub trait DefaultTypeErase<'r>: Sized {
56+
const SPECIALIZED: bool = false;
57+
58+
fn cast(self) -> ErasedError<'r> { Box::new(()) }
59+
}
60+
61+
impl<'r, T: 'r> DefaultTypeErase<'r> for Resolve<'r, T> {}
62+
63+
/// "Specialized" "implementation" of `Transient` for `T: Transient`. This is
64+
/// what Rust will resolve `Resolve<T>::item` to when `T: Transient`.
65+
impl<'r, T: Transient + Send + Sync + 'r> Resolve<'r, T>
66+
where T::Transience: CanTranscendTo<Co<'r>>
67+
{
68+
pub const SPECIALIZED: bool = true;
69+
70+
pub fn cast(self) -> ErasedError<'r> { Box::new(self.0) }
71+
}
72+
}
73+
74+
#[cfg(test)]
75+
mod test {
76+
// use std::any::TypeId;
77+
78+
use transient::{Transient, TypeId};
79+
80+
use super::resolution::{Resolve, DefaultTypeErase};
81+
82+
struct NotAny;
83+
#[derive(Transient)]
84+
struct YesAny;
85+
86+
#[test]
87+
fn check_can_determine() {
88+
let not_any = Resolve::new(NotAny).cast();
89+
assert_eq!(not_any.type_id(), TypeId::of::<()>());
90+
91+
let yes_any = Resolve::new(YesAny).cast();
92+
assert_ne!(yes_any.type_id(), TypeId::of::<()>());
93+
}
94+
95+
// struct HasSentinel<T>(T);
96+
97+
// #[test]
98+
// fn parent_works() {
99+
// let child = resolve!(YesASentinel, HasSentinel<YesASentinel>);
100+
// assert!(child.type_name.ends_with("YesASentinel"));
101+
// assert_eq!(child.parent.unwrap(), TypeId::of::<HasSentinel<YesASentinel>>());
102+
// assert!(child.specialized);
103+
104+
// let not_a_direct_sentinel = resolve!(HasSentinel<YesASentinel>);
105+
// assert!(not_a_direct_sentinel.type_name.contains("HasSentinel"));
106+
// assert!(not_a_direct_sentinel.type_name.contains("YesASentinel"));
107+
// assert!(not_a_direct_sentinel.parent.is_none());
108+
// assert!(!not_a_direct_sentinel.specialized);
109+
// }
110+
}

0 commit comments

Comments
 (0)