Skip to content

Commit 750f2e0

Browse files
committed
Refactor implementaiton to check for useVenv in Model.java
1 parent 62a82cf commit 750f2e0

File tree

3 files changed

+36
-33
lines changed

3 files changed

+36
-33
lines changed

frontend/server/src/main/java/org/pytorch/serve/util/messages/EnvironmentUtils.java

+7-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import java.util.Map;
1111
import java.util.regex.Pattern;
1212
import org.pytorch.serve.archive.model.Manifest;
13-
import org.pytorch.serve.archive.model.ModelConfig;
1413
import org.pytorch.serve.util.ConfigManager;
1514
import org.pytorch.serve.wlm.Model;
1615
import org.slf4j.Logger;
@@ -77,14 +76,12 @@ public static String[] getEnvString(String cwd, String modelPath, String handler
7776

7877
public static String getPythonRunTime(Model model) {
7978
String pythonRuntime;
80-
Manifest.RuntimeType runtime = model.getModelArchive().getManifest().getRuntime();
79+
Manifest.RuntimeType runtime = model.getRuntimeType();
8180
if (runtime == Manifest.RuntimeType.PYTHON) {
8281
pythonRuntime = configManager.getPythonExecutable();
83-
ModelConfig modelConfig = model.getModelArchive().getModelConfig();
84-
Path pythonVenvRuntime = Paths.get(getPythonVenvPath(model), "bin", "python");
85-
if (modelConfig != null
86-
&& modelConfig.getUseVenv() == true
87-
&& Files.exists(pythonVenvRuntime)) {
82+
Path pythonVenvRuntime =
83+
Paths.get(getPythonVenvPath(model).toString(), "bin", "python");
84+
if (model.isUseVenv() && Files.exists(pythonVenvRuntime)) {
8885
pythonRuntime = pythonVenvRuntime.toString();
8986
}
9087
} else {
@@ -93,13 +90,13 @@ public static String getPythonRunTime(Model model) {
9390
return pythonRuntime;
9491
}
9592

96-
public static String getPythonVenvPath(Model model) {
93+
public static File getPythonVenvPath(Model model) {
9794
File modelDir = model.getModelDir();
9895
if (Files.isSymbolicLink(modelDir.toPath())) {
9996
modelDir = modelDir.getParentFile();
10097
}
101-
Path venvPath = Paths.get(modelDir.getAbsolutePath(), "venv");
102-
return venvPath.toString();
98+
Path venvPath = Paths.get(modelDir.getAbsolutePath(), "venv").toAbsolutePath();
99+
return venvPath.toFile();
103100
}
104101

105102
public static String[] getCppEnvString(String libPath) {

frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java

+6
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ public class Model {
8282
private boolean useJobTicket;
8383
private AtomicInteger numJobTickets;
8484
private boolean continuousBatching;
85+
private boolean useVenv;
8586

8687
public Model(ModelArchive modelArchive, int queueSize) {
8788
this.modelArchive = modelArchive;
8889
if (modelArchive != null && modelArchive.getModelConfig() != null) {
8990
continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
91+
useVenv = modelArchive.getModelConfig().getUseVenv();
9092
if (modelArchive.getModelConfig().getParallelLevel() > 0
9193
&& modelArchive.getModelConfig().getParallelType()
9294
!= ModelConfig.ParallelType.NONE) {
@@ -630,6 +632,10 @@ public boolean isContinuousBatching() {
630632
return continuousBatching;
631633
}
632634

635+
public boolean isUseVenv() {
636+
return useVenv;
637+
}
638+
633639
public boolean hasTensorParallel() {
634640
switch (this.parallelType) {
635641
case PP:

frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java

+23-23
Original file line numberDiff line numberDiff line change
@@ -211,29 +211,29 @@ private ModelArchive createModelArchive(
211211

212212
private void setupModelVenv(Model model)
213213
throws IOException, InterruptedException, ModelException {
214-
ModelConfig modelConfig = model.getModelArchive().getModelConfig();
215-
if (model.getModelArchive().getManifest().getRuntime() != Manifest.RuntimeType.PYTHON
216-
|| modelConfig == null
217-
|| modelConfig.getUseVenv() != true) {
214+
if (model.getRuntimeType() != Manifest.RuntimeType.PYTHON || !model.isUseVenv()) {
218215
return;
219216
}
220217

221-
String venvPath = EnvironmentUtils.getPythonVenvPath(model);
218+
File venvPath = EnvironmentUtils.getPythonVenvPath(model);
222219
List<String> commandParts = new ArrayList<>();
223220
commandParts.add(configManager.getPythonExecutable());
224221
commandParts.add("-m");
225222
commandParts.add("venv");
226223
commandParts.add("--clear");
227224
commandParts.add("--system-site-packages");
228-
commandParts.add(venvPath);
225+
commandParts.add(venvPath.toString());
229226

230227
ProcessBuilder processBuilder = new ProcessBuilder(commandParts);
231228

232229
if (isValidDependencyPath(venvPath)) {
233-
processBuilder.directory(Paths.get(venvPath).toFile().getParentFile());
230+
processBuilder.directory(venvPath.getParentFile());
234231
} else {
235232
throw new ModelException(
236-
"Invalid python venv path for model " + model.getModelName() + ": " + venvPath);
233+
"Invalid python venv path for model "
234+
+ model.getModelName()
235+
+ ": "
236+
+ venvPath.toString());
237237
}
238238
Map<String, String> environment = processBuilder.environment();
239239
String[] envp =
@@ -261,12 +261,14 @@ private void setupModelVenv(Model model)
261261

262262
if (exitCode == 0) {
263263
logger.info(
264-
"Created virtual environment for model {}: {}", model.getModelName(), venvPath);
264+
"Created virtual environment for model {}: {}",
265+
model.getModelName(),
266+
venvPath.toString());
265267
} else {
266268
logger.error(
267269
"Virtual environment creation for model {} at {} failed:\n{}",
268270
model.getModelName(),
269-
venvPath,
271+
venvPath.toString(),
270272
outputString.toString());
271273
throw new ModelException(
272274
"Virtual environment creation failed for model " + model.getModelName());
@@ -282,24 +284,22 @@ private void setupModelDependencies(Model model)
282284
return;
283285
}
284286

285-
ModelConfig modelConfig = model.getModelArchive().getModelConfig();
286287
String pythonRuntime = EnvironmentUtils.getPythonRunTime(model);
287288
Path requirementsFilePath =
288-
Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile);
289+
Paths.get(model.getModelDir().getAbsolutePath(), requirementsFile).toAbsolutePath();
289290
List<String> commandParts = new ArrayList<>();
290291
ProcessBuilder processBuilder = new ProcessBuilder();
291292

292-
if (modelConfig != null && modelConfig.getUseVenv() == true) {
293-
if (!isValidDependencyPath(pythonRuntime)) {
293+
if (model.isUseVenv()) {
294+
if (!isValidDependencyPath(Paths.get(pythonRuntime).toFile())) {
294295
throw new ModelException(
295296
"Invalid python venv runtime path for model "
296297
+ model.getModelName()
297298
+ ": "
298299
+ pythonRuntime);
299300
}
300301

301-
processBuilder.directory(
302-
Paths.get(EnvironmentUtils.getPythonVenvPath(model)).toFile().getParentFile());
302+
processBuilder.directory(EnvironmentUtils.getPythonVenvPath(model).getParentFile());
303303

304304
commandParts.add(pythonRuntime);
305305
commandParts.add("-m");
@@ -311,14 +311,13 @@ private void setupModelDependencies(Model model)
311311
commandParts.add("-r");
312312
commandParts.add(requirementsFilePath.toString());
313313
} else {
314-
File dependencyPath = model.getModelDir();
314+
File dependencyPath = model.getModelDir().getAbsolutePath();
315315
if (Files.isSymbolicLink(dependencyPath.toPath())) {
316316
dependencyPath = dependencyPath.getParentFile();
317317
}
318-
if (!isValidDependencyPath(dependencyPath.getPath())) {
318+
if (!isValidDependencyPath(dependencyPath)) {
319319
throw new ModelException(
320-
"Invalid 3rd party package installation path "
321-
+ dependencyPath.getCanonicalPath());
320+
"Invalid 3rd party package installation path " + dependencyPath.toString());
322321
}
323322

324323
processBuilder.directory(dependencyPath);
@@ -329,7 +328,7 @@ private void setupModelDependencies(Model model)
329328
commandParts.add("install");
330329
commandParts.add("-U");
331330
commandParts.add("-t");
332-
commandParts.add(dependencyPath.getAbsolutePath());
331+
commandParts.add(dependencyPath.toString());
333332
commandParts.add("-r");
334333
commandParts.add(requirementsFilePath.toString());
335334
}
@@ -374,8 +373,9 @@ private void setupModelDependencies(Model model)
374373
}
375374
}
376375

377-
private boolean isValidDependencyPath(String dependencyPath) {
378-
if (Paths.get(dependencyPath)
376+
private boolean isValidDependencyPath(File dependencyPath) {
377+
if (dependencyPath
378+
.toPath()
379379
.normalize()
380380
.startsWith(FileUtils.getTempDirectory().toPath().normalize())) {
381381
return true;

0 commit comments

Comments
 (0)