Skip to content

Commit 8f9d57e

Browse files
ilmarkovilmarkov
andauthored
[Perf] SM100 FP8 GEMM Optimizations after cutlass_profiler (vllm-project#20071)
Signed-off-by: ilmarkov <imarkov@redhat.com> Co-authored-by: ilmarkov <imarkov@redhat.com>
1 parent cc6609d commit 8f9d57e

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,40 +29,40 @@ struct sm100_fp8_config_default {
2929
template <typename InType, typename OutType,
3030
template <typename, typename, typename> typename Epilogue>
3131
struct sm100_fp8_config_M256 {
32-
// M in (128, 256]
32+
// M in (64, 256]
3333
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
3434
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
3535
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
3636
using TileShape = Shape<_128, _128, _128>;
37-
using ClusterShape = Shape<_2, _2, _1>;
37+
using ClusterShape = Shape<_2, _1, _1>;
3838
using Cutlass3xGemm =
3939
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
4040
KernelSchedule, EpilogueSchedule>;
4141
};
4242

4343
template <typename InType, typename OutType,
4444
template <typename, typename, typename> typename Epilogue>
45-
struct sm100_fp8_config_M128 {
46-
// M in (64, 128]
45+
struct sm100_fp8_config_M64 {
46+
// M in (16, 64]
4747
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
4848
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
4949
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
50-
using TileShape = Shape<_128, _128, _256>;
51-
using ClusterShape = Shape<_2, _4, _1>;
50+
using TileShape = Shape<_64, _64, _128>;
51+
using ClusterShape = Shape<_1, _1, _1>;
5252
using Cutlass3xGemm =
5353
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
5454
KernelSchedule, EpilogueSchedule>;
5555
};
5656

5757
template <typename InType, typename OutType,
5858
template <typename, typename, typename> typename Epilogue>
59-
struct sm100_fp8_config_M64 {
60-
// M in [1, 64]
59+
struct sm100_fp8_config_M16 {
60+
// M in [1, 16]
6161
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
6262
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
6363
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
64-
using TileShape = Shape<_64, _64, _256>;
65-
using ClusterShape = Shape<_1, _8, _1>;
64+
using TileShape = Shape<_64, _64, _128>;
65+
using ClusterShape = Shape<_1, _4, _1>;
6666
using Cutlass3xGemm =
6767
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
6868
KernelSchedule, EpilogueSchedule>;
@@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
8282
using Cutlass3xGemmDefault =
8383
typename sm100_fp8_config_default<InType, OutType,
8484
Epilogue>::Cutlass3xGemm;
85+
using Cutlass3xGemmM16 =
86+
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
8587
using Cutlass3xGemmM64 =
8688
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
87-
using Cutlass3xGemmM128 =
88-
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
8989
using Cutlass3xGemmM256 =
9090
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
9191

9292
uint32_t const m = a.size(0);
9393
uint32_t const mp2 =
94-
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
94+
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
9595

96-
if (mp2 <= 64) {
97-
// m in [1, 64]
98-
return cutlass_gemm_caller<Cutlass3xGemmM64>(
96+
if (mp2 <= 16) {
97+
// m in [1, 16]
98+
return cutlass_gemm_caller<Cutlass3xGemmM16>(
9999
out, a, b, std::forward<EpilogueArgs>(args)...);
100-
} else if (mp2 <= 128) {
101-
// m in (64, 128]
102-
return cutlass_gemm_caller<Cutlass3xGemmM128>(
100+
} else if (mp2 <= 64) {
101+
// m in (16, 64]
102+
return cutlass_gemm_caller<Cutlass3xGemmM64>(
103103
out, a, b, std::forward<EpilogueArgs>(args)...);
104104
} else if (mp2 <= 256) {
105-
// m in (128, 256]
105+
// m in (64, 256]
106106
return cutlass_gemm_caller<Cutlass3xGemmM256>(
107107
out, a, b, std::forward<EpilogueArgs>(args)...);
108108
} else {

0 commit comments

Comments
 (0)