@@ -29,40 +29,40 @@ struct sm100_fp8_config_default {
29
29
template <typename InType, typename OutType,
30
30
template <typename , typename , typename > typename Epilogue>
31
31
struct sm100_fp8_config_M256 {
32
- // M in (128 , 256]
32
+ // M in (64 , 256]
33
33
static_assert (std::is_same<InType, cutlass::float_e4m3_t >());
34
34
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
35
35
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
36
36
using TileShape = Shape<_128, _128, _128>;
37
- using ClusterShape = Shape<_2, _2 , _1>;
37
+ using ClusterShape = Shape<_2, _1 , _1>;
38
38
using Cutlass3xGemm =
39
39
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
40
40
KernelSchedule, EpilogueSchedule>;
41
41
};
42
42
43
43
template <typename InType, typename OutType,
44
44
template <typename , typename , typename > typename Epilogue>
45
- struct sm100_fp8_config_M128 {
46
- // M in (64, 128 ]
45
+ struct sm100_fp8_config_M64 {
46
+ // M in (16, 64 ]
47
47
static_assert (std::is_same<InType, cutlass::float_e4m3_t >());
48
48
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
49
49
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
50
- using TileShape = Shape<_128, _128, _256 >;
51
- using ClusterShape = Shape<_2, _4 , _1>;
50
+ using TileShape = Shape<_64, _64, _128 >;
51
+ using ClusterShape = Shape<_1, _1 , _1>;
52
52
using Cutlass3xGemm =
53
53
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
54
54
KernelSchedule, EpilogueSchedule>;
55
55
};
56
56
57
57
template <typename InType, typename OutType,
58
58
template <typename , typename , typename > typename Epilogue>
59
- struct sm100_fp8_config_M64 {
60
- // M in [1, 64 ]
59
+ struct sm100_fp8_config_M16 {
60
+ // M in [1, 16 ]
61
61
static_assert (std::is_same<InType, cutlass::float_e4m3_t >());
62
62
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
63
63
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
64
- using TileShape = Shape<_64, _64, _256 >;
65
- using ClusterShape = Shape<_1, _8 , _1>;
64
+ using TileShape = Shape<_64, _64, _128 >;
65
+ using ClusterShape = Shape<_1, _4 , _1>;
66
66
using Cutlass3xGemm =
67
67
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
68
68
KernelSchedule, EpilogueSchedule>;
@@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
82
82
using Cutlass3xGemmDefault =
83
83
typename sm100_fp8_config_default<InType, OutType,
84
84
Epilogue>::Cutlass3xGemm;
85
+ using Cutlass3xGemmM16 =
86
+ typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
85
87
using Cutlass3xGemmM64 =
86
88
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
87
- using Cutlass3xGemmM128 =
88
- typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
89
89
using Cutlass3xGemmM256 =
90
90
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
91
91
92
92
uint32_t const m = a.size (0 );
93
93
uint32_t const mp2 =
94
- std::max (static_cast <uint32_t >(64 ), next_pow_2 (m)); // next power of 2
94
+ std::max (static_cast <uint32_t >(16 ), next_pow_2 (m)); // next power of 2
95
95
96
- if (mp2 <= 64 ) {
97
- // m in [1, 64 ]
98
- return cutlass_gemm_caller<Cutlass3xGemmM64 >(
96
+ if (mp2 <= 16 ) {
97
+ // m in [1, 16 ]
98
+ return cutlass_gemm_caller<Cutlass3xGemmM16 >(
99
99
out, a, b, std::forward<EpilogueArgs>(args)...);
100
- } else if (mp2 <= 128 ) {
101
- // m in (64, 128 ]
102
- return cutlass_gemm_caller<Cutlass3xGemmM128 >(
100
+ } else if (mp2 <= 64 ) {
101
+ // m in (16, 64 ]
102
+ return cutlass_gemm_caller<Cutlass3xGemmM64 >(
103
103
out, a, b, std::forward<EpilogueArgs>(args)...);
104
104
} else if (mp2 <= 256 ) {
105
- // m in (128 , 256]
105
+ // m in (64 , 256]
106
106
return cutlass_gemm_caller<Cutlass3xGemmM256>(
107
107
out, a, b, std::forward<EpilogueArgs>(args)...);
108
108
} else {
0 commit comments