Skip to content

Commit 97abeb1

Browse files
Authored by Duncan Moss
[feat] enable SM100 CUTLASS block scaled group gemm for smaller batch sizes (#20640)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
1 parent 34dad19 commit 97abeb1

File tree

2 files changed: +5 lines, -7 lines

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -522,16 +522,14 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
522522
return out.to(dtype=out_dtype)
523523

524524

525-
def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor,
526-
w1: torch.Tensor,
525+
def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
527526
w2: torch.Tensor) -> bool:
528527

529-
def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int):
530-
return M >= 128 and N % 128 == 0 and K % 128 == 0
528+
def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
529+
return N % 128 == 0 and K % 128 == 0
531530

532-
m = hidden_states.size(0)
533531
_, K, N = w2.size()
534-
if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K):
532+
if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K):
535533
logger.debug(
536534
"CutlassBlockScaledGroupedGemm disabled: unaligned problem size.")
537535
return False

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1180,7 +1180,7 @@ def fused_experts(
11801180
apply_router_weight_on_input=apply_router_weight_on_input,
11811181
)
11821182
elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
1183-
and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)):
1183+
and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
11841184
assert apply_router_weight_on_input is False
11851185
return run_cutlass_block_scaled_fused_experts(
11861186
a=hidden_states,

0 commit comments

Comments
 (0)