-
Notifications
You must be signed in to change notification settings - Fork 878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Api change #2888
Api change #2888
Changes from 7 commits
4ddb74e
1304ca3
fde5f4d
8ec3c40
d50beae
d5c25e6
1b35372
0010c86
62cadc8
51f5889
129b652
cecc086
6876781
4848114
fa32c28
c5c0f77
84343c4
3b7fe29
0d7afe3
13caa27
a49acf5
6b329b2
0690e01
eb37eff
71bc7f5
3e18230
f69c632
599c635
55cedd5
4c7080d
54841d9
4e1c1ea
698b95a
68b3b04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package org.pytorch.serve.archive.model; | ||
|
||
public class InvalidKeyException extends ModelException { | ||
|
||
private static final long serialVersionUID = 1L; | ||
|
||
/** | ||
* Constructs an {@code InvalidKeyException} with the specified detail message. | ||
* | ||
* @param message The detail message (which is saved for later retrieval by the {@link | ||
* #getMessage()} method) | ||
*/ | ||
public InvalidKeyException(String message) { | ||
super(message); | ||
} | ||
|
||
/** | ||
* Constructs an {@code InvalidKeyException} with the specified detail message and cause. | ||
* | ||
* <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically | ||
* incorporated into this exception's detail message. | ||
* | ||
* @param message The detail message (which is saved for later retrieval by the {@link | ||
* #getMessage()} method) | ||
* @param cause The cause (which is saved for later retrieval by the {@link #getCause()} | ||
* method). (A null value is permitted, and indicates that the cause is nonexistent or | ||
* unknown.) | ||
*/ | ||
public InvalidKeyException(String message, Throwable cause) { | ||
super(message, cause); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package org.pytorch.serve.archive.model; | ||
|
||
public class KeyTimeOutException extends ModelException { | ||
|
||
private static final long serialVersionUID = 1L; | ||
|
||
/** | ||
* Constructs an {@code KeyTimeOutException} with the specified detail message. | ||
* | ||
* @param message The detail message (which is saved for later retrieval by the {@link | ||
* #getMessage()} method) | ||
*/ | ||
public KeyTimeOutException(String message) { | ||
super(message); | ||
} | ||
|
||
/** | ||
* Constructs an {@code KeyTimeOutException} with the specified detail message and cause. | ||
* | ||
* <p>Note that the detail message associated with {@code cause} is <i>not</i> automatically | ||
* incorporated into this exception's detail message. | ||
* | ||
* @param message The detail message (which is saved for later retrieval by the {@link | ||
* #getMessage()} method) | ||
* @param cause The cause (which is saved for later retrieval by the {@link #getCause()} | ||
* method). (A null value is permitted, and indicates that the cause is nonexistent or | ||
* unknown.) | ||
*/ | ||
public KeyTimeOutException(String message, Throwable cause) { | ||
super(message, cause); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,10 +59,13 @@ public void handleRequest( | |
String[] segments) | ||
throws ModelException, DownloadArchiveException, WorkflowException, | ||
WorkerInitializationException { | ||
ConfigManager configManager = ConfigManager.getInstance(); | ||
if (isInferenceReq(segments)) { | ||
if (endpointMap.getOrDefault(segments[1], null) != null) { | ||
configManager.checkTokenAuthorization(req, false); | ||
handleCustomEndpoint(ctx, req, segments, decoder); | ||
} else { | ||
configManager.checkTokenAuthorization(req, true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The token authentication implementation should not be spread across each endpoint. It should be handled by InvalidRequestHandler. |
||
switch (segments[1]) { | ||
case "ping": | ||
Runnable r = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
import org.pytorch.serve.openapi.OpenApiUtils; | ||
import org.pytorch.serve.servingsdk.ModelServerEndpoint; | ||
import org.pytorch.serve.util.ApiUtils; | ||
import org.pytorch.serve.util.ConfigManager; | ||
import org.pytorch.serve.util.JsonUtils; | ||
import org.pytorch.serve.util.NettyUtils; | ||
import org.pytorch.serve.util.messages.RequestInput; | ||
|
@@ -61,6 +62,8 @@ public void handleRequest( | |
String[] segments) | ||
throws ModelException, DownloadArchiveException, WorkflowException, | ||
WorkerInitializationException { | ||
ConfigManager configManager = ConfigManager.getInstance(); | ||
configManager.checkTokenAuthorization(req, false); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
if (isManagementReq(segments)) { | ||
if (endpointMap.getOrDefault(segments[1], null) != null) { | ||
handleCustomEndpoint(ctx, req, segments, decoder); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
import org.pytorch.serve.archive.model.ModelException; | ||
import org.pytorch.serve.archive.workflow.WorkflowException; | ||
import org.pytorch.serve.http.HttpRequestHandlerChain; | ||
import org.pytorch.serve.util.ConfigManager; | ||
import org.pytorch.serve.util.NettyUtils; | ||
import org.pytorch.serve.wlm.WorkerInitializationException; | ||
import org.slf4j.Logger; | ||
|
@@ -47,6 +48,8 @@ public void handleRequest( | |
String[] segments) | ||
throws ModelException, DownloadArchiveException, WorkflowException, | ||
WorkerInitializationException { | ||
ConfigManager configManager = ConfigManager.getInstance(); | ||
configManager.checkTokenAuthorization(req, true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
if (segments.length >= 2 && "metrics".equals(segments[1])) { | ||
ByteBuf resBuf = Unpooled.directBuffer(); | ||
List<String> params = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,8 @@ | |
|
||
import com.google.gson.JsonObject; | ||
import com.google.gson.reflect.TypeToken; | ||
import io.netty.handler.codec.http.FullHttpRequest; | ||
import io.netty.handler.codec.http.HttpMethod; | ||
import io.netty.handler.ssl.SslContext; | ||
import io.netty.handler.ssl.SslContextBuilder; | ||
import io.netty.handler.ssl.util.SelfSignedCertificate; | ||
|
@@ -20,11 +22,14 @@ | |
import java.security.KeyFactory; | ||
import java.security.KeyStore; | ||
import java.security.PrivateKey; | ||
import java.security.SecureRandom; | ||
import java.security.cert.Certificate; | ||
import java.security.cert.CertificateFactory; | ||
import java.security.cert.X509Certificate; | ||
import java.security.spec.InvalidKeySpecException; | ||
import java.security.spec.PKCS8EncodedKeySpec; | ||
import java.time.DateTimeException; | ||
import java.time.Instant; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.Base64; | ||
|
@@ -36,13 +41,17 @@ | |
import java.util.Map; | ||
import java.util.Properties; | ||
import java.util.Set; | ||
import java.util.concurrent.TimeUnit; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
import java.util.regex.PatternSyntaxException; | ||
import org.apache.commons.cli.CommandLine; | ||
import org.apache.commons.cli.Option; | ||
import org.apache.commons.cli.Options; | ||
import org.apache.commons.io.IOUtils; | ||
import org.pytorch.serve.archive.model.InvalidKeyException; | ||
import org.pytorch.serve.archive.model.KeyTimeOutException; | ||
import org.pytorch.serve.archive.model.ModelException; | ||
import org.pytorch.serve.metrics.MetricBuilder; | ||
import org.pytorch.serve.servingsdk.snapshot.SnapshotSerializer; | ||
import org.pytorch.serve.snapshot.SnapshotSerializerFactory; | ||
|
@@ -146,6 +155,10 @@ public final class ConfigManager { | |
private boolean telemetryEnabled; | ||
private Logger logger = LoggerFactory.getLogger(ConfigManager.class); | ||
|
||
private SecureRandom secureRandom = new SecureRandom(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two seem to be unused |
||
private Base64.Encoder baseEncoder = Base64.getUrlEncoder(); | ||
public String keyFileLocation; | ||
|
||
private ConfigManager(Arguments args) throws IOException { | ||
prop = new Properties(); | ||
|
||
|
@@ -837,6 +850,104 @@ public boolean isSnapshotDisabled() { | |
return snapshotDisabled; | ||
} | ||
|
||
public boolean isTokenExpired(Instant expirationTime) { | ||
return !(Instant.now().isBefore(expirationTime)); | ||
} | ||
|
||
public String generateToken() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that generateToken is as same as genereateKey in plugin. I don't understand why the token implementation need twice. |
||
byte[] randomBytes = new byte[6]; | ||
secureRandom.nextBytes(randomBytes); | ||
return baseEncoder.encodeToString(randomBytes); | ||
} | ||
|
||
public boolean generateKeyFile() throws IOException { | ||
String fileData = " "; | ||
String absoluteFilePath = getCanonicalPath(".") + "/key_file.txt"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that the toke file path is different from the path implemented in plugin. Again, I don't understand why the token implementation need twice. |
||
keyFileLocation = absoluteFilePath; | ||
File file = new File(absoluteFilePath); | ||
if (!file.createNewFile() && !file.exists()) { | ||
return false; | ||
} | ||
Integer timeToExpiration = 30; // in minutes | ||
fileData = | ||
"Management Key: " | ||
+ generateToken() | ||
+ "\n" | ||
+ "Inference Key: " | ||
+ generateToken() | ||
+ " --- Expiration time: " | ||
+ Instant.now().plusSeconds(TimeUnit.MINUTES.toSeconds(timeToExpiration)) | ||
+ "\n"; | ||
Files.write(Paths.get("key_file.txt"), fileData.getBytes()); | ||
return true; | ||
} | ||
|
||
public List<String> parseFile(File tsTokenFile) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed to maintain the data in memory and no longer need to parse the key file. |
||
List<String> parsedTokens = new ArrayList<>(); | ||
try { | ||
InputStream stream = Files.newInputStream(tsTokenFile.toPath()); | ||
byte[] array = new byte[100]; | ||
stream.read(array); | ||
String data = new String(array); | ||
String[] arrOfData = data.split("\n", 2); | ||
String[] managementArr = arrOfData[0].split(" ", 3); | ||
String[] inferenceArr = arrOfData[1].split(" ", 7); | ||
parsedTokens.add(managementArr[2]); | ||
parsedTokens.add(inferenceArr[2]); | ||
String[] expirationArr = inferenceArr[6].split("\n", 2); | ||
parsedTokens.add(expirationArr[0]); | ||
} catch (IOException | ArrayIndexOutOfBoundsException e) { | ||
System.out.println("Unable to read key file or key file has been modified"); | ||
return null; | ||
} | ||
return parsedTokens; | ||
} | ||
|
||
public void checkTokenAuthorization(FullHttpRequest req, boolean inferenceRequest) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This implementation is too heavy for real time inference. is ConfigManager the right place to implement the entire token authentication? |
||
throws ModelException { | ||
HttpMethod method = req.method(); | ||
String filePath = keyFileLocation; | ||
if (filePath != null) { | ||
File tsTokenFile = new File(filePath); | ||
if (tsTokenFile.exists()) { | ||
List<String> parsedTokens = parseFile(tsTokenFile); | ||
String managementToken = parsedTokens.get(0); | ||
String inferenceToken = parsedTokens.get(1); | ||
Instant expirationTime = Instant.now(); | ||
try { | ||
expirationTime = Instant.parse(parsedTokens.get(2)); | ||
} catch (DateTimeException e) { | ||
e.printStackTrace(); | ||
System.out.println("{\n\t\"Error\": Key File has been modified \n}\n"); | ||
} | ||
String tokenBearer = req.headers().get("Authorization"); | ||
if (tokenBearer == null) { | ||
throw new InvalidKeyException("NO TOKEN PROVIDED"); | ||
} | ||
String[] arrOfStr = tokenBearer.split(" ", 2); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could make parsing the token from the header more robust using a regex pattern. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No longer need to parse the key file. |
||
if (arrOfStr.length == 1) { | ||
throw new InvalidKeyException("NO TOKEN PROVIDED"); | ||
} | ||
String token = arrOfStr[1]; | ||
String key = managementToken; | ||
if (inferenceRequest) { | ||
key = inferenceToken; | ||
} | ||
|
||
if (token.equals(key)) { | ||
if (isTokenExpired(expirationTime) && inferenceRequest) { | ||
throw new KeyTimeOutException("THE CURRENT TOKEN IS EXPIRED"); | ||
} | ||
System.out.println("TOKEN AUTHORIZATION WORKED"); | ||
} else { | ||
throw new InvalidKeyException("TOKEN IS INCORRECT "); | ||
} | ||
} else { | ||
System.out.println("TOKEN AUTHORIZATION IS NOT ENABLED"); | ||
} | ||
} | ||
} | ||
|
||
public boolean isSSLEnabled(ConnectorType connectorType) { | ||
String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080"); | ||
switch (connectorType) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,9 @@ public void handleRequest( | |
String[] segments) | ||
throws ModelException, DownloadArchiveException, WorkflowException, | ||
WorkerInitializationException { | ||
|
||
ConfigManager configManager = ConfigManager.getInstance(); | ||
configManager.checkTokenAuthorization(req, true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
if ("wfpredict".equalsIgnoreCase(segments[1])) { | ||
if (segments.length < 3) { | ||
throw new ResourceNotFoundException(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ | |
import org.pytorch.serve.http.MethodNotAllowedException; | ||
import org.pytorch.serve.http.ResourceNotFoundException; | ||
import org.pytorch.serve.http.StatusResponse; | ||
import org.pytorch.serve.util.ConfigManager; | ||
import org.pytorch.serve.util.JsonUtils; | ||
import org.pytorch.serve.util.NettyUtils; | ||
import org.pytorch.serve.wlm.WorkerInitializationException; | ||
|
@@ -63,6 +64,9 @@ public void handleRequest( | |
String[] segments) | ||
throws ModelException, DownloadArchiveException, WorkflowException, | ||
WorkerInitializationException { | ||
|
||
ConfigManager configManager = ConfigManager.getInstance(); | ||
configManager.checkTokenAuthorization(req, false); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
if (isManagementReq(segments)) { | ||
if (!"workflows".equals(segments[1])) { | ||
throw new ResourceNotFoundException(); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should a new key be generated whenever the model server is started?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, when the model server is started, the key file is generated with the three keys.