Skip to content

Commit b40937d

Browse files
committed
Add AMD frontend support
1 parent 610e1a1 commit b40937d

28 files changed

+1859
-146
lines changed

.gitignore

+6
Original file line number | Diff line number | Diff line change
@@ -46,3 +46,9 @@ instances.yaml.backup
4646
# cpp
4747
cpp/_build
4848
cpp/third-party
49+
50+
# projects
51+
.tool-versions
52+
**/*/.classpath
53+
**/*/.settings
54+
**/*/.project

frontend/build.gradle

+2-2
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@ def javaProjects() {
3737

3838
configure(javaProjects()) {
3939
apply plugin: 'java-library'
40-
sourceCompatibility = 1.8
41-
targetCompatibility = 1.8
40+
sourceCompatibility = JavaVersion.VERSION_17
41+
targetCompatibility = JavaVersion.VERSION_17
4242

4343
defaultTasks 'jar'
4444

Original file line number | Diff line number | Diff line change
@@ -0,0 +1,90 @@
1+
package org.pytorch.serve.device;
2+
3+
import java.text.MessageFormat;
4+
import org.pytorch.serve.device.interfaces.IAcceleratorUtility;
5+
6+
public class Accelerator {
7+
public final Integer id;
8+
public final AcceleratorVendor vendor;
9+
public final String model;
10+
public IAcceleratorUtility acceleratorUtility;
11+
public Float usagePercentage;
12+
public Float memoryUtilizationPercentage;
13+
public Integer memoryAvailableMegabytes;
14+
public Integer memoryUtilizationMegabytes;
15+
16+
public Accelerator(String acceleratorName, AcceleratorVendor vendor, Integer gpuId) {
17+
this.model = acceleratorName;
18+
this.vendor = vendor;
19+
this.id = gpuId;
20+
this.usagePercentage = (float) 0.0;
21+
this.memoryUtilizationPercentage = (float) 0.0;
22+
this.memoryAvailableMegabytes = 0;
23+
this.memoryUtilizationMegabytes = 0;
24+
}
25+
26+
// Getters
27+
public Integer getMemoryAvailableMegaBytes() {
28+
return memoryAvailableMegabytes;
29+
}
30+
31+
public AcceleratorVendor getVendor() {
32+
return vendor;
33+
}
34+
35+
public String getAcceleratorModel() {
36+
return model;
37+
}
38+
39+
public Integer getAcceleratorId() {
40+
return id;
41+
}
42+
43+
public Float getUsagePercentage() {
44+
return usagePercentage;
45+
}
46+
47+
public Float getMemoryUtilizationPercentage() {
48+
return memoryUtilizationPercentage;
49+
}
50+
51+
public Integer getMemoryUtilizationMegabytes() {
52+
return memoryUtilizationMegabytes;
53+
}
54+
55+
// Setters
56+
public void setMemoryAvailableMegaBytes(Integer memoryAvailable) {
57+
this.memoryAvailableMegabytes = memoryAvailable;
58+
}
59+
60+
public void setUsagePercentage(Float acceleratorUtilization) {
61+
this.usagePercentage = acceleratorUtilization;
62+
}
63+
64+
public void setMemoryUtilizationPercentage(Float memoryUtilizationPercentage) {
65+
this.memoryUtilizationPercentage = memoryUtilizationPercentage;
66+
}
67+
68+
public void setMemoryUtilizationMegabytes(Integer memoryUtilizationMegabytes) {
69+
this.memoryUtilizationMegabytes = memoryUtilizationMegabytes;
70+
}
71+
72+
// Other Methods
73+
public String utilizationToString() {
74+
final String message =
75+
MessageFormat.format(
76+
"gpuId::{0} utilization.gpu::{1} % utilization.memory::{2} % memory.used::{3} MiB",
77+
id,
78+
usagePercentage,
79+
memoryUtilizationPercentage,
80+
memoryUtilizationMegabytes);
81+
82+
return message;
83+
}
84+
85+
public void updateDynamicAttributes(Accelerator updated) {
86+
this.usagePercentage = updated.usagePercentage;
87+
this.memoryUtilizationPercentage = updated.memoryUtilizationPercentage;
88+
this.memoryUtilizationMegabytes = updated.memoryUtilizationMegabytes;
89+
}
90+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,9 @@
1+
package org.pytorch.serve.device;
2+
3+
/**
 * Vendors of accelerator hardware that the device package distinguishes between.
 *
 * <p>{@link #UNKNOWN} is the catch-all constant for a vendor that is none of the named ones.
 */
public enum AcceleratorVendor {
    AMD,
    NVIDIA,
    INTEL,
    APPLE,
    UNKNOWN
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,167 @@
1+
package org.pytorch.serve.device;
2+
3+
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.pytorch.serve.device.interfaces.IAcceleratorUtility;
import org.pytorch.serve.device.utils.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
14+
15+
public class SystemInfo {
16+
static final Logger logger = LoggerFactory.getLogger(SystemInfo.class);
17+
//
18+
// Contains information about the system (physical or virtual machine)
19+
// we are running the workload on.
20+
// Specifically how many accelerators and info about them.
21+
//
22+
23+
public AcceleratorVendor acceleratorVendor;
24+
ArrayList<Accelerator> accelerators;
25+
private IAcceleratorUtility acceleratorUtil;
26+
27+
public SystemInfo() {
28+
// Detect and set the vendor of any accelerators in the system
29+
this.acceleratorVendor = detectVendorType();
30+
this.accelerators = new ArrayList<Accelerator>();
31+
32+
// If accelerators are present (vendor != UNKNOWN),
33+
// initialize accelerator utilities
34+
Optional.of(hasAccelerators())
35+
// Only proceed if hasAccelerators() returns true
36+
.filter(Boolean::booleanValue)
37+
// Execute this block if accelerators are present
38+
.ifPresent(
39+
hasAcc -> {
40+
// Create the appropriate utility class based on vendor
41+
this.acceleratorUtil = createAcceleratorUtility();
42+
// Populate the accelerators list based on environment
43+
// variables and available devices
44+
populateAccelerators();
45+
});
46+
47+
// Safely handle accelerator metrics update
48+
Optional.ofNullable(accelerators)
49+
// Only proceed if the accelerators list is not empty
50+
.filter(list -> !list.isEmpty())
51+
// Update metrics (utilization, memory, etc.) for all accelerators if list
52+
// exists and not empty
53+
.ifPresent(list -> updateAcceleratorMetrics());
54+
}
55+
56+
private IAcceleratorUtility createAcceleratorUtility() {
57+
switch (this.acceleratorVendor) {
58+
case AMD:
59+
return new ROCmUtil();
60+
case NVIDIA:
61+
return new CudaUtil();
62+
case INTEL:
63+
return new XpuUtil();
64+
case APPLE:
65+
return new AppleUtil();
66+
default:
67+
return null;
68+
}
69+
}
70+
71+
private void populateAccelerators() {
72+
if (this.acceleratorUtil != null) {
73+
String envVarName = this.acceleratorUtil.getGpuEnvVariableName();
74+
String requestedAcceleratorIds = System.getenv(envVarName);
75+
LinkedHashSet<Integer> availableAcceleratorIds =
76+
IAcceleratorUtility.parseVisibleDevicesEnv(requestedAcceleratorIds);
77+
this.accelerators =
78+
this.acceleratorUtil.getAvailableAccelerators(availableAcceleratorIds);
79+
} else {
80+
this.accelerators = new ArrayList<>();
81+
}
82+
}
83+
84+
boolean hasAccelerators() {
85+
return this.acceleratorVendor != AcceleratorVendor.UNKNOWN;
86+
}
87+
88+
public Integer getNumberOfAccelerators() {
89+
// since we instance create `accelerators` as an empty list
90+
// in the constructor, the null check should be redundant.
91+
// leaving it to be sure.
92+
return (accelerators != null) ? accelerators.size() : 0;
93+
}
94+
95+
public static AcceleratorVendor detectVendorType() {
96+
if (isCommandAvailable("rocm-smi")) {
97+
return AcceleratorVendor.AMD;
98+
} else if (isCommandAvailable("nvidia-smi")) {
99+
return AcceleratorVendor.NVIDIA;
100+
} else if (isCommandAvailable("xpu-smi")) {
101+
return AcceleratorVendor.INTEL;
102+
} else if (isCommandAvailable("system_profiler")) {
103+
return AcceleratorVendor.APPLE;
104+
} else {
105+
return AcceleratorVendor.UNKNOWN;
106+
}
107+
}
108+
109+
private static boolean isCommandAvailable(String command) {
110+
String operatingSystem = System.getProperty("os.name").toLowerCase();
111+
String commandCheck = operatingSystem.contains("win") ? "where" : "which";
112+
ProcessBuilder processBuilder = new ProcessBuilder(commandCheck, command);
113+
try {
114+
Process process = processBuilder.start();
115+
int exitCode = process.waitFor();
116+
return exitCode == 0;
117+
} catch (IOException | InterruptedException e) {
118+
return false;
119+
}
120+
}
121+
122+
public ArrayList<Accelerator> getAccelerators() {
123+
return this.accelerators;
124+
}
125+
126+
private void updateAccelerators(List<Accelerator> updatedAccelerators) {
127+
// Create a map of existing accelerators with ID as key
128+
Map<Integer, Accelerator> existingAcceleratorsMap =
129+
this.accelerators.stream().collect(Collectors.toMap(acc -> acc.id, acc -> acc));
130+
131+
// Update existing accelerators and add new ones
132+
this.accelerators =
133+
updatedAccelerators.stream()
134+
.map(
135+
updatedAcc -> {
136+
Accelerator existingAcc =
137+
existingAcceleratorsMap.get(updatedAcc.id);
138+
if (existingAcc != null) {
139+
existingAcc.updateDynamicAttributes(updatedAcc);
140+
return existingAcc;
141+
} else {
142+
return updatedAcc;
143+
}
144+
})
145+
.collect(Collectors.toCollection(ArrayList::new));
146+
}
147+
148+
public void updateAcceleratorMetrics() {
149+
if (this.acceleratorUtil != null) {
150+
List<Accelerator> updatedAccelerators =
151+
this.acceleratorUtil.getUpdatedAcceleratorsUtilization(this.accelerators);
152+
153+
updateAccelerators(updatedAccelerators);
154+
}
155+
}
156+
157+
public AcceleratorVendor getAcceleratorVendor() {
158+
return this.acceleratorVendor;
159+
}
160+
161+
public String getVisibleDevicesEnvName() {
162+
if (this.accelerators.isEmpty() || this.accelerators == null) {
163+
return null;
164+
}
165+
return this.accelerators.get(0).acceleratorUtility.getGpuEnvVariableName();
166+
}
167+
}

0 commit comments

Comments
 (0)