Commit 8cd3cde

udaij12, Ubuntu, and mreso authored and committed
Api change (pytorch#2888)
* token authorization update
* update token
* token authorization plugin test
* fix format
* add key file generation at default
* fix format
* updated token plugin
* fixed file delete
* fixed imports
* added custom expection
* fix format
* token handler
* fix doc
* fixed handler
* added integration tests
* Added integration tests
* updating token auth
* small changes to token auth
* fixing changes
* changed keyfile to dictionary and updated readme and tests
* remove comments
* changes to tests
* added config file
* reduce time for expiration test
* change test to mnist
* removing install from src
* final test change
* Fix spellcheck

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Matthias Reso <[email protected]>
1 parent 44da600 commit 8cd3cde

21 files changed, +681 -6 lines changed

docs/token_authorization_api.md

+50
@@ -0,0 +1,50 @@
+# TorchServe token authorization API
+
+## Configuration
+1. Enable token authorization by passing the provided plugin at startup with the `--plugins-path` option.
+2. TorchServe enables token authorization whenever the plugin is provided. A file named `key_file.json` is generated in the current working directory.
+    1. Example key file:
+
+```python
+  {
+    "management": {
+      "key": "B-E5KSRM",
+      "expiration time": "2024-02-16T21:12:24.801167Z"
+    },
+    "inference": {
+      "key": "gNRuA7dS",
+      "expiration time": "2024-02-16T21:12:24.801148Z"
+    },
+    "API": {
+      "key": "yv9uQajP"
+    }
+  }
+```
+
+3. There are three keys, each with a different use.
+    1. Management key: Used for management APIs. Example:
+    `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer I_J_ItMb"`
+    2. Inference key: Used for inference APIs. Example:
+    `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"`
+    3. API key: Used for the token authorization API. See item 4 for API use.
+4. The plugin also includes an API that generates a new key to replace either the management or the inference key.
+    1. Management example:
+    `curl localhost:8081/token?type=management -H "Authorization: Bearer m4M-5IBY"` replaces the current management key in the key file with a new one and updates its expiration time.
+    2. Inference example:
+    `curl localhost:8081/token?type=inference -H "Authorization: Bearer m4M-5IBY"`
+
+    Users will have to use either one of the APIs above.
+
+5. When users shut down the server, the key file is deleted.
+
+
+## Customization
+TorchServe offers several ways to customize token authorization so that owners can reach the desired result.
+1. The time to expiration defaults to 60 minutes and can be changed in `config.properties` by adding `token_expiration_min`, e.g. `token_expiration_min=30`.
+2. The token authorization code is consolidated in the plugin and can therefore be changed without impacting the frontend or the end result. The only things the user cannot change are:
+    1. The urlPattern for the plugin must be 'token' and the class name must not change.
+    2. The return types and signatures of the `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions must not change. The code inside the functions can be modified as needed.
+
+## Notes
+1. DO NOT MODIFY THE KEY FILE. Modifying the key file might interfere with reading and writing the file, preventing new keys from being displayed in it properly.
+2. Three tokens give the owner the most flexibility and let them adapt the tokens to their needs. The owner of the server can give users the inference token if users should not modify models. The owner can also give users the management key if users should be able to add and remove models.
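The curl examples in the documentation above pass each key as a bearer token in the `Authorization` header. A minimal Java sketch of the same idea follows, using the JDK's `java.net.http` request builder. The class and method names (`BearerRequestSketch`, `bearerHeader`, `authorizedGet`) are hypothetical and not part of this commit, and the key is a placeholder rather than a value read from `key_file.json`. The sketch only builds the request; it does not send it.

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Hypothetical helper, not part of this commit: builds a request that carries
// a TorchServe key in the Authorization header, as the curl examples do.
class BearerRequestSketch {

    // Returns the header value used in the curl examples: "Bearer <key>".
    static String bearerHeader(String key) {
        return "Bearer " + key;
    }

    // Builds (but does not send) a GET request with the key attached.
    static HttpRequest authorizedGet(String url, String key) {
        return HttpRequest.newBuilder(URI.create(url))
                .header("Authorization", bearerHeader(key))
                .GET()
                .build();
    }

    public static void main(String[] args) {
        // Placeholder key; a real one would come from key_file.json.
        HttpRequest req =
                authorizedGet("http://localhost:8081/models/densenet161", "I_J_ItMb");
        System.out.println(req.method() + " " + req.uri());
        System.out.println(req.headers().firstValue("Authorization").orElse("<missing>"));
    }
}
```

Sending the request would additionally need an `HttpClient` and a running server with the plugin loaded, which is outside the scope of this sketch.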
@@ -0,0 +1,32 @@
+package org.pytorch.serve.archive.model;
+
+public class InvalidKeyException extends ModelException {
+
+    private static final long serialVersionUID = 1L;
+
+    /**
+     * Constructs an {@code InvalidKeyException} with the specified detail message.
+     *
+     * @param message The detail message (which is saved for later retrieval by the {@link
+     *     #getMessage()} method)
+     */
+    public InvalidKeyException(String message) {
+        super(message);
+    }
+
+    /**
+     * Constructs an {@code InvalidKeyException} with the specified detail message and cause.
+     *
+     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
+     * incorporated into this exception's detail message.
+     *
+     * @param message The detail message (which is saved for later retrieval by the {@link
+     *     #getMessage()} method)
+     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
+     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
+     *     unknown.)
+     */
+    public InvalidKeyException(String message, Throwable cause) {
+        super(message, cause);
+    }
+}

frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java

+8
@@ -10,6 +10,7 @@
 import org.pytorch.serve.http.HttpRequestHandler;
 import org.pytorch.serve.http.HttpRequestHandlerChain;
 import org.pytorch.serve.http.InvalidRequestHandler;
+import org.pytorch.serve.http.TokenAuthorizationHandler;
 import org.pytorch.serve.http.api.rest.ApiDescriptionRequestHandler;
 import org.pytorch.serve.http.api.rest.InferenceRequestHandler;
 import org.pytorch.serve.http.api.rest.ManagementRequestHandler;
@@ -18,6 +19,7 @@
 import org.pytorch.serve.servingsdk.impl.PluginsManager;
 import org.pytorch.serve.util.ConfigManager;
 import org.pytorch.serve.util.ConnectorType;
+import org.pytorch.serve.util.TokenType;
 import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler;
 import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler;
 import org.slf4j.Logger;
@@ -63,6 +65,9 @@ public void initChannel(Channel ch) {
         HttpRequestHandlerChain httpRequestHandlerChain = apiDescriptionRequestHandler;
         if (ConnectorType.ALL.equals(connectorType)
                 || ConnectorType.INFERENCE_CONNECTOR.equals(connectorType)) {
+            httpRequestHandlerChain =
+                    httpRequestHandlerChain.setNextHandler(
+                            new TokenAuthorizationHandler(TokenType.INFERENCE));
             httpRequestHandlerChain =
                     httpRequestHandlerChain.setNextHandler(
                             new InferenceRequestHandler(
@@ -80,6 +85,9 @@ public void initChannel(Channel ch) {
         }
         if (ConnectorType.ALL.equals(connectorType)
                 || ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) {
+            httpRequestHandlerChain =
+                    httpRequestHandlerChain.setNextHandler(
+                            new TokenAuthorizationHandler(TokenType.MANAGEMENT));
             httpRequestHandlerChain =
                     httpRequestHandlerChain.setNextHandler(
                             new ManagementRequestHandler(
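The hunks above place a token check in the handler chain immediately before the inference and management handlers. The sketch below illustrates that ordering with a simplified chain. All names here (`HandlerChainSketch`, `Handler`, `Named`, `visitOrder`) are hypothetical stand-ins, and the sketch assumes a `setNextHandler` that returns the handler it was given, which is what the reassignment pattern in the diff suggests. Unlike the real handlers, every handler here simply records its name and forwards the request.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for the handler chain; names are illustrative only.
class HandlerChainSketch {

    abstract static class Handler {
        private Handler next;

        // Links the next handler and returns it, so the caller can keep
        // appending to the tail of the chain.
        Handler setNextHandler(Handler nextHandler) {
            this.next = nextHandler;
            return nextHandler;
        }

        // Runs this handler, then forwards to the next one if there is one.
        void handle(String path, List<String> trace) {
            process(path, trace);
            if (next != null) {
                next.handle(path, trace);
            }
        }

        abstract void process(String path, List<String> trace);
    }

    // A handler that only records that it was visited.
    static final class Named extends Handler {
        private final String name;

        Named(String name) {
            this.name = name;
        }

        @Override
        void process(String path, List<String> trace) {
            trace.add(name);
        }
    }

    // Builds api-description -> token-check -> inference and returns the visit order.
    static List<String> visitOrder(String path) {
        Handler head = new Named("api-description");
        Handler tail = head;
        tail = tail.setNextHandler(new Named("token-check"));
        tail = tail.setNextHandler(new Named("inference"));
        List<String> trace = new ArrayList<>();
        head.handle(path, trace);
        return trace;
    }

    public static void main(String[] args) {
        System.out.println(visitOrder("/predictions/densenet161"));
    }
}
```

Because the token check sits ahead of the request handler, a check that throws would stop the request before it reaches inference or management logic.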
@@ -0,0 +1,103 @@
+package org.pytorch.serve.http;
+
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.http.FullHttpRequest;
+import io.netty.handler.codec.http.QueryStringDecoder;
+import java.lang.reflect.*;
+import org.pytorch.serve.archive.DownloadArchiveException;
+import org.pytorch.serve.archive.model.InvalidKeyException;
+import org.pytorch.serve.archive.model.ModelException;
+import org.pytorch.serve.archive.workflow.WorkflowException;
+import org.pytorch.serve.util.ConfigManager;
+import org.pytorch.serve.util.TokenType;
+import org.pytorch.serve.wlm.WorkerInitializationException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A class handling token check for all inbound HTTP requests
+ *
+ * <p>This class //
+ */
+public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
+
+    private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
+    private static TokenType tokenType;
+    private static Boolean tokenEnabled = false;
+    private static Class<?> tokenClass;
+    private static Object tokenObject;
+    private static Double timeToExpirationMinutes = 60.0;
+
+    /** Creates a new {@code InferenceRequestHandler} instance. */
+    public TokenAuthorizationHandler(TokenType type) {
+        tokenType = type;
+    }
+
+    @Override
+    public void handleRequest(
+            ChannelHandlerContext ctx,
+            FullHttpRequest req,
+            QueryStringDecoder decoder,
+            String[] segments)
+            throws ModelException, DownloadArchiveException, WorkflowException,
+                    WorkerInitializationException {
+        if (tokenEnabled) {
+            if (tokenType == TokenType.MANAGEMENT) {
+                if (req.toString().contains("/token")) {
+                    checkTokenAuthorization(req, "token");
+                } else {
+                    checkTokenAuthorization(req, "management");
+                }
+            } else if (tokenType == TokenType.INFERENCE) {
+                checkTokenAuthorization(req, "inference");
+            }
+        }
+        chain.handleRequest(ctx, req, decoder, segments);
+    }
+
+    public static void setupTokenClass() {
+        try {
+            tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
+            tokenObject = tokenClass.getDeclaredConstructor().newInstance();
+            Method method = tokenClass.getMethod("setTime", Double.class);
+            Double time = ConfigManager.getInstance().getTimeToExpiration();
+            if (time != 0.0) {
+                timeToExpirationMinutes = time;
+            }
+            method.invoke(tokenObject, timeToExpirationMinutes);
+            method = tokenClass.getMethod("generateKeyFile", String.class);
+            if ((boolean) method.invoke(tokenObject, "token")) {
+                logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY");
+            }
+        } catch (NoSuchMethodException
+                | IllegalAccessException
+                | InstantiationException
+                | InvocationTargetException
+                | ClassNotFoundException e) {
+            e.printStackTrace();
+            logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
+            throw new IllegalStateException("Unable to import token class", e);
+        }
+        tokenEnabled = true;
+    }
+
+    private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {
+
+        try {
+            Method method =
+                    tokenClass.getMethod(
+                            "checkTokenAuthorization",
+                            io.netty.handler.codec.http.FullHttpRequest.class,
+                            String.class);
+            boolean result = (boolean) (method.invoke(tokenObject, req, type));
+            if (!result) {
+                throw new InvalidKeyException(
+                        "Token Authentication failed. Token either incorrect, expired, or not provided correctly");
+            }
+        } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+            e.printStackTrace();
+            throw new InvalidKeyException(
+                    "Token Authentication failed. Token either incorrect, expired, or not provided correctly");
+        }
+    }
+}
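The handler above reaches the plugin's token class only through reflection: it loads the class by name, instantiates it, and invokes methods looked up by name and parameter types. The sketch below shows that pattern in isolation. `ReflectiveTokenSketch`, `DemoToken`, and the `String`-based `checkTokenAuthorization` signature are hypothetical stand-ins (the real method takes a Netty `FullHttpRequest`), and returning `false` on lookup failure is a simplification chosen for this sketch rather than the commit's behavior.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Hypothetical sketch: load an optional class by name and call it reflectively.
class ReflectiveTokenSketch {

    // Stand-in for the plugin's token class; the real one lives in the plugin jar.
    public static class DemoToken {
        private final String expected = "secret";

        // The type argument is ignored here; the real plugin selects a key by type.
        public boolean checkTokenAuthorization(String header, String type) {
            return ("Bearer " + expected).equals(header);
        }
    }

    // Loads the class, creates an instance, and invokes the check method.
    // Returns false when the class or method cannot be resolved or invoked.
    static boolean check(String className, String header, String type) {
        try {
            Class<?> tokenClass = Class.forName(className);
            Object tokenObject = tokenClass.getDeclaredConstructor().newInstance();
            Method method =
                    tokenClass.getMethod("checkTokenAuthorization", String.class, String.class);
            return (boolean) method.invoke(tokenObject, header, type);
        } catch (ClassNotFoundException
                | NoSuchMethodException
                | InstantiationException
                | IllegalAccessException
                | InvocationTargetException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        String name = DemoToken.class.getName();
        System.out.println(check(name, "Bearer secret", "inference"));
        System.out.println(check(name, "Bearer wrong", "inference"));
        System.out.println(check("no.such.Token", "Bearer secret", "inference"));
    }
}
```

Looking the class up by name keeps the caller free of a compile-time dependency on the plugin, at the cost of moving signature mismatches from compile time to run time.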

frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java

-1
@@ -30,7 +30,6 @@ public void handleRequest(
             String[] segments)
             throws ModelException, DownloadArchiveException, WorkflowException,
                     WorkerInitializationException {
-
         if (isApiDescription(segments)) {
             String path = decoder.path();
             if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method()))

frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java

+4
@@ -5,6 +5,7 @@
 import java.util.Map;
 import java.util.ServiceLoader;
 import org.pytorch.serve.http.InvalidPluginException;
+import org.pytorch.serve.http.TokenAuthorizationHandler;
 import org.pytorch.serve.servingsdk.ModelServerEndpoint;
 import org.pytorch.serve.servingsdk.annotations.Endpoint;
 import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
@@ -30,6 +31,9 @@ public void initialize() {
         logger.info("Initializing plugins manager...");
         inferenceEndpoints = initInferenceEndpoints();
         managementEndpoints = initManagementEndpoints();
+        if (managementEndpoints.containsKey("token")) {
+            TokenAuthorizationHandler.setupTokenClass();
+        }
     }
 
     private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) {

frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

+12
@@ -107,6 +107,7 @@ public final class ConfigManager {
     private static final String TS_WORKFLOW_STORE = "workflow_store";
     private static final String TS_CPP_LOG_CONFIG = "cpp_log_config";
     private static final String TS_OPEN_INFERENCE_PROTOCOL = "ts_open_inference_protocol";
+    private static final String TS_TOKEN_EXPIRATION_TIME_MIN = "token_expiration_min";
 
     // Configuration which are not documented or enabled through environment variables
     private static final String USE_NATIVE_IO = "use_native_io";
@@ -859,6 +860,17 @@ public boolean isSnapshotDisabled() {
         return snapshotDisabled;
     }
 
+    public Double getTimeToExpiration() {
+        if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) {
+            try {
+                return Double.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN));
+            } catch (NumberFormatException e) {
+                logger.error("Token expiration not a valid integer");
+            }
+        }
+        return 0.0;
+    }
+
     public boolean isSSLEnabled(ConnectorType connectorType) {
         String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080");
         switch (connectorType) {
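Taken together, `getTimeToExpiration` above and the `time != 0.0` check in `setupTokenClass` mean that a missing or unparsable `token_expiration_min` leaves the 60-minute default in place. The sketch below folds both steps into one function to make that behavior easy to see. `ExpirationConfigSketch`, `resolveMinutes`, and `propsWith` are hypothetical names, and a plain `java.util.Properties` object stands in for the real configuration.

```java
import java.util.Properties;

// Hypothetical sketch of the fallback: unset or unparsable values yield the default.
class ExpirationConfigSketch {

    static final double DEFAULT_MINUTES = 60.0;

    // Step 1 mirrors the getter: 0.0 when the property is missing or invalid.
    // Step 2 mirrors the caller: a 0.0 result keeps the 60-minute default.
    static double resolveMinutes(Properties prop) {
        String raw = prop.getProperty("token_expiration_min");
        double configured = 0.0;
        if (raw != null) {
            try {
                configured = Double.parseDouble(raw);
            } catch (NumberFormatException e) {
                configured = 0.0; // invalid value: fall through to the default
            }
        }
        return configured != 0.0 ? configured : DEFAULT_MINUTES;
    }

    // Convenience for building a Properties object with one optional entry.
    static Properties propsWith(String value) {
        Properties prop = new Properties();
        if (value != null) {
            prop.setProperty("token_expiration_min", value);
        }
        return prop;
    }

    public static void main(String[] args) {
        System.out.println(resolveMinutes(propsWith("30")));
        System.out.println(resolveMinutes(propsWith(null)));
        System.out.println(resolveMinutes(propsWith("abc")));
    }
}
```

One consequence of using 0.0 as the "not set" marker is that an explicit `token_expiration_min=0` is also treated as unset and resolves to the default.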
@@ -0,0 +1,7 @@
+package org.pytorch.serve.util;
+
+public enum TokenType {
+    INFERENCE,
+    MANAGEMENT,
+    TOKEN_API
+}

frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java

+1
@@ -80,6 +80,7 @@ public void handleRequest(
             String[] segments)
             throws ModelException, DownloadArchiveException, WorkflowException,
                     WorkerInitializationException {
+
         if ("wfpredict".equalsIgnoreCase(segments[1])) {
             if (segments.length < 3) {
                 throw new ResourceNotFoundException();

frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java

+1
@@ -63,6 +63,7 @@ public void handleRequest(
             String[] segments)
             throws ModelException, DownloadArchiveException, WorkflowException,
                     WorkerInitializationException {
+
         if (isManagementReq(segments)) {
             if (!"workflows".equals(segments[1])) {
                 throw new ResourceNotFoundException();

plugins/endpoints/build.gradle

+1-1
@@ -1,6 +1,7 @@
 dependencies {
     implementation "com.google.code.gson:gson:${gson_version}"
     implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
+    implementation "io.netty:netty-all:4.1.53.Final"
 }
 
 project.ext{
@@ -16,4 +17,3 @@ jar {
     exclude "META-INF//LICENSE*"
     exclude "META-INF//NOTICE*"
 }
-
