@@ -241,7 +241,7 @@ void get_cutlass_moe_mm_data(
241
241
// mm to run it for.
242
242
int32_t version_num = get_sm_version_num ();
243
243
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
244
- (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90 )
244
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 )
245
245
get_cutlass_moe_mm_data_caller (topk_ids, expert_offsets, problem_sizes1,
246
246
problem_sizes2, input_permutation,
247
247
output_permutation, num_experts, n, k,
@@ -252,7 +252,7 @@ void get_cutlass_moe_mm_data(
252
252
false ,
253
253
" No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
254
254
" CUDA device capability: " ,
255
- version_num, " . Required capability: 90" );
255
+ version_num, " . Required capability: 90 or 100 " );
256
256
}
257
257
258
258
void get_cutlass_pplx_moe_mm_data (torch::Tensor& expert_offsets,
@@ -265,7 +265,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
265
265
// This function currently gets compiled only if we have a valid cutlass moe
266
266
// mm to run it for.
267
267
int32_t version_num = get_sm_version_num ();
268
- #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
268
+ #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
269
+ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
269
270
get_cutlass_pplx_moe_mm_data_caller (expert_offsets, problem_sizes1,
270
271
problem_sizes2, expert_num_tokens,
271
272
num_local_experts, padded_m, n, k);
@@ -275,7 +276,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
275
276
false ,
276
277
" No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
277
278
" for CUDA device capability: " ,
278
- version_num, " . Required capability: 90" );
279
+ version_num, " . Required capability: 90 or 100 " );
279
280
}
280
281
281
282
void cutlass_scaled_mm_azp (torch::Tensor& c, torch::Tensor const & a,
0 commit comments