Skip to content

Commit 3d184b9

Browse files
djmmossnv-dmoss
andauthored
[feat]: CUTLASS block scaled group gemm for SM100 (vllm-project#19757)
Signed-off-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Duncan Moss <dmoss@nvidia.com>
1 parent 2f35a02 commit 3d184b9

File tree

13 files changed

+726
-30
lines changed

13 files changed

+726
-30
lines changed

CMakeLists.txt

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
259259
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
260260

261261
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
262-
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
262+
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
263263

264264
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
265265
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -615,6 +615,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
615615
"in CUDA target architectures.")
616616
endif()
617617
endif()
618+
619+
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
620+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
621+
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
622+
set_gencode_flags_for_srcs(
623+
SRCS "${SRCS}"
624+
CUDA_ARCHS "${SCALED_MM_ARCHS}")
625+
list(APPEND VLLM_EXT_SRC "${SRCS}")
626+
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
627+
message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
628+
else()
629+
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
630+
message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
631+
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
632+
"if you intend on running FP8 quantized MoE models on Blackwell.")
633+
else()
634+
message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
635+
"in CUDA target architectures")
636+
endif()
637+
endif()
618638

619639
#
620640
# Machete kernels

csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
#include "cute/algorithm/functional.hpp"
4646
#include "cute/atom/mma_atom.hpp"
4747
#include "cute/algorithm/gemm.hpp"
48-
#include "cute/tensor_predicate.hpp"
4948
#include "cute/numeric/arithmetic_tuple.hpp"
5049

5150
#include "cutlass_extensions/gemm/dispatch_policy.hpp"

csrc/ops.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ void cutlass_moe_mm(
239239
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
240240
bool per_act_token, bool per_out_ch);
241241

242+
void cutlass_blockwise_scaled_grouped_mm(
243+
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
244+
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
245+
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets);
246+
242247
void cutlass_fp4_group_mm(
243248
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
244249
const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales,

csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
5151
// These are the minimum alignments needed for the kernels to compile
5252
static constexpr int AlignmentAB =
5353
128 / cutlass::sizeof_bits<ElementAB>::value;
54-
static constexpr int AlignmentCD = 4;
54+
static constexpr int AlignmentCD =
55+
128 / cutlass::sizeof_bits<ElementD>::value;
5556

5657
using CollectiveEpilogue =
5758
typename cutlass::epilogue::collective::CollectiveBuilder<

0 commit comments

Comments
 (0)