Skip to content

Commit 6d06ac7

Browse files
committed
Update code to work properly with borrowed errors
1 parent 6427db2 commit 6d06ac7

File tree

8 files changed

+105
-69
lines changed

8 files changed

+105
-69
lines changed

core/lib/src/catcher/types.rs

-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ impl<'r, L, R> TypedError<'r> for Either<L, R>
9393
fn name(&self) -> &'static str { std::any::type_name::<Self>() }
9494

9595
fn source(&'r self) -> Option<&'r (dyn TypedError<'r> + 'r)> {
96-
println!("Downcasting either");
9796
match self {
9897
Self::Left(v) => Some(v),
9998
Self::Right(v) => Some(v),

core/lib/src/erased.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use futures::future::BoxFuture;
88
use http::request::Parts;
99
use tokio::io::{AsyncRead, ReadBuf};
1010

11+
use crate::catcher::TypedError;
1112
use crate::data::{Data, IoHandler, RawStream};
1213
use crate::{Request, Response, Rocket, Orbit};
1314

@@ -34,10 +35,12 @@ impl Drop for ErasedRequest {
3435
fn drop(&mut self) { }
3536
}
3637

37-
#[derive(Debug)]
38+
// TODO: #[derive(Debug)]
3839
pub struct ErasedResponse {
3940
// XXX: SAFETY: This (dependent) field must come first due to drop order!
4041
response: Response<'static>,
42+
// XXX: SAFETY: This (dependent) field must come first due to drop order!
43+
error: Option<Box<dyn TypedError<'static> + 'static>>,
4144
_request: Arc<ErasedRequest>,
4245
}
4346

@@ -94,7 +97,8 @@ impl ErasedRequest {
9497
T,
9598
&'r Rocket<Orbit>,
9699
&'r Request<'r>,
97-
Data<'r>
100+
Data<'r>,
101+
&'r mut Option<Box<dyn TypedError<'r> + 'r>>,
98102
) -> BoxFuture<'r, Response<'r>>,
99103
) -> ErasedResponse
100104
where T: Send + Sync + 'static,
@@ -111,15 +115,19 @@ impl ErasedRequest {
111115
};
112116

113117
let parent = parent;
118+
let mut error_ptr: Option<Box<dyn TypedError<'static> + 'static>> = None;
114119
let response: Response<'_> = {
115120
let parent: &ErasedRequest = &parent;
116121
let parent: &'static ErasedRequest = unsafe { transmute(parent) };
117122
let rocket: &Rocket<Orbit> = &parent._rocket;
118123
let request: &Request<'_> = &parent.request;
119-
dispatch(token, rocket, request, data).await
124+
// SAFETY: error_ptr is transmuted into the same type, with the same lifetime as the request.
125+
// It is kept alive by the erased response, so that the response type can borrow from it
126+
dispatch(token, rocket, request, data, unsafe { transmute(&mut error_ptr)}).await
120127
};
121128

122129
ErasedResponse {
130+
error: error_ptr,
123131
_request: parent,
124132
response,
125133
}

core/lib/src/lifecycle.rs

+66-41
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ impl Rocket<Orbit> {
9494
_token: RequestToken,
9595
request: &'r Request<'s>,
9696
data: Data<'r>,
97+
error_ptr: &'r mut Option<Box<dyn TypedError<'r> + 'r>>,
9798
// io_stream: impl Future<Output = io::Result<IoStream>> + Send,
9899
) -> Response<'r> {
99100
// Remember if the request is `HEAD` for later body stripping.
@@ -109,16 +110,24 @@ impl Rocket<Orbit> {
109110
request._set_method(Method::Get);
110111
match self.route(request, data).await {
111112
Outcome::Success(response) => response,
112-
Outcome::Error((status, error))
113-
=> self.dispatch_error(status, request, error).await,
114-
Outcome::Forward((_, status, error))
115-
=> self.dispatch_error(status, request, error).await,
113+
Outcome::Error((status, error)) => {
114+
*error_ptr = error;
115+
self.dispatch_error(status, request, error_ptr.as_ref().map(|b| b.as_ref())).await
116+
},
117+
Outcome::Forward((_, status, error)) => {
118+
*error_ptr = error;
119+
self.dispatch_error(status, request, error_ptr.as_ref().map(|b| b.as_ref())).await
120+
},
116121
}
117122
}
118-
Outcome::Forward((_, status, error))
119-
=> self.dispatch_error(status, request, error).await,
120-
Outcome::Error((status, error))
121-
=> self.dispatch_error(status, request, error).await,
123+
Outcome::Forward((_, status, error)) => {
124+
*error_ptr = error;
125+
self.dispatch_error(status, request, error_ptr.as_ref().map(|b| b.as_ref())).await
126+
},
127+
Outcome::Error((status, error)) => {
128+
*error_ptr = error;
129+
self.dispatch_error(status, request, error_ptr.as_ref().map(|b| b.as_ref())).await
130+
},
122131
};
123132

124133
// Set the cookies. Note that error responses will only include cookies
@@ -236,7 +245,7 @@ impl Rocket<Orbit> {
236245
&'s self,
237246
mut status: Status,
238247
req: &'r Request<'s>,
239-
mut error: Option<Box<dyn TypedError<'r> + 'r>>,
248+
mut error: Option<&'r dyn TypedError<'r>>,
240249
) -> Response<'r> {
241250
// We may wish to relax this in the future.
242251
req.cookies().reset_delta();
@@ -273,48 +282,64 @@ impl Rocket<Orbit> {
273282
async fn invoke_catcher<'s, 'r: 's>(
274283
&'s self,
275284
status: Status,
276-
error: Option<Box<dyn TypedError<'r> + 'r>>,
285+
error: Option<&'r dyn TypedError<'r>>,
277286
req: &'r Request<'s>
278287
) -> Result<Response<'r>, Option<Status>> {
279-
let error_ty = error.as_ref().map(|e| e.as_any().type_id());
280-
println!("Catching {:?}", error.as_ref().map(|e| e.name()));
281-
if let Some(catcher) = self.router.catch(status, req, error_ty) {
282-
self.invoke_specific_catcher(catcher, status, error.as_ref().map(|e| e.as_ref()), req).await
283-
} else if let Some(source) = error.as_ref().and_then(|e| e.source()) {
284-
println!("Catching {:?}", source.name());
285-
let error_ty = source.as_any().type_id();
286-
if let Some(catcher) = self.router.catch(status, req, Some(error_ty)) {
287-
self.invoke_specific_catcher(catcher, status, error.as_ref().and_then(|e| e.source()), req).await
288-
} else {
289-
info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = status.code,
290-
"no registered catcher: using Rocket default");
291-
Ok(catcher::default_handler(status, req))
288+
let mut error_copy = error;
289+
let mut counter = 0;
290+
// Matches error [.source ...] type
291+
while error_copy.is_some() && counter < 5 {
292+
if let Some(catcher) = self.router.catch(status, req, error_copy.map(|e| e.trait_obj_typeid())) {
293+
return self.invoke_specific_catcher(catcher, status, error_copy, req).await;
292294
}
293-
} else {
294-
info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = status.code,
295-
"no registered catcher: using Rocket default");
296-
Ok(catcher::default_handler(status, req))
295+
error_copy = error_copy.and_then(|e| e.source());
296+
counter += 1;
297297
}
298-
// if let Some(catcher) = self.router.catch(status, req, error.as_ref().map(|t| t.as_any().type_id())) {
299-
// catcher.trace_info();
300-
// catch_handle(
301-
// catcher.name.as_deref(),
302-
// || catcher.handler.handle(status, req, error)
303-
// ).await
304-
// .map(|result| result.map_err(|(s, e)| (Some(s), e)))
305-
// .unwrap_or_else(|| Err((None, None)))
306-
// } else {
307-
// info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = status.code,
308-
// "no registered catcher: using Rocket default");
309-
// Ok(catcher::default_handler(status, req))
310-
// }
298+
// Matches None type
299+
if let Some(catcher) = self.router.catch(status, req, None) {
300+
return self.invoke_specific_catcher(catcher, status, None, req).await;
301+
}
302+
let mut error_copy = error;
303+
let mut counter = 0;
304+
// Matches error [.source ...] type, and any status
305+
while error_copy.is_some() && counter < 5 {
306+
if let Some(catcher) = self.router.catch_any(status, req, error_copy.map(|e| e.trait_obj_typeid())) {
307+
return self.invoke_specific_catcher(catcher, status, error_copy, req).await;
308+
}
309+
error_copy = error_copy.and_then(|e| e.source());
310+
counter += 1;
311+
}
312+
// Matches None type, and any status
313+
if let Some(catcher) = self.router.catch_any(status, req, None) {
314+
return self.invoke_specific_catcher(catcher, status, None, req).await;
315+
}
316+
if let Some(error) = error {
317+
if let Ok(res) = error.respond_to(req) {
318+
return Ok(res);
319+
// TODO: this ignores the returned status.
320+
}
321+
}
322+
// Rocket default catcher
323+
info!(name: "catcher", name = "rocket::default", "uri.base" = "/", code = status.code,
324+
"no registered catcher: using Rocket default");
325+
Ok(catcher::default_handler(status, req))
326+
// TODO: document:
327+
// Set of matching catchers, tried in order:
328+
// - Matches error type
329+
// - Matches error.source type
330+
// - Matches error.source.source type
331+
// - ... etc
332+
// - Matches None type
333+
// - Registered default handler
334+
// - Rocket default handler
335+
// At each step, the catcher with the longest path is selected
311336
}
312337

313338
async fn invoke_specific_catcher<'s, 'r: 's>(
314339
&'s self,
315340
catcher: &Catcher,
316341
status: Status,
317-
error: Option<&'r (dyn TypedError<'r> + 'r)>,
342+
error: Option<&'r dyn TypedError<'r>>,
318343
req: &'r Request<'s>
319344
) -> Result<Response<'r>, Option<Status>> {
320345
catcher.trace_info();

core/lib/src/local/asynchronous/request.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,19 @@ impl<'c> LocalRequest<'c> {
8686
// _shouldn't_ error. Check that now and error only if not.
8787
if self.inner().uri() == invalid {
8888
error!("invalid request URI: {:?}", invalid.path());
89-
return LocalResponse::new(self.request, move |req| {
89+
return LocalResponse::new(self.request, move |req, error_ptr| {
9090
// TODO: Ideally the RequestErrors should contain actual information.
91-
rocket.dispatch_error(Status::BadRequest, req, Some(Box::new(RequestErrors::new(&[]))))
91+
*error_ptr = Some(Box::new(RequestErrors::new(&[])));
92+
rocket.dispatch_error(Status::BadRequest, req, error_ptr.as_ref().map(|b| b.as_ref()))
9293
}).await
9394
}
9495
}
9596

9697
// Actually dispatch the request.
9798
let mut data = Data::local(self.data);
9899
let token = rocket.preprocess(&mut self.request, &mut data).await;
99-
let response = LocalResponse::new(self.request, move |req| {
100-
rocket.dispatch(token, req, data)
100+
let response = LocalResponse::new(self.request, move |req, error_ptr| {
101+
rocket.dispatch(token, req, data, error_ptr)
101102
}).await;
102103

103104
// If the client is tracking cookies, updates the internal cookie jar

core/lib/src/local/asynchronous/response.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::{pin::Pin, task::{Context, Poll}};
44

55
use tokio::io::{AsyncRead, ReadBuf};
66

7+
use crate::catcher::TypedError;
78
use crate::http::CookieJar;
89
use crate::{Request, Response};
910

@@ -55,6 +56,7 @@ use crate::{Request, Response};
5556
pub struct LocalResponse<'c> {
5657
// XXX: SAFETY: This (dependent) field must come first due to drop order!
5758
response: Response<'c>,
59+
_error: Option<Box<dyn TypedError<'c> + 'c>>,
5860
cookies: CookieJar<'c>,
5961
_request: Box<Request<'c>>,
6062
}
@@ -65,8 +67,8 @@ impl Drop for LocalResponse<'_> {
6567

6668
impl<'c> LocalResponse<'c> {
6769
pub(crate) fn new<F, O>(req: Request<'c>, f: F) -> impl Future<Output = LocalResponse<'c>>
68-
where F: FnOnce(&'c Request<'c>) -> O + Send,
69-
O: Future<Output = Response<'c>> + Send
70+
where F: FnOnce(&'c Request<'c>, &'c mut Option<Box<dyn TypedError<'c> + 'c>>) -> O + Send,
71+
O: Future<Output = Response<'c>> + Send + 'c
7072
{
7173
// `LocalResponse` is a self-referential structure. In particular,
7274
// `response` and `cookies` can refer to `_request` and its contents. As
@@ -93,17 +95,21 @@ impl<'c> LocalResponse<'c> {
9395
let request: &'c Request<'c> = unsafe { &*(&*boxed_req as *const _) };
9496

9597
async move {
98+
use std::mem::transmute;
99+
let mut error: Option<Box<dyn TypedError<'c> + 'c>> = None;
96100
// NOTE: The cookie jar `secure` state will not reflect the last
97101
// known value in `request.cookies()`. This is okay: new cookies
98102
// should never be added to the resulting jar which is the only time
99103
// the value is used to set cookie defaults.
100-
let response: Response<'c> = f(request).await;
104+
// SAFETY: Much like request above, error can borrow from request, and
105+
// response can borrow from request. TODO
106+
let response: Response<'c> = f(request, unsafe { transmute(&mut error) }).await;
101107
let mut cookies = CookieJar::new(None, request.rocket());
102108
for cookie in response.cookies() {
103109
cookies.add_original(cookie.into_owned());
104110
}
105111

106-
LocalResponse { _request: boxed_req, cookies, response, }
112+
LocalResponse { _request: boxed_req, _error: error, cookies, response, }
107113
}
108114
}
109115
}

core/lib/src/router/matcher.rs

+2-9
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,12 @@ impl Catcher {
137137
/// let b_count = b.base().segments().filter(|s| !s.is_empty()).count();
138138
/// assert!(b_count > a_count);
139139
/// ```
140+
// TODO: document error matching
140141
pub fn matches(&self, status: Status, request: &Request<'_>, error: Option<TypeId>) -> bool {
141142
self.code.map_or(true, |code| code == status.code)
142-
&& self.error_matches(error)
143+
&& error == self.error_type.map(|(ty, _)| ty)
143144
&& self.base().segments().prefix_of(request.uri().path().segments())
144145
}
145-
146-
fn error_matches(&self, error: Option<TypeId>) -> bool {
147-
if let Some((ty, _)) = self.error_type {
148-
error.map_or(false, |t| t == ty)
149-
} else {
150-
true
151-
}
152-
}
153146
}
154147

155148

core/lib/src/router/router.rs

+7-4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::collections::HashMap;
22

33
use transient::TypeId;
44

5+
use crate::catcher::TypedError;
56
use crate::request::Request;
67
use crate::http::{Method, Status};
78

@@ -54,16 +55,18 @@ impl Router {
5455
}
5556

5657
// For many catchers, using aho-corasick or similar should be much faster.
57-
pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>, error_ty: Option<TypeId>) -> Option<&Catcher> {
58+
// TODO: document difference between catch, and catch_any
59+
pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>, error: Option<TypeId>) -> Option<&Catcher> {
5860
// Note that catchers are presorted by descending base length.
5961
self.catchers.get(&Some(status.code))
60-
.and_then(|c| c.iter().find(|c| c.matches(status, req, error_ty)))
62+
.and_then(|c| c.iter().find(|c| c.matches(status, req, error)))
6163
}
6264

63-
pub fn catch_any<'r>(&self, status: Status, req: &'r Request<'r>, error_ty: Option<TypeId>) -> Option<&Catcher> {
65+
// For many catchers, using aho-corasick or similar should be much faster.
66+
pub fn catch_any<'r>(&self, status: Status, req: &'r Request<'r>, error: Option<TypeId>) -> Option<&Catcher> {
6467
// Note that catchers are presorted by descending base length.
6568
self.catchers.get(&None)
66-
.and_then(|c| c.iter().find(|c| c.matches(status, req, error_ty)))
69+
.and_then(|c| c.iter().find(|c| c.matches(status, req, error)))
6770
}
6871

6972
fn collisions<'a, I, T>(&self, items: I) -> impl Iterator<Item = (T, T)> + 'a

core/lib/src/server.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,17 @@ impl Rocket<Orbit> {
4343
let mut response = request.into_response(
4444
stream,
4545
|rocket, request, data| Box::pin(rocket.preprocess(request, data)),
46-
|token, rocket, request, data| Box::pin(async move {
46+
|token, rocket, request, data, error_ptr| Box::pin(async move {
4747
if !request.errors.is_empty() {
48+
*error_ptr = Some(Box::new(RequestErrors::new(&request.errors)));
4849
return rocket.dispatch_error(
4950
Status::BadRequest,
5051
request,
51-
Some(Box::new(RequestErrors::new(&request.errors)))
52+
error_ptr.as_ref().map(|b| b.as_ref())
5253
).await;
5354
}
5455

55-
rocket.dispatch(token, request, data).await
56+
rocket.dispatch(token, request, data, error_ptr).await
5657
})
5758
).await;
5859

0 commit comments

Comments
 (0)