Commit bcdfb2a (parent: ba8c300)

[Bugfix] Fix incorrect dispatch for CutlassBlockScaledGroupedGemm and DeepGEMM (#20933)

Signed-off-by: mgoin <mgoin64@gmail.com>

File tree: 1 file changed, +10 −5 lines

  • vllm/model_executor/layers/quantization/fp8.py

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 10 additions & 5 deletions
@@ -488,22 +488,27 @@ def __init__(self, quant_config: Fp8Config):
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
                 logger.warning_once("Model is not block quantized. Not using "
-                                    " DeepGemm kernels")
+                                    "DeepGemm kernels")
             elif (current_platform.is_cuda()
-                  and current_platform.has_device_capability(90)):
+                  and current_platform.is_device_capability(90)):
                 logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                 self.allow_deep_gemm = True
+            elif (current_platform.is_cuda()
+                  and is_blackwell_deep_gemm_used()):
+                logger.info_once("Using DeepGemm SM100 kernels for "
+                                 "Fp8MoEMethod.")
+                self.allow_deep_gemm = True
             else:
                 logger.warning_once(
                     "DeepGemm not supported on the current platform.")
 
         # Check for CutlassBlockScaledGroupedGemm support.
         self.allow_cutlass_block_scaled_grouped_gemm = False
         if not self.block_quant:
-            logger.warning_once("Model is not block quantized. Not using "
-                                "CutlassBlockScaledGroupedGemm kernels")
+            logger.debug_once("Model is not block quantized. Not using "
+                              "CutlassBlockScaledGroupedGemm kernels")
         elif (current_platform.is_cuda()
-              and current_platform.has_device_capability(100)):
+              and current_platform.is_device_capability(100)):
             logger.info_once(
                 "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
             )
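The substantive change is the capability predicate: has_device_capability(N) passes for any device at or above capability N, so on Blackwell (SM100, capability 10.0) the old code satisfied the SM90 check, took the Hopper DeepGemm branch, and never reached the SM100 paths; is_device_capability(N) requires an exact match. Below is a minimal standalone sketch of the two predicates' assumed semantics, written as hypothetical helpers rather than vLLM's actual platform API:

# A minimal sketch (not vLLM's code) of the two dispatch predicates.
# Assumption, based on the fix above: has_device_capability(90) behaves
# as a ">= 9.0" test, while is_device_capability(90) is an exact match.

def has_device_capability(required: int, device: int) -> bool:
    # True for any GPU at or above `required` (SM90 *and* SM100 pass 90).
    return device >= required

def is_device_capability(required: int, device: int) -> bool:
    # True only for an exact match (SM90 passes 90, SM100 does not).
    return device == required

for device in (89, 90, 100):  # Ada, Hopper, Blackwell
    old_branch = has_device_capability(90, device)  # pre-fix check
    new_branch = is_device_capability(90, device)   # post-fix check
    print(f"SM{device}: old SM90 branch={old_branch}, "
          f"new SM90 branch={new_branch}")

# SM100 prints old=True / new=False: before the fix a Blackwell GPU was
# wrongly routed into the SM90 DeepGemm branch; after the fix it falls
# through to the new SM100 branch guarded by is_blackwell_deep_gemm_used().

The same reasoning applies to the CutlassBlockScaledGroupedGemm check, where the exact-match test on capability 100 keeps that kernel restricted to SM100 devices.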
