1
1
mod parse;
2
2
3
- use devise:: ext:: SpanDiagnosticExt ;
4
- use devise:: { Diagnostic , Level , Result , Spanned } ;
3
+ use devise:: { Result , Spanned } ;
5
4
use proc_macro2:: { TokenStream , Span } ;
6
5
7
6
use crate :: http_codegen:: Optional ;
8
- use crate :: syn_ext:: ReturnTypeExt ;
7
+ use crate :: syn_ext:: { IdentExt , ReturnTypeExt } ;
9
8
use crate :: exports:: * ;
10
9
11
- fn error_arg_ty ( arg : & syn:: FnArg ) -> Result < & syn:: Type > {
12
- match arg {
13
- syn:: FnArg :: Receiver ( _) => Err ( Diagnostic :: spanned (
14
- arg. span ( ) ,
15
- Level :: Error ,
16
- "Catcher cannot have self as a parameter" ,
17
- ) ) ,
18
- syn:: FnArg :: Typed ( syn:: PatType { ty, .. } ) => match ty. as_ref ( ) {
19
- syn:: Type :: Reference ( syn:: TypeReference { elem, .. } ) => Ok ( elem. as_ref ( ) ) ,
20
- _ => Err ( Diagnostic :: spanned (
21
- ty. span ( ) ,
22
- Level :: Error ,
23
- "Error type must be a reference" ,
24
- ) ) ,
25
- } ,
10
+ use self :: parse:: ErrorGuard ;
11
+
12
+ use super :: param:: Guard ;
13
+
14
+ fn error_type ( guard : & ErrorGuard ) -> TokenStream {
15
+ let ty = & guard. ty ;
16
+ quote ! {
17
+ ( #_catcher:: TypeId :: of:: <#ty>( ) , :: std:: any:: type_name:: <#ty>( ) )
18
+ }
19
+ }
20
+
21
+ fn error_guard_decl ( guard : & ErrorGuard ) -> TokenStream {
22
+ let ( ident, ty) = ( guard. ident . rocketized ( ) , & guard. ty ) ;
23
+ quote_spanned ! { ty. span( ) =>
24
+ let #ident: & #ty = match #_catcher:: downcast( __error_init. as_ref( ) ) {
25
+ Some ( v) => v,
26
+ None => return #_Result:: Err ( ( #__status, __error_init) ) ,
27
+ } ;
28
+ }
29
+ }
30
+
31
+ fn request_guard_decl ( guard : & Guard ) -> TokenStream {
32
+ let ( ident, ty) = ( guard. fn_ident . rocketized ( ) , & guard. ty ) ;
33
+ quote_spanned ! { ty. span( ) =>
34
+ let #ident: #ty = match <#ty as #FromRequest >:: from_request( #__req) . await {
35
+ #Outcome :: Success ( __v) => __v,
36
+ #Outcome :: Forward ( __e) => {
37
+ :: rocket:: trace:: info!(
38
+ name: "forward" ,
39
+ target: concat!( "rocket::codegen::catch::" , module_path!( ) ) ,
40
+ parameter = stringify!( #ident) ,
41
+ type_name = stringify!( #ty) ,
42
+ status = __e. code,
43
+ "request guard forwarding; trying next catcher"
44
+ ) ;
45
+
46
+ return #_Err( ( #__status, __error_init) ) ;
47
+ } ,
48
+ #[ allow( unreachable_code) ]
49
+ #Outcome :: Error ( ( __c, __e) ) => {
50
+ :: rocket:: trace:: info!(
51
+ name: "failure" ,
52
+ target: concat!( "rocket::codegen::catch::" , module_path!( ) ) ,
53
+ parameter = stringify!( #ident) ,
54
+ type_name = stringify!( #ty) ,
55
+ reason = %#display_hack!( & __e) ,
56
+ "request guard failed; forwarding to 500 handler"
57
+ ) ;
58
+
59
+ return #_Err( ( #Status :: InternalServerError , __error_init) ) ;
60
+ }
61
+ } ;
26
62
}
27
63
}
28
64
@@ -31,7 +67,7 @@ pub fn _catch(
31
67
input : proc_macro:: TokenStream
32
68
) -> Result < TokenStream > {
33
69
// Parse and validate all of the user's input.
34
- let catch = parse:: Attribute :: parse ( args. into ( ) , input) ?;
70
+ let catch = parse:: Attribute :: parse ( args. into ( ) , input. into ( ) ) ?;
35
71
36
72
// Gather everything we'll need to generate the catcher.
37
73
let user_catcher_fn = & catch. function ;
@@ -40,48 +76,27 @@ pub fn _catch(
40
76
let status_code = Optional ( catch. status . map ( |s| s. code ) ) ;
41
77
let deprecated = catch. function . attrs . iter ( ) . find ( |a| a. path ( ) . is_ident ( "deprecated" ) ) ;
42
78
43
- // Determine the number of parameters that will be passed in.
44
- if catch. function . sig . inputs . len ( ) > 3 {
45
- return Err ( catch. function . sig . paren_token . span . join ( )
46
- . error ( "invalid number of arguments: must be zero, one, or two" )
47
- . help ( "catchers optionally take `&Request` or `Status, &Request`" ) ) ;
48
- }
49
-
50
79
// This ensures that "Responder not implemented" points to the return type.
51
80
let return_type_span = catch. function . sig . output . ty ( )
52
81
. map ( |ty| ty. span ( ) )
53
82
. unwrap_or_else ( Span :: call_site) ;
54
83
55
- // TODO: how to handle request?
56
- // - Right now: (), (&Req), (Status, &Req) allowed
57
- // Set the `req` and `status` spans to that of their respective function
58
- // arguments for a more correct `wrong type` error span. `rev` to be cute.
59
- let codegen_args = & [ __req, __status, __error] ;
60
- let inputs = catch. function . sig . inputs . iter ( ) . rev ( )
61
- . zip ( codegen_args. iter ( ) )
62
- . map ( |( fn_arg, codegen_arg) | match fn_arg {
63
- syn:: FnArg :: Receiver ( _) => codegen_arg. respanned ( fn_arg. span ( ) ) ,
64
- syn:: FnArg :: Typed ( a) => codegen_arg. respanned ( a. ty . span ( ) )
65
- } ) . rev ( ) ;
66
- let ( make_error, error_type) = if catch. function . sig . inputs . len ( ) >= 3 {
67
- let arg = catch. function . sig . inputs . first ( ) . unwrap ( ) ;
68
- let ty = error_arg_ty ( arg) ?;
69
- ( quote_spanned ! ( arg. span( ) =>
70
- let #__error: & #ty = match :: rocket:: catcher:: downcast( __error_init. as_ref( ) ) {
71
- Some ( v) => v,
72
- None => return #_Result:: Err ( ( #__status, __error_init) ) ,
73
- } ;
74
- ) , quote ! { Some ( ( #_catcher:: TypeId :: of:: <#ty>( ) , :: std:: any:: type_name:: <#ty>( ) ) ) } )
75
- } else {
76
- ( quote ! { } , quote ! { None } )
77
- } ;
84
+ let status_guard = catch. status_guard . as_ref ( ) . map ( |( _, s) | {
85
+ let ident = s. rocketized ( ) ;
86
+ quote ! { let #ident = #__status; }
87
+ } ) ;
88
+ let error_guard = catch. error_guard . as_ref ( ) . map ( error_guard_decl) ;
89
+ let error_type = Optional ( catch. error_guard . as_ref ( ) . map ( error_type) ) ;
90
+ let request_guards = catch. request_guards . iter ( ) . map ( request_guard_decl) ;
91
+ let parameter_names = catch. arguments . map . values ( )
92
+ . map ( |( ident, _) | ident. rocketized ( ) ) ;
78
93
79
94
// We append `.await` to the function call if this is `async`.
80
95
let dot_await = catch. function . sig . asyncness
81
96
. map ( |a| quote_spanned ! ( a. span( ) => . await ) ) ;
82
97
83
98
let catcher_response = quote_spanned ! ( return_type_span => {
84
- let ___responder = #user_catcher_fn_name( #( #inputs ) , * ) #dot_await;
99
+ let ___responder = #user_catcher_fn_name( #( #parameter_names ) , * ) #dot_await;
85
100
#_response:: Responder :: respond_to( ___responder, #__req) . map_err( |s| ( s, __error_init) ) ?
86
101
} ) ;
87
102
@@ -104,7 +119,9 @@ pub fn _catch(
104
119
__error_init: #ErasedError <' __r>,
105
120
) -> #_catcher:: BoxFuture <' __r> {
106
121
#_Box:: pin( async move {
107
- #make_error
122
+ #error_guard
123
+ #status_guard
124
+ #( #request_guards) *
108
125
let __response = #catcher_response;
109
126
#_Result:: Ok (
110
127
#Response :: build( )
0 commit comments