Skip to content

Commit fa32c28

Browse files
committed
token handler
1 parent 4848114 commit fa32c28

File tree

13 files changed

+199
-135
lines changed

13 files changed

+199
-135
lines changed

docs/token_authorization_api.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# TorchServe token authorization API
22

3-
## Customer Use
3+
## Configuration
44
1. Enable token authorization by adding the provided plugin at startup using the `--plugin-path` option.
5-
2. Torchserve will enable token authorization if the plugin is provided. In the model server home folder a file `key_file.txt` will be generated.
5+
2. TorchServe will enable token authorization if the plugin is provided. A file named `key_file.txt` will be generated in the current working directory.
66
1. Example key file:
77

88
`Management Key: aadJv_R6 --- Expiration time: 2024-01-16T22:23:32.952499Z`
@@ -11,19 +11,19 @@
1111

1212
`API Key: xryL_Vzs`
1313
3. There are 3 keys, and each has a different use.
14-
1. Management key: Used for management apis. Example:
14+
1. Management key: Used for management APIs. Example:
1515
`curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer aadJv_R6"`
16-
2. Inference key: Used for inference apis. Example:
16+
2. Inference key: Used for inference APIs. Example:
1717
`curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer poZXAlqe"`
18-
3. API key: Used for the token authorization api. Check section 4 for api use.
18+
3. API key: Used for the token authorization API. Check section 4 for API use.
1919
4. Having 3 tokens gives the owner the most flexibility and lets them adapt the tokens to their use. Owners of the server can provide users with only the inference token if users should not modify models. The owner can also provide users with the management key if the owner wants users to add and remove models.
20-
4. The plugin also includes an api in order to generate a new key to replace either the management or inference key.
20+
4. The plugin also includes an API in order to generate a new key to replace either the management or inference key.
2121
1. Management Example:
2222
`curl localhost:8081/token?type=management -H "Authorization: Bearer xryL_Vzs"` will replace the current management key in the key_file with a new one and will update the expiration time.
2323
2. Inference example:
2424
`curl localhost:8081/token?type=inference -H "Authorization: Bearer xryL_Vzs"`
2525

26-
Users will have to use either one of the apis above.
26+
Users will have to use either one of the APIs above.
2727

2828
5. When users shut down the server the key_file will be deleted.
2929

frontend/server/src/main/java/org/pytorch/serve/ModelServer.java

-5
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,6 @@ public static void main(String[] args) {
8383
ConfigManager.init(arguments);
8484
ConfigManager configManager = ConfigManager.getInstance();
8585
PluginsManager.getInstance().initialize();
86-
Map<String, ModelServerEndpoint> plugins =
87-
PluginsManager.getInstance().getManagementEndpoints();
88-
if (plugins.containsKey("token")) {
89-
configManager.setupTokenClass();
90-
}
9186
MetricCache.init();
9287
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
9388
ModelServer modelServer = new ModelServer(configManager);

frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java

+8
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
import org.pytorch.serve.http.HttpRequestHandler;
1111
import org.pytorch.serve.http.HttpRequestHandlerChain;
1212
import org.pytorch.serve.http.InvalidRequestHandler;
13+
import org.pytorch.serve.http.TokenAuthorizationHandler;
1314
import org.pytorch.serve.http.api.rest.ApiDescriptionRequestHandler;
1415
import org.pytorch.serve.http.api.rest.InferenceRequestHandler;
1516
import org.pytorch.serve.http.api.rest.ManagementRequestHandler;
1617
import org.pytorch.serve.http.api.rest.PrometheusMetricsRequestHandler;
1718
import org.pytorch.serve.servingsdk.impl.PluginsManager;
1819
import org.pytorch.serve.util.ConfigManager;
1920
import org.pytorch.serve.util.ConnectorType;
21+
import org.pytorch.serve.util.TokenType;
2022
import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler;
2123
import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler;
2224

@@ -59,6 +61,9 @@ public void initChannel(Channel ch) {
5961
HttpRequestHandlerChain httpRequestHandlerChain = apiDescriptionRequestHandler;
6062
if (ConnectorType.ALL.equals(connectorType)
6163
|| ConnectorType.INFERENCE_CONNECTOR.equals(connectorType)) {
64+
httpRequestHandlerChain =
65+
httpRequestHandlerChain.setNextHandler(
66+
new TokenAuthorizationHandler(TokenType.INFERENCE));
6267
httpRequestHandlerChain =
6368
httpRequestHandlerChain.setNextHandler(
6469
new InferenceRequestHandler(
@@ -68,6 +73,9 @@ public void initChannel(Channel ch) {
6873
}
6974
if (ConnectorType.ALL.equals(connectorType)
7075
|| ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) {
76+
httpRequestHandlerChain =
77+
httpRequestHandlerChain.setNextHandler(
78+
new TokenAuthorizationHandler(TokenType.MANAGEMENT));
7179
httpRequestHandlerChain =
7280
httpRequestHandlerChain.setNextHandler(
7381
new ManagementRequestHandler(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package org.pytorch.serve.http;
2+
3+
import io.netty.channel.ChannelHandlerContext;
4+
import io.netty.handler.codec.http.FullHttpRequest;
5+
import io.netty.handler.codec.http.QueryStringDecoder;
6+
import java.lang.reflect.*;
7+
import org.pytorch.serve.archive.DownloadArchiveException;
8+
import org.pytorch.serve.archive.model.InvalidKeyException;
9+
import org.pytorch.serve.archive.model.ModelException;
10+
import org.pytorch.serve.archive.workflow.WorkflowException;
11+
import org.pytorch.serve.util.ConfigManager;
12+
import org.pytorch.serve.util.TokenType;
13+
import org.pytorch.serve.wlm.WorkerInitializationException;
14+
import org.slf4j.Logger;
15+
import org.slf4j.LoggerFactory;
16+
17+
/**
18+
* A class handling inbound HTTP requests to the inference API.
19+
*
20+
* <p>This class //
21+
*/
22+
public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
23+
24+
private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
25+
private static TokenType tokenType;
26+
private static Boolean tokenEnabled;
27+
private static Class<?> tokenClass;
28+
private static Object tokenObject;
29+
private static Integer timeToExpirationMinutes = 60;
30+
31+
/** Creates a new {@code InferenceRequestHandler} instance. */
32+
public TokenAuthorizationHandler(TokenType type) {
33+
tokenType = type;
34+
}
35+
36+
@Override
37+
public void handleRequest(
38+
ChannelHandlerContext ctx,
39+
FullHttpRequest req,
40+
QueryStringDecoder decoder,
41+
String[] segments)
42+
throws ModelException, DownloadArchiveException, WorkflowException,
43+
WorkerInitializationException {
44+
ConfigManager configManager = ConfigManager.getInstance();
45+
if (tokenType == TokenType.MANAGEMENT) {
46+
if (req.toString().contains("/token")) {
47+
checkTokenAuthorization(req, 0);
48+
} else {
49+
checkTokenAuthorization(req, 1);
50+
}
51+
} else if (tokenType == TokenType.INFERENCE) {
52+
checkTokenAuthorization(req, 2);
53+
}
54+
chain.handleRequest(ctx, req, decoder, segments);
55+
}
56+
57+
public static void setupTokenClass() {
58+
try {
59+
tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
60+
tokenObject = tokenClass.getDeclaredConstructor().newInstance();
61+
Method method = tokenClass.getMethod("setTime", Integer.class);
62+
Integer time = ConfigManager.getInstance().getTimeToExpiration();
63+
if (time == 0) {
64+
timeToExpirationMinutes = time;
65+
}
66+
method.invoke(tokenObject, timeToExpirationMinutes);
67+
method = tokenClass.getMethod("generateKeyFile", Integer.class);
68+
if ((boolean) method.invoke(tokenObject, 0)) {
69+
logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY");
70+
}
71+
} catch (ClassNotFoundException e) {
72+
logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
73+
e.printStackTrace();
74+
return;
75+
} catch (NoSuchMethodException
76+
| IllegalAccessException
77+
| InstantiationException
78+
| InvocationTargetException e) {
79+
e.printStackTrace();
80+
logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
81+
return;
82+
}
83+
tokenEnabled = true;
84+
}
85+
86+
private void checkTokenAuthorization(FullHttpRequest req, Integer type) throws ModelException {
87+
88+
if (tokenEnabled) {
89+
try {
90+
Method method =
91+
tokenClass.getMethod(
92+
"checkTokenAuthorization",
93+
io.netty.handler.codec.http.FullHttpRequest.class,
94+
Integer.class);
95+
boolean result = (boolean) (method.invoke(tokenObject, req, type));
96+
if (!result) {
97+
throw new InvalidKeyException(
98+
"Token Authenticaation failed. Token either incorrect, expired, or not provided correctly");
99+
}
100+
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
101+
e.printStackTrace();
102+
throw new InvalidKeyException(
103+
"Token Authenticaation failed. Token either incorrect, expired, or not provided correctly");
104+
}
105+
}
106+
}
107+
}

frontend/server/src/main/java/org/pytorch/serve/http/api/rest/InferenceRequestHandler.java

-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,6 @@ public void handleRequest(
5959
String[] segments)
6060
throws ModelException, DownloadArchiveException, WorkflowException,
6161
WorkerInitializationException {
62-
ConfigManager configManager = ConfigManager.getInstance();
63-
configManager.checkTokenAuthorization(req, 2);
6462
if (isInferenceReq(segments)) {
6563
if (endpointMap.getOrDefault(segments[1], null) != null) {
6664
handleCustomEndpoint(ctx, req, segments, decoder);

frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ManagementRequestHandler.java

-6
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import org.pytorch.serve.openapi.OpenApiUtils;
3333
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
3434
import org.pytorch.serve.util.ApiUtils;
35-
import org.pytorch.serve.util.ConfigManager;
3635
import org.pytorch.serve.util.JsonUtils;
3736
import org.pytorch.serve.util.NettyUtils;
3837
import org.pytorch.serve.util.messages.RequestInput;
@@ -62,15 +61,10 @@ public void handleRequest(
6261
String[] segments)
6362
throws ModelException, DownloadArchiveException, WorkflowException,
6463
WorkerInitializationException {
65-
ConfigManager configManager = ConfigManager.getInstance();
6664
if (isManagementReq(segments)) {
6765
if (endpointMap.getOrDefault(segments[1], null) != null) {
68-
if (req.toString().contains("/token")) {
69-
configManager.checkTokenAuthorization(req, 0);
70-
}
7166
handleCustomEndpoint(ctx, req, segments, decoder);
7267
} else {
73-
configManager.checkTokenAuthorization(req, 1);
7468
if (!"models".equals(segments[1])) {
7569
throw new ResourceNotFoundException();
7670
}

frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import java.util.Map;
66
import java.util.ServiceLoader;
77
import org.pytorch.serve.http.InvalidPluginException;
8+
import org.pytorch.serve.http.TokenAuthorizationHandler;
89
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
910
import org.pytorch.serve.servingsdk.annotations.Endpoint;
1011
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
@@ -30,6 +31,9 @@ public void initialize() {
3031
logger.info("Initializing plugins manager...");
3132
inferenceEndpoints = initInferenceEndpoints();
3233
managementEndpoints = initManagementEndpoints();
34+
if (managementEndpoints.containsKey("token")) {
35+
TokenAuthorizationHandler.setupTokenClass();
36+
}
3337
}
3438

3539
private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) {

frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

+4-87
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import com.google.gson.JsonObject;
44
import com.google.gson.reflect.TypeToken;
5-
import io.netty.handler.codec.http.FullHttpRequest;
65
import io.netty.handler.ssl.SslContext;
76
import io.netty.handler.ssl.SslContextBuilder;
87
import io.netty.handler.ssl.util.SelfSignedCertificate;
@@ -46,8 +45,6 @@
4645
import org.apache.commons.cli.Option;
4746
import org.apache.commons.cli.Options;
4847
import org.apache.commons.io.IOUtils;
49-
import org.pytorch.serve.archive.model.InvalidKeyException;
50-
import org.pytorch.serve.archive.model.ModelException;
5148
import org.pytorch.serve.metrics.MetricBuilder;
5249
import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer;
5350
import org.pytorch.serve.snapshot.SnapshotSerializerFactory;
@@ -850,95 +847,15 @@ public boolean isSnapshotDisabled() {
850847
return snapshotDisabled;
851848
}
852849

853-
// Imports the token class and sets the expiration time either default or custom
854-
// calls generate key file in token api to create 3 keys and logs the result
855-
public void setupTokenClass() {
856-
try {
857-
tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
858-
tokenObject = tokenClass.getDeclaredConstructor().newInstance();
859-
Method method = tokenClass.getMethod("setTime", Integer.class);
860-
if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) {
861-
timeToExpiration = Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME));
862-
}
863-
method.invoke(tokenObject, timeToExpiration);
864-
method = tokenClass.getMethod("generateKeyFile", Integer.class);
865-
if ((boolean) method.invoke(tokenObject, 0)) {
866-
System.out.println("TOKEN CLASS IMPORTED SUCCESSFULLY");
867-
dumpKeyLogs();
868-
}
869-
} catch (ClassNotFoundException e) {
870-
e.printStackTrace();
871-
} catch (NoSuchMethodException
872-
| IllegalAccessException
873-
| InstantiationException
874-
| InvocationTargetException e) {
875-
e.printStackTrace();
876-
}
877-
tokenAuthorizationEnabled = true;
878-
}
879-
880-
public void dumpKeyLogs() {
881-
String managementKey = "";
882-
String inferenceKey = "";
883-
String apiKey = "";
884-
try {
885-
Method method = tokenClass.getMethod("getManagementKey");
886-
managementKey = (String) method.invoke(tokenObject);
887-
method = tokenClass.getMethod("getInferenceKey");
888-
inferenceKey = (String) method.invoke(tokenObject);
889-
method = tokenClass.getMethod("getKey");
890-
apiKey = (String) method.invoke(tokenObject);
891-
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
892-
e.printStackTrace();
893-
}
894-
895-
logger.info("KEY FILE PATH: " + System.getProperty("user.dir") + "/key_file.txt");
896-
logger.info("MANAGEMENT KEY: " + managementKey);
897-
logger.info("INFERNCE KEY: " + inferenceKey);
898-
logger.info("API KEY: " + apiKey);
899-
logger.info(
900-
"MANAGEMENT API Example: curl http://localhost:8081/models/<model> -H \"Authorization: Bearer "
901-
+ managementKey
902-
+ "\"");
903-
logger.info(
904-
"INFERNCE API Example: curl http://127.0.0.1:8080/predictions/<model> -T <examples/image_classifier/kitten.jpg> -H \"Authorization: Bearer "
905-
+ inferenceKey
906-
+ "\"");
907-
logger.info(
908-
"API API Example: curl localhost:8081/token?type=management -H \"Authorization: Bearer "
909-
+ apiKey
910-
+ "\"");
911-
}
912-
913850
public boolean isTokenEnabled() {
914851
return tokenAuthorizationEnabled;
915852
}
916853

917-
// Calls the checkTokenAuthorization function in the token plugin
918-
// expects two inputs: the fullhttpRequest and an integer which is associated with the type
919-
// 0: token api
920-
// 1: management api
921-
// 2: inference api
922-
public void checkTokenAuthorization(FullHttpRequest req, Integer requestType)
923-
throws ModelException {
924-
925-
if (tokenAuthorizationEnabled) {
926-
try {
927-
Method method =
928-
tokenClass.getMethod(
929-
"checkTokenAuthorization",
930-
io.netty.handler.codec.http.FullHttpRequest.class,
931-
Integer.class);
932-
boolean result = (boolean) (method.invoke(tokenObject, req, requestType));
933-
if (!result) {
934-
throw new InvalidKeyException(
935-
"Token Authenticaation failed. Token either incorrect, expired, or not provided correctly");
936-
}
937-
System.out.println("TOKEN AUTHORIZATION WORKED");
938-
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
939-
e.printStackTrace();
940-
}
854+
public Integer getTimeToExpiration() {
855+
if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME) != null) {
856+
return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME));
941857
}
858+
return 0;
942859
}
943860

944861
public boolean isSSLEnabled(ConnectorType connectorType) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package org.pytorch.serve.util;
2+
3+
/** The kinds of API a token authorization check can apply to. */
public enum TokenType {
    /** Inference API requests. */
    INFERENCE,

    /** Management API requests. */
    MANAGEMENT,

    /** Presumably requests to the token endpoint itself; no use is visible here. */
    TOKEN_API
}

plugins/endpoints/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
dependencies {
22
implementation "com.google.code.gson:gson:${gson_version}"
33
implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
4+
implementation "io.netty:netty-all:4.1.53.Final"
45
}
56

67
project.ext{
@@ -16,4 +17,3 @@ jar {
1617
exclude "META-INF//LICENSE*"
1718
exclude "META-INF//NOTICE*"
1819
}
19-

0 commit comments

Comments
 (0)