@@ -93,6 +93,15 @@ class Wint2xMmaBase {
93
93
static int const kWarpGemmIterations =
94
94
(WarpGemm::kK / Operator::Policy::MmaShape::kK );
95
95
96
+ // / Number of warp-level GEMM oeprations per load for B
97
+ static constexpr int kWarpGemmIterationsPerLoadForB =
98
+ Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK ;
99
+ static_assert (!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB ), " " );
100
+
101
+ static constexpr int kWarpLoadIterationsForB =
102
+ kWarpGemmIterations / kWarpGemmIterationsPerLoadForB ;
103
+
104
+
96
105
// / Number of stages
97
106
static int const kStages = Stages;
98
107
@@ -131,16 +140,16 @@ class Wint2xMmaBase {
131
140
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow ,
132
141
Shape::kN + Policy::SmemPaddingB::kColumn >;
133
142
134
- // w uint8; local_scale uint8;
135
- constexpr static int kZippedRowsPerStages = Shape::kK / 4 + (Shape::kK + 127 ) / 128 ;
143
+ // local_scale uint4
144
+ constexpr static int kGroupWiseParamRows = Shape::kK / 64 ;
145
+
146
+ using GroupWiseParamShapeB = MatrixShape<kGroupWiseParamRows * kStages , Shape::kN >;
136
147
137
148
// code_scale float; code_zp float; super_scale ElementB
138
- constexpr static int kColumnWiseParamsRows = 2 * sizeof (float ) +
149
+ constexpr static int kColumnWiseParamRows = 2 * sizeof (float ) +
139
150
sizeof_bits<typename Operator::ElementB>::value / 8 ;
140
151
141
- using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages , Shape::kN >;
142
-
143
- using NopaddingShapeB = MatrixShape<Shape::kK , Shape::kN >;
152
+ using ColumnWiseParamShapeB = MatrixShape<kColumnWiseParamRows , Shape::kN >;
144
153
145
154
public:
146
155
//
@@ -153,12 +162,11 @@ class Wint2xMmaBase {
153
162
// / Buffer for B operand
154
163
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount > operand_B;
155
164
156
- // / Buffer for quanted B operand
157
- AlignedBuffer<uint8_t , ZippedShapeB ::kCount > operand_zipped_B ;
165
+ // / Buffer for local_scale of B operand
166
+ AlignedBuffer<uint4b_t , GroupWiseParamShapeB ::kCount > operand_local_scale_B ;
158
167
159
- // / Buffer for unzip B operand
160
- AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount >
161
- operand_unzip_B;
168
+ // / Buffer for column-wise params of B operand
169
+ AlignedBuffer<uint8_t , ColumnWiseParamShapeB::kCount > operand_column_wise_B;
162
170
163
171
public:
164
172
//
@@ -188,14 +196,6 @@ class Wint2xMmaBase {
188
196
TensorRefB operand_B_ref () {
189
197
return TensorRefB{operand_B.data (), LayoutB ()};
190
198
}
191
-
192
- CUTLASS_HOST_DEVICE
193
- uint8_t *operand_zipped_B_ptr () { return operand_zipped_B.data (); }
194
-
195
- CUTLASS_HOST_DEVICE
196
- typename Operator::ElementB *operand_unzip_B_ptr () {
197
- return operand_unzip_B.data ();
198
- }
199
199
};
200
200
201
201
protected:
0 commit comments