Api change #2888

Merged
merged 34 commits into from
Feb 20, 2024
Changes from 7 commits
4ddb74e
token authorization update
Dec 5, 2023
1304ca3
update token
Dec 20, 2023
fde5f4d
token authorization plugin test
Dec 20, 2023
8ec3c40
fix format
Dec 20, 2023
d50beae
add key file generation at default
udaij12 Jan 3, 2024
d5c25e6
fix format
udaij12 Jan 3, 2024
1b35372
Merge branch 'master' into api_change # Please enter a commit message…
Jan 4, 2024
0010c86
Merge branch 'master' into api_change
udaij12 Jan 17, 2024
62cadc8
updated token plugin
udaij12 Jan 17, 2024
51f5889
Merge branch 'api_change' of https://github.com/pytorch/serve into ap…
udaij12 Jan 17, 2024
129b652
fixed file delete
udaij12 Jan 17, 2024
cecc086
fixed imports
udaij12 Jan 17, 2024
6876781
added custom expection
udaij12 Jan 17, 2024
4848114
fix format
udaij12 Jan 18, 2024
fa32c28
token handler
udaij12 Jan 30, 2024
c5c0f77
fix doc
udaij12 Jan 30, 2024
84343c4
merging master
udaij12 Jan 31, 2024
3b7fe29
fixed handler
udaij12 Jan 31, 2024
0d7afe3
added integration tests
udaij12 Feb 1, 2024
13caa27
Added integration tests
udaij12 Feb 2, 2024
a49acf5
updating token auth
udaij12 Feb 14, 2024
6b329b2
small changes to token auth
udaij12 Feb 15, 2024
0690e01
fixing changes
udaij12 Feb 15, 2024
eb37eff
changed keyfile to dictionary and updated readme and tests
udaij12 Feb 16, 2024
71bc7f5
remove comments
udaij12 Feb 16, 2024
3e18230
changes to tests
udaij12 Feb 16, 2024
f69c632
added config file
udaij12 Feb 16, 2024
599c635
Merge branch 'master' into api_change
mreso Feb 16, 2024
55cedd5
reduce time for expiration test
udaij12 Feb 17, 2024
4c7080d
Merge branch 'api_change' of https://github.com/pytorch/serve into ap…
udaij12 Feb 17, 2024
54841d9
change test to mnist
udaij12 Feb 17, 2024
4e1c1ea
removing install from src
udaij12 Feb 17, 2024
698b95a
final test change
udaij12 Feb 17, 2024
68b3b04
Fix spellcheck
mreso Feb 19, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package org.pytorch.serve.archive.model;

public class InvalidKeyException extends ModelException {

private static final long serialVersionUID = 1L;

/**
* Constructs an {@code InvalidKeyException} with the specified detail message.
*
* @param message The detail message (which is saved for later retrieval by the {@link
* #getMessage()} method)
*/
public InvalidKeyException(String message) {
super(message);
}

/**
* Constructs an {@code InvalidKeyException} with the specified detail message and cause.
*
* <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
* incorporated into this exception's detail message.
*
* @param message The detail message (which is saved for later retrieval by the {@link
* #getMessage()} method)
* @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
* method). (A null value is permitted, and indicates that the cause is nonexistent or
* unknown.)
*/
public InvalidKeyException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -0,0 +1,32 @@
package org.pytorch.serve.archive.model;

public class KeyTimeOutException extends ModelException {

private static final long serialVersionUID = 1L;

/**
* Constructs an {@code KeyTimeOutException} with the specified detail message.
*
* @param message The detail message (which is saved for later retrieval by the {@link
* #getMessage()} method)
*/
public KeyTimeOutException(String message) {
super(message);
}

/**
* Constructs an {@code KeyTimeOutException} with the specified detail message and cause.
*
* <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically
* incorporated into this exception's detail message.
*
* @param message The detail message (which is saved for later retrieval by the {@link
* #getMessage()} method)
* @param cause The cause (which is saved for later retrieval by the {@link #getCause()}
* method). (A null value is permitted, and indicates that the cause is nonexistent or
* unknown.)
*/
public KeyTimeOutException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -83,6 +83,11 @@ public static void main(String[] args) {
ConfigManager.init(arguments);
ConfigManager configManager = ConfigManager.getInstance();
PluginsManager.getInstance().initialize();
Map<String, ModelServerEndpoint> plugins =
PluginsManager.getInstance().getInferenceEndpoints();
if (plugins.containsKey("token")) {
configManager.generateKeyFile();
Collaborator

Should a new key be generated every time the model server starts?

Collaborator Author

Yes. When the model server starts, the key file is generated with the three keys.

}
MetricCache.init();
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE);
ModelServer modelServer = new ModelServer(configManager);
@@ -30,7 +30,6 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {

if (isApiDescription(segments)) {
String path = decoder.path();
if (("/".equals(path) && HttpMethod.OPTIONS.equals(req.method()))
@@ -59,10 +59,13 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {
ConfigManager configManager = ConfigManager.getInstance();
if (isInferenceReq(segments)) {
if (endpointMap.getOrDefault(segments[1], null) != null) {
configManager.checkTokenAuthorization(req, false);
handleCustomEndpoint(ctx, req, segments, decoder);
} else {
configManager.checkTokenAuthorization(req, true);
Collaborator

The token authentication implementation should not be spread across each endpoint. It should be handled by InvalidRequestHandler.
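One way to read this review comment is that the token check belongs in a single link at the front of the handler chain rather than in every endpoint handler. The sketch below illustrates that idea with simplified, hypothetical stand-ins (`Handler`, `TokenAuthHandler`, `PingHandler`, `UnauthorizedException`); these are not TorchServe's actual handler classes or signatures, and the merged implementation may differ.

```java
import java.util.Map;

// Hypothetical, simplified handler contract (not TorchServe's HttpRequestHandlerChain).
interface Handler {
    String handle(String path, Map<String, String> headers);
}

// Hypothetical exception used only in this sketch.
final class UnauthorizedException extends RuntimeException {
    UnauthorizedException(String message) {
        super(message);
    }
}

// First link in the chain: checks the token once, then delegates to the next handler.
final class TokenAuthHandler implements Handler {
    private final String expectedToken;
    private final Handler next;

    TokenAuthHandler(String expectedToken, Handler next) {
        this.expectedToken = expectedToken;
        this.next = next;
    }

    @Override
    public String handle(String path, Map<String, String> headers) {
        String header = headers.get("Authorization");
        if (header == null || !header.equals("Bearer " + expectedToken)) {
            throw new UnauthorizedException("Token missing or incorrect");
        }
        // Endpoint handlers further down the chain need no auth logic of their own.
        return next.handle(path, headers);
    }
}

// Stand-in for an endpoint handler that knows nothing about tokens.
final class PingHandler implements Handler {
    @Override
    public String handle(String path, Map<String, String> headers) {
        return "handled " + path;
    }
}
```

With this shape, `new TokenAuthHandler(token, new PingHandler())` guards every request that enters the chain, so adding a new endpoint does not require another call to the authorization check.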

switch (segments[1]) {
case "ping":
Runnable r =
@@ -32,6 +32,7 @@
import org.pytorch.serve.openapi.OpenApiUtils;
import org.pytorch.serve.servingsdk.ModelServerEndpoint;
import org.pytorch.serve.util.ApiUtils;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.util.messages.RequestInput;
@@ -61,6 +62,8 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {
ConfigManager configManager = ConfigManager.getInstance();
configManager.checkTokenAuthorization(req, false);
Collaborator

ditto

if (isManagementReq(segments)) {
if (endpointMap.getOrDefault(segments[1], null) != null) {
handleCustomEndpoint(ctx, req, segments, decoder);
@@ -24,6 +24,7 @@
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.archive.workflow.WorkflowException;
import org.pytorch.serve.http.HttpRequestHandlerChain;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.slf4j.Logger;
@@ -47,6 +48,8 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {
ConfigManager configManager = ConfigManager.getInstance();
configManager.checkTokenAuthorization(req, true);
Collaborator

ditto

if (segments.length >= 2 && "metrics".equals(segments[1])) {
ByteBuf resBuf = Unpooled.directBuffer();
List<String> params =
@@ -2,6 +2,8 @@

import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.SelfSignedCertificate;
@@ -20,11 +22,14 @@
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.time.DateTimeException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
@@ -36,13 +41,17 @@
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.io.IOUtils;
import org.pytorch.serve.archive.model.InvalidKeyException;
import org.pytorch.serve.archive.model.KeyTimeOutException;
import org.pytorch.serve.archive.model.ModelException;
import org.pytorch.serve.metrics.MetricBuilder;
import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer;
import org.pytorch.serve.snapshot.SnapshotSerializerFactory;
@@ -146,6 +155,10 @@ public final class ConfigManager {
private boolean telemetryEnabled;
private Logger logger = LoggerFactory.getLogger(ConfigManager.class);

private SecureRandom secureRandom = new SecureRandom();
Collaborator

These two seem to be unused

private Base64.Encoder baseEncoder = Base64.getUrlEncoder();
public String keyFileLocation;

private ConfigManager(Arguments args) throws IOException {
prop = new Properties();

@@ -837,6 +850,104 @@ public boolean isSnapshotDisabled() {
return snapshotDisabled;
}

public boolean isTokenExpired(Instant expirationTime) {
return !(Instant.now().isBefore(expirationTime));
}

public String generateToken() {
Collaborator

It seems that generateToken is the same as generateKey in the plugin. I don't understand why the token implementation is needed twice.

byte[] randomBytes = new byte[6];
secureRandom.nextBytes(randomBytes);
return baseEncoder.encodeToString(randomBytes);
}

public boolean generateKeyFile() throws IOException {
String fileData = " ";
String absoluteFilePath = getCanonicalPath(".") + "/key_file.txt";
Collaborator

It seems that the token file path is different from the path implemented in the plugin. Again, I don't understand why the token implementation is needed twice.

keyFileLocation = absoluteFilePath;
File file = new File(absoluteFilePath);
if (!file.createNewFile() && !file.exists()) {
return false;
}
Integer timeToExpiration = 30; // in minutes
fileData =
"Management Key: "
+ generateToken()
+ "\n"
+ "Inference Key: "
+ generateToken()
+ " --- Expiration time: "
+ Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(timeToExpiration))
+ "\n";
Files.write(Paths.get("key_file.txt"), fileData.getBytes());
return true;
}

public List<String> parseFile(File tsTokenFile) {
Collaborator

parseFile is invoked by checkTokenAuthorization, which in turn is called during all inference and management API calls. Reading from the file during each of these calls is going to be expensive. Could we maintain the required data, i.e., the tokens and expiration time, in memory so that we can readily check against them? We could also flush this data to the file, using say updateKeyFile, when it is initially generated and whenever any of the data changes.

Collaborator Author

Changed to maintain the data in memory; the key file no longer needs to be parsed.
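The in-memory approach discussed in this thread could look roughly like the sketch below. It is an illustration only: `InMemoryTokenStore` and its method names are hypothetical and do not come from this PR, and the clock is passed in explicitly so that the expiry logic is easy to test. As in the diff, only the inference key is subject to expiration.

```java
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;

// Hypothetical sketch: keeps keys and expiry in memory so request handling never touches the file.
final class InMemoryTokenStore {
    private final SecureRandom random = new SecureRandom();
    private final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
    private final Duration timeToLive;
    private final String managementKey;
    private String inferenceKey;
    private Instant inferenceExpiry;

    InMemoryTokenStore(Duration timeToLive, Instant now) {
        this.timeToLive = timeToLive;
        this.managementKey = newKey();
        rotateInferenceKey(now);
    }

    private String newKey() {
        byte[] bytes = new byte[6]; // same length as generateToken() in this diff
        random.nextBytes(bytes);
        return encoder.encodeToString(bytes);
    }

    // Replaces the inference key and resets its expiry; a file flush could be triggered here.
    synchronized void rotateInferenceKey(Instant now) {
        inferenceKey = newKey();
        inferenceExpiry = now.plus(timeToLive);
    }

    // Pure in-memory check: no file I/O on the request path.
    synchronized boolean isValid(String token, boolean inferenceRequest, Instant now) {
        if (token == null) {
            return false;
        }
        if (!inferenceRequest) {
            return token.equals(managementKey);
        }
        return token.equals(inferenceKey) && now.isBefore(inferenceExpiry);
    }

    String managementKey() {
        return managementKey;
    }

    synchronized String inferenceKey() {
        return inferenceKey;
    }
}
```

The key file then becomes a write-only artifact for operators: it would be rewritten when the store is created and when `rotateInferenceKey` runs, but never read while serving requests.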

List<String> parsedTokens = new ArrayList<>();
try {
InputStream stream = Files.newInputStream(tsTokenFile.toPath());
byte[] array = new byte[100];
stream.read(array);
String data = new String(array);
String[] arrOfData = data.split("\n", 2);
String[] managementArr = arrOfData[0].split(" ", 3);
String[] inferenceArr = arrOfData[1].split(" ", 7);
parsedTokens.add(managementArr[2]);
parsedTokens.add(inferenceArr[2]);
String[] expirationArr = inferenceArr[6].split("\n", 2);
parsedTokens.add(expirationArr[0]);
} catch (IOException | ArrayIndexOutOfBoundsException e) {
System.out.println("Unable to read key file or key file has been modified");
return null;
}
return parsedTokens;
}

public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceRequest)
Collaborator

This implementation is too heavy for real-time inference. Is ConfigManager the right place to implement the entire token authentication?

throws ModelException {
HttpMethod method = req.method();
String filePath = keyFileLocation;
if (filePath != null) {
File tsTokenFile = new File(filePath);
if (tsTokenFile.exists()) {
List<String> parsedTokens = parseFile(tsTokenFile);
String managementToken = parsedTokens.get(0);
String inferenceToken = parsedTokens.get(1);
Instant expirationTime = Instant.now();
try {
expirationTime = Instant.parse(parsedTokens.get(2));
} catch (DateTimeException e) {
e.printStackTrace();
System.out.println("{\n\t\"Error\": Key File has been modified \n}\n");
}
String tokenBearer = req.headers().get("Authorization");
if (tokenBearer == null) {
throw new InvalidKeyException("NO TOKEN PROVIDED");
}
String[] arrOfStr = tokenBearer.split(" ", 2);
Collaborator

We could make parsing the token from the header more robust using a regex pattern.

Collaborator Author

No longer need to parse the key file.
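A regex-based header parse, as suggested in this thread, might look like the sketch below. `BearerTokenParser` is a hypothetical helper, not code from this PR; the character class follows the `b64token` grammar from RFC 6750, and accepting extra whitespace and any letter case for the scheme is an assumption about how lenient the server should be.

```java
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper: extracts the token from an "Authorization: Bearer <token>" header value.
final class BearerTokenParser {
    // Scheme name is matched case-insensitively; the token must be a single b64token with no spaces.
    private static final Pattern BEARER =
            Pattern.compile(
                    "^\\s*Bearer\\s+([A-Za-z0-9\\-._~+/]+=*)\\s*$", Pattern.CASE_INSENSITIVE);

    private BearerTokenParser() {}

    // Returns the token, or empty if the header is missing or malformed.
    static Optional<String> parse(String headerValue) {
        if (headerValue == null) {
            return Optional.empty();
        }
        Matcher matcher = BEARER.matcher(headerValue);
        return matcher.matches() ? Optional.of(matcher.group(1)) : Optional.empty();
    }
}
```

Unlike `tokenBearer.split(" ", 2)`, this rejects headers with a different scheme (for example `Basic`) and headers with stray content after the token, so the caller only has to handle the present and empty cases.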

if (arrOfStr.length == 1) {
throw new InvalidKeyException("NO TOKEN PROVIDED");
}
String token = arrOfStr[1];
String key = managementToken;
if (inferenceRequest) {
key = inferenceToken;
}

if (token.equals(key)) {
if (isTokenExpired(expirationTime) && inferenceRequest) {
throw new KeyTimeOutException("THE CURRENT TOKEN IS EXPIRED");
}
System.out.println("TOKEN AUTHORIZATION WORKED");
} else {
throw new InvalidKeyException("TOKEN IS INCORRECT ");
}
} else {
System.out.println("TOKEN AUTHORIZATION IS NOT ENABLED");
}
}
}

public boolean isSSLEnabled(ConnectorType connectorType) {
String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080");
switch (connectorType) {
@@ -80,6 +80,9 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {

ConfigManager configManager = ConfigManager.getInstance();
configManager.checkTokenAuthorization(req, true);
Collaborator

ditto

if ("wfpredict".equalsIgnoreCase(segments[1])) {
if (segments.length < 3) {
throw new ResourceNotFoundException();
@@ -24,6 +24,7 @@
import org.pytorch.serve.http.MethodNotAllowedException;
import org.pytorch.serve.http.ResourceNotFoundException;
import org.pytorch.serve.http.StatusResponse;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.JsonUtils;
import org.pytorch.serve.util.NettyUtils;
import org.pytorch.serve.wlm.WorkerInitializationException;
@@ -63,6 +64,9 @@ public void handleRequest(
String[] segments)
throws ModelException, DownloadArchiveException, WorkflowException,
WorkerInitializationException {

ConfigManager configManager = ConfigManager.getInstance();
configManager.checkTokenAuthorization(req, false);
Collaborator

ditto

if (isManagementReq(segments)) {
if (!"workflows".equals(segments[1])) {
throw new ResourceNotFoundException();