Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d7e72c8

Browse files
committedFeb 2, 2024
Enable per model useVenv configuration option
1 parent e1c0734 commit d7e72c8

File tree

3 files changed

+114
-51
lines changed

3 files changed

+114
-51
lines changed
 

‎frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java

+21-1
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,14 @@ public class ModelConfig {
6565
private int maxSequenceJobQueueSize = 1;
6666
/** the max number of sequences can be accepted. The default value is 1. */
6767
private int maxNumSequence = 1;
68-
6968
/** continuousBatching is a flag to enable continuous batching. */
7069
private boolean continuousBatching;
70+
/**
71+
* Create python virtual environment when using python backend to: 1) Install model dependencies
72+
* (if enabled globally using install_py_dep_per_model=true) 2) Run workers for model loading
73+
* and inference
74+
*/
75+
private boolean useVenv;
7176

7277
public static ModelConfig build(Map<String, Object> yamlMap) {
7378
ModelConfig modelConfig = new ModelConfig();
@@ -207,6 +212,13 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
207212
v);
208213
}
209214
break;
215+
case "useVenv":
216+
if (v instanceof Boolean) {
217+
modelConfig.setUseVenv((boolean) v);
218+
} else {
219+
logger.warn("Invalid useVenv: {}, should be true or false", v);
220+
}
221+
break;
210222
default:
211223
break;
212224
}
@@ -379,6 +391,14 @@ public void setMaxNumSequence(int maxNumSequence) {
379391
this.maxNumSequence = Math.max(1, maxNumSequence);
380392
}
381393

394+
public boolean getUseVenv() {
395+
return useVenv;
396+
}
397+
398+
public void setUseVenv(boolean useVenv) {
399+
this.useVenv = useVenv;
400+
}
401+
382402
public enum ParallelType {
383403
NONE(""),
384404
PP("pp"),

‎frontend/server/src/main/java/org/pytorch/serve/util/messages/EnvironmentUtils.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Map;
1111
import java.util.regex.Pattern;
1212
import org.pytorch.serve.archive.model.Manifest;
13+
import org.pytorch.serve.archive.model.ModelConfig;
1314
import org.pytorch.serve.util.ConfigManager;
1415
import org.pytorch.serve.wlm.Model;
1516
import org.slf4j.Logger;
@@ -79,8 +80,11 @@ public static String getPythonRunTime(Model model) {
7980
Manifest.RuntimeType runtime = model.getModelArchive().getManifest().getRuntime();
8081
if (runtime == Manifest.RuntimeType.PYTHON) {
8182
pythonRuntime = configManager.getPythonExecutable();
83+
ModelConfig modelConfig = model.getModelArchive().getModelConfig();
8284
Path pythonVenvRuntime = Paths.get(getPythonVenvPath(model), "bin", "python");
83-
if (Files.exists(pythonVenvRuntime)) {
85+
if (modelConfig != null
86+
&& modelConfig.getUseVenv() == true
87+
&& Files.exists(pythonVenvRuntime)) {
8488
pythonRuntime = pythonVenvRuntime.toString();
8589
}
8690
} else {

‎frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java

+88-49
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import java.io.IOException;
77
import java.io.InputStreamReader;
88
import java.net.HttpURLConnection;
9+
import java.nio.file.Files;
910
import java.nio.file.Path;
1011
import java.nio.file.Paths;
1112
import java.util.ArrayList;
@@ -99,6 +100,8 @@ public void registerAndUpdateModel(String modelName, JsonObject modelInfo)
99100

100101
createVersionedModel(tempModel, versionId);
101102

103+
setupModelVenv(tempModel);
104+
102105
setupModelDependencies(tempModel);
103106
if (defaultVersion) {
104107
modelManager.setDefaultVersion(modelName, versionId);
@@ -152,6 +155,8 @@ public ModelArchive registerModel(
152155
}
153156
}
154157

158+
setupModelVenv(tempModel);
159+
155160
setupModelDependencies(tempModel);
156161

157162
logger.info("Model {} loaded.", tempModel.getModelName());
@@ -206,9 +211,16 @@ private ModelArchive createModelArchive(
206211

207212
private void setupModelVenv(Model model)
208213
throws IOException, InterruptedException, ModelException {
214+
ModelConfig modelConfig = model.getModelArchive().getModelConfig();
215+
if (model.getModelArchive().getManifest().getRuntime() != Manifest.RuntimeType.PYTHON
216+
|| modelConfig == null
217+
|| modelConfig.getUseVenv() != true) {
218+
return;
219+
}
220+
209221
String venvPath = EnvironmentUtils.getPythonVenvPath(model);
210222
List<String> commandParts = new ArrayList<>();
211-
commandParts.add(EnvironmentUtils.getPythonRunTime(model));
223+
commandParts.add(configManager.getPythonExecutable());
212224
commandParts.add("-m");
213225
commandParts.add("venv");
214226
commandParts.add("--clear");
@@ -217,7 +229,7 @@ private void setupModelVenv(Model model)
217229

218230
ProcessBuilder processBuilder = new ProcessBuilder(commandParts);
219231

220-
if (isValidVenvPath(venvPath)) {
232+
if (isValidDependencyPath(venvPath)) {
221233
processBuilder.directory(Paths.get(venvPath).toFile().getParentFile());
222234
} else {
223235
throw new ModelException(
@@ -266,20 +278,29 @@ private void setupModelDependencies(Model model)
266278
String requirementsFile =
267279
model.getModelArchive().getManifest().getModel().getRequirementsFile();
268280

269-
if (configManager.getInstallPyDepPerModel() && requirementsFile != null) {
270-
setupModelVenv(model);
271-
String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);
272-
if (!isValidVenvPath(pythonRuntime)) {
281+
if (!configManager.getInstallPyDepPerModel() || requirementsFile == null) {
282+
return;
283+
}
284+
285+
ModelConfig modelConfig = model.getModelArchive().getModelConfig();
286+
String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);
287+
Path requirementsFilePath =
288+
Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile);
289+
List<String> commandParts = new ArrayList<>();
290+
ProcessBuilder processBuilder = new ProcessBuilder();
291+
292+
if (modelConfig != null && modelConfig.getUseVenv() == true) {
293+
if (!isValidDependencyPath(pythonRuntime)) {
273294
throw new ModelException(
274295
"Invalid python venv runtime path for model "
275296
+ model.getModelName()
276297
+ ": "
277298
+ pythonRuntime);
278299
}
279300

280-
Path requirementsFilePath =
281-
Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile);
282-
List<String> commandParts = new ArrayList<>();
301+
processBuilder.directory(
302+
Paths.get(EnvironmentUtils.getPythonVenvPath(model)).toFile().getParentFile());
303+
283304
commandParts.add(pythonRuntime);
284305
commandParts.add("-m");
285306
commandParts.add("pip");
@@ -289,54 +310,72 @@ private void setupModelDependencies(Model model)
289310
commandParts.add("only-if-needed");
290311
commandParts.add("-r");
291312
commandParts.add(requirementsFilePath.toString());
292-
293-
ProcessBuilder processBuilder = new ProcessBuilder(commandParts);
294-
295-
String[] envp =
296-
EnvironmentUtils.getEnvString(
297-
configManager.getModelServerHome(),
298-
model.getModelDir().getAbsolutePath(),
299-
null);
300-
Map<String, String> environment = processBuilder.environment();
301-
for (String envVar : envp) {
302-
String[] parts = envVar.split("=", 2);
303-
if (parts.length == 2) {
304-
environment.put(parts[0], parts[1]);
305-
}
313+
} else {
314+
File dependencyPath = model.getModelDir();
315+
if (Files.isSymbolicLink(dependencyPath.toPath())) {
316+
dependencyPath = dependencyPath.getParentFile();
317+
}
318+
if (!isValidDependencyPath(dependencyPath.getPath())) {
319+
throw new ModelException(
320+
"Invalid 3rd party package installation path "
321+
+ dependencyPath.getCanonicalPath());
306322
}
307-
processBuilder.directory(
308-
Paths.get(EnvironmentUtils.getPythonVenvPath(model)).toFile().getParentFile());
309-
processBuilder.redirectErrorStream(true);
310323

311-
Process process = processBuilder.start();
324+
processBuilder.directory(dependencyPath);
312325

313-
int exitCode = process.waitFor();
314-
String line;
315-
StringBuilder outputString = new StringBuilder();
316-
BufferedReader brdr =
317-
new BufferedReader(new InputStreamReader(process.getInputStream()));
318-
while ((line = brdr.readLine()) != null) {
319-
outputString.append(line + "\n");
320-
}
326+
commandParts.add(pythonRuntime);
327+
commandParts.add("-m");
328+
commandParts.add("pip");
329+
commandParts.add("install");
330+
commandParts.add("-U");
331+
commandParts.add("-t");
332+
commandParts.add(dependencyPath.getAbsolutePath());
333+
commandParts.add("-r");
334+
commandParts.add(requirementsFilePath.toString());
335+
}
321336

322-
if (exitCode == 0) {
323-
logger.info(
324-
"Installed custom pip packages for model {}:\n{}",
325-
model.getModelName(),
326-
outputString.toString());
327-
} else {
328-
logger.error(
329-
"Custom pip package installation failed for model {}:\n{}",
330-
model.getModelName(),
331-
outputString.toString());
332-
throw new ModelException(
333-
"Custom pip package installation failed for model " + model.getModelName());
337+
processBuilder.command(commandParts);
338+
String[] envp =
339+
EnvironmentUtils.getEnvString(
340+
configManager.getModelServerHome(),
341+
model.getModelDir().getAbsolutePath(),
342+
null);
343+
Map<String, String> environment = processBuilder.environment();
344+
for (String envVar : envp) {
345+
String[] parts = envVar.split("=", 2);
346+
if (parts.length == 2) {
347+
environment.put(parts[0], parts[1]);
334348
}
335349
}
350+
processBuilder.redirectErrorStream(true);
351+
352+
Process process = processBuilder.start();
353+
354+
int exitCode = process.waitFor();
355+
String line;
356+
StringBuilder outputString = new StringBuilder();
357+
BufferedReader brdr = new BufferedReader(new InputStreamReader(process.getInputStream()));
358+
while ((line = brdr.readLine()) != null) {
359+
outputString.append(line + "\n");
360+
}
361+
362+
if (exitCode == 0) {
363+
logger.info(
364+
"Installed custom pip packages for model {}:\n{}",
365+
model.getModelName(),
366+
outputString.toString());
367+
} else {
368+
logger.error(
369+
"Custom pip package installation failed for model {}:\n{}",
370+
model.getModelName(),
371+
outputString.toString());
372+
throw new ModelException(
373+
"Custom pip package installation failed for model " + model.getModelName());
374+
}
336375
}
337376

338-
private boolean isValidVenvPath(String venvPath) {
339-
if (Paths.get(venvPath)
377+
private boolean isValidDependencyPath(String dependencyPath) {
378+
if (Paths.get(dependencyPath)
340379
.normalize()
341380
.startsWith(FileUtils.getTempDirectory().toPath().normalize())) {
342381
return true;

0 commit comments

Comments
 (0)
Please sign in to comment.