Skip to content

Commit f918cd1

Browse files
udaij12Ubuntu
and
Ubuntu
authored
Token Authorization fixes (#3192)
* test * token authorization integration * env false regression cpu ci * testing ci * testing newman * fix newman tests * testing pytest * testing cmd arg * pytest fixes * fixing tests * doc update * spell check * fixing priority between config file and cmd * test fixes * removing unneeded files * Delete unneeded files * review fixes * removing comments * adding doc clarification and new test * changes to docs * adding new tests * fixing merge conflict * format fix * format fixes * addressing comments * fixing merge conflict * fixing merge conflict * fixing merge conflict * fix merge conflict * doc update * fixing format * fix to benchmarks --------- Co-authored-by: Ubuntu <[email protected]>
1 parent 9336ad2 commit f918cd1

File tree

9 files changed

+79
-4397
lines changed

9 files changed

+79
-4397
lines changed

benchmarks/utils/system_under_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def start(self):
116116
click.secho("*Starting local Torchserve instance...", fg="green")
117117

118118
ts_cmd = (
119-
f"torchserve --start --model-store {self.execution_params['tmp_dir']}/model_store --model-api-enabled --disable-token"
119+
f"torchserve --start --model-store {self.execution_params['tmp_dir']}/model_store --model-api-enabled --disable-token "
120120
f"--workflow-store {self.execution_params['tmp_dir']}/wf_store "
121121
f"--ts-config {self.execution_params['tmp_dir']}/benchmark/conf/{self.execution_params['config_properties_name']} "
122122
f" > {self.execution_params['tmp_dir']}/benchmark/logs/model_metrics.log"
@@ -195,7 +195,7 @@ def start(self):
195195
f"docker run {self.execution_params['docker_runtime']} {backend_profiling} --name ts --user root -p "
196196
f"127.0.0.1:{inference_port}:{inference_port} -p 127.0.0.1:{management_port}:{management_port} "
197197
f"-v {self.execution_params['tmp_dir']}:/tmp {enable_gpu} -itd {docker_image} "
198-
f'"torchserve --start --model-store /home/model-server/model-store --model-api-enabled --disable-token'
198+
f'"torchserve --start --model-store /home/model-server/model-store --model-api-enabled --disable-token '
199199
f"\--workflow-store /home/model-server/wf-store "
200200
f"--ts-config /tmp/benchmark/conf/{self.execution_params['config_properties_name']} > "
201201
f'/tmp/benchmark/logs/model_metrics.log"'

docs/token_authorization_api.md

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
# TorchServe token authorization API
22

3-
Torchserve now supports token authorization by default.
3+
TorchServe now enforces token authorization by default
4+
45

56
## How to set and disable Token Authorization
67
* Global environment variable: use `TS_DISABLE_TOKEN_AUTHORIZATION` and set to `true` to disable and `false` to enable token authorization. Note that `enable_envvars_config=true` must be set in config.properties for global environment variables to be used
78
* Command line: Command line can only be used to disable token authorization by adding the `--disable-token` flag.
89
* Config properties file: use `disable_token_authorization` and set to `true` to disable and `false` to enable token authorization.
910

10-
Priority between env variables, cmd, and config file follows the following [TorchServer standard](https://github.com/pytorch/serve/blob/c74a29e8144bc12b84196775076b0e8cf3c5a6fc/docs/configuration.md#advanced-configuration)
11+
Priority between env variables, cmd, and config file follows the following [TorchServer standard](https://github.com/pytorch/serve/blob/master/docs/configuration.md)
12+
1113
* Example 1:
1214
* Config file: `disable_token_authorization=false`
1315

@@ -48,7 +50,7 @@ Priority between env variables, cmd, and config file follows the following [Torc
4850
2. Inference key: Used for inference APIs. Example:
4951
`curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"`
5052
3. API key: Used for the token authorization API. Check section 4 for API use.
51-
4. The plugin also includes an API in order to generate a new key to replace either the management or inference key.
53+
4. API in order to generate a new key to replace either the management or inference key.
5254
1. Management Example:
5355
`curl localhost:8081/token?type=management -H "Authorization: Bearer m4M-5IBY"` will replace the current management key in the key_file with a new one and will update the expiration time.
5456
2. Inference example:
@@ -61,4 +63,4 @@ Priority between env variables, cmd, and config file follows the following [Torc
6163
## Notes
6264
1. DO NOT MODIFY THE KEY FILE. Modifying the key file might impact reading and writing to the file thus preventing new keys from properly being displayed in the file.
6365
2. Time to expiration is set to default at 60 minutes but can be changed in the config.properties by adding `token_expiration_min`. Ex:`token_expiration_min=30`
64-
3. 3 tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should only be able to run inferences against models that have already been loaded. The owner can also provide owners with the management key if owners want users to add and remove models.
66+
3. Three tokens allow the owner with the most flexibility in use and enables them to adapt the tokens to their use. Owners of the server can provide users with the inference token if users should only be able to run inferences against models that have already been loaded. The owner can also provide owners with the management key if owners want users to add and remove models.

frontend/server/src/main/java/org/pytorch/serve/ModelServer.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import org.pytorch.serve.archive.model.ModelNotFoundException;
4141
import org.pytorch.serve.grpcimpl.GRPCInterceptor;
4242
import org.pytorch.serve.grpcimpl.GRPCServiceFactory;
43+
import org.pytorch.serve.http.TokenAuthorizationHandler;
4344
import org.pytorch.serve.http.messages.RegisterModelRequest;
4445
import org.pytorch.serve.metrics.MetricCache;
4546
import org.pytorch.serve.metrics.MetricManager;
@@ -86,7 +87,7 @@ public static void main(String[] args) {
8687
ConfigManager.Arguments arguments = new ConfigManager.Arguments(cmd);
8788
ConfigManager.init(arguments);
8889
ConfigManager configManager = ConfigManager.getInstance();
89-
configManager.setupToken();
90+
TokenAuthorizationHandler.setupToken();
9091
PluginsManager.getInstance().initialize();
9192
MetricCache.init();
9293
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);

frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java

+48-91
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
4040
private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
4141
private static TokenType tokenType;
4242
private static Boolean tokenEnabled = false;
43-
private static Token tokenClass;
43+
private static Token token;
4444
private static Object tokenObject;
45-
private static Double timeToExpirationMinutes = 60.0;
4645

4746
/** Creates a new {@code InferenceRequestHandler} instance. */
4847
public TokenAuthorizationHandler(TokenType type) {
@@ -62,11 +61,12 @@ public void handleRequest(
6261
if (req.toString().contains("/token")) {
6362
try {
6463
checkTokenAuthorization(req, "token");
65-
String resp = tokenClass.updateKeyFile(req);
64+
String queryResponse = parseQuery(req);
65+
String resp = token.updateKeyFile(queryResponse);
6666
NettyUtils.sendJsonResponse(ctx, resp);
6767
return;
6868
} catch (Exception e) {
69-
logger.error("TOKEN CLASS UPDATED UNSUCCESSFULLY");
69+
logger.error("Key file updated unsuccessfully");
7070
throw new InvalidKeyException(
7171
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
7272
}
@@ -76,48 +76,60 @@ public void handleRequest(
7676
} else if (tokenType == TokenType.INFERENCE) {
7777
checkTokenAuthorization(req, "inference");
7878
}
79-
} else {
80-
if (tokenType == TokenType.MANAGEMENT && req.toString().contains("/token")) {
81-
throw new ResourceNotFoundException();
82-
}
8379
}
8480
chain.handleRequest(ctx, req, decoder, segments);
8581
}
8682

87-
public static void setupTokenClass() {
88-
try {
89-
tokenClass = new Token();
90-
Double time = ConfigManager.getInstance().getTimeToExpiration();
91-
String home = ConfigManager.getInstance().getModelServerHome();
92-
tokenClass.setFilePath(home);
93-
if (time != 0.0) {
94-
timeToExpirationMinutes = time;
95-
}
96-
tokenClass.setTime(timeToExpirationMinutes);
97-
if (tokenClass.generateKeyFile("token")) {
98-
logger.info("Token Authorization Enabled");
83+
public static void setupToken() {
84+
if (!ConfigManager.getInstance().getDisableTokenAuthorization()) {
85+
try {
86+
token = new Token();
87+
if (token.generateKeyFile("token")) {
88+
logger.info("Token Authorization Enabled");
89+
}
90+
} catch (IOException e) {
91+
e.printStackTrace();
92+
logger.error("Token Authorization setup unsuccessfully");
93+
throw new IllegalStateException("Token Authorization setup unsuccessfully", e);
9994
}
100-
} catch (Exception e) {
101-
e.printStackTrace();
102-
logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
103-
throw new IllegalStateException("Unable to import token class", e);
95+
tokenEnabled = true;
10496
}
105-
tokenEnabled = true;
10697
}
10798

10899
private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {
100+
String tokenBearer = req.headers().get("Authorization");
101+
if (tokenBearer == null) {
102+
throw new InvalidKeyException(
103+
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
104+
}
105+
String[] arrOfStr = tokenBearer.split(" ", 2);
106+
if (arrOfStr.length == 1) {
107+
throw new InvalidKeyException(
108+
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
109+
}
110+
String currToken = arrOfStr[1];
109111

110-
try {
111-
boolean result = tokenClass.checkTokenAuthorization(req, type);
112-
if (!result) {
113-
throw new InvalidKeyException(
114-
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
115-
}
116-
} catch (Exception e) {
112+
boolean result = token.checkTokenAuthorization(currToken, type);
113+
if (!result) {
117114
throw new InvalidKeyException(
118115
"Token Authentication failed. Token either incorrect, expired, or not provided correctly");
119116
}
120117
}
118+
119+
// parses query and either returns management/inference or a wrong type error
120+
private String parseQuery(FullHttpRequest req) {
121+
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
122+
Map<String, List<String>> parameters = decoder.parameters();
123+
List<String> values = parameters.get("type");
124+
if (values != null && !values.isEmpty()) {
125+
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
126+
return values.get(0);
127+
} else {
128+
return "WRONG TYPE";
129+
}
130+
}
131+
return "NO TYPE PROVIDED";
132+
}
121133
}
122134

123135
class Token {
@@ -126,14 +138,12 @@ class Token {
126138
private static String inferenceKey;
127139
private static Instant managementExpirationTimeMinutes;
128140
private static Instant inferenceExpirationTimeMinutes;
129-
private static Double timeToExpirationMinutes;
130141
private SecureRandom secureRandom = new SecureRandom();
131142
private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
132143
private String fileName = "key_file.json";
133-
private String filePath = "";
144+
private String filePath = ConfigManager.getInstance().getModelServerHome();
134145

135-
public String updateKeyFile(FullHttpRequest req) throws IOException {
136-
String queryResponse = parseQuery(req);
146+
public String updateKeyFile(String queryResponse) throws IOException {
137147
String test = "";
138148
if ("management".equals(queryResponse)) {
139149
generateKeyFile("management");
@@ -145,36 +155,17 @@ public String updateKeyFile(FullHttpRequest req) throws IOException {
145155
return test;
146156
}
147157

148-
// parses query and either returns management/inference or a wrong type error
149-
public String parseQuery(FullHttpRequest req) {
150-
QueryStringDecoder decoder = new QueryStringDecoder(req.uri());
151-
Map<String, List<String>> parameters = decoder.parameters();
152-
List<String> values = parameters.get("type");
153-
if (values != null && !values.isEmpty()) {
154-
if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
155-
return values.get(0);
156-
} else {
157-
return "WRONG TYPE";
158-
}
159-
}
160-
return "NO TYPE PROVIDED";
161-
}
162-
163158
public String generateKey() {
164159
byte[] randomBytes = new byte[6];
165160
secureRandom.nextBytes(randomBytes);
166161
return baseEncoder.encodeToString(randomBytes);
167162
}
168163

169164
public Instant generateTokenExpiration() {
170-
long secondsToAdd = (long) (timeToExpirationMinutes * 60);
165+
long secondsToAdd = (long) (ConfigManager.getInstance().getTimeToExpiration() * 60);
171166
return Instant.now().plusSeconds(secondsToAdd);
172167
}
173168

174-
public void setFilePath(String path) {
175-
filePath = path;
176-
}
177-
178169
// generates a key file with new keys depending on the parameter provided
179170
public boolean generateKeyFile(String type) throws IOException {
180171
String userDirectory = filePath + "/" + fileName;
@@ -248,7 +239,7 @@ public boolean setFilePermissions() {
248239
}
249240

250241
// checks the token provided in the http with the saved keys depening on parameters
251-
public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
242+
public boolean checkTokenAuthorization(String token, String type) {
252243
String key;
253244
Instant expiration;
254245
switch (type) {
@@ -265,16 +256,6 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
265256
expiration = inferenceExpirationTimeMinutes;
266257
}
267258

268-
String tokenBearer = req.headers().get("Authorization");
269-
if (tokenBearer == null) {
270-
return false;
271-
}
272-
String[] arrOfStr = tokenBearer.split(" ", 2);
273-
if (arrOfStr.length == 1) {
274-
return false;
275-
}
276-
String token = arrOfStr[1];
277-
278259
if (token.equals(key)) {
279260
if (expiration != null && isTokenExpired(expiration)) {
280261
return false;
@@ -288,28 +269,4 @@ public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
288269
public boolean isTokenExpired(Instant expirationTime) {
289270
return !(Instant.now().isBefore(expirationTime));
290271
}
291-
292-
public String getManagementKey() {
293-
return managementKey;
294-
}
295-
296-
public String getInferenceKey() {
297-
return inferenceKey;
298-
}
299-
300-
public String getKey() {
301-
return apiKey;
302-
}
303-
304-
public Instant getInferenceExpirationTime() {
305-
return inferenceExpirationTimeMinutes;
306-
}
307-
308-
public Instant getManagementExpirationTime() {
309-
return managementExpirationTimeMinutes;
310-
}
311-
312-
public void setTime(Double time) {
313-
timeToExpirationMinutes = time;
314-
}
315272
}

frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

+1-10
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
import org.apache.commons.cli.Options;
4747
import org.apache.commons.io.IOUtils;
4848
import org.pytorch.serve.archive.model.Manifest;
49-
import org.pytorch.serve.http.TokenAuthorizationHandler;
5049
import org.pytorch.serve.metrics.MetricBuilder;
5150
import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer;
5251
import org.pytorch.serve.snapshot.SnapshotSerializerFactory;
@@ -450,14 +449,6 @@ public boolean isOpenInferenceProtocol() {
450449
return Boolean.parseBoolean(prop.getProperty(TS_OPEN_INFERENCE_PROTOCOL, "false"));
451450
}
452451

453-
public boolean setupToken() {
454-
boolean disable_token_authorization = getDisableTokenAuthorization();
455-
if (!disable_token_authorization) {
456-
TokenAuthorizationHandler.setupTokenClass();
457-
}
458-
return true;
459-
}
460-
461452
public boolean isGRPCSSLEnabled() {
462453
return Boolean.parseBoolean(getProperty(TS_ENABLE_GRPC_SSL, "false"));
463454
}
@@ -1001,7 +992,7 @@ public Double getTimeToExpiration() {
1001992
logger.error("Token expiration not a valid integer");
1002993
}
1003994
}
1004-
return 0.0;
995+
return 60.0;
1005996
}
1006997

1007998
public String getTsHeaderKeySequenceId() {

0 commit comments

Comments
 (0)