@@ -22,24 +22,30 @@ using ArchTag = cutlass::arch::Sm90;
 using OperatorClass = cutlass::arch::OpClassTensorOp;
 
 using LayoutA = cutlass::layout::RowMajor;
+using LayoutA_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
 using LayoutB = cutlass::layout::ColumnMajor;
-using LayoutC = cutlass::layout::RowMajor;
+using LayoutB_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutB>::type;
+using LayoutD = cutlass::layout::RowMajor;
+using LayoutD_Transpose =
+    typename cutlass::layout::LayoutTranspose<LayoutD>::type;
+using LayoutC = LayoutD;
+using LayoutC_Transpose = LayoutD_Transpose;
 
 template <typename ElementAB_, typename ElementC_,
           template <typename, typename, typename> typename Epilogue_,
           typename TileShape, typename ClusterShape, typename KernelSchedule,
-          typename EpilogueSchedule>
+          typename EpilogueSchedule, bool swap_ab_ = false>
 struct cutlass_3x_group_gemm {
+  static constexpr bool swap_ab = swap_ab_;
   using ElementAB = ElementAB_;
   using ElementC = void;
   using ElementD = ElementC_;
   using ElementAccumulator = float;
 
   using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
 
-  using StrideC =
-      cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
-
   static constexpr int AlignmentAB =
       128 / cutlass::sizeof_bits<ElementAB>::value;
   static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
@@ -50,19 +56,26 @@ struct cutlass_3x_group_gemm {
       typename cutlass::epilogue::collective::CollectiveBuilder<
           ArchTag, OperatorClass, TileShape, ClusterShape,
           cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
-          ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
-          LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
+          ElementAccumulator, ElementC,
+          conditional_t<swap_ab, LayoutC_Transpose*, LayoutC*>, AlignmentC,
+          ElementD, conditional_t<swap_ab, LayoutD_Transpose*, LayoutD*>,
+          AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
 
   static constexpr size_t CEStorageSize =
       sizeof(typename CollectiveEpilogue::SharedStorage);
   using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
       static_cast<int>(CEStorageSize)>;
 
-  using CollectiveMainloop =
+  using CollectiveMainloop = conditional_t<
+      swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB,
+          ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator,
+          TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp,
       typename cutlass::gemm::collective::CollectiveBuilder<
           ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
           LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
-          Stages, KernelSchedule>::CollectiveOp;
+          Stages, KernelSchedule>::CollectiveOp>;
 
   using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
       ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
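Note (illustrative, not part of the patch): the conditional mainloop above feeds B with `LayoutB_Transpose` as the first operand and A with `LayoutA_Transpose` as the second, so the kernel effectively computes D^T = B^T * A^T; because the epilogue also selects the transposed output layout, that result occupies exactly the same bytes as a row-major D, so no extra transpose pass is needed. The plain C++ sketch below (no CUTLASS dependency, arbitrary small sizes) checks that identity.

```cpp
// Sketch only: why swapping operands and transposing layouts preserves the
// output buffer byte-for-byte. Assumes A is M x K row-major, B is K x N
// column-major, and D is M x N row-major, as in the GEMM types above.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int M = 3, N = 4, K = 5;
  std::vector<int64_t> a(M * K), b(K * N);
  for (int i = 0; i < M * K; ++i) a[i] = i + 1;
  for (int i = 0; i < K * N; ++i) b[i] = 2 * i - 7;

  // Reference path: D = A * B, stored row-major (element (m, n) at m * N + n).
  std::vector<int64_t> d_ref(M * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k)
        d_ref[m * N + n] += a[m * K + k] * b[n * K + k];

  // Swapped path: the same B buffer read as B^T (N x K row-major) times the
  // same A buffer read as A^T (K x M column-major), i.e. D^T = B^T * A^T.
  // The N x M result is written column-major (element (n, m) at m * N + n),
  // which is exactly the row-major layout of D -- same bytes, no transpose.
  std::vector<int64_t> d_swap(M * N, 0);
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m)
      for (int k = 0; k < K; ++k)
        d_swap[m * N + n] += b[n * K + k] * a[m * K + k];

  assert(d_ref == d_swap);  // identical output buffers
  return 0;
}
```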
@@ -78,12 +91,12 @@ void cutlass_group_gemm_caller(
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
     torch::Tensor const& b_strides, torch::Tensor const& c_strides,
     bool per_act_token, bool per_out_ch) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
+
   using ElementAB = typename Gemm::ElementAB;
   using ElementD = typename Gemm::ElementD;
 
   int num_experts = static_cast<int>(expert_offsets.size(0));
-  int k_size = a_tensors.size(1);
-  int n_size = out_tensors.size(1);
 
   auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
 
@@ -110,19 +123,35 @@ void cutlass_group_gemm_caller(
       problem_sizes.data_ptr());
   ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
 
-  typename GemmKernel::MainloopArguments mainloop_args{
-      static_cast<const ElementAB**>(a_ptrs.data_ptr()),
-      static_cast<StrideA*>(a_strides.data_ptr()),
-      static_cast<const ElementAB**>(b_ptrs.data_ptr()),
-      static_cast<StrideB*>(b_strides.data_ptr())};
+  typename GemmKernel::MainloopArguments mainloop_args;
+  if constexpr (swap_ab) {
+    mainloop_args = typename GemmKernel::MainloopArguments{
+        static_cast<const ElementAB**>(b_ptrs.data_ptr()),
+        static_cast<StrideB*>(b_strides.data_ptr()),
+        static_cast<const ElementAB**>(a_ptrs.data_ptr()),
+        static_cast<StrideA*>(a_strides.data_ptr())};
+  } else {
+    mainloop_args = typename GemmKernel::MainloopArguments{
+        static_cast<const ElementAB**>(a_ptrs.data_ptr()),
+        static_cast<StrideA*>(a_strides.data_ptr()),
+        static_cast<const ElementAB**>(b_ptrs.data_ptr()),
+        static_cast<StrideB*>(b_strides.data_ptr())};
+  }
 
   // Currently, we are only able to do broadcast on either all or none a_scales
   // and on either all or none b_scales
   typename GemmKernel::EpilogueArguments epilogue_args{
       Gemm::Epilogue::prepare_args(
-          static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
-          static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
-          per_act_token, per_out_ch),
+          swap_ab ? static_cast<const ElementAccumulator**>(
+                        b_scales_ptrs.data_ptr())
+                  : static_cast<const ElementAccumulator**>(
+                        a_scales_ptrs.data_ptr()),
+          swap_ab ? static_cast<const ElementAccumulator**>(
+                        a_scales_ptrs.data_ptr())
+                  : static_cast<const ElementAccumulator**>(
+                        b_scales_ptrs.data_ptr()),
+          swap_ab ? per_out_ch : per_act_token,
+          swap_ab ? per_act_token : per_out_ch),
       nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<StrideC*>(c_strides.data_ptr())};
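Note (illustrative, not part of the patch): swapping A and B also swaps the roles of the per-act-token and per-out-channel scales, which is why the caller above exchanges both the scale pointer arrays and the two broadcast flags. With D = diag(sa) * A * B * diag(sb), the transposed problem is D^T = diag(sb) * B^T * A^T * diag(sa), so sb now attaches to the first operand and sa to the output columns. The standalone sketch below (plain C++, arbitrary test values) verifies that the scaled results agree.

```cpp
// Sketch only: the row/column roles of the two scale vectors trade places
// when the operands are swapped, matching the flag swap in the caller above.
#include <cassert>
#include <vector>

int main() {
  constexpr int M = 2, N = 3, K = 4;
  std::vector<double> a(M * K), b(K * N), sa(M), sb(N);
  for (int i = 0; i < M * K; ++i) a[i] = i + 1;
  for (int i = 0; i < K * N; ++i) b[i] = 0.5 * i - 1;
  for (int m = 0; m < M; ++m) sa[m] = 1.0 + m;          // per-act-token scales
  for (int n = 0; n < N; ++n) sb[n] = 0.25 * (n + 1);   // per-out-channel scales

  // Reference: D(m, n) = sa[m] * sb[n] * sum_k A(m, k) * B(k, n), row-major D
  // (A row-major, B column-major, matching the kernel above).
  std::vector<double> d_ref(M * N, 0), d_swap(M * N, 0);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      for (int k = 0; k < K; ++k) d_ref[m * N + n] += a[m * K + k] * b[n * K + k];
      d_ref[m * N + n] *= sa[m] * sb[n];
    }

  // Swapped problem: first operand B^T scaled by sb, second operand A^T scaled
  // by sa; the column-major result lands in the same buffer layout as d_ref.
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m) {
      for (int k = 0; k < K; ++k) d_swap[m * N + n] += b[n * K + k] * a[m * K + k];
      d_swap[m * N + n] *= sb[n] * sa[m];
    }

  assert(d_ref == d_swap);  // same scaled output either way
  return 0;
}
```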