Skip to content

Commit f8c8bb8

Browse files
committed
Rework catch attribute
See catch attribute docs for the new syntax.
1 parent 1308c19 commit f8c8bb8

File tree

6 files changed

+260
-94
lines changed

6 files changed

+260
-94
lines changed

core/codegen/src/attribute/catch/mod.rs

+68-51
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,64 @@
11
mod parse;
22

3-
use devise::ext::SpanDiagnosticExt;
4-
use devise::{Diagnostic, Level, Result, Spanned};
3+
use devise::{Result, Spanned};
54
use proc_macro2::{TokenStream, Span};
65

76
use crate::http_codegen::Optional;
8-
use crate::syn_ext::ReturnTypeExt;
7+
use crate::syn_ext::{IdentExt, ReturnTypeExt};
98
use crate::exports::*;
109

11-
fn error_arg_ty(arg: &syn::FnArg) -> Result<&syn::Type> {
12-
match arg {
13-
syn::FnArg::Receiver(_) => Err(Diagnostic::spanned(
14-
arg.span(),
15-
Level::Error,
16-
"Catcher cannot have self as a parameter",
17-
)),
18-
syn::FnArg::Typed(syn::PatType { ty, .. }) => match ty.as_ref() {
19-
syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(elem.as_ref()),
20-
_ => Err(Diagnostic::spanned(
21-
ty.span(),
22-
Level::Error,
23-
"Error type must be a reference",
24-
)),
25-
},
10+
use self::parse::ErrorGuard;
11+
12+
use super::param::Guard;
13+
14+
fn error_type(guard: &ErrorGuard) -> TokenStream {
15+
let ty = &guard.ty;
16+
quote! {
17+
(#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>())
18+
}
19+
}
20+
21+
fn error_guard_decl(guard: &ErrorGuard) -> TokenStream {
22+
let (ident, ty) = (guard.ident.rocketized(), &guard.ty);
23+
quote_spanned! { ty.span() =>
24+
let #ident: &#ty = match #_catcher::downcast(__error_init.as_ref()) {
25+
Some(v) => v,
26+
None => return #_Result::Err((#__status, __error_init)),
27+
};
28+
}
29+
}
30+
31+
fn request_guard_decl(guard: &Guard) -> TokenStream {
32+
let (ident, ty) = (guard.fn_ident.rocketized(), &guard.ty);
33+
quote_spanned! { ty.span() =>
34+
let #ident: #ty = match <#ty as #FromRequest>::from_request(#__req).await {
35+
#Outcome::Success(__v) => __v,
36+
#Outcome::Forward(__e) => {
37+
::rocket::trace::info!(
38+
name: "forward",
39+
target: concat!("rocket::codegen::catch::", module_path!()),
40+
parameter = stringify!(#ident),
41+
type_name = stringify!(#ty),
42+
status = __e.code,
43+
"request guard forwarding; trying next catcher"
44+
);
45+
46+
return #_Err((#__status, __error_init));
47+
},
48+
#[allow(unreachable_code)]
49+
#Outcome::Error((__c, __e)) => {
50+
::rocket::trace::info!(
51+
name: "failure",
52+
target: concat!("rocket::codegen::catch::", module_path!()),
53+
parameter = stringify!(#ident),
54+
type_name = stringify!(#ty),
55+
reason = %#display_hack!(&__e),
56+
"request guard failed; forwarding to 500 handler"
57+
);
58+
59+
return #_Err((#Status::InternalServerError, __error_init));
60+
}
61+
};
2662
}
2763
}
2864

@@ -31,7 +67,7 @@ pub fn _catch(
3167
input: proc_macro::TokenStream
3268
) -> Result<TokenStream> {
3369
// Parse and validate all of the user's input.
34-
let catch = parse::Attribute::parse(args.into(), input)?;
70+
let catch = parse::Attribute::parse(args.into(), input.into())?;
3571

3672
// Gather everything we'll need to generate the catcher.
3773
let user_catcher_fn = &catch.function;
@@ -40,48 +76,27 @@ pub fn _catch(
4076
let status_code = Optional(catch.status.map(|s| s.code));
4177
let deprecated = catch.function.attrs.iter().find(|a| a.path().is_ident("deprecated"));
4278

43-
// Determine the number of parameters that will be passed in.
44-
if catch.function.sig.inputs.len() > 3 {
45-
return Err(catch.function.sig.paren_token.span.join()
46-
.error("invalid number of arguments: must be zero, one, or two")
47-
.help("catchers optionally take `&Request` or `Status, &Request`"));
48-
}
49-
5079
// This ensures that "Responder not implemented" points to the return type.
5180
let return_type_span = catch.function.sig.output.ty()
5281
.map(|ty| ty.span())
5382
.unwrap_or_else(Span::call_site);
5483

55-
// TODO: how to handle request?
56-
// - Right now: (), (&Req), (Status, &Req) allowed
57-
// Set the `req` and `status` spans to that of their respective function
58-
// arguments for a more correct `wrong type` error span. `rev` to be cute.
59-
let codegen_args = &[__req, __status, __error];
60-
let inputs = catch.function.sig.inputs.iter().rev()
61-
.zip(codegen_args.iter())
62-
.map(|(fn_arg, codegen_arg)| match fn_arg {
63-
syn::FnArg::Receiver(_) => codegen_arg.respanned(fn_arg.span()),
64-
syn::FnArg::Typed(a) => codegen_arg.respanned(a.ty.span())
65-
}).rev();
66-
let (make_error, error_type) = if catch.function.sig.inputs.len() >= 3 {
67-
let arg = catch.function.sig.inputs.first().unwrap();
68-
let ty = error_arg_ty(arg)?;
69-
(quote_spanned!(arg.span() =>
70-
let #__error: &#ty = match ::rocket::catcher::downcast(__error_init.as_ref()) {
71-
Some(v) => v,
72-
None => return #_Result::Err((#__status, __error_init)),
73-
};
74-
), quote! {Some((#_catcher::TypeId::of::<#ty>(), ::std::any::type_name::<#ty>()))})
75-
} else {
76-
(quote! {}, quote! {None})
77-
};
84+
let status_guard = catch.status_guard.as_ref().map(|(_, s)| {
85+
let ident = s.rocketized();
86+
quote! { let #ident = #__status; }
87+
});
88+
let error_guard = catch.error_guard.as_ref().map(error_guard_decl);
89+
let error_type = Optional(catch.error_guard.as_ref().map(error_type));
90+
let request_guards = catch.request_guards.iter().map(request_guard_decl);
91+
let parameter_names = catch.arguments.map.values()
92+
.map(|(ident, _)| ident.rocketized());
7893

7994
// We append `.await` to the function call if this is `async`.
8095
let dot_await = catch.function.sig.asyncness
8196
.map(|a| quote_spanned!(a.span() => .await));
8297

8398
let catcher_response = quote_spanned!(return_type_span => {
84-
let ___responder = #user_catcher_fn_name(#(#inputs),*) #dot_await;
99+
let ___responder = #user_catcher_fn_name(#(#parameter_names),*) #dot_await;
85100
#_response::Responder::respond_to(___responder, #__req).map_err(|s| (s, __error_init))?
86101
});
87102

@@ -104,7 +119,9 @@ pub fn _catch(
104119
__error_init: #ErasedError<'__r>,
105120
) -> #_catcher::BoxFuture<'__r> {
106121
#_Box::pin(async move {
107-
#make_error
122+
#error_guard
123+
#status_guard
124+
#(#request_guards)*
108125
let __response = #catcher_response;
109126
#_Result::Ok(
110127
#Response::build()

core/codegen/src/attribute/catch/parse.rs

+112-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
use devise::ext::SpanDiagnosticExt;
1+
use devise::ext::{SpanDiagnosticExt, TypeExt};
22
use devise::{Diagnostic, FromMeta, MetaItem, Result, SpanWrapped, Spanned};
3-
use proc_macro2::TokenStream;
3+
use proc_macro2::{Span, TokenStream, Ident};
4+
use quote::ToTokens;
45

5-
use crate::attribute::param::Dynamic;
6+
use crate::attribute::param::{Dynamic, Guard};
7+
use crate::name::{ArgumentMap, Arguments, Name};
8+
use crate::proc_macro_ext::Diagnostics;
9+
use crate::syn_ext::FnArgExt;
610
use crate::{http, http_codegen};
711

812
/// This structure represents the parsed `catch` attribute and associated items.
@@ -11,15 +15,66 @@ pub struct Attribute {
1115
pub status: Option<http::Status>,
1216
/// The function that was decorated with the `catch` attribute.
1317
pub function: syn::ItemFn,
14-
pub error: Option<SpanWrapped<Dynamic>>,
18+
pub arguments: Arguments,
19+
pub error_guard: Option<ErrorGuard>,
20+
pub status_guard: Option<(Name, syn::Ident)>,
21+
pub request_guards: Vec<Guard>,
22+
}
23+
24+
pub struct ErrorGuard {
25+
pub span: Span,
26+
pub name: Name,
27+
pub ident: syn::Ident,
28+
pub ty: syn::Type,
29+
}
30+
31+
impl ErrorGuard {
32+
fn new(param: SpanWrapped<Dynamic>, args: &Arguments) -> Result<Self> {
33+
if let Some((ident, ty)) = args.map.get(&param.name) {
34+
match ty {
35+
syn::Type::Reference(syn::TypeReference { elem, .. }) => Ok(Self {
36+
span: param.span(),
37+
name: param.name.clone(),
38+
ident: ident.clone(),
39+
ty: elem.as_ref().clone(),
40+
}),
41+
ty => {
42+
let msg = format!(
43+
"Error argument must be a reference, found `{}`",
44+
ty.to_token_stream()
45+
);
46+
let diag = param.span()
47+
.error("invalid type")
48+
.span_note(ty.span(), msg)
49+
.help(format!("Perhaps use `&{}` instead", ty.to_token_stream()));
50+
Err(diag)
51+
}
52+
}
53+
} else {
54+
let msg = format!("expected argument named `{}` here", param.name);
55+
let diag = param.span().error("unused parameter").span_note(args.span, msg);
56+
Err(diag)
57+
}
58+
}
59+
}
60+
61+
fn status_guard(param: SpanWrapped<Dynamic>, args: &Arguments) -> Result<(Name, Ident)> {
62+
if let Some((ident, _)) = args.map.get(&param.name) {
63+
Ok((param.name.clone(), ident.clone()))
64+
} else {
65+
let msg = format!("expected argument named `{}` here", param.name);
66+
let diag = param.span().error("unused parameter").span_note(args.span, msg);
67+
Err(diag)
68+
}
1569
}
1670

1771
/// We generate a full parser for the meta-item for great error messages.
1872
#[derive(FromMeta)]
1973
struct Meta {
2074
#[meta(naked)]
2175
code: Code,
22-
// error: Option<SpanWrapped<Dynamic>>,
76+
error: Option<SpanWrapped<Dynamic>>,
77+
status: Option<SpanWrapped<Dynamic>>,
2378
}
2479

2580
/// `Some` if there's a code, `None` if it's `default`.
@@ -46,16 +101,66 @@ impl FromMeta for Code {
46101

47102
impl Attribute {
48103
pub fn parse(args: TokenStream, input: proc_macro::TokenStream) -> Result<Self> {
104+
let mut diags = Diagnostics::new();
105+
49106
let function: syn::ItemFn = syn::parse(input)
50107
.map_err(Diagnostic::from)
51108
.map_err(|diag| diag.help("`#[catch]` can only be used on functions"))?;
52109

53110
let attr: MetaItem = syn::parse2(quote!(catch(#args)))?;
54-
let status = Meta::from_meta(&attr)
111+
let attr = Meta::from_meta(&attr)
55112
.map(|meta| meta)
56113
.map_err(|diag| diag.help("`#[catch]` expects a status code int or `default`: \
57114
`#[catch(404)]` or `#[catch(default)]`"))?;
58115

59-
Ok(Attribute { status: status.code.0, function, error: None })
116+
let span = function.sig.paren_token.span.join();
117+
let mut arguments = Arguments { map: ArgumentMap::new(), span };
118+
for arg in function.sig.inputs.iter() {
119+
if let Some((ident, ty)) = arg.typed() {
120+
let value = (ident.clone(), ty.with_stripped_lifetimes());
121+
arguments.map.insert(Name::from(ident), value);
122+
} else {
123+
let span = arg.span();
124+
let diag = if arg.wild().is_some() {
125+
span.error("handler arguments must be named")
126+
.help("to name an ignored handler argument, use `_name`")
127+
} else {
128+
span.error("handler arguments must be of the form `ident: Type`")
129+
};
130+
131+
diags.push(diag);
132+
}
133+
}
134+
// let mut error_guard = None;
135+
let error_guard = attr.error.clone()
136+
.map(|p| ErrorGuard::new(p, &arguments))
137+
.and_then(|p| p.map_err(|e| diags.push(e)).ok());
138+
let status_guard = attr.status.clone()
139+
.map(|n| status_guard(n, &arguments))
140+
.and_then(|p| p.map_err(|e| diags.push(e)).ok());
141+
let request_guards = arguments.map.iter()
142+
.filter(|(name, _)| {
143+
let mut all_other_guards = error_guard.iter()
144+
.map(|g| &g.name)
145+
.chain(status_guard.iter().map(|(n, _)| n));
146+
147+
all_other_guards.all(|n| n != *name)
148+
})
149+
.enumerate()
150+
.map(|(index, (name, (ident, ty)))| Guard {
151+
source: Dynamic { index, name: name.clone(), trailing: false },
152+
fn_ident: ident.clone(),
153+
ty: ty.clone(),
154+
})
155+
.collect();
156+
157+
diags.head_err_or(Attribute {
158+
status: attr.code.0,
159+
function,
160+
arguments,
161+
error_guard,
162+
status_guard,
163+
request_guards,
164+
})
60165
}
61166
}

core/codegen/src/attribute/route/parse.rs

+1-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::proc_macro_ext::Diagnostics;
88
use crate::http_codegen::{Method, MediaType};
99
use crate::attribute::param::{Parameter, Dynamic, Guard};
1010
use crate::syn_ext::FnArgExt;
11-
use crate::name::Name;
11+
use crate::name::{ArgumentMap, Arguments, Name};
1212
use crate::http::ext::IntoOwned;
1313
use crate::http::uri::{Origin, fmt};
1414

@@ -31,14 +31,6 @@ pub struct Route {
3131
pub arguments: Arguments,
3232
}
3333

34-
type ArgumentMap = IndexMap<Name, (syn::Ident, syn::Type)>;
35-
36-
#[derive(Debug)]
37-
pub struct Arguments {
38-
pub span: Span,
39-
pub map: ArgumentMap
40-
}
41-
4234
/// The parsed `#[route(..)]` attribute.
4335
#[derive(Debug, FromMeta)]
4436
pub struct Attribute {

0 commit comments

Comments
 (0)