Commit 9e5552a

kaln27 and mgoin authored
[NVIDIA] Support Cutlass w8a8 FP8 for Blackwell Geforce GPUs (sm120) (#17280)
Signed-off-by: kaln27 <liaojuncheng123@foxmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent 0c600b9 commit 9e5552a

7 files changed: +238 −1 lines changed

CMakeLists.txt

Lines changed: 30 additions & 0 deletions
@@ -420,6 +420,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()
 
+
+  # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
+  # CUDA 12.8 or later
+  cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+    set(SRCS
+      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
+      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${SCALED_MM_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
+    # Let scaled_mm_c2x know it doesn't need to build these arches
+    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
+    message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
+  else()
+    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
+      message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
+                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
+                     "later if you intend on running FP8 quantized models on "
+                     "Blackwell.")
+    else()
+      message(STATUS "Not building scaled_mm_c3x_sm120 as no compatible archs found "
+                     "in CUDA target architectures")
+    endif()
+  endif()
+
+
 # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
 # require CUDA 12.8 or later
 cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
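
The block above only compiles the new sources and defines -DENABLE_SCALED_MM_SM120=1 when the target arch list contains 12.0/12.0a and the CUDA toolkit is new enough. A minimal sketch of how that flag is then consumed on the C++ side (the macro name and guard idiom come from this commit; the constant below is illustrative only, not part of the diff):

// Sketch: translation units compile the sm120 path only when CMake set the flag.
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
// sm120 entry points (e.g. cutlass_scaled_mm_sm120) are compiled in.
constexpr bool kHasScaledMmSm120 = true;   // illustrative constant, not in the commit
#else
constexpr bool kHasScaledMmSm120 = false;
#endif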

csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh

Lines changed: 61 additions & 0 deletions
@@ -144,4 +144,65 @@ struct cutlass_3x_gemm_sm100 {
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
 };
 
+template <typename ElementAB_, typename ElementD_,
+          template <typename, typename, typename> typename Epilogue_,
+          typename TileShape, typename ClusterShape, typename KernelSchedule,
+          typename EpilogueSchedule>
+struct cutlass_3x_gemm_sm120 {
+  using ElementAB = ElementAB_;
+  using LayoutA = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+
+  using LayoutB = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB =
+      128 / cutlass::sizeof_bits<ElementAB>::value;
+
+  using ElementC = void;
+  using LayoutC = cutlass::layout::RowMajor;
+  static constexpr int AlignmentC =
+      128 / cutlass::sizeof_bits<ElementD_>::value;
+
+  using ElementD = ElementD_;
+  using LayoutD = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = AlignmentC;
+
+  using ElementAcc =
+      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                                float>::type;
+  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
+
+  // MMA type
+  using ElementAccumulator = float;
+
+  // Epilogue types
+  using ElementBias = cutlass::half_t;
+  using ElementCompute = float;
+  using ElementAux = ElementD;
+  using LayoutAux = LayoutD;
+  using ElementAmax = float;
+
+  using EVTCompute = typename Epilogue::EVTCompute;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
+          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
+          ElementD, LayoutD, AlignmentD, EpilogueSchedule,
+          EVTCompute>::CollectiveOp;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
+          LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
+          ElementAccumulator, TileShape, ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          KernelSchedule>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
+};
+
 } // namespace vllm
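
For orientation, a hedged sketch (not part of the diff) of how cutlass_3x_gemm_sm120 is meant to be instantiated; the shapes and schedules mirror sm120_fp8_config_default from the dispatch header added below, and the cute shape aliases (Shape, _128, _1) are assumed to be in scope as they are there:

// Instantiate the SM120 wrapper for fp8 e4m3 inputs, fp16 output, and the
// per-token/per-channel scaling epilogue used elsewhere in this commit.
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;

using Sm120Fp8Gemm =
    vllm::cutlass_3x_gemm_sm120<cutlass::float_e4m3_t, cutlass::half_t,
                                vllm::c3x::ScaledEpilogue, TileShape,
                                ClusterShape, KernelSchedule, EpilogueSchedule>;

// GemmKernel is the type that c3x::cutlass_gemm_caller ultimately launches.
using Sm120Fp8Kernel = Sm120Fp8Gemm::GemmKernel;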

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp

Lines changed: 6 additions & 0 deletions
@@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                  torch::Tensor const& b_scales,
                                  std::optional<torch::Tensor> const& bias);
 
+void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
+                                 torch::Tensor const& b,
+                                 torch::Tensor const& a_scales,
+                                 torch::Tensor const& b_scales,
+                                 std::optional<torch::Tensor> const& bias);
+
 void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+#include "scaled_mm_kernels.hpp"
+#include "scaled_mm_sm120_fp8_dispatch.cuh"
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
+
+namespace vllm {
+
+void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
+                                 torch::Tensor const& b,
+                                 torch::Tensor const& a_scales,
+                                 torch::Tensor const& b_scales,
+                                 std::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == out.dtype(),
+                "currently bias dtype must match output dtype ", out.dtype());
+    return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
+        out, a, b, a_scales, b_scales);
+  }
+}
+
+} // namespace vllm
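
A hedged host-side usage sketch (not from the commit) for the entry point above, with per-token activation scales and per-channel weight scales; tensor construction uses the standard libtorch C++ API and the helper name is hypothetical:

#include <torch/torch.h>
#include <optional>
// #include "scaled_mm_kernels.hpp"  // declares vllm::cutlass_scaled_mm_sm120_fp8

// Hypothetical helper showing the expected dtypes and layouts.
void run_sm120_fp8_example(int64_t M, int64_t N, int64_t K) {
  auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);

  // Row-major fp8 activations and a column-major fp8 weight view.
  auto a = torch::randn({M, K}, torch::kCUDA).to(torch::kFloat8_e4m3fn);
  auto b = torch::randn({N, K}, torch::kCUDA).to(torch::kFloat8_e4m3fn).t();

  auto a_scales = torch::ones({M, 1}, f32);  // per-token
  auto b_scales = torch::ones({1, N}, f32);  // per-channel
  auto out = torch::empty({M, N},
                          torch::dtype(torch::kBFloat16).device(torch::kCUDA));

  // No bias -> ScaledEpilogue; passing a bias tensor selects ScaledEpilogueBias.
  vllm::cutlass_scaled_mm_sm120_fp8(out, a, b, a_scales, b_scales,
                                    std::nullopt);
}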
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+#pragma once
+
+#include "scaled_mm.cuh"
+#include "cutlass_gemm_caller.cuh"
+
+/**
+ * This file defines Gemm kernel configurations for SM120 (fp8) based on the
+ * Gemm shape.
+ */
+
+namespace vllm {
+
+using c3x::cutlass_gemm_caller;
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm120_fp8_config_default {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_1, _1, _1>;  // Only works with Shape<_1, _1, _1>
+  using Cutlass3xGemm =
+      cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
+                                            torch::Tensor const& a,
+                                            torch::Tensor const& b,
+                                            EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+
+  using Cutlass3xGemmDefault =
+      typename sm120_fp8_config_default<InType, OutType,
+                                        Epilogue>::Cutlass3xGemm;
+  return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+      out, a, b, std::forward<EpilogueArgs>(args)...);
+}
+
+template <template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
+                                          torch::Tensor const& a,
+                                          torch::Tensor const& b,
+                                          EpilogueArgs&&... epilogue_args) {
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+
+  if (out.dtype() == torch::kBFloat16) {
+    return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
+                                           cutlass::bfloat16_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
+  } else {
+    TORCH_CHECK(out.dtype() == torch::kFloat16);
+    return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
+                                           cutlass::half_t, Epilogue>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
+  }
+}
+
+} // namespace vllm
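
The header comment says configurations are chosen based on the GEMM shape, but this commit ships a single default config. A hedged sketch of how a hypothetical small-M specialization could sit next to it (struct name and tile size are illustrative, not in the commit):

// Hypothetical: a second config for small M, selected at runtime on a.size(0),
// following the shape-based dispatch pattern the file comment describes.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_m64 {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using TileShape = Shape<_64, _128, _128>;  // illustrative tile size
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

// cutlass_gemm_sm120_fp8_dispatch would then branch, e.g.:
//   if (a.size(0) <= 64) { use sm120_fp8_config_m64 } else { use the default }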
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#include <cudaTypedefs.h>
+#include "c3x/scaled_mm_kernels.hpp"
+
+#include "cuda_utils.h"
+
+/*
+   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
+   NVIDIA GPUs with sm120 (Blackwell Geforce).
+*/
+
+#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
+
+void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+
+  int M = a.size(0), N = b.size(1), K = a.size(1);
+  TORCH_CHECK(
+      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
+          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
+      "Currently, block scaled fp8 gemm is not implemented for Blackwell");
+
+  // Standard per-tensor/per-token/per-channel scaling
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
+              "Currently, only fp8 gemm is implemented for Blackwell");
+  vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
+}
+
+#endif
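
The entry point above rejects block-scaled fp8 and accepts only per-tensor, per-token, or per-channel float32 scales. A small hedged restatement of that contract as a standalone check (the helper name is hypothetical; the conditions are copied from the function above):

// Hypothetical helper restating the accepted scale shapes for an M x K 'a'
// and K x N 'b': a_scales has 1 or M elements, b_scales has 1 or N elements,
// and both must be float32 and contiguous.
inline void check_sm120_scale_contract(torch::Tensor const& a,
                                       torch::Tensor const& b,
                                       torch::Tensor const& a_scales,
                                       torch::Tensor const& b_scales) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32 &&
              b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
}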

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 16 additions & 1 deletion
@@ -41,6 +41,14 @@ void cutlass_moe_mm_sm90(
 
 #endif
 
+#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
+void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
+                             torch::Tensor const& b,
+                             torch::Tensor const& a_scales,
+                             torch::Tensor const& b_scales,
+                             std::optional<torch::Tensor> const& bias);
+#endif
+
 #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
 void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
@@ -168,8 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
   int32_t version_num = get_sm_version_num();
 
+#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
+  if (version_num >= 120) {
+    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
+    return;
+  }
+#endif
+
 #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
-  if (version_num >= 100) {
+  if (version_num >= 100 && version_num < 120) {
     cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
     return;
   }
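
Taken together, the two hunks give sm120 the first crack at dispatch and fence the sm100 path off from anything reporting 120 or higher. A condensed, hedged sketch of the resulting routing inside cutlass_scaled_mm (GeForce Blackwell parts report compute capability 12.0, so get_sm_version_num() yields 120 for them; datacenter Blackwell reports 10.x):

  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
  if (version_num >= 120) {                       // e.g. GeForce RTX 50-series (sm120)
    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (version_num >= 100 && version_num < 120) {  // datacenter Blackwell only
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  // Older architectures (sm90, sm89, ...) follow, unchanged by this commit.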
