Skip to content

Commit 5d040ac

Browse files
committed
Implement Wint2ParamsAccessor to load extra quant params from global memory.
Change-Id: I042bac28c5df64673f259933cc9764f7b4aced15
1 parent 2bda35d commit 5d040ac

File tree

9 files changed

+888
-743
lines changed

9 files changed

+888
-743
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,65 @@
1818
#pragma once
1919

2020
#include "cutlass_extensions/arch/mma.h"
21-
#include "cutlass_extensions/interleaved_numeric_conversion.h"
2221
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
2322
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
23+
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
2424

2525
namespace cutlass {
2626
namespace gemm {
2727
namespace threadblock {
2828

2929
////////////////////////////////////////////////////////////////////////////////
3030

31+
template <typename ThreadblockShape, typename ElementT, int GroupSize>
32+
struct DefaultQuantParamsIterators {
33+
private:
34+
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
35+
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
36+
37+
static constexpr int kRows =
38+
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
39+
static constexpr int kColumns = ThreadblockShape::kN;
40+
41+
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
42+
layout::PitchLinearShape<kColumns, kRows>,
43+
kColumns / kAlignment, kAlignment>;
44+
45+
public:
46+
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
47+
MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
48+
IteratorThreadMap, kAlignment>;
49+
using SmemIterator = Iterator;
50+
51+
//using AccessType = cutlass::Array<ElementT, kAlignment>;
52+
//using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
53+
// MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor,
54+
// 0, IteratorThreadMap, AccessType>;
55+
};
56+
57+
template <typename ThreadblockShape, int GroupSize>
58+
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
59+
private:
60+
static constexpr int kAlignment = 128 / sizeof_bits<uint4b_t>::value;
61+
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
62+
63+
static constexpr int kRows =
64+
(GroupSize == -1) ? 1 : (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
65+
static constexpr int kColumns =
66+
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
67+
68+
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
69+
layout::PitchLinearShape<kColumns, kRows>,
70+
kColumns / kAlignment, kAlignment>;
71+
72+
public:
73+
using Iterator =
74+
cutlass::transform::threadblock::PredicatedTileIterator<
75+
cutlass::MatrixShape<kRows, kColumns>, uint4b_t,
76+
layout::RowMajor, 0, IteratorThreadMap, kAlignment>;
77+
using SmemIterator = Iterator;
78+
};
79+
3180
template <
3281
/// Element type for A matrix operand
3382
typename ElementA_,
@@ -100,7 +149,7 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
100149
layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
101150
kStages, Operator, SharedMemoryClear>
102151
{
103-
152+
public:
104153
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
105154
"Element A must be fp16 or bf16");
106155

@@ -110,6 +159,12 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
110159
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
111160
"Mma multistage must dequantize after ldsm");
112161

162+
using ElementSuperScale = ElementA;
163+
using ElementLocalScale = uint4b_t;
164+
using ElementCodeScaleZp = float;
165+
166+
static constexpr int kGroupSize = 64;
167+
113168
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
114169
? cutlass::arch::CacheOperation::Global
115170
: cutlass::arch::CacheOperation::Always;
@@ -157,16 +212,36 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
157212
IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
158213
AccessTypeB>;
159214

160-
using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
161-
ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
215+
private:
216+
// Define iterators over tiles from extra quant params for B operand
217+
using IteratorSuperScale = typename DefaultQuantParamsIterators<
218+
ThreadblockShape, ElementSuperScale, -1>::Iterator;
219+
using SmemIteratorSuperScale = typename DefaultQuantParamsIterators<
220+
ThreadblockShape, ElementSuperScale, -1>::SmemIterator;
221+
222+
using IteratorLocalScale = typename DefaultQuantParamsIterators<
223+
ThreadblockShape, ElementLocalScale, kGroupSize>::Iterator;
224+
using SmemIteratorLocalScale = typename DefaultQuantParamsIterators<
225+
ThreadblockShape, ElementLocalScale, kGroupSize>::SmemIterator;
226+
227+
using IteratorCodeScaleZp = typename DefaultQuantParamsIterators<
228+
ThreadblockShape, ElementCodeScaleZp, -1>::Iterator;
229+
// Taken from ::SmemIterator (not ::Iterator) so that it mirrors
// SmemIteratorSuperScale / SmemIteratorLocalScale and keeps working if the
// shared-memory iterator type ever diverges from the global-memory one.
using SmemIteratorCodeScaleZp = typename DefaultQuantParamsIterators<
    ThreadblockShape, ElementCodeScaleZp, -1>::SmemIterator;
231+
232+
public:
233+
using QuantParamsAccessor = Wint2ParamsAccessor<
234+
ElementA, ThreadblockShape, IteratorSuperScale, SmemIteratorSuperScale,
235+
IteratorLocalScale, SmemIteratorLocalScale,
236+
IteratorCodeScaleZp, SmemIteratorCodeScaleZp, kStages, kGroupSize>;
162237

163238
// Define the threadblock-scoped multistage matrix multiply
164239
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
165240
typename MmaCore::Shape,
166241
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
167242
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
168-
ElementAccumulator, layout::RowMajor,
169-
typename MmaCore::MmaPolicy, kStages, TransformBAfterLDS, SharedMemoryClear>;
243+
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy,
244+
kStages, QuantParamsAccessor, SharedMemoryClear>;
170245
};
171246

172247
} // namespace threadblock

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ template <
6363
typename Policy_,
6464
/// Number of stages,
6565
int Stages,
66-
/// Used for partial specialization
67-
typename Enable = bool>
66+
/// Shape of the extra quant params stored in shared memory
67+
typename QuantParamsShape>
6868
class Wint2xMmaBase {
6969
public:
7070
///< Size of the Gemm problem - concept: gemm::GemmShape<>
@@ -101,7 +101,6 @@ class Wint2xMmaBase {
101101
static constexpr int kWarpLoadIterationsForB =
102102
kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;
103103

104-
105104
/// Number of stages
106105
static int const kStages = Stages;
107106

@@ -140,16 +139,8 @@ class Wint2xMmaBase {
140139
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
141140
Shape::kN + Policy::SmemPaddingB::kColumn>;
142141

143-
// local_scale uint4
144-
constexpr static int kGroupWiseParamRows = Shape::kK / 64;
145-
146-
using GroupWiseParamShapeB = MatrixShape<kGroupWiseParamRows * kStages, Shape::kN>;
147-
148-
// code_scale float; code_zp float; super_scale ElementB
149-
constexpr static int kColumnWiseParamRows = 2 * sizeof(float) +
150-
sizeof_bits<typename Operator::ElementB>::value / 8;
151-
152-
using ColumnWiseParamShapeB = MatrixShape<kColumnWiseParamRows, Shape::kN>;
142+
/// Shape of all quant params in shared memory
143+
using QuantParamsShapeB = QuantParamsShape;
153144

154145
public:
155146
//
@@ -162,11 +153,8 @@ class Wint2xMmaBase {
162153
/// Buffer for B operand
163154
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
164155

165-
/// Buffer for local_scale of B operand
166-
AlignedBuffer<uint4b_t, GroupWiseParamShapeB::kCount> operand_local_scale_B;
167-
168-
/// Buffer for column-wise params of B operand
169-
AlignedBuffer<uint8_t, ColumnWiseParamShapeB::kCount> operand_column_wise_B;
156+
/// Buffer for extra quant params of B operand
157+
AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;
170158

171159
public:
172160
//

0 commit comments

Comments
 (0)