Skip to content

Commit 44d2e6a

Browse files
authored
[Bugfix] Build moe_data for both sm100 and sm90 (#20086)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 2d7779f commit 44d2e6a

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

CMakeLists.txt

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -513,6 +513,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
513513
CUDA_ARCHS "${FP4_ARCHS}")
514514
list(APPEND VLLM_EXT_SRC "${SRCS}")
515515
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
516+
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
516517
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
517518
else()
518519
message(STATUS "Not building NVFP4 as no compatible archs were found.")
@@ -547,8 +548,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
547548
# if it's possible to compile MoE kernels that use its output.
548549
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
549550
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
550-
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
551-
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
551+
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
552552
set_gencode_flags_for_srcs(
553553
SRCS "${SRCS}"
554554
CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -566,6 +566,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
566566
endif()
567567
endif()
568568

569+
# moe_data.cu is used by all CUTLASS MoE kernels.
570+
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
571+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
572+
set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
573+
set_gencode_flags_for_srcs(
574+
SRCS "${SRCS}"
575+
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
576+
list(APPEND VLLM_EXT_SRC "${SRCS}")
577+
endif()
578+
569579
#
570580
# Machete kernels
571581

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ void get_cutlass_moe_mm_data(
241241
// mm to run it for.
242242
int32_t version_num = get_sm_version_num();
243243
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
244-
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
244+
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
245245
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
246246
problem_sizes2, input_permutation,
247247
output_permutation, num_experts, n, k,
@@ -252,7 +252,7 @@ void get_cutlass_moe_mm_data(
252252
false,
253253
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
254254
"CUDA device capability: ",
255-
version_num, ". Required capability: 90");
255+
version_num, ". Required capability: 90 or 100");
256256
}
257257

258258
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
@@ -265,7 +265,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
265265
// This function currently gets compiled only if we have a valid cutlass moe
266266
// mm to run it for.
267267
int32_t version_num = get_sm_version_num();
268-
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
268+
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
269+
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
269270
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
270271
problem_sizes2, expert_num_tokens,
271272
num_local_experts, padded_m, n, k);
@@ -275,7 +276,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
275276
false,
276277
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
277278
"for CUDA device capability: ",
278-
version_num, ". Required capability: 90");
279+
version_num, ". Required capability: 90 or 100");
279280
}
280281

281282
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,

0 commit comments

Comments
 (0)