Skip to content

Commit 0b60689

Browse files
Xrekibaoqiwen
authored and committed
Implement FastInterleavedAndBiasedNumericArrayConverter for wint2.
Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca
1 parent 487d643 commit 0b60689

File tree

6 files changed

+409
-808
lines changed

6 files changed

+409
-808
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "cutlass_extensions/arch/mma.h"
2121
#include "cutlass_extensions/interleaved_numeric_conversion.h"
22+
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
2223
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
2324

2425
namespace cutlass {
@@ -156,13 +157,16 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
156157
IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
157158
AccessTypeB>;
158159

160+
using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
161+
ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
162+
159163
// Define the threadblock-scoped multistage matrix multiply
160164
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
161165
typename MmaCore::Shape,
162166
IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
163167
IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
164168
ElementAccumulator, layout::RowMajor,
165-
typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
169+
typename MmaCore::MmaPolicy, kStages, TransformBAfterLDS, SharedMemoryClear>;
166170
};
167171

168172
} // namespace threadblock

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ class Wint2xMmaBase {
9393
static int const kWarpGemmIterations =
9494
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
9595

96+
/// Number of warp-level GEMM operations per load for B
97+
static constexpr int kWarpGemmIterationsPerLoadForB =
98+
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
99+
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
100+
101+
static constexpr int kWarpLoadIterationsForB =
102+
kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;
103+
104+
96105
/// Number of stages
97106
static int const kStages = Stages;
98107

@@ -131,16 +140,16 @@ class Wint2xMmaBase {
131140
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
132141
Shape::kN + Policy::SmemPaddingB::kColumn>;
133142

134-
// w uint8; local_scale uint8;
135-
constexpr static int kZippedRowsPerStages = Shape::kK / 4 + (Shape::kK + 127) / 128;
143+
// local_scale uint4
144+
constexpr static int kGroupWiseParamRows = Shape::kK / 64;
145+
146+
using GroupWiseParamShapeB = MatrixShape<kGroupWiseParamRows * kStages, Shape::kN>;
136147

137148
// code_scale float; code_zp float; super_scale ElementB
138-
constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
149+
constexpr static int kColumnWiseParamRows = 2 * sizeof(float) +
139150
sizeof_bits<typename Operator::ElementB>::value / 8;
140151

141-
using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
142-
143-
using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
152+
using ColumnWiseParamShapeB = MatrixShape<kColumnWiseParamRows, Shape::kN>;
144153

145154
public:
146155
//
@@ -153,12 +162,11 @@ class Wint2xMmaBase {
153162
/// Buffer for B operand
154163
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
155164

156-
/// Buffer for quanted B operand
157-
AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
165+
/// Buffer for local_scale of B operand
166+
AlignedBuffer<uint4b_t, GroupWiseParamShapeB::kCount> operand_local_scale_B;
158167

159-
/// Buffer for unzip B operand
160-
AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
161-
operand_unzip_B;
168+
/// Buffer for column-wise params of B operand
169+
AlignedBuffer<uint8_t, ColumnWiseParamShapeB::kCount> operand_column_wise_B;
162170

163171
public:
164172
//
@@ -188,14 +196,6 @@ class Wint2xMmaBase {
188196
TensorRefB operand_B_ref() {
189197
return TensorRefB{operand_B.data(), LayoutB()};
190198
}
191-
192-
CUTLASS_HOST_DEVICE
193-
uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
194-
195-
CUTLASS_HOST_DEVICE
196-
typename Operator::ElementB *operand_unzip_B_ptr() {
197-
return operand_unzip_B.data();
198-
}
199199
};
200200

201201
protected:

0 commit comments

Comments
 (0)