Commit 5923ab9

Authored by Duncan Moss
[fix]: disable cutlass block scaled group gemm for EP (#20781)
Signed-off-by: Duncan Moss <djm.moss@gmail.com>
1 parent 0cf893c commit 5923ab9

File tree

3 files changed, +34 -9 lines changed

csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu

Lines changed: 4 additions & 5 deletions

@@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
       reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
           layout_sfb.data_ptr())};

-  cutlass::KernelHardwareInfo hw_info;
-  hw_info.device_id = a_ptrs.get_device();
-  hw_info.sm_count =
-      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
-          hw_info.device_id);
+  int device_id = a_ptrs.device().index();
+  static const cutlass::KernelHardwareInfo hw_info{
+      device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+                     device_id)};

   // Epilogue Arguments
   typename GemmKernel::EpilogueArguments epilogue_args{
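Note (not part of this commit): the hunk above constructs the cutlass::KernelHardwareInfo once as a function-local static, keyed off the tensor's device index, instead of re-querying the multiprocessor count on every call. A rough Python analogue of the same caching idea, assuming torch with CUDA available; the helper name sm_count is illustrative only:

import functools

import torch


@functools.lru_cache(maxsize=None)
def sm_count(device_index: int) -> int:
    # Query the streaming-multiprocessor count once per device and cache it,
    # loosely mirroring the `static const` hw_info initialization in the C++ change.
    return torch.cuda.get_device_properties(device_index).multi_processor_count


if torch.cuda.is_available():
    x = torch.zeros(1, device="cuda")
    print(sm_count(x.device.index))  # repeated calls hit the cache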

vllm/model_executor/layers/fused_moe/cutlass_moe.py

Lines changed: 27 additions & 2 deletions

@@ -553,8 +553,10 @@ def cutlass_moe_fp4(a: torch.Tensor,
     return out.to(dtype=out_dtype)


-def _valid_cutlass_block_scaled_grouped_gemm(w1: torch.Tensor,
-                                             w2: torch.Tensor) -> bool:
+def _valid_cutlass_block_scaled_grouped_gemm(
+        w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str,
+        apply_router_weight_on_input: bool,
+        expert_map: Optional[torch.Tensor]) -> bool:

     def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
         return N % 128 == 0 and K % 128 == 0
@@ -570,6 +572,29 @@ def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int):
             "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).")
         return False

+    if expert_map is not None:
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: expert_parallel is"
+            " not supported.")
+        return False
+
+    if activation != "silu":
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: only activation silu is"
+            " supported.")
+        return False
+
+    if apply_router_weight_on_input:
+        logger.debug("CutlassBlockScaledGroupedGemm disabled:"
+                     " apply_router_weight_on_input is not supported.")
+        return False
+
+    if inplace:
+        logger.debug(
+            "CutlassBlockScaledGroupedGemm disabled: inplace is not supported."
+        )
+        return False
+
     return True

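With this change the validity check itself rejects expert parallelism: when an expert_map is passed, the helper returns False and fused_experts (see the next file) falls back to its default path rather than asserting. A minimal, hypothetical usage sketch, not verified end to end; it assumes vllm is importable, that the private helper keeps the signature shown above, and that block-scaled weights are fp8 tensors whose dimensions are multiples of 128:

import torch

from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    _valid_cutlass_block_scaled_grouped_gemm)

# Hypothetical fp8 block-scaled weights: every dimension is a multiple of 128.
# The dtype assumption (float8_e4m3fn) is ours, not stated in the diff.
w1 = torch.empty((8, 256, 128), dtype=torch.float8_e4m3fn)
w2 = torch.empty((8, 128, 128), dtype=torch.float8_e4m3fn)

# With an expert_map present (expert parallelism), the path is rejected.
expert_map = torch.arange(8, dtype=torch.int32)
print(_valid_cutlass_block_scaled_grouped_gemm(
    w1, w2, inplace=False, activation="silu",
    apply_router_weight_on_input=False, expert_map=expert_map))  # False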

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 3 additions & 2 deletions

@@ -1192,8 +1192,9 @@ def fused_experts(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )
     elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8
-          and _valid_cutlass_block_scaled_grouped_gemm(w1, w2)):
-        assert apply_router_weight_on_input is False
+          and _valid_cutlass_block_scaled_grouped_gemm(
+              w1, w2, inplace, activation, apply_router_weight_on_input,
+              expert_map)):
         return run_cutlass_block_scaled_fused_experts(
             a=hidden_states,
             w1=w1,
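Behaviorally, the bare assert on apply_router_weight_on_input is gone: unsupported configurations no longer raise, they simply route to the default fused-experts path. A simplified, hypothetical restatement of that dispatch change (not vllm code; names are illustrative only):

# Old behavior: the CUTLASS branch was entered on weight checks alone,
# then asserted on the flag inside the branch.
def old_dispatch(valid_weights: bool, apply_router_weight_on_input: bool) -> str:
    if valid_weights:
        assert apply_router_weight_on_input is False
        return "cutlass_block_scaled_grouped_gemm"
    return "default_fused_experts"


# New behavior: the flag is part of the validity check, so the CUTLASS branch
# is skipped and no assertion can fire.
def new_dispatch(valid_weights: bool, apply_router_weight_on_input: bool) -> str:
    if valid_weights and not apply_router_weight_on_input:
        return "cutlass_block_scaled_grouped_gemm"
    return "default_fused_experts"


assert new_dispatch(True, True) == "default_fused_experts"  # falls back, no error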
