Commit 80d38b8

[V1] [ROCm] [AITER] Upgrade AITER to commit 916bf3c and bugfix APIs (#20880)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent 211b6a6 commit 80d38b8

File tree: 3 files changed, +48 -5 lines changed

docker/Dockerfile.rocm_base

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="6487649"
+ARG AITER_BRANCH="916bf3c"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base

vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py

Lines changed: 46 additions & 3 deletions

@@ -8,11 +8,55 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 from .cutlass import CutlassScaledMMLinearKernel
 from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
 
 
+def rocm_aiter_gemm_w8a8_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    from aiter import gemm_a8w8_CK
+
+    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
+    # a to be [M, K]
+    # b to be [N, K]
+    # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
+    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
+
+
+def rocm_aiter_gemm_w8a8_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8",
+        op_func=rocm_aiter_gemm_w8a8_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
+
 class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
 
     @classmethod
@@ -111,10 +155,9 @@ def apply_weights(self,
             " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
             "does not support AITER block scaled GEMM.")
 
-        from aiter import gemm_a8w8_CK
-
         # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
         # a to be [M, K]
         # b to be [N, K]
         # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
-        return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
+                                                   bias, out_dtype)
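
The registration above routes the AITER kernel through torch.ops.vllm.rocm_aiter_gemm_w8a8, whose fake implementation gives torch.compile a shape-only stand-in for the ROCm kernel. A minimal sketch of invoking the registered op; the op name and argument order come from this diff, while the tensor shapes, scale layouts, and dtypes are illustrative assumptions, and running it requires a ROCm build of vLLM with AITER installed:

import torch

# Importing the module applies the ROCm-only op registration shown above.
import vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter  # noqa: F401

M, K, N = 32, 4096, 4096
A = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")  # [M, K] activations
B = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda")  # [N, K] weights
As = torch.rand(M, 1, dtype=torch.float32, device="cuda")  # assumed per-token activation scales
Bs = torch.rand(1, N, dtype=torch.float32, device="cuda")  # assumed per-channel weight scales

# Signature mirrors rocm_aiter_gemm_w8a8_impl:
# (A, B, As, Bs, bias, output_dtype) -> [M, N] tensor.
Y = torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, None, torch.bfloat16)
assert Y.shape == (M, N)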

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
 ) -> torch.Tensor:
     import aiter as rocm_aiter
 
-    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
 
 def rocm_aiter_gemm_w8a8_blockscale_fake(
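
The API bugfix here tracks a rename in AITER itself: at the pinned revision, gemm_a8w8_blockscale_CK became gemm_a8w8_blockscale. A hedged sketch of code that tolerates either spelling; both attribute names come from this diff, but the fallback pattern is illustrative and not part of the commit:

import aiter as rocm_aiter

# Prefer the post-916bf3c name; fall back to the pre-rename entry
# point for older AITER checkouts.
gemm_blockscale = getattr(
    rocm_aiter, "gemm_a8w8_blockscale",
    getattr(rocm_aiter, "gemm_a8w8_blockscale_CK", None))
if gemm_blockscale is None:
    raise ImportError("installed AITER exposes no w8a8 blockscale GEMM")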
