Skip to content

Commit 55cedd5

Browse files
committed
reduce time for expiration test
1 parent f69c632 commit 55cedd5

File tree

7 files changed

+22
-18
lines changed

7 files changed

+22
-18
lines changed

docs/token_authorization_api.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# TorchServe token authorization API
22

33
## Configuration
4-
1. Enable token authorization by adding the provided plugin at start using the `--plugin-path` command.
4+
1. Enable token authorization by adding the provided plugin at start using the `--plugins-path` command.
55
2. Torchserve will enable token authorization if the plugin is provided. In the current working directory a file `key_file.json` will be generated.
66
1. Example key file:
77

frontend/server/src/main/java/org/pytorch/serve/http/TokenAuthorizationHandler.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ public class TokenAuthorizationHandler extends HttpRequestHandlerChain {
2626
private static Boolean tokenEnabled = false;
2727
private static Class<?> tokenClass;
2828
private static Object tokenObject;
29-
private static Integer timeToExpirationMinutes = 60;
29+
private static Double timeToExpirationMinutes = 60.0;
3030

3131
/** Creates a new {@code InferenceRequestHandler} instance. */
3232
public TokenAuthorizationHandler(TokenType type) {
@@ -59,9 +59,9 @@ public static void setupTokenClass() {
5959
try {
6060
tokenClass = Class.forName("org.pytorch.serve.plugins.endpoint.Token");
6161
tokenObject = tokenClass.getDeclaredConstructor().newInstance();
62-
Method method = tokenClass.getMethod("setTime", Integer.class);
63-
Integer time = ConfigManager.getInstance().getTimeToExpiration();
64-
if (time != 0) {
62+
Method method = tokenClass.getMethod("setTime", Double.class);
63+
Double time = ConfigManager.getInstance().getTimeToExpiration();
64+
if (time != 0.0) {
6565
timeToExpirationMinutes = time;
6666
}
6767
method.invoke(tokenObject, timeToExpirationMinutes);

frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -860,15 +860,15 @@ public boolean isSnapshotDisabled() {
860860
return snapshotDisabled;
861861
}
862862

863-
public Integer getTimeToExpiration() {
863+
public Double getTimeToExpiration() {
864864
if (prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN) != null) {
865865
try {
866-
return Integer.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN));
866+
return Double.valueOf(prop.getProperty(TS_TOKEN_EXPIRATION_TIME_MIN));
867867
} catch (NumberFormatException e) {
868868
logger.error("Token expiration not a valid integer");
869869
}
870870
}
871-
return 0;
871+
return 0.0;
872872
}
873873

874874
public boolean isSSLEnabled(ConnectorType connectorType) {

plugins/endpoints/src/main/java/org/pytorch/serve/plugins/endpoint/Token.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.util.List;
2020
import java.util.Map;
2121
import java.util.Set;
22-
import java.util.concurrent.TimeUnit;
2322
import org.pytorch.serve.servingsdk.Context;
2423
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
2524
import org.pytorch.serve.servingsdk.annotations.Endpoint;
@@ -39,7 +38,7 @@ public class Token extends ModelServerEndpoint {
3938
private static String inferenceKey;
4039
private static Instant managementExpirationTimeMinutes;
4140
private static Instant inferenceExpirationTimeMinutes;
42-
private static Integer timeToExpirationMinutes;
41+
private static Double timeToExpirationMinutes;
4342
private SecureRandom secureRandom = new SecureRandom();
4443
private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
4544
private String fileName = "key_file.json";
@@ -80,7 +79,8 @@ public String generateKey() {
8079
}
8180

8281
public Instant generateTokenExpiration() {
83-
return Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(timeToExpirationMinutes));
82+
long secondsToAdd = (long) (timeToExpirationMinutes * 60);
83+
return Instant.now().plusSeconds(secondsToAdd);
8484
}
8585

8686
// generates a key file with new keys depending on the parameter provided
@@ -217,7 +217,7 @@ public Instant getManagementExpirationTime() {
217217
return managementExpirationTimeMinutes;
218218
}
219219

220-
public void setTime(Integer time) {
220+
public void setTime(Double time) {
221221
timeToExpirationMinutes = time;
222222
}
223223
}

test/pytest/test_token_authorization.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def setup_torchserve():
6666
MODEL_STORE = os.path.join(ROOT_DIR, "model_store/")
6767
PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path")
6868

69+
Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)
70+
6971
test_utils.start_torchserve(no_config_snapshots=True, plugin_folder=PLUGIN_STORE)
7072

7173
key = read_key_file("management")
@@ -94,10 +96,10 @@ def setup_torchserve_expiration():
9496
MODEL_STORE = os.path.join(ROOT_DIR, "model_store/")
9597
PLUGIN_STORE = os.path.join(ROOT_DIR, "plugins-path")
9698

99+
Path(test_utils.MODEL_STORE).mkdir(parents=True, exist_ok=True)
100+
97101
test_utils.start_torchserve(
98-
snapshot_file=config_file,
99-
no_config_snapshots=True,
100-
plugin_folder=PLUGIN_STORE,
102+
snapshot_file=config_file, no_config_snapshots=True, plugin_folder=PLUGIN_STORE
101103
)
102104

103105
key = read_key_file("management")
@@ -214,13 +216,14 @@ def test_token_management_api(setup_torchserve):
214216

215217

216218
# Test expiration time
219+
@pytest.mark.module2
217220
def test_token_expiration_time(setup_torchserve_expiration):
218221
key = read_key_file("management")
219222
header = {"Authorization": f"Bearer {key}"}
220223
response = requests.get("http://localhost:8081/models/resnet18", headers=header)
221224
assert response.status_code == 200, "Token check failed"
222225

223-
time.sleep(60)
226+
time.sleep(15)
224227

225228
response = requests.get("http://localhost:8081/models/resnet18", headers=header)
226229
assert response.status_code == 400, "Token check failed"
+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
token_expiration_min=1
1+
token_expiration_min=0.25

ts/model_server.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import os
6+
import pathlib
67
import platform
78
import re
89
import subprocess
@@ -48,7 +49,7 @@ def start() -> None:
4849
try:
4950
parent = psutil.Process(pid)
5051
parent.terminate()
51-
os.remove(os.getcwd() + "/key_file.json")
52+
pathlib.Path("key_file.json").unlink(missing_ok=True)
5253
if args.foreground:
5354
try:
5455
parent.wait(timeout=60)

0 commit comments

Comments
 (0)