Commit eaa92d4

[ROCm] [Feature] [Doc] [Dockerfile] [BugFix] Support Per-Token-Activation Per-Channel-Weight FP8 Quantization Inferencing (vllm-project#12501)
1 parent 0630d45 commit eaa92d4
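
The commit title refers to per-token-activation, per-channel-weight (PTPC) FP8 quantization: each activation row (token) and each weight output channel carries its own scale. As a rough orientation before the diffs below, here is a conceptual sketch of that scaling scheme; it is illustrative only (the `ptpc_quantize` helper, shapes, and the CPU-friendly `float8_e4m3fn` default are assumptions, not code from this commit — the ROCm path targets `torch.float8_e4m3fnuz`).

```python
import torch

def ptpc_quantize(x: torch.Tensor, w: torch.Tensor,
                  fp8=torch.float8_e4m3fn):
    """Quantize activations per token (row) and weights per output channel."""
    fp8_max = torch.finfo(fp8).max
    # One scale per activation row (token): amax over the hidden dimension.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / fp8_max
    # One scale per weight output channel: amax over the input dimension.
    w_scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / fp8_max
    x_q = (x / x_scale).clamp(-fp8_max, fp8_max).to(fp8)
    w_q = (w / w_scale).clamp(-fp8_max, fp8_max).to(fp8)
    return x_q, x_scale, w_q, w_scale

x = torch.randn(4, 16)    # [tokens, hidden]
w = torch.randn(32, 16)   # [out_channels, hidden]
x_q, x_s, w_q, w_s = ptpc_quantize(x, w)
# Reference GEMM on dequantized operands; real kernels fuse the scales into
# the FP8 matmul (the tests below mention a rowwise torch._scaled_mm on ROCm).
y = (x_q.float() * x_s) @ (w_q.float() * w_s).t()
print(y.shape)  # torch.Size([4, 32])
```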

File tree

8 files changed: +295 -32 lines


Dockerfile.rocm_base (+1 -1)

@@ -6,7 +6,7 @@ ARG RCCL_BRANCH="648a58d"
 ARG RCCL_REPO="https://github.com/ROCm/rccl"
 ARG TRITON_BRANCH="e5be006"
 ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
-ARG PYTORCH_BRANCH="8d4926e"
+ARG PYTORCH_BRANCH="3a585126"
 ARG PYTORCH_VISION_BRANCH="v0.19.1"
 ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"

docs/source/getting_started/installation/gpu/rocm.inc.md (+45 -19)

@@ -1,6 +1,6 @@
 # Installation
 
-vLLM supports AMD GPUs with ROCm 6.2.
+vLLM supports AMD GPUs with ROCm 6.3.
 
 :::{attention}
 There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
@@ -9,7 +9,7 @@ There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
 ## Requirements
 
 - GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
-- ROCm 6.2
+- ROCm 6.3
 
 ## Set up using Python
 
@@ -24,9 +24,15 @@ Currently, there are no pre-built ROCm wheels.
 - [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html)
 - [PyTorch](https://pytorch.org/)
 
-For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
+For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.3_ubuntu24.04_py3.12_pytorch_release_2.4.0`, `rocm/pytorch-nightly`. If you are using docker image, you can skip to Step 3.
 
-Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/)
+Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/). Example:
+
+```console
+# Install PyTorch
+$ pip uninstall torch -y
+$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.3
+```
 
 1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton)
 
@@ -37,7 +43,7 @@ Currently, there are no pre-built ROCm wheels.
 pip uninstall -y triton
 git clone https://github.com/OpenAI/triton.git
 cd triton
-git checkout e192dba
+git checkout e5be006
 cd python
 pip3 install .
 cd ../..
@@ -49,15 +55,15 @@ Currently, there are no pre-built ROCm wheels.
 
 2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile)
 
-Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
+Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
 Alternatively, wheels intended for vLLM use can be accessed under the releases.
 
-For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
+For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
 
 ```console
 git clone https://github.com/ROCm/flash-attention.git
 cd flash-attention
-git checkout 3cea2fb
+git checkout b7d29fb
 git submodule update --init
 GPU_ARCHS="gfx90a" python3 setup.py install
 cd ..
@@ -67,20 +73,16 @@ Currently, there are no pre-built ROCm wheels.
 You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 :::
 
-3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps:
+3. Build vLLM. For example, vLLM on ROCM 6.3 can be built with the following steps:
 
 ```bash
 $ pip install --upgrade pip
 
-# Install PyTorch
-$ pip uninstall torch -y
-$ pip install --no-cache-dir --pre torch --index-url https://download.pytorch.org/whl/rocm6.2
-
 # Build & install AMD SMI
 $ pip install /opt/rocm/share/amd_smi
 
 # Install dependencies
-$ pip install --upgrade numba scipy huggingface-hub[cli]
+$ pip install --upgrade numba scipy huggingface-hub[cli,hf_transfer] setuptools_scm
 $ pip install "numpy<2"
 $ pip install -r requirements-rocm.txt
 
@@ -104,7 +106,7 @@ Currently, there are no pre-built ROCm wheels.
 For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization).
 :::
 
-## Set up using Docker
+## Set up using Docker (Recommended)
 
 ### Pre-built images
 
@@ -120,7 +122,12 @@ for instructions on how to use this prebuilt docker image.
 
 Building the Docker image from source is the recommended way to use vLLM with ROCm.
 
-First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
+#### (Optional) Build an image with ROCm software stack
+
+Build a docker image from <gh-file:Dockerfile.rocm_base> which setup ROCm software stack needed by the vLLM.
+**This step is optional as this rocm_base image is usually prebuilt and store at [Docker Hub](https://hub.docker.com/r/rocm/vllm-dev) under tag `rocm/vllm-dev:base` to speed up user experience.**
+If you choose to build this rocm_base image yourself, the steps are as follows.
+
 It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
 
 ```console
@@ -131,7 +138,26 @@ It is important that the user kicks off the docker build using buildkit. Either
 }
 ```
 
-<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
+To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
+
+```console
+DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm_base -t rocm/vllm-dev:base .
+```
+
+#### Build an image with vLLM
+
+First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
+It is important that the user kicks off the docker build using buildkit. Either the user put `DOCKER_BUILDKIT=1` as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
+
+```console
+{
+"features": {
+"buildkit": true
+}
+}
+```
+
+<gh-file:Dockerfile.rocm> uses ROCm 6.3 by default, but also supports ROCm 5.7, 6.0, 6.1, and 6.2, in older vLLM branches.
 It provides flexibility to customize the build of docker image using the following arguments:
 
 - `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using <gh-file:Dockerfile.rocm_base>
@@ -141,13 +167,13 @@ It provides flexibility to customize the build of docker image using the following arguments:
 
 Their values can be passed in when running `docker build` with `--build-arg` options.
 
-To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
+To build vllm on ROCm 6.3 for MI200 and MI300 series, you can use the default:
 
 ```console
 DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
 ```
 
-To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
+To build vllm on ROCm 6.3 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
 
 ```console
 DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm .

tests/quantization/test_fp8.py (+38 -11)

@@ -55,10 +55,21 @@ def check_model(model):
 
         assert isinstance(attn.quant_method, Fp8KVCacheMethod)
 
-        # NOTE: it is valid for scales to be 1.0 (default value), but
-        # we know these checkpoints have scales < 1.0
-        assert 0.0 < attn._k_scale < 1.0
-        assert 0.0 < attn._v_scale < 1.0
+        if not current_platform.is_rocm():
+            # NOTE: This code path requires validation on Non-CUDA platform
+            # NOTE: it is valid for scales to be 1.0 (default value), but
+            # we know these checkpoints have scales < 1.0
+            assert 0.0 < attn._k_scale < 1.0
+            assert 0.0 < attn._v_scale < 1.0
+        else:
+            # NOTE: This code path is for ROCm platform
+            # NOTE: it is valid for scales to be 1.0 (default value), but
+            # we know these checkpoints have scales < 1.0
+            # However on ROCm platform, the _k_scale and _v_scale will be
+            # scaled by a factor of 2 as described in
+            # vllm/model_executor/layers/quantization/kv_cache.py
+            assert 0.0 < attn._k_scale < (1.0 * 2.0)
+            assert 0.0 < attn._v_scale < (1.0 * 2.0)
 
     llm.apply_model(check_model)
 
@@ -91,13 +102,29 @@ def check_model(model):
         assert attn._k_scale == 1.0
         assert attn._v_scale == 1.0
 
-        if current_platform.has_device_capability(89) and not force_marlin:
-            # For GPUs with hardware support, we keep weights in fp8
-            assert fc1.weight.dtype == torch.float8_e4m3fn
-        else:
-            # For GPUs without hardware support, we pack the fp8 weights
-            # for weight-only quantization using Marlin kernels
-            assert fc1.weight.dtype == torch.int32
+        if current_platform.is_cuda():
+            if current_platform.has_device_capability(
+                    89) and not force_marlin:
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fn
+            else:
+                # For GPUs without hardware support, we pack the fp8 weights
+                # for weight-only quantization using Marlin kernels
+                assert fc1.weight.dtype == torch.int32
+        elif current_platform.is_rocm():
+            # Only MI300 and above support quantization='fp8'
+            if current_platform.has_device_capability(
+                    94) and not force_marlin:
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+            else:  # unsupported ROCm platform
+                pytest.skip(
+                    "Skip `test_load_fp16_model`. "
+                    "It only runs on ROCm platform with FP8 compute."
+                    " e.g. MI300X and above.")
+        else:  # unsupported platform
+            pytest.skip("Skip `test_load_fp16_model`. "
+                        "It only runs on CUDA and ROCm platform.")
 
     llm.apply_model(check_model)
 
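
For context on the `(1.0 * 2.0)` bound in the ROCm branch: the test follows the behaviour referenced in `vllm/model_executor/layers/quantization/kv_cache.py`, where checkpoint k/v scales are doubled because ROCm's FP8 type is `e4m3fnuz` rather than `e4m3fn`, whose representable range is larger. A minimal illustration (the `rocm_adjusted_scale` helper is hypothetical, not this commit's code):

```python
import torch

# e4m3fnuz uses a different exponent bias than e4m3fn, so its representable
# range is smaller; vLLM compensates on ROCm by doubling checkpoint scales.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0

def rocm_adjusted_scale(checkpoint_scale: float) -> float:
    # Hypothetical helper mirroring the factor asserted in the test above.
    return checkpoint_scale * 2.0

print(rocm_adjusted_scale(0.5))  # 1.0, hence the `< 1.0 * 2.0` upper bound
```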

tests/quantization/test_ptpc_fp8.py (+55 -0, new file)

@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests whether PTPC w8a8 FP8 computation is enabled correctly.
+
+Run `pytest tests/quantization/test_ptpc_fp8.py --forked`.
+"""
+import pytest
+import torch
+
+from tests.quantization.utils import is_quant_method_supported
+from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod
+from vllm.model_executor.layers.quantization.ptpc_fp8 import (
+    PTPCFp8LinearMethod)
+from vllm.platforms import current_platform
+
+
+@pytest.mark.skipif(not is_quant_method_supported("ptpc_fp8"),
+                    reason="PTPC FP8 is not supported on this GPU type.")
+@pytest.mark.skipif(not current_platform.is_rocm(),
+                    reason="This test is for ROCm GPU.")
+@pytest.mark.parametrize("dtype", ["auto", "bfloat16", "float16"])
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
+def test_ptpc_fp8_rocm(vllm_runner, dtype: str, kv_cache_dtype: str) -> None:
+
+    try:
+        with vllm_runner("facebook/opt-125m",
+                         dtype=dtype,
+                         quantization="ptpc_fp8",
+                         kv_cache_dtype=kv_cache_dtype) as llm:
+
+            model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+            fc1 = model.model.decoder.layers[0].fc1
+            assert isinstance(fc1.quant_method, PTPCFp8LinearMethod)
+            if kv_cache_dtype == "ptpc_fp8":
+                attn = model.model.decoder.layers[0].self_attn.attn
+                assert isinstance(attn.quant_method, Fp8KVCacheMethod)
+                assert attn._k_scale == 1.0
+                assert attn._v_scale == 1.0
+
+            if current_platform.has_device_capability(94):
+                # For GPUs with hardware support, we keep weights in fp8
+                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+            else:
+                pytest.skip()
+
+            output = llm.generate_greedy("Hello my name is", max_tokens=20)
+            assert output
+    except AssertionError as e:
+        if str(
+                e
+        ) == "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16. torch.float16 is specified.":  # noqa: E501
+            # If the error message matches, the test passes
+            pass
+        else:
+            # If the error message does not match, re-raise the exception
+            raise
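
The same arguments exercised by this test can be used through vLLM's offline API. A minimal usage sketch mirroring the configuration above (the model choice and sampling settings are illustrative; bfloat16 is picked because the test's error string notes that the rowwise hipBLASLt GEMM currently only supports bf16 output):

```python
from vllm import LLM, SamplingParams

# Offline inference with the new PTPC FP8 path; arguments mirror the test.
llm = LLM(model="facebook/opt-125m",
          dtype="bfloat16",           # rowwise torch._scaled_mm output dtype
          quantization="ptpc_fp8",
          kv_cache_dtype="fp8")

outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)
```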

vllm/model_executor/layers/quantization/__init__.py (+3 -0)

@@ -11,6 +11,7 @@
     "deepspeedfp",
     "tpu_int8",
     "fp8",
+    "ptpc_fp8",
     "fbgemm_fp8",
     "modelopt",
     # The order of gptq methods is important for config.py iteration over
@@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .modelopt import ModelOptFp8Config
     from .moe_wna16 import MoeWNA16Config
     from .neuron_quant import NeuronQuantConfig
+    from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
     from .tpu_int8 import Int8TpuConfig
 
@@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "gptq": GPTQConfig,
         "compressed-tensors": CompressedTensorsConfig,
         "bitsandbytes": BitsAndBytesConfig,
+        "ptpc_fp8": PTPCFp8Config,
         "qqq": QQQConfig,
         "hqq": HQQMarlinConfig,
         "experts_int8": ExpertsInt8Config,

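With the registrations above in place, the new method name resolves like any other entry in the quantization registry. A small sketch of that lookup (assumes an environment with a vLLM build containing this commit):

```python
from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

# "ptpc_fp8" is now a recognised method name and maps to its config class.
assert "ptpc_fp8" in QUANTIZATION_METHODS
config_cls = get_quantization_config("ptpc_fp8")
print(config_cls.__name__)  # PTPCFp8Config
```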