Skip to content

Commit f0f2342

Browse files
committed
Major fixes for matching and responder
- Update Responder derive to match new trait - Update lifecycle to deal with error types correctly - Update catcher rank to ignore type - now handled by lifecycle
1 parent 99bba53 commit f0f2342

File tree

8 files changed

+313
-121
lines changed

8 files changed

+313
-121
lines changed

core/codegen/src/derive/responder.rs

+163-33
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use quote::ToTokens;
22
use devise::{*, ext::{TypeExt, SpanDiagnosticExt}};
3-
use proc_macro2::TokenStream;
3+
use proc_macro2::{Span, TokenStream};
4+
use syn::{Ident, Lifetime, Type};
45

5-
use crate::exports::*;
6+
use crate::{exports::*, syn_ext::IdentExt};
67
use crate::syn_ext::{TypeExt as _, GenericsExt as _};
78
use crate::http_codegen::{ContentType, Status};
89

@@ -25,32 +26,7 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream {
2526
.type_bound_mapper(MapperBuild::new()
2627
.try_enum_map(|m, e| mapper::enum_null(m, e))
2728
.try_fields_map(|_, fields| {
28-
let generic_idents = fields.parent.input().generics().type_idents();
29-
let lifetime = |ty: &syn::Type| syn::Lifetime::new("'o", ty.span());
30-
let mut types = fields.iter()
31-
.map(|f| (f, &f.field.inner.ty))
32-
.map(|(f, ty)| (f, ty.with_replaced_lifetimes(lifetime(ty))));
33-
34-
let mut bounds = vec![];
35-
if let Some((_, ty)) = types.next() {
36-
if !ty.is_concrete(&generic_idents) {
37-
let span = ty.span();
38-
bounds.push(quote_spanned!(span => #ty: #_response::Responder<'r, 'o>));
39-
}
40-
}
41-
42-
for (f, ty) in types {
43-
let attr = FieldAttr::one_from_attrs("response", &f.attrs)?.unwrap_or_default();
44-
if ty.is_concrete(&generic_idents) || attr.ignore {
45-
continue;
46-
}
47-
48-
bounds.push(quote_spanned! { ty.span() =>
49-
#ty: ::std::convert::Into<#_http::Header<'o>>
50-
});
51-
}
52-
53-
Ok(quote!(#(#bounds,)*))
29+
bounds_from_fields(fields)
5430
})
5531
)
5632
.validator(ValidatorBuild::new()
@@ -75,16 +51,30 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream {
7551
fn set_header_tokens<T: ToTokens + Spanned>(item: T) -> TokenStream {
7652
quote_spanned!(item.span() => __res.set_header(#item);)
7753
}
54+
55+
let error_outcome = match fields.parent {
56+
FieldParent::Variant(p) => {
57+
// let name = p.parent.ident.append("Error");
58+
// let var_name = &p.ident;
59+
// quote! { #name::#var_name(e) }
60+
quote! { #_catcher::AnyError(#_Box::new(e)) }
61+
},
62+
_ => quote! { e },
63+
};
7864

7965
let attr = ItemAttr::one_from_attrs("response", fields.parent.attrs())?
8066
.unwrap_or_default();
8167

8268
let responder = fields.iter().next().map(|f| {
8369
let (accessor, ty) = (f.accessor(), f.ty.with_stripped_lifetimes());
8470
quote_spanned! { f.span() =>
85-
let mut __res = #try_outcome!(<#ty as #_response::Responder>::respond_to(
71+
let mut __res = match <#ty as #_response::Responder>::respond_to(
8672
#accessor, __req
87-
));
73+
) {
74+
#Outcome::Success(val) => val,
75+
#Outcome::Error(e) => return #Outcome::Error(#error_outcome),
76+
#Outcome::Forward(f) => return #Outcome::Forward(f),
77+
};
8878
}
8979
}).expect("have at least one field");
9080

@@ -118,14 +108,154 @@ pub fn derive_responder(input: proc_macro::TokenStream) -> TokenStream {
118108
type Error = #output;
119109
})
120110
.try_struct_map(|_, item| {
121-
let responder = item.fields.iter().next().map(|f| {
122-
&f.ty
111+
let (old, ty) = item.fields.iter().next().map(|f| {
112+
let ty = f.ty.with_replaced_lifetimes(Lifetime::new("'o", Span::call_site()));
113+
let old = f.ty.with_replaced_lifetimes(Lifetime::new("'a", Span::call_site()));
114+
(old, ty)
123115
}).expect("have at least one field");
116+
let type_params: Vec<_> = item.generics.type_params().map(|p| &p.ident).collect();
117+
let output_life = if old == ty && ty.is_concrete(&type_params) {
118+
quote! { 'static }
119+
} else {
120+
quote! { 'o }
121+
};
124122

125123
Ok(quote! {
126-
<#responder as #_response::Responder<'r, 'o>>::Error
124+
<#ty as #_response::Responder<'r, #output_life>>::Error
127125
})
128126
})
127+
.enum_map(|_, _item| {
128+
// let name = item.ident.append("Error");
129+
// let response_types: Vec<_> = item.variants()
130+
// .flat_map(|f| responder_types(f.fields()).into_iter()).collect();
131+
// // TODO: add where clauses, and filter for the type params I need
132+
// let type_params: Vec<_> = item.generics
133+
// .type_params()
134+
// .map(|p| &p.ident)
135+
// .filter(|p| generic_used(p, &response_types))
136+
// .collect();
137+
// quote!{ #name<'r, 'o, #(#type_params,)*> }
138+
quote!{ #_catcher::AnyError<'r> }
139+
})
129140
)
141+
// .outer_mapper(MapperBuild::new()
142+
// .enum_map(|_, item| {
143+
// let name = item.ident.append("Error");
144+
// let variants = item.variants().map(|d| {
145+
// let var_name = &d.ident;
146+
// let (old, ty) = d.fields().iter().next().map(|f| {
147+
// let ty = f.ty.with_replaced_lifetimes(Lifetime::new("'o", Span::call_site()));
148+
// (f.ty.clone(), ty)
149+
// }).expect("have at least one field");
150+
// let output_life = if old == ty {
151+
// quote! { 'static }
152+
// } else {
153+
// quote! { 'o }
154+
// };
155+
// quote!{
156+
// #var_name(<#ty as #_response::Responder<'r, #output_life>>::Error),
157+
// }
158+
// });
159+
// let source = item.variants().map(|d| {
160+
// let var_name = &d.ident;
161+
// quote!{
162+
// Self::#var_name(v) => #_Some(v),
163+
// }
164+
// });
165+
// let response_types: Vec<_> = item.variants()
166+
// .flat_map(|f| responder_types(f.fields()).into_iter()).collect();
167+
// // TODO: add where clauses, and filter for the type params I need
168+
// let type_params: Vec<_> = item.generics
169+
// .type_params()
170+
// .map(|p| &p.ident)
171+
// .filter(|p| generic_used(p, &response_types))
172+
// .collect();
173+
// // let bounds: Vec<_> = item.variants().map(|f| bounds_from_fields(f.fields()).expect("Bounds must be valid")).collect();
174+
// let bounds: Vec<_> = item.variants()
175+
// .flat_map(|f| responder_types(f.fields()).into_iter())
176+
// .map(|t| quote!{#t: #_response::Responder<'r, 'o>,})
177+
// .collect();
178+
// quote!{
179+
// pub enum #name<'r, 'o, #(#type_params: 'r,)*>
180+
// where #(#bounds)*
181+
// {
182+
// #(#variants)*
183+
// UnusedVariant(
184+
// // Make this variant impossible to construct
185+
// ::std::convert::Infallible,
186+
// ::std::marker::PhantomData<&'o ()>,
187+
// ),
188+
// }
189+
// // TODO: validate this impl - roughly each variant must be (at least) inv
190+
// // wrt a lifetime, since they impl CanTransendTo<Inv<'r>>
191+
// // TODO: also need to add requirements on the type parameters
192+
// unsafe impl<'r, 'o: 'r, #(#type_params: 'r,)*> ::rocket::catcher::Transient for #name<'r, 'o, #(#type_params,)*>
193+
// where #(#bounds)*
194+
// {
195+
// type Static = #name<'static, 'static>;
196+
// type Transience = ::rocket::catcher::Inv<'r>;
197+
// }
198+
// impl<'r, 'o: 'r, #(#type_params,)*> #TypedError<'r> for #name<'r, 'o, #(#type_params,)*>
199+
// where #(#bounds)*
200+
// {
201+
// fn source(&self) -> #_Option<&dyn #TypedError<'r>> {
202+
// match self {
203+
// #(#source)*
204+
// Self::UnusedVariant(f, ..) => match *f { }
205+
// }
206+
// }
207+
// }
208+
// }
209+
// })
210+
// )
130211
.to_tokens()
131212
}
213+
214+
fn generic_used(ident: &Ident, res_types: &[Type]) -> bool {
215+
res_types.iter().any(|t| !t.is_concrete(&[ident]))
216+
}
217+
218+
fn responder_types(fields: Fields<'_>) -> Vec<Type> {
219+
let generic_idents = fields.parent.input().generics().type_idents();
220+
let lifetime = |ty: &syn::Type| syn::Lifetime::new("'o", ty.span());
221+
let mut types = fields.iter()
222+
.map(|f| (f, &f.field.inner.ty))
223+
.map(|(f, ty)| (f, ty.with_replaced_lifetimes(lifetime(ty))));
224+
225+
let mut bounds = vec![];
226+
if let Some((_, ty)) = types.next() {
227+
if !ty.is_concrete(&generic_idents) {
228+
bounds.push(ty);
229+
}
230+
}
231+
bounds
232+
}
233+
234+
fn bounds_from_fields(fields: Fields<'_>) -> Result<TokenStream> {
235+
let generic_idents = fields.parent.input().generics().type_idents();
236+
let lifetime = |ty: &syn::Type| syn::Lifetime::new("'o", ty.span());
237+
let mut types = fields.iter()
238+
.map(|f| (f, &f.field.inner.ty))
239+
.map(|(f, ty)| (f, ty.with_replaced_lifetimes(lifetime(ty))));
240+
241+
let mut bounds = vec![];
242+
if let Some((_, ty)) = types.next() {
243+
if !ty.is_concrete(&generic_idents) {
244+
let span = ty.span();
245+
bounds.push(quote_spanned!(span => #ty: #_response::Responder<'r, 'o>));
246+
}
247+
}
248+
249+
for (f, ty) in types {
250+
let attr = FieldAttr::one_from_attrs("response", &f.attrs)?.unwrap_or_default();
251+
if ty.is_concrete(&generic_idents) || attr.ignore {
252+
continue;
253+
}
254+
255+
bounds.push(quote_spanned! { ty.span() =>
256+
#ty: ::std::convert::Into<#_http::Header<'o>>
257+
});
258+
}
259+
260+
Ok(quote!(#(#bounds,)*))
261+
}

core/lib/src/catcher/catcher.rs

+1-6
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ pub struct Catcher {
133133
// with more nonempty segments have lower ranks => higher precedence.
134134
// Doubled to provide space between for typed catchers.
135135
fn rank(base: Path<'_>) -> isize {
136-
-(base.segments().filter(|s| !s.is_empty()).count() as isize) * 2
136+
-(base.segments().filter(|s| !s.is_empty()).count() as isize)
137137
}
138138

139139
impl Catcher {
@@ -355,11 +355,6 @@ impl From<StaticInfo> for Catcher {
355355
#[inline]
356356
fn from(info: StaticInfo) -> Catcher {
357357
let mut catcher = Catcher::new(info.code, info.handler);
358-
if info.error_type.is_some() {
359-
// Lower rank if the error_type is defined, to ensure typed catchers
360-
// are always tried first
361-
catcher.rank -= 1;
362-
}
363358
catcher.name = Some(info.name.into());
364359
catcher.error_type = info.error_type;
365360
catcher.location = Some(info.location);

core/lib/src/catcher/types.rs

+17-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use either::Either;
2-
use transient::{Any, CanRecoverFrom, CanTranscendTo, Downcast, Inv, Transience};
2+
use transient::{Any, CanRecoverFrom, CanTranscendTo, Downcast, Transience};
33
use crate::{http::Status, Request, Response};
44
#[doc(inline)]
5-
pub use transient::{Static, Transient, TypeId};
5+
pub use transient::{Static, Transient, TypeId, Inv};
66

77
/// Polyfill for trait upcasting to [`Any`]
88
pub trait AsAny<Tr: Transience>: Any<Tr> + Sealed {
@@ -57,6 +57,8 @@ pub trait TypedError<'r>: AsAny<Inv<'r>> + Send + Sync + 'r {
5757

5858
impl<'r> TypedError<'r> for std::convert::Infallible { }
5959

60+
impl<'r> TypedError<'r> for () { }
61+
6062
impl<'r> TypedError<'r> for std::io::Error {
6163
fn status(&self) -> Status {
6264
match self.kind() {
@@ -107,6 +109,18 @@ impl<'r, L, R> TypedError<'r> for Either<L, R>
107109
}
108110
}
109111

112+
// TODO: This cannot be used as a bound on an untyped catcher to get any error type.
113+
// This is mostly an implementation detail (and issue with double boxing) for
114+
// the responder derive
115+
#[derive(Transient)]
116+
pub struct AnyError<'r>(pub Box<dyn TypedError<'r> + 'r>);
117+
118+
impl<'r> TypedError<'r> for AnyError<'r> {
119+
fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> {
120+
Some(self.0.as_ref())
121+
}
122+
}
123+
110124
pub fn downcast<'r, T: Transient + 'r>(v: Option<&'r dyn TypedError<'r>>) -> Option<&'r T>
111125
where T::Transience: CanRecoverFrom<Inv<'r>>
112126
{
@@ -177,7 +191,7 @@ pub mod resolution {
177191
{
178192
pub const SPECIALIZED: bool = true;
179193

180-
pub fn cast(self) -> Option<Box<dyn TypedError<'r>>> { Some(Box::new(self.0))}
194+
pub fn cast(self) -> Option<Box<dyn TypedError<'r>>> { Some(Box::new(self.0)) }
181195
}
182196

183197
/// Wrapper type to hold the return type of `resolve_typed_catcher`.

0 commit comments

Comments
 (0)