18
18
#pragma once
19
19
20
20
#include " cutlass_extensions/arch/mma.h"
21
- #include " cutlass_extensions/interleaved_numeric_conversion.h"
22
21
#include " cutlass_extensions/gemm/threadblock/default_dq_mma.h"
23
22
#include " cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
23
+ #include " cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
24
24
25
25
namespace cutlass {
26
26
namespace gemm {
27
27
namespace threadblock {
28
28
29
29
// //////////////////////////////////////////////////////////////////////////////
30
30
31
+ template <typename ThreadblockShape, typename ElementT, int GroupSize>
32
+ struct DefaultQuantParamsIterators {
33
+ private:
34
+ static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
35
+ static_assert ((ThreadblockShape::kN % kAlignment ) == 0 , " " );
36
+
37
+ static constexpr int kRows =
38
+ (GroupSize == -1 ) ? 1 : (ThreadblockShape::kK + GroupSize - 1 ) / GroupSize;
39
+ static constexpr int kColumns = ThreadblockShape::kN ;
40
+
41
+ using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
42
+ layout::PitchLinearShape<kColumns , kRows >,
43
+ kColumns / kAlignment , kAlignment >;
44
+
45
+ public:
46
+ using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
47
+ MatrixShape<kRows , kColumns >, ElementT, layout::RowMajor, 0 ,
48
+ IteratorThreadMap, kAlignment >;
49
+ using SmemIterator = Iterator;
50
+
51
+ // using AccessType = cutlass::Array<ElementT, kAlignment>;
52
+ // using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
53
+ // MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor,
54
+ // 0, IteratorThreadMap, AccessType>;
55
+ };
56
+
57
+ template <typename ThreadblockShape, int GroupSize>
58
+ struct DefaultQuantParamsIterators <ThreadblockShape, uint4b_t , GroupSize> {
59
+ private:
60
+ static constexpr int kAlignment = 128 / sizeof_bits<uint4b_t >::value;
61
+ static_assert ((ThreadblockShape::kN % kAlignment ) == 0 , " " );
62
+
63
+ static constexpr int kRows =
64
+ (GroupSize == -1 ) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1 ) / (2 * GroupSize);
65
+ static constexpr int kColumns =
66
+ (GroupSize == -1 ) ? ThreadblockShape::kN : ThreadblockShape::kN * 2 ;
67
+
68
+ using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
69
+ layout::PitchLinearShape<kColumns , kRows >,
70
+ kColumns / kAlignment , kAlignment >;
71
+
72
+ public:
73
+ using Iterator =
74
+ cutlass::transform::threadblock::PredicatedTileIterator<
75
+ cutlass::MatrixShape<kRows , kColumns >, uint4b_t ,
76
+ layout::RowMajor, 0 , IteratorThreadMap, kAlignment >;
77
+ using SmemIterator = Iterator;
78
+ };
79
+
31
80
template <
32
81
// / Element type for A matrix operand
33
82
typename ElementA_,
@@ -100,7 +149,7 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
100
149
layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
101
150
kStages , Operator, SharedMemoryClear>
102
151
{
103
-
152
+ public:
104
153
static_assert (platform::is_same<ElementA, half_t >::value || platform::is_same<ElementA, bfloat16_t >::value,
105
154
" Element A must be fp16 or bf16" );
106
155
@@ -110,6 +159,12 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
110
159
static_assert (platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
111
160
" Mma multistage must dequantize after ldsm" );
112
161
162
+ using ElementSuperScale = ElementA;
163
+ using ElementLocalScale = uint4b_t ;
164
+ using ElementCodeScaleZp = float ;
165
+
166
+ static constexpr int kGroupSize = 64 ;
167
+
113
168
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA ) == 128 )
114
169
? cutlass::arch::CacheOperation::Global
115
170
: cutlass::arch::CacheOperation::Always;
@@ -157,16 +212,36 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
157
212
IteratorShapeB, ElementB, layout::ColumnMajor, 0 , InterleavedThreadMapB,
158
213
AccessTypeB>;
159
214
160
- using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
161
- ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements >;
215
+ private:
216
+ // Define iterators over tiles from extra quant params for B operand
217
+ using IteratorSuperScale = typename DefaultQuantParamsIterators<
218
+ ThreadblockShape, ElementSuperScale, -1 >::Iterator;
219
+ using SmemIteratorSuperScale = typename DefaultQuantParamsIterators<
220
+ ThreadblockShape, ElementSuperScale, -1 >::SmemIterator;
221
+
222
+ using IteratorLocalScale = typename DefaultQuantParamsIterators<
223
+ ThreadblockShape, ElementLocalScale, kGroupSize >::Iterator;
224
+ using SmemIteratorLocalScale = typename DefaultQuantParamsIterators<
225
+ ThreadblockShape, ElementLocalScale, kGroupSize >::SmemIterator;
226
+
227
+ using IteratorCodeScaleZp = typename DefaultQuantParamsIterators<
228
+ ThreadblockShape, ElementCodeScaleZp, -1 >::Iterator;
229
+ using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators<
230
+ ThreadblockShape, ElementCodeScaleZp, -1 >::Iterator;
231
+
232
+ public:
233
+ using QuantParamsAccessor = Wint2ParamsAccessor<
234
+ ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale,
235
+ IteratorLocalScale, SmemIteratorLocalScale,
236
+ IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages , kGroupSize >;
162
237
163
238
// Define the threadblock-scoped multistage matrix multiply
164
239
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
165
240
typename MmaCore::Shape,
166
241
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA ,
167
242
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB ,
168
- ElementAccumulator, layout::RowMajor,
169
- typename MmaCore::MmaPolicy, kStages , TransformBAfterLDS , SharedMemoryClear>;
243
+ ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy,
244
+ kStages , QuantParamsAccessor , SharedMemoryClear>;
170
245
};
171
246
172
247
} // namespace threadblock
0 commit comments