Skip to content

Commit c77d0fb

Browse files
author
Edward J Kim
authoredJul 30, 2020
Add xgboost examples for inference in script mode (aws#1341)
* Add xgboost examples for inference in script mode * Address comments by Eric
1 parent a63dec5 commit c77d0fb

File tree

4 files changed

+518
-0
lines changed

4 files changed

+518
-0
lines changed
 
Loading
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# A copy of the License is located at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
12+
# permissions and limitations under the License.
13+
import json
14+
import os
15+
import pickle as pkl
16+
17+
import numpy as np
18+
19+
import sagemaker_xgboost_container.encoder as xgb_encoders
20+
21+
22+
def model_fn(model_dir):
23+
"""
24+
Deserialize and return fitted model.
25+
"""
26+
model_file = "xgboost-model"
27+
booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
28+
return booster
29+
30+
31+
def input_fn(request_body, request_content_type):
32+
"""
33+
The SageMaker XGBoost model server receives the request data body and the content type,
34+
and invokes the `input_fn`.
35+
36+
Return a DMatrix (an object that can be passed to predict_fn).
37+
"""
38+
if request_content_type == "text/libsvm":
39+
return xgb_encoders.libsvm_to_dmatrix(request_body)
40+
else:
41+
raise ValueError(
42+
"Content type {} is not supported.".format(request_content_type)
43+
)
44+
45+
46+
def predict_fn(input_data, model):
47+
"""
48+
SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`.
49+
50+
Return a two-dimensional NumPy array where the first columns are predictions
51+
and the remaining columns are the feature contributions (SHAP values) for that prediction.
52+
"""
53+
prediction = model.predict(input_data)
54+
feature_contribs = model.predict(input_data, pred_contribs=True)
55+
output = np.hstack((prediction[:, np.newaxis], feature_contribs))
56+
return output
57+
58+
59+
def output_fn(prediction, content_type):
60+
"""
61+
After invoking predict_fn, the model server invokes `output_fn`.
62+
"""
63+
if content_type == "application/json":
64+
return json.dumps(prediction.tolist())
65+
else:
66+
raise ValueError("Content type {} is not supported.".format(content_type))

0 commit comments

Comments
 (0)
Please sign in to comment.