
Commit 18057f3

udaij12 authored Feb 20, 2024
Api change (#2888)
* token authorization update
* update token
* token authorization plugin test
* fix format
* add key file generation at default
* fix format
* updated token plugin
* fixed file delete
* fixed imports
* added custom expection
* fix format
* token handler
* fix doc
* fixed handler
* added integration tests
* Added integration tests
* updating token auth
* small changes to token auth
* fixing changes
* changed keyfile to dictionary and updated readme and tests
* remove comments
* changes to tests
* added config file
* reduce time for expiration test
* change test to mnist
* removing install from src
* final test change
* Fix spellcheck

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-11-32.us-west-2.compute.internal>
Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
1 parent cd52683 commit 18057f3

21 files changed: +681 −6 lines
 

‎docs/token_authorization_api.md

+50
@@ -0,0 +1,50 @@
# TorchServe token authorization API

## Configuration
1. Enable token authorization by adding the provided plugin at startup with the `--plugins-path` command-line option.
2. TorchServe enables token authorization when the plugin is provided and generates a `key_file.json` file in the current working directory.
   1. Example key file:

   ```json
   {
     "management": {
       "key": "B-E5KSRM",
       "expiration time": "2024-02-16T21:12:24.801167Z"
     },
     "inference": {
       "key": "gNRuA7dS",
       "expiration time": "2024-02-16T21:12:24.801148Z"
     },
     "API": {
       "key": "yv9uQajP"
     }
   }
   ```

3. There are three keys, each with a different use (see the Python sketch at the end of this section).
   1. Management key: used for management APIs. Example:
   `curl http://localhost:8081/models/densenet161 -H "Authorization: Bearer I_J_ItMb"`
   2. Inference key: used for inference APIs. Example:
   `curl http://127.0.0.1:8080/predictions/densenet161 -T examples/image_classifier/kitten.jpg -H "Authorization: Bearer FINhR1fj"`
   3. API key: used for the token authorization API. See item 4 for its use.
4. The plugin also includes an API for generating a new key to replace either the management or inference key.
   1. Management example:
   `curl localhost:8081/token?type=management -H "Authorization: Bearer m4M-5IBY"` replaces the current management key in the key file with a new one and updates the expiration time.
   2. Inference example:
   `curl localhost:8081/token?type=inference -H "Authorization: Bearer m4M-5IBY"`

   To replace a key, users must use one of the two calls above.
5. When the server is shut down, the key file is deleted.

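For illustration only (this sketch is not part of the committed documentation), the same calls can be made from Python with the `requests` library. It assumes TorchServe is running on the default ports with the token plugin loaded, that `key_file.json` is readable from the current working directory, and that a model named `densenet161` is registered; the model name and image path are borrowed from the curl examples above.

```python
import json

import requests

# Read the generated keys from key_file.json in the server's working directory.
with open("key_file.json") as f:
    keys = json.load(f)

management_key = keys["management"]["key"]
inference_key = keys["inference"]["key"]
api_key = keys["API"]["key"]

# Management API: describe a registered model (the model name is illustrative).
resp = requests.get(
    "http://localhost:8081/models/densenet161",
    headers={"Authorization": f"Bearer {management_key}"},
)
print(resp.status_code, resp.text)

# Inference API: send an image as the request body (equivalent to `curl -T`).
with open("examples/image_classifier/kitten.jpg", "rb") as img:
    resp = requests.post(
        "http://127.0.0.1:8080/predictions/densenet161",
        data=img,
        headers={"Authorization": f"Bearer {inference_key}"},
    )
print(resp.status_code, resp.text)

# Token API: rotate the inference key, then re-read key_file.json to pick up the new key.
resp = requests.get(
    "http://localhost:8081/token",
    params={"type": "inference"},
    headers={"Authorization": f"Bearer {api_key}"},
)
with open("key_file.json") as f:
    inference_key = json.load(f)["inference"]["key"]
```

Once `/token` regenerates a key, the previous key stops working, so clients must re-read `key_file.json`; the integration tests added in this commit rely on exactly that behavior.
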
## Customization
TorchServe offers several ways to customize token authorization so owners can achieve the behavior they want.
1. Time to expiration defaults to 60 minutes but can be changed in `config.properties` by adding `token_expiration_min`. Example: `token_expiration_min=30` (see the sketch after this list).
2. The token authorization code is consolidated in the plugin and can therefore be changed without impacting the frontend or the end result. The only constraints are:
   1. The plugin's `urlPattern` must be `token`, and the class name must not change.
   2. The return types and signatures of the `generateKeyFile`, `checkTokenAuthorization`, and `setTime` functions must not change; their implementations, however, can be modified as needed.

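As an editorial sketch (not part of the commit) of how a shortened expiration surfaces to clients: the integration test added in this commit starts TorchServe with `token_expiration_min=0.25` (a 15-second lifetime) and expects an HTTP 400 once the key has expired. A hypothetical client can recover by minting a new key with the API key and re-reading `key_file.json`; the `mnist` model name and default ports below follow that test.

```python
import json
import time

import requests


def management_headers():
    # Re-read key_file.json so the latest management key is always used.
    with open("key_file.json") as f:
        keys = json.load(f)
    return {"Authorization": "Bearer " + keys["management"]["key"]}


# With token_expiration_min=0.25 (15 seconds), a key that works now...
resp = requests.get("http://localhost:8081/models/mnist", headers=management_headers())
assert resp.status_code == 200

time.sleep(20)

# ...is rejected once the expiration window has passed (the test expects HTTP 400).
resp = requests.get("http://localhost:8081/models/mnist", headers=management_headers())
assert resp.status_code != 200

# Recover by regenerating the management key with the API key, then retry with the fresh key.
with open("key_file.json") as f:
    api_key = json.load(f)["API"]["key"]
requests.get(
    "http://localhost:8081/token",
    params={"type": "management"},
    headers={"Authorization": "Bearer " + api_key},
)
resp = requests.get("http://localhost:8081/models/mnist", headers=management_headers())
assert resp.status_code == 200
```
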
## Notes
1. DO NOT MODIFY THE KEY FILE. Modifying the key file can interfere with reading and writing it, preventing new keys from being recorded in the file correctly.
2. Having three separate keys gives the server owner the most flexibility to adapt the tokens to their needs. For example, the owner can hand out only the inference key to users who should not manage models, and the management key to users who are allowed to add and remove models.
InvalidKeyException.java

+32

@@ -0,0 +1,32 @@
package org.pytorch.serve.archive.model;

public class InvalidKeyException extends ModelException {

    private static final long serialVersionUID = 1L;

    /**
     * Constructs an {@code InvalidKeyException} with the specified detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     */
    public InvalidKeyException(String message) {
        super(message);
    }

    /**
     * Constructs an {@code InvalidKeyException} with the specified detail message and cause.
     *
     * <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
     * incorporated into this exception's detail message.
     *
     * @param message The detail message (which is saved for later retrieval by the {@link
     *     #getMessage()} method)
     * @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
     *     method). (A null value is permitted, and indicates that the cause is nonexistent or
     *     unknown.)
     */
    public InvalidKeyException(String message, Throwable cause) {
        super(message, cause);
    }
}

‎frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java

+8
@@ -10,6 +10,7 @@
 import org.pytorch.serve.http.HttpRequestHandler;
 import org.pytorch.serve.http.HttpRequestHandlerChain;
 import org.pytorch.serve.http.InvalidRequestHandler;
+import org.pytorch.serve.http.TokenAuthorizationHandler;
 import org.pytorch.serve.http.api.rest.ApiDescriptionRequestHandler;
 import org.pytorch.serve.http.api.rest.InferenceRequestHandler;
 import org.pytorch.serve.http.api.rest.ManagementRequestHandler;
@@ -18,6 +19,7 @@
 import org.pytorch.serve.servingsdk.impl.PluginsManager;
 import org.pytorch.serve.util.ConfigManager;
 import org.pytorch.serve.util.ConnectorType;
+import org.pytorch.serve.util.TokenType;
 import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler;
 import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler;
 import org.slf4j.Logger;
@@ -63,6 +65,9 @@ public void initChannel(Channel ch) {
         HttpRequestHandlerChain httpRequestHandlerChain = apiDescriptionRequestHandler;
         if (ConnectorType.ALL.equals(connectorType)
                 || ConnectorType.INFERENCE_CONNECTOR.equals(connectorType)) {
+            httpRequestHandlerChain =
+                    httpRequestHandlerChain.setNextHandler(
+                            new TokenAuthorizationHandler(TokenType.INFERENCE));
             httpRequestHandlerChain =
                     httpRequestHandlerChain.setNextHandler(
                             new InferenceRequestHandler(
@@ -80,6 +85,9 @@ public void initChannel(Channel ch) {
         }
         if (ConnectorType.ALL.equals(connectorType)
                 || ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) {
+            httpRequestHandlerChain =
+                    httpRequestHandlerChain.setNextHandler(
+                            new TokenAuthorizationHandler(TokenType.MANAGEMENT));
             httpRequestHandlerChain =
                     httpRequestHandlerChain.setNextHandler(
                             new ManagementRequestHandler(
TokenAuthorizationHandler.java

+103

@@ -0,0 +1,103 @@
package org.pytorch.serve.http;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.lang.reflect.*;
import org.pytorch.serve.archive.DownloadArchiveException;
import org.pytorch.serve.archive.model.InvalidKeyException;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.workflow.WorkflowException;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.TokenType;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A class handling the token check for all inbound HTTP requests.
 *
 * <p>When token authorization is enabled, this handler validates the Bearer token of each request
 * against the key of the matching type (management, inference, or token API) before passing the
 * request down the handler chain.
 */
public class TokenAuthorizationHandler extends HttpRequestHandlerChain {

    private static final Logger logger = LoggerFactory.getLogger(TokenAuthorizationHandler.class);
    private static TokenType tokenType;
    private static Boolean tokenEnabled = false;
    private static Class<?> tokenClass;
    private static Object tokenObject;
    private static Double timeToExpirationMinutes = 60.0;

    /** Creates a new {@code TokenAuthorizationHandler} instance for the given token type. */
    public TokenAuthorizationHandler(TokenType type) {
        tokenType = type;
    }

    @Override
    public void handleRequest(
            ChannelHandlerContext ctx,
            FullHttpRequest req,
            QueryStringDecoder decoder,
            String[] segments)
            throws ModelException, DownloadArchiveException, WorkflowException,
                    WorkerInitializationException {
        if (tokenEnabled) {
            if (tokenType == TokenType.MANAGEMENT) {
                if (req.toString().contains("/token")) {
                    checkTokenAuthorization(req, "token");
                } else {
                    checkTokenAuthorization(req, "management");
                }
            } else if (tokenType == TokenType.INFERENCE) {
                checkTokenAuthorization(req, "inference");
            }
        }
        chain.handleRequest(ctx, req, decoder, segments);
    }

    public static void setupTokenClass() {
        try {
            tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
            tokenObject = tokenClass.getDeclaredConstructor().newInstance();
            Method method = tokenClass.getMethod("setTime", Double.class);
            Double time = ConfigManager.getInstance().getTimeToExpiration();
            if (time != 0.0) {
                timeToExpirationMinutes = time;
            }
            method.invoke(tokenObject, timeToExpirationMinutes);
            method = tokenClass.getMethod("generateKeyFile", String.class);
            if ((boolean) method.invoke(tokenObject, "token")) {
                logger.info("TOKEN CLASS IMPORTED SUCCESSFULLY");
            }
        } catch (NoSuchMethodException
                | IllegalAccessException
                | InstantiationException
                | InvocationTargetException
                | ClassNotFoundException e) {
            e.printStackTrace();
            logger.error("TOKEN CLASS IMPORTED UNSUCCESSFULLY");
            throw new IllegalStateException("Unable to import token class", e);
        }
        tokenEnabled = true;
    }

    private void checkTokenAuthorization(FullHttpRequest req, String type) throws ModelException {

        try {
            Method method =
                    tokenClass.getMethod(
                            "checkTokenAuthorization",
                            io.netty.handler.codec.http.FullHttpRequest.class,
                            String.class);
            boolean result = (boolean) (method.invoke(tokenObject, req, type));
            if (!result) {
                throw new InvalidKeyException(
                        "Token Authentication failed. Token either incorrect, expired, or not provided correctly");
            }
        } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
            e.printStackTrace();
            throw new InvalidKeyException(
                    "Token Authentication failed. Token either incorrect, expired, or not provided correctly");
        }
    }
}

‎frontend/server/src/main/java/org/pytorch/serve/http/api/rest/ApiDescriptionRequestHandler.java

-1
@@ -30,7 +30,6 @@ public void handleRequest(
             String[] segments)
             throws ModelException, DownloadArchiveException, WorkflowException,
                     WorkerInitializationException {
-
         if (isApiDescription(segments)) {
             String path = decoder.path();
             if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method()))

‎frontend/server/src/main/java/org/pytorch/serve/servingsdk/impl/PluginsManager.java

+4
@@ -5,6 +5,7 @@
 import java.util.Map;
 import java.util.ServiceLoader;
 import org.pytorch.serve.http.InvalidPluginException;
+import org.pytorch.serve.http.TokenAuthorizationHandler;
 import org.pytorch.serve.servingsdk.ModelServerEndpoint;
 import org.pytorch.serve.servingsdk.annotations.Endpoint;
 import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
@@ -30,6 +31,9 @@ public void initialize() {
         logger.info("Initializing plugins manager...");
         inferenceEndpoints = initInferenceEndpoints();
         managementEndpoints = initManagementEndpoints();
+        if (managementEndpoints.containsKey("token")) {
+            TokenAuthorizationHandler.setupTokenClass();
+        }
     }

     private boolean validateEndpointPlugin(Annotation a, EndpointTypes type) {

‎frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

+12
@@ -107,6 +107,7 @@ public final class ConfigManager {
     private static final String TS_WORKFLOW_STORE = "workflow_store";
     private static final String TS_CPP_LOG_CONFIG = "cpp_log_config";
     private static final String TS_OPEN_INFERENCE_PROTOCOL = "ts_open_inference_protocol";
+    private static final String TS_TOKEN_EXPIRATION_TIME_MIN = "token_expiration_min";

     // Configuration which are not documented or enabled through environment variables
     private static final String USE_NATIVE_IO = "use_native_io";
@@ -859,6 +860,17 @@ public boolean isSnapshotDisabled() {
         return snapshotDisabled;
     }

+    public Double getTimeToExpiration() {
+        if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) {
+            try {
+                return Double.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN));
+            } catch (NumberFormatException e) {
+                logger.error("Token expiration not a valid integer");
+            }
+        }
+        return 0.0;
+    }
+
     public boolean isSSLEnabled(ConnectorType connectorType) {
         String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080");
         switch (connectorType) {
TokenType.java

+7

@@ -0,0 +1,7 @@
package org.pytorch.serve.util;

public enum TokenType {
    INFERENCE,
    MANAGEMENT,
    TOKEN_API
}

‎frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowInferenceRequestHandler.java

+1
@@ -80,6 +80,7 @@ public void handleRequest(
             String[] segments)
             throws ModelException, DownloadArchiveException, WorkflowException,
                     WorkerInitializationException {
+
         if ("wfpredict".equalsIgnoreCase(segments[1])) {
             if (segments.length < 3) {
                 throw new ResourceNotFoundException();

‎frontend/server/src/main/java/org/pytorch/serve/workflow/api/http/WorkflowMgmtRequestHandler.java

+1
@@ -63,6 +63,7 @@ public void handleRequest(
             String[] segments)
             throws ModelException, DownloadArchiveException, WorkflowException,
                     WorkerInitializationException {
+
         if (isManagementReq(segments)) {
             if (!"workflows".equals(segments[1])) {
                 throw new ResourceNotFoundException();

‎plugins/endpoints/build.gradle

+1-1
@@ -1,6 +1,7 @@
 dependencies {
     implementation "com.google.code.gson:gson:${gson_version}"
     implementation "org.pytorch:torchserve-plugins-sdk:${torchserve_sdk_version}"
+    implementation "io.netty:netty-all:4.1.53.Final"
 }

 project.ext{
@@ -16,4 +17,3 @@ jar {
     exclude "META-INF//LICENSE*"
     exclude "META-INF//NOTICE*"
 }
-
Token.java

+223

@@ -0,0 +1,223 @@
package org.pytorch.serve.plugins.endpoint;

// import java.util.Properties;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonObject;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.PosixFilePermission;
import java.nio.file.attribute.PosixFilePermissions;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.pytorch.serve.servingsdk.Context;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.servingsdk.annotations.Endpoint;
import org.pytorch.serve.servingsdk.annotations.helpers.EndpointTypes;
import org.pytorch.serve.servingsdk.http.Request;
import org.pytorch.serve.servingsdk.http.Response;

// import org.pytorch.serve.util.TokenType;

@Endpoint(
        urlPattern = "token",
        endpointType = EndpointTypes.MANAGEMENT,
        description = "Token authentication endpoint")
public class Token extends ModelServerEndpoint {
    private static String apiKey;
    private static String managementKey;
    private static String inferenceKey;
    private static Instant managementExpirationTimeMinutes;
    private static Instant inferenceExpirationTimeMinutes;
    private static Double timeToExpirationMinutes;
    private SecureRandom secureRandom = new SecureRandom();
    private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
    private String fileName = "key_file.json";

    @Override
    public void doGet(Request req, Response rsp, Context ctx) throws IOException {
        String queryResponse = parseQuery(req);
        String test = "";
        if ("management".equals(queryResponse)) {
            generateKeyFile("management");
        } else if ("inference".equals(queryResponse)) {
            generateKeyFile("inference");
        } else {
            test = "{\n\t\"Error\": " + queryResponse + "\n}\n";
        }
        rsp.getOutputStream().write(test.getBytes(StandardCharsets.UTF_8));
    }

    // parses the query and either returns management/inference or a wrong-type error
    public String parseQuery(Request req) {
        QueryStringDecoder decoder = new QueryStringDecoder(req.getRequestURI());
        Map<String, List<String>> parameters = decoder.parameters();
        List<String> values = parameters.get("type");
        if (values != null && !values.isEmpty()) {
            if ("management".equals(values.get(0)) || "inference".equals(values.get(0))) {
                return values.get(0);
            } else {
                return "WRONG TYPE";
            }
        }
        return "NO TYPE PROVIDED";
    }

    public String generateKey() {
        byte[] randomBytes = new byte[6];
        secureRandom.nextBytes(randomBytes);
        return baseEncoder.encodeToString(randomBytes);
    }

    public Instant generateTokenExpiration() {
        long secondsToAdd = (long) (timeToExpirationMinutes * 60);
        return Instant.now().plusSeconds(secondsToAdd);
    }

    // generates a key file with new keys depending on the parameter provided
    public boolean generateKeyFile(String type) throws IOException {
        String userDirectory = System.getProperty("user.dir") + "/" + fileName;
        File file = new File(userDirectory);
        if (!file.createNewFile() && !file.exists()) {
            return false;
        }
        if (apiKey == null) {
            apiKey = generateKey();
        }
        switch (type) {
            case "management":
                managementKey = generateKey();
                managementExpirationTimeMinutes = generateTokenExpiration();
                break;
            case "inference":
                inferenceKey = generateKey();
                inferenceExpirationTimeMinutes = generateTokenExpiration();
                break;
            default:
                managementKey = generateKey();
                inferenceKey = generateKey();
                inferenceExpirationTimeMinutes = generateTokenExpiration();
                managementExpirationTimeMinutes = generateTokenExpiration();
        }

        JsonObject parentObject = new JsonObject();

        JsonObject managementObject = new JsonObject();
        managementObject.addProperty("key", managementKey);
        managementObject.addProperty("expiration time", managementExpirationTimeMinutes.toString());
        parentObject.add("management", managementObject);

        JsonObject inferenceObject = new JsonObject();
        inferenceObject.addProperty("key", inferenceKey);
        inferenceObject.addProperty("expiration time", inferenceExpirationTimeMinutes.toString());
        parentObject.add("inference", inferenceObject);

        JsonObject apiObject = new JsonObject();
        apiObject.addProperty("key", apiKey);
        parentObject.add("API", apiObject);

        Files.write(
                Paths.get(fileName),
                new GsonBuilder()
                        .setPrettyPrinting()
                        .create()
                        .toJson(parentObject)
                        .getBytes(StandardCharsets.UTF_8));

        if (!setFilePermissions()) {
            try {
                Files.delete(Paths.get(fileName));
            } catch (IOException e) {
                return false;
            }
            return false;
        }
        return true;
    }

    public boolean setFilePermissions() {
        Path path = Paths.get(fileName);
        try {
            Set<PosixFilePermission> permissions = PosixFilePermissions.fromString("rw-------");
            Files.setPosixFilePermissions(path, permissions);
        } catch (Exception e) {
            return false;
        }
        return true;
    }

    // checks the token provided in the HTTP request against the saved keys, depending on the type parameter
    public boolean checkTokenAuthorization(FullHttpRequest req, String type) {
        String key;
        Instant expiration;
        switch (type) {
            case "token":
                key = apiKey;
                expiration = null;
                break;
            case "management":
                key = managementKey;
                expiration = managementExpirationTimeMinutes;
                break;
            default:
                key = inferenceKey;
                expiration = inferenceExpirationTimeMinutes;
        }

        String tokenBearer = req.headers().get("Authorization");
        if (tokenBearer == null) {
            return false;
        }
        String[] arrOfStr = tokenBearer.split(" ", 2);
        if (arrOfStr.length == 1) {
            return false;
        }
        String token = arrOfStr[1];

        if (token.equals(key)) {
            if (expiration != null && isTokenExpired(expiration)) {
                return false;
            }
        } else {
            return false;
        }
        return true;
    }

    public boolean isTokenExpired(Instant expirationTime) {
        return !(Instant.now().isBefore(expirationTime));
    }

    public String getManagementKey() {
        return managementKey;
    }

    public String getInferenceKey() {
        return inferenceKey;
    }

    public String getKey() {
        return apiKey;
    }

    public Instant getInferenceExpirationTime() {
        return inferenceExpirationTimeMinutes;
    }

    public Instant getManagementExpirationTime() {
        return managementExpirationTimeMinutes;
    }

    public void setTime(Double time) {
        timeToExpirationMinutes = time;
    }
}
@@ -1 +1,2 @@
 org.pytorch.serve.plugins.endpoint.ExecutionParameters
+org.pytorch.serve.plugins.endpoint.Token

‎plugins/gradle/wrapper/gradle-wrapper.properties

+1-1
@@ -3,4 +3,4 @@ distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-6.4-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-7.3-all.zip

‎plugins/settings.gradle

+1-2
@@ -9,5 +9,4 @@

 rootProject.name = 'plugins'
 include 'endpoints'
-include 'DDBEndPoint'
-
+// include 'DDBEndPoint'

‎test/pytest/test_data/0.png

272 Bytes
+218
@@ -0,0 +1,218 @@
import json
import os
import shutil
import subprocess
import tempfile
import time
from pathlib import Path

import pytest
import requests
import test_utils

ROOT_DIR = os.path.join(tempfile.gettempdir(), "workspace")
REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
data_file_zero = os.path.join(REPO_ROOT, "test/pytest/test_data/0.png")
config_file = os.path.join(REPO_ROOT, "test/resources/config_token.properties")


# Set up token plugin
def get_plugin_jar():
    new_folder_path = os.path.join(ROOT_DIR, "plugins-path")
    plugin_folder = os.path.join(REPO_ROOT, "plugins")
    os.makedirs(new_folder_path, exist_ok=True)
    os.chdir(plugin_folder)
    subprocess.run(["./gradlew", "formatJava"])
    result = subprocess.run(["./gradlew", "build"])
    jar_path = os.path.join(plugin_folder, "endpoints/build/libs")
    jar_file = [file for file in os.listdir(jar_path) if file.endswith(".jar")]
    if jar_file:
        shutil.move(
            os.path.join(jar_path, jar_file[0]),
            os.path.join(new_folder_path, jar_file[0]),
        )
    os.chdir(REPO_ROOT)


# Parse json file and return key
def read_key_file(type):
    json_file_path = os.path.join(REPO_ROOT, "key_file.json")
    with open(json_file_path) as json_file:
        json_data = json.load(json_file)

    options = {
        "management": json_data.get("management", {}).get("key", "NOT_PRESENT"),
        "inference": json_data.get("inference", {}).get("key", "NOT_PRESENT"),
        "token": json_data.get("API", {}).get("key", "NOT_PRESENT"),
    }
    key = options.get(type, "Invalid data type")
    return key


@pytest.fixture(scope="module")
def setup_torchserve():
    get_plugin_jar()
    MODEL_STORE = os.path.join(ROOT_DIR, "model_store/")
    PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path")

    Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)

    test_utils.start_torchserve(no_config_snapshots=True, plugin_folder=PLUGIN_STORE)

    key = read_key_file("management")
    header = {"Authorization": f"Bearer {key}"}

    params = (
        ("model_name", "mnist"),
        ("url", "mnist.mar"),
        ("initial_workers", "1"),
        ("synchronous", "true"),
    )
    response = requests.post(
        "http://localhost:8081/models", params=params, headers=header
    )
    file_content = Path(f"{REPO_ROOT}/key_file.json").read_text()
    print(file_content)

    yield "test"

    test_utils.stop_torchserve()


@pytest.fixture(scope="module")
def setup_torchserve_expiration():
    get_plugin_jar()
    MODEL_STORE = os.path.join(ROOT_DIR, "model_store/")
    PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path")

    Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)

    test_utils.start_torchserve(
        snapshot_file=config_file, no_config_snapshots=True, plugin_folder=PLUGIN_STORE
    )

    key = read_key_file("management")
    header = {"Authorization": f"Bearer {key}"}

    params = (
        ("model_name", "mnist"),
        ("url", "mnist.mar"),
        ("initial_workers", "1"),
        ("synchronous", "true"),
    )
    response = requests.post(
        "http://localhost:8081/models", params=params, headers=header
    )
    file_content = Path(f"{REPO_ROOT}/key_file.json").read_text()
    print(file_content)

    yield "test"

    test_utils.stop_torchserve()


# Test describe model API with token enabled
def test_managament_api_with_token(setup_torchserve):
    key = read_key_file("management")
    header = {"Authorization": f"Bearer {key}"}
    response = requests.get("http://localhost:8081/models/mnist", headers=header)

    assert response.status_code == 200, "Token check failed"


# Test describe model API with an incorrect token
def test_managament_api_with_incorrect_token(setup_torchserve):
    # Using random key
    header = {"Authorization": "Bearer abcd1234"}
    response = requests.get(f"http://localhost:8081/models/mnist", headers=header)

    assert response.status_code == 400, "Token check failed"


# Test inference API with token enabled
def test_inference_api_with_token(setup_torchserve):
    key = read_key_file("inference")
    header = {"Authorization": f"Bearer {key}"}

    response = requests.post(
        url="http://localhost:8080/predictions/mnist",
        files={"data": open(data_file_zero, "rb")},
        headers=header,
    )

    assert response.status_code == 200, "Token check failed"


# Test inference API with incorrect token
def test_inference_api_with_incorrect_token(setup_torchserve):
    # Using random key
    header = {"Authorization": "Bearer abcd1234"}

    response = requests.post(
        url="http://localhost:8080/predictions/mnist",
        files={"data": open(data_file_zero, "rb")},
        headers=header,
    )

    assert response.status_code == 400, "Token check failed"


# Test Token API for regenerating new inference key
def test_token_inference_api(setup_torchserve):
    token_key = read_key_file("token")
    inference_key = read_key_file("inference")
    header_inference = {"Authorization": f"Bearer {inference_key}"}
    header_token = {"Authorization": f"Bearer {token_key}"}
    params = {"type": "inference"}

    # check inference works with current token
    response = requests.post(
        url="http://localhost:8080/predictions/mnist",
        files={"data": open(data_file_zero, "rb")},
        headers=header_inference,
    )
    assert response.status_code == 200, "Token check failed"

    # generate new inference token and check it is different
    response = requests.get(
        url="http://localhost:8081/token", params=params, headers=header_token
    )
    assert response.status_code == 200, "Token check failed"
    assert inference_key != read_key_file("inference"), "Key file not updated"

    # check inference does not work with the original token
    response = requests.post(
        url="http://localhost:8080/predictions/mnist",
        files={"data": open(data_file_zero, "rb")},
        headers=header_inference,
    )
    assert response.status_code == 400, "Token check failed"


# Test Token API for regenerating new management key
def test_token_management_api(setup_torchserve):
    token_key = read_key_file("token")
    management_key = read_key_file("management")
    header = {"Authorization": f"Bearer {token_key}"}
    params = {"type": "management"}

    response = requests.get(
        url="http://localhost:8081/token", params=params, headers=header
    )

    assert management_key != read_key_file("management"), "Key file not updated"
    assert response.status_code == 200, "Token check failed"


# Test expiration time
@pytest.mark.module2
def test_token_expiration_time(setup_torchserve_expiration):
    key = read_key_file("management")
    header = {"Authorization": f"Bearer {key}"}
    response = requests.get("http://localhost:8081/models/mnist", headers=header)
    assert response.status_code == 200, "Token check failed"

    time.sleep(15)

    response = requests.get("http://localhost:8081/models/mnist", headers=header)
    assert response.status_code == 400, "Token check failed"

‎test/pytest/test_utils.py

+7-1
@@ -54,7 +54,11 @@ def run(self):


 def start_torchserve(
-    model_store=None, snapshot_file=None, no_config_snapshots=False, gen_mar=True
+    model_store=None,
+    snapshot_file=None,
+    no_config_snapshots=False,
+    gen_mar=True,
+    plugin_folder=None,
 ):
     stop_torchserve()
     crate_mar_file_table()
@@ -63,6 +67,8 @@ def start_torchserve(
     if gen_mar:
         mg.gen_mar(model_store)
     cmd.extend(["--model-store", model_store])
+    if plugin_folder:
+        cmd.extend(["--plugins-path", plugin_folder])
     if snapshot_file:
         cmd.extend(["--ts-config", snapshot_file])
     if no_config_snapshots:
test/resources/config_token.properties

+1

@@ -0,0 +1 @@
token_expiration_min=0.25

‎ts/model_server.py

+2
@@ -3,6 +3,7 @@
 """

 import os
+import pathlib
 import platform
 import re
 import subprocess
@@ -48,6 +49,7 @@ def start() -> None:
         try:
             parent = psutil.Process(pid)
             parent.terminate()
+            pathlib.Path("key_file.json").unlink(missing_ok=True)
             if args.foreground:
                 try:
                     parent.wait(timeout=60)

‎ts_scripts/spellcheck_conf/wordlist.txt

+8
@@ -1187,3 +1187,11 @@ FxGraphCache
 TorchInductor
 fx
 locustapache
+FINhR
+IBY
+ItMb
+checkTokenAuthorization
+fj
+generateKeyFile
+setTime
+urlPattern
