
Commit 5780121

shixianc (Shixian Cui) authored
[Perf] Add swap_ab to SM90 FP8 non-block CUTLASS moe grouped gemm (#20911)
Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
1 parent c7d8724 commit 5780121

File tree

4 files changed: +135 −50 lines


csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu

Lines changed: 37 additions & 12 deletions
@@ -29,19 +29,36 @@ struct sm90_fp8_config_default {
 
 template <typename InType, typename OutType,
           template <typename, typename, typename> typename Epilogue>
-struct sm90_fp8_config_M16 {
-  // M in [1, 16]
+struct sm90_fp8_config_M4 {
+  // M in [1, 4]
   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
   using KernelSchedule =
       cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
   using EpilogueSchedule =
       cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-  using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
-  using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
+  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
+  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
 
   using Cutlass3xGemm =
       cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                            KernelSchedule, EpilogueSchedule>;
+                            KernelSchedule, EpilogueSchedule, true>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_fp8_config_M64 {
+  // M in (4, 64]
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
+  using EpilogueSchedule =
+      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+  using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
+  using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
+
+  using Cutlass3xGemm =
+      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule, true>;
 };
 
 template <typename InType, typename OutType,
@@ -102,7 +119,9 @@ void run_cutlass_moe_mm_sm90(
       InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
   using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
       InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
-  using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
+  using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
+      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
       InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
   using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
       InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
@@ -111,18 +130,24 @@
   uint32_t const n = out_tensors.size(1);
   uint32_t const k = a_tensors.size(1);
 
-  if (n >= 8192) {
-    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
+  // Use swap_ab for M <= 64 by default to reduce padding
+  if (m <= 4) {
+    cutlass_group_gemm_caller<Cutlass3xGemmM4>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
         problem_sizes, a_strides, b_strides, c_strides, per_act_token,
         per_out_ch);
-  } else if (k >= 8192) {
-    cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
+  } else if (m <= 64) {
+    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
         problem_sizes, a_strides, b_strides, c_strides, per_act_token,
         per_out_ch);
-  } else if (m <= 16) {
-    cutlass_group_gemm_caller<Cutlass3xGemmM16>(
+  } else if (n >= 8192) {
+    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  } else if (k >= 8192) {
+    cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
         out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
         problem_sizes, a_strides, b_strides, c_strides, per_act_token,
         per_out_ch);
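Note: the dispatch above is the heart of the change. Batches with few tokens per expert now take the new swap_ab configs, relying on the identity C = A*B <=> C^T = B^T * A^T: with the operands swapped, the tiny token count M lands on the tile's N axis, where these configs use a 16-wide tile, instead of being padded up to a much larger M tile. A minimal host-side sketch (plain C++, not vLLM code) checking that identity with loop GEMMs:

#include <cassert>
#include <vector>

// Row-major reference GEMM: C[M x N] = A[M x K] * B[K x N].
void gemm(const std::vector<float>& A, const std::vector<float>& B,
          std::vector<float>& C, int M, int N, int K) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = acc;
    }
}

int main() {
  int M = 3, N = 8, K = 4;  // tiny M, as in low-batch MoE decode
  std::vector<float> A(M * K), B(K * N), C(M * N), Ct(N * M);
  for (int i = 0; i < M * K; ++i) A[i] = float(i % 7) - 3.f;
  for (int i = 0; i < K * N; ++i) B[i] = float(i % 5) - 2.f;

  gemm(A, B, C, M, N, K);  // straight GEMM

  // Swapped GEMM: materialize B^T [N x K] and A^T [K x M] (the layout
  // transposes in grouped_mm_c3x.cuh achieve this without copies) and
  // compute C^T [N x M] = B^T * A^T.
  std::vector<float> Bt(N * K), At(K * M);
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n) Bt[n * K + k] = B[k * N + n];
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k) At[k * M + m] = A[m * K + k];
  gemm(Bt, At, Ct, N, M, K);

  // Same accumulation order, so the results match exactly.
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) assert(C[m * N + n] == Ct[n * M + m]);
  return 0;
}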

csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh

Lines changed: 48 additions & 19 deletions
@@ -22,24 +22,30 @@ using ArchTag = cutlass::arch::Sm90;
 using OperatorClass = cutlass::arch::OpClassTensorOp;
 
 using LayoutA = cutlass::layout::RowMajor;
+using LayoutA_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
 using LayoutB = cutlass::layout::ColumnMajor;
-using LayoutC = cutlass::layout::RowMajor;
+using LayoutB_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+using LayoutD = cutlass::layout::RowMajor;
+using LayoutD_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutD>::type;
+using LayoutC = LayoutD;
+using LayoutC_Transpose = LayoutD_Transpose;
 
 template <typename ElementAB_, typename ElementC_,
           template <typename, typename, typename> typename Epilogue_,
           typename TileShape, typename ClusterShape, typename KernelSchedule,
-          typename EpilogueSchedule>
+          typename EpilogueSchedule, bool swap_ab_ = false>
 struct cutlass_3x_group_gemm {
+  static constexpr bool swap_ab = swap_ab_;
   using ElementAB = ElementAB_;
   using ElementC = void;
   using ElementD = ElementC_;
   using ElementAccumulator = float;
 
   using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
 
-  using StrideC =
-      cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
-
   static constexpr int AlignmentAB =
       128 / cutlass::sizeof_bits<ElementAB>::value;
   static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
@@ -50,19 +56,26 @@ struct cutlass_3x_group_gemm {
       typename cutlass::epilogue::collective::CollectiveBuilder<
           ArchTag, OperatorClass, TileShape, ClusterShape,
           cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
-          ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
-          LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
+          ElementAccumulator, ElementC,
+          conditional_t<swap_ab, LayoutC_Transpose*, LayoutC*>, AlignmentC,
+          ElementD, conditional_t<swap_ab, LayoutD_Transpose*, LayoutD*>,
+          AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
   using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
       static_cast<int>(CEStorageSize)>;
 
-  using CollectiveMainloop =
+  using CollectiveMainloop = conditional_t<
+      swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB,
+          ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator,
+          TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp,
       typename cutlass::gemm::collective::CollectiveBuilder<
           ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
           LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
-          Stages, KernelSchedule>::CollectiveOp;
+          Stages, KernelSchedule>::CollectiveOp>;
 
   using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
       ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
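Note: the conditional_t here selects between two fully-formed CollectiveBuilder instantiations at compile time, so the swapped and unswapped mainloops share one struct and the choice costs nothing at runtime. A reduced, self-contained sketch of the pattern with stand-in types (hypothetical names, no CUTLASS dependency):

#include <type_traits>

struct StandardMainloop {};  // stand-in for CollectiveBuilder<A, B, ...>
struct SwappedMainloop {};   // stand-in for CollectiveBuilder<B^T, A^T, ...>

template <bool swap_ab_>
struct group_gemm_sketch {
  static constexpr bool swap_ab = swap_ab_;
  // Only the selected alias is ever used; the other branch is just a type.
  using CollectiveMainloop =
      std::conditional_t<swap_ab, SwappedMainloop, StandardMainloop>;
};

static_assert(std::is_same_v<group_gemm_sketch<true>::CollectiveMainloop,
                             SwappedMainloop>);
static_assert(std::is_same_v<group_gemm_sketch<false>::CollectiveMainloop,
                             StandardMainloop>);

int main() { return 0; }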
@@ -78,12 +91,12 @@ void cutlass_group_gemm_caller(
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
     torch::Tensor const& b_strides, torch::Tensor const& c_strides,
     bool per_act_token, bool per_out_ch) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
+
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;
 
   int num_experts = static_cast<int>(expert_offsets.size(0));
-  int k_size = a_tensors.size(1);
-  int n_size = out_tensors.size(1);
 
   auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
@@ -110,19 +123,35 @@ void cutlass_group_gemm_caller(
       problem_sizes.data_ptr());
   ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
 
-  typename GemmKernel::MainloopArguments mainloop_args{
-      static_cast<const ElementAB**>(a_ptrs.data_ptr()),
-      static_cast<StrideA*>(a_strides.data_ptr()),
-      static_cast<const ElementAB**>(b_ptrs.data_ptr()),
-      static_cast<StrideB*>(b_strides.data_ptr())};
+  typename GemmKernel::MainloopArguments mainloop_args;
+  if constexpr (swap_ab) {
+    mainloop_args = typename GemmKernel::MainloopArguments{
+        static_cast<const ElementAB**>(b_ptrs.data_ptr()),
+        static_cast<StrideB*>(b_strides.data_ptr()),
+        static_cast<const ElementAB**>(a_ptrs.data_ptr()),
+        static_cast<StrideA*>(a_strides.data_ptr())};
+  } else {
+    mainloop_args = typename GemmKernel::MainloopArguments{
+        static_cast<const ElementAB**>(a_ptrs.data_ptr()),
+        static_cast<StrideA*>(a_strides.data_ptr()),
+        static_cast<const ElementAB**>(b_ptrs.data_ptr()),
+        static_cast<StrideB*>(b_strides.data_ptr())};
+  }
 
   // Currently, we are only able to do broadcast on either all or none a_scales
   // and on either all or none b_scales
   typename GemmKernel::EpilogueArguments epilogue_args{
       Gemm::Epilogue::prepare_args(
-          static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
-          static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
-          per_act_token, per_out_ch),
+          swap_ab ? static_cast<const ElementAccumulator**>(
+                        b_scales_ptrs.data_ptr())
+                  : static_cast<const ElementAccumulator**>(
+                        a_scales_ptrs.data_ptr()),
+          swap_ab ? static_cast<const ElementAccumulator**>(
+                        a_scales_ptrs.data_ptr())
+                  : static_cast<const ElementAccumulator**>(
+                        b_scales_ptrs.data_ptr()),
+          swap_ab ? per_out_ch : per_act_token,
+          swap_ab ? per_act_token : per_out_ch),
       nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<StrideC*>(c_strides.data_ptr())};
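Note: because the epilogue now sees C^T, the two scale arrays trade roles. a_scales, which scale rows of C (tokens), become column scales of C^T, and b_scales move from columns to rows; the per_act_token/per_out_ch broadcast flags swap with them. A small plain-C++ check of that bookkeeping (illustrative values, not vLLM code):

#include <cassert>
#include <vector>

int main() {
  int M = 2, N = 3;
  std::vector<float> acc = {1, 2, 3, 4, 5, 6};        // C, row-major [M x N]
  std::vector<float> accT = {1, 4, 2, 5, 3, 6};       // C^T, row-major [N x M]
  std::vector<float> a_scales = {0.5f, 2.0f};         // per token (M)
  std::vector<float> b_scales = {1.0f, 3.0f, 0.25f};  // per channel (N)

  // Standard epilogue: D[m][n] = a_scales[m] * b_scales[n] * C[m][n].
  std::vector<float> D(M * N);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      D[m * N + n] = a_scales[m] * b_scales[n] * acc[m * N + n];

  // Swapped epilogue: the row scale is now b_scales, the column scale
  // a_scales -- exactly the argument swap in prepare_args above.
  std::vector<float> Dt(N * M);
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m)
      Dt[n * M + m] = b_scales[n] * a_scales[m] * accT[n * M + m];

  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) assert(D[m * N + n] == Dt[n * M + m]);
  return 0;
}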

csrc/quantization/cutlass_w8a8/moe/moe_data.cu

Lines changed: 49 additions & 19 deletions
@@ -6,7 +6,10 @@
 #include <iostream>
 
 constexpr uint64_t THREADS_PER_EXPERT = 512;
+// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90()
+constexpr int SWAP_AB_THRESHOLD = 64;
 
+template <bool SWAP_AB>
 __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
                                       int32_t* problem_sizes1,
                                       int32_t* problem_sizes2,
@@ -24,40 +27,53 @@ __global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
 
   if (threadIdx.x == 0) {
     int final_occurrences = atomic_buffer[expert_id];
-    problem_sizes1[expert_id * 3] = final_occurrences;
-    problem_sizes1[expert_id * 3 + 1] = 2 * n;
-    problem_sizes1[expert_id * 3 + 2] = k;
-    problem_sizes2[expert_id * 3] = final_occurrences;
-    problem_sizes2[expert_id * 3 + 1] = k;
-    problem_sizes2[expert_id * 3 + 2] = n;
+    if constexpr (!SWAP_AB) {
+      problem_sizes1[expert_id * 3] = final_occurrences;
+      problem_sizes1[expert_id * 3 + 1] = 2 * n;
+      problem_sizes1[expert_id * 3 + 2] = k;
+      problem_sizes2[expert_id * 3] = final_occurrences;
+      problem_sizes2[expert_id * 3 + 1] = k;
+      problem_sizes2[expert_id * 3 + 2] = n;
+    } else {
+      problem_sizes1[expert_id * 3] = 2 * n;
+      problem_sizes1[expert_id * 3 + 1] = final_occurrences;
+      problem_sizes1[expert_id * 3 + 2] = k;
+      problem_sizes2[expert_id * 3] = k;
+      problem_sizes2[expert_id * 3 + 1] = final_occurrences;
+      problem_sizes2[expert_id * 3 + 2] = n;
+    }
   }
 }
 
 __global__ void compute_expert_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
-    int32_t* atomic_buffer, const int num_experts) {
+    int32_t* atomic_buffer, const int num_experts, const int topk_length) {
   int32_t tot_offset = 0;
   expert_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
     atomic_buffer[i] = tot_offset;
-    tot_offset += problem_sizes1[i * 3];
+    tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
+                                                  : problem_sizes1[i * 3 + 1];
     expert_offsets[i + 1] = tot_offset;
   }
 }
 
 __global__ void compute_expert_blockscale_offsets(
     const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
-    int32_t* blockscale_offsets, int32_t* atomic_buffer,
-    const int num_experts) {
+    int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
+    const int topk_length) {
   int32_t tot_offset = 0;
   int32_t tot_offset_round = 0;
   expert_offsets[0] = 0;
   blockscale_offsets[0] = 0;
   for (int i = 0; i < num_experts; ++i) {
+    int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
+                             ? problem_sizes1[i * 3]
+                             : problem_sizes1[i * 3 + 1];
     atomic_buffer[i] = tot_offset;
-    tot_offset += problem_sizes1[i * 3];
+    tot_offset += cur_offset;
     expert_offsets[i + 1] = tot_offset;
-    tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128;
+    tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128;
     blockscale_offsets[i + 1] = tot_offset_round;
   }
 }
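Note: under SWAP_AB the per-expert problem shape is recorded transposed, (2n, tokens, k) instead of (tokens, 2n, k), so the token count moves from index 0 to index 1 of each triple; the offset kernels use topk_length > SWAP_AB_THRESHOLD to know which layout they are reading. A host-side sketch with hypothetical routing counts:

#include <cassert>
#include <vector>

constexpr int SWAP_AB_THRESHOLD = 64;  // mirrors moe_data.cu

int main() {
  int n = 128, k = 256;
  std::vector<int> tokens_per_expert = {3, 0, 5};  // hypothetical routing
  int num_experts = 3, topk_length = 8;            // 8 <= 64 -> swapped path

  // Build problem_sizes1 the way compute_problem_sizes<true> does.
  std::vector<int> problem_sizes1;
  for (int t : tokens_per_expert) {
    problem_sizes1.push_back(2 * n);  // index 0: 2n
    problem_sizes1.push_back(t);      // index 1: token count
    problem_sizes1.push_back(k);      // index 2: k
  }

  // compute_expert_offsets logic: pick the slot holding the token count.
  std::vector<int> expert_offsets(num_experts + 1, 0);
  int tot = 0;
  for (int i = 0; i < num_experts; ++i) {
    tot += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
                                           : problem_sizes1[i * 3 + 1];
    expert_offsets[i + 1] = tot;
  }
  assert(expert_offsets[1] == 3 && expert_offsets[2] == 3 &&
         expert_offsets[3] == 8);
  return 0;
}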
@@ -102,22 +118,36 @@ void get_cutlass_moe_mm_data_caller(
   torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
 
   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
-  compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
-      static_cast<const int32_t*>(topk_ids.data_ptr()),
-      static_cast<int32_t*>(problem_sizes1.data_ptr()),
-      static_cast<int32_t*>(problem_sizes2.data_ptr()),
-      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
+
+  if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
+    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
+        static_cast<const int32_t*>(topk_ids.data_ptr()),
+        static_cast<int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(problem_sizes2.data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
+        k);
+  } else {
+    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
+        static_cast<const int32_t*>(topk_ids.data_ptr()),
+        static_cast<int32_t*>(problem_sizes1.data_ptr()),
+        static_cast<int32_t*>(problem_sizes2.data_ptr()),
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
+        k);
+  }
+
   if (blockscale_offsets.has_value()) {
     compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
         static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
+        topk_ids.numel());
   } else {
     compute_expert_offsets<<<1, 1, 0, stream>>>(
         static_cast<const int32_t*>(problem_sizes1.data_ptr()),
         static_cast<int32_t*>(expert_offsets.data_ptr()),
-        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
+        static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
+        topk_ids.numel());
   }
   compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
       static_cast<const int32_t*>(topk_ids.data_ptr()),
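Note: since SWAP_AB is a template parameter, the host launcher selects the kernel variant with a runtime branch on topk_ids.numel(); both instantiations are compiled, but each kernel body is branch-free. A minimal CUDA sketch of that dispatch pattern (demo names, not vLLM code):

#include <cstdio>
#include <cuda_runtime.h>

template <bool SWAP_AB>
__global__ void demo_kernel(int* out) {
  if constexpr (SWAP_AB) {
    *out = 1;  // swapped-layout code path
  } else {
    *out = 0;  // standard code path
  }
}

void launch(int topk_length, int* d_out) {
  constexpr int SWAP_AB_THRESHOLD = 64;
  // Runtime value -> compile-time flag, mirroring the launcher above.
  if (topk_length > SWAP_AB_THRESHOLD) {
    demo_kernel<false><<<1, 1>>>(d_out);
  } else {
    demo_kernel<true><<<1, 1>>>(d_out);
  }
}

int main() {
  int* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(int));
  int h_out = -1;
  launch(8, d_out);  // 8 <= 64 -> swapped variant selected
  cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("swap_ab selected: %d\n", h_out);
  cudaFree(d_out);
  return 0;
}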

tests/kernels/moe/test_cutlass_moe.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     (2, 1024, 1536),
     (2, 3072, 1024),
     (2, 3072, 1536),
+    (7, 3072, 1536),
     (64, 1024, 1024),
     (64, 1024, 1536),
     (64, 3072, 1024),
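Note: the added (7, 3072, 1536) case has m = 7, which falls in the (4, 64] bucket, so the new swapped M64 config is presumably exercised end to end by this test.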
