Skip to content

Commit c791a85

Browse files
gshtras and charlifu authored
Cherry pick skinny gemms (#544)
* enable skinny gemm for bs4

  Signed-off-by: charlifu <charlifu@amd.com>

* add cache to on_mi250_mi300

  Signed-off-by: charlifu <charlifu@amd.com>

---------

Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
1 parent b526478 commit c791a85

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

vllm/model_executor/layers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()
 
-    if m > 8 and 0 < n < 4:
+    if m > 8 and 0 < n <= 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:

vllm/platforms/rocm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
         return device_id
 
 
+@cache
 def on_mi250_mi300() -> bool:
     GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
     return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"])

0 commit comments

Comments
 (0)