|
8 | 8 | import vllm.envs as envs
|
9 | 9 | from vllm import _custom_ops as ops
|
10 | 10 | from vllm.platforms import current_platform
|
| 11 | +from vllm.utils import direct_register_custom_op |
11 | 12 |
|
12 | 13 | from .cutlass import CutlassScaledMMLinearKernel
|
13 | 14 | from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
14 | 15 |
|
15 | 16 |
|
def rocm_aiter_gemm_w8a8_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Dispatch an INT8 scaled GEMM to AITER's CK kernel.

    ``gemm_a8w8_CK(a, b, scale_a, scale_b, bias)`` expects
    ``a`` to be [M, K] and ``b`` to be [N, K]; note that
    ``CutlassScaledMMLinearKernel`` prepares the weight ``w_q``
    in [K, N] format, so callers transpose it first.
    """
    # Imported lazily: aiter is only present on ROCm installs.
    from aiter import gemm_a8w8_CK

    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
| 34 | + |
def rocm_aiter_gemm_w8a8_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Shape-propagation fake for the custom op: [M, K] x [N, K] -> [M, N].

    Returns an uninitialized tensor of the output shape/dtype on A's
    device; used by torch.compile tracing instead of the real kernel.
    """
    num_rows = A.shape[0]
    num_cols = B.shape[0]
    return torch.empty(num_rows,
                       num_cols,
                       dtype=output_dtype,
                       device=A.device)
| 49 | + |
# Register the AITER w8a8 GEMM as a vLLM custom op so torch.compile can
# trace through it via the fake impl. Registration happens at import time,
# but only on ROCm — the real impl requires the `aiter` package.
if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_aiter_gemm_w8a8",
        op_func=rocm_aiter_gemm_w8a8_impl,
        # The op allocates a fresh output; it mutates none of its inputs.
        mutates_args=[],
        fake_impl=rocm_aiter_gemm_w8a8_fake,
        dispatch_key=current_platform.dispatch_key,
    )
| 58 | + |
| 59 | + |
16 | 60 | class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
17 | 61 |
|
18 | 62 | @classmethod
|
@@ -111,10 +155,9 @@ def apply_weights(self,
|
111 | 155 | " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
|
112 | 156 | "does not support AITER block scaled GEMM.")
|
113 | 157 |
|
114 |
| - from aiter import gemm_a8w8_CK |
115 |
| - |
116 | 158 | # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
|
117 | 159 | # a to be [M, K]
|
118 | 160 | # b to be [N, K]
|
119 | 161 | # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
120 |
| - return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) |
| 162 | + return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s, |
| 163 | + bias, out_dtype) |
0 commit comments