
Commit 2bda35d

Implement FastInterleavedAndBiasedNumericArrayConverter for wint2.
Change-Id: I438f2b18ab964a04ae1cdb09d9e7d9f7b95eafca
1 parent 487d643 commit 2bda35d
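
For context on the naming: a FastInterleavedAndBiasedNumericArrayConverter-style converter unpacks low-bit integer weights into the compute type in registers. Below is a minimal host-side sketch of the 2-bit ("wint2") idea, assuming a plain little-endian packing of four weights per byte and a fixed bias of 2; the packing order, the bias value, and the unpack_wint2 helper are illustrative assumptions, not the layout or API of the actual converter, which works on an interleaved register layout with vectorized bit tricks.

#include <array>
#include <cstdint>
#include <cstdio>

// Illustrative only: unpack four 2-bit weights per byte and subtract a bias
// so the unsigned storage range [0, 3] maps onto a signed range.
// The packing order and bias here are assumptions, not the CUTLASS layout.
std::array<float, 4> unpack_wint2(uint8_t packed, float bias = 2.0f) {
  std::array<float, 4> out{};
  for (int i = 0; i < 4; ++i) {
    uint8_t q = (packed >> (2 * i)) & 0x3;   // extract the i-th 2-bit field
    out[i] = static_cast<float>(q) - bias;   // re-bias into the compute type
  }
  return out;
}

int main() {
  uint8_t packed = 0b11100100;                // fields 0..3 hold 0, 1, 2, 3
  auto w = unpack_wint2(packed);
  for (float v : w) std::printf("%.1f ", v);  // prints -2.0 -1.0 0.0 1.0
  std::printf("\n");
}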

5 files changed: +285 -45 lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h

Lines changed: 5 additions & 1 deletion

@@ -19,6 +19,7 @@
 
 #include "cutlass_extensions/arch/mma.h"
 #include "cutlass_extensions/interleaved_numeric_conversion.h"
+#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
 #include "cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h"
 
 namespace cutlass {
@@ -156,13 +157,16 @@ struct DefaultWint2xMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlig
       IteratorShapeB, ElementB, layout::ColumnMajor, 0, InterleavedThreadMapB,
       AccessTypeB>;
 
+  using TransformBAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
+      ElementA, ElementB, MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
+
   // Define the threadblock-scoped multistage matrix multiply
   using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
       typename MmaCore::Shape,
       IteratorA, typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA,
       IteratorB, typename MmaCore::SmemIteratorB, MmaCore::kCacheOpB,
       ElementAccumulator, layout::RowMajor,
-      typename MmaCore::MmaPolicy, kStages, SharedMemoryClear>;
+      typename MmaCore::MmaPolicy, kStages, TransformBAfterLDS, SharedMemoryClear>;
 };
 
 } // namespace threadblock
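
As a reference for the wiring above: Wint2xMmaMultistage only needs the transform type to expose a result_type and a fragment-to-fragment call operator, which is how transform_B_ is used in the mainloop diff further down. A minimal sketch of that contract follows; ToyFragment and ToyTransformBAfterLDS are hypothetical stand-ins, not CUTLASS types, and the conversion body is a placeholder.

#include <array>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for a cutlass::Array fragment; illustration only.
template <typename T, int N>
using ToyFragment = std::array<T, N>;

// The contract the pipeline relies on: a functor that maps a fragment of
// packed/quantized B elements to a fragment of compute-type elements,
// entirely in registers, right after the shared-memory load (LDS).
template <typename ElementCompute, typename ElementPacked, int kElements>
struct ToyTransformBAfterLDS {
  using source_type = ToyFragment<ElementPacked, kElements>;
  using result_type = ToyFragment<ElementCompute, kElements>;

  result_type operator()(source_type const& src) const {
    result_type dst{};
    for (int i = 0; i < kElements; ++i) {
      // Placeholder conversion; the real converter also un-interleaves the
      // packed layout and removes the storage bias.
      dst[i] = static_cast<ElementCompute>(src[i]);
    }
    return dst;
  }
};

int main() {
  // Mirrors the mainloop usage: unpacked_frag_B = transform_B_(warp_loaded_frag_B_[0]).
  ToyTransformBAfterLDS<float, uint8_t, 4> transform_B_;
  ToyTransformBAfterLDS<float, uint8_t, 4>::source_type loaded = {1, 2, 3, 0};
  auto unpacked = transform_B_(loaded);
  std::printf("%.1f %.1f %.1f %.1f\n", unpacked[0], unpacked[1], unpacked[2], unpacked[3]);
}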

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_base.h

Lines changed: 19 additions & 19 deletions

@@ -93,6 +93,15 @@ class Wint2xMmaBase {
   static int const kWarpGemmIterations =
       (WarpGemm::kK / Operator::Policy::MmaShape::kK);
 
+  /// Number of warp-level GEMM oeprations per load for B
+  static constexpr int kWarpGemmIterationsPerLoadForB =
+      Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
+  static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
+
+  static constexpr int kWarpLoadIterationsForB =
+      kWarpGemmIterations / kWarpGemmIterationsPerLoadForB;
+
+
   /// Number of stages
   static int const kStages = Stages;
 
@@ -131,16 +140,16 @@
   using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
                              Shape::kN + Policy::SmemPaddingB::kColumn>;
 
-  // w uint8; local_scale uint8;
-  constexpr static int kZippedRowsPerStages = Shape::kK / 4 + (Shape::kK + 127) / 128;
+  // local_scale uint4
+  constexpr static int kGroupWiseParamRows = Shape::kK / 64;
+
+  using GroupWiseParamShapeB = MatrixShape<kGroupWiseParamRows * kStages, Shape::kN>;
 
   // code_scale float; code_zp float; super_scale ElementB
-  constexpr static int kColumnWiseParamsRows = 2 * sizeof(float) +
+  constexpr static int kColumnWiseParamRows = 2 * sizeof(float) +
       sizeof_bits<typename Operator::ElementB>::value / 8;
 
-  using ZippedShapeB = MatrixShape<kColumnWiseParamsRows + kZippedRowsPerStages * kStages, Shape::kN>;
-
-  using NopaddingShapeB = MatrixShape<Shape::kK, Shape::kN>;
+  using ColumnWiseParamShapeB = MatrixShape<kColumnWiseParamRows, Shape::kN>;
 
 public:
   //
@@ -153,12 +162,11 @@
   /// Buffer for B operand
   AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
 
-  /// Buffer for quanted B operand
-  AlignedBuffer<uint8_t, ZippedShapeB::kCount> operand_zipped_B;
+  /// Buffer for local_scale of B operand
+  AlignedBuffer<uint4b_t, GroupWiseParamShapeB::kCount> operand_local_scale_B;
 
-  /// Buffer for unzip B operand
-  AlignedBuffer<typename Operator::ElementB, NopaddingShapeB::kCount>
-      operand_unzip_B;
+  /// Buffer for column-wise params of B operand
+  AlignedBuffer<uint8_t, ColumnWiseParamShapeB::kCount> operand_column_wise_B;
 
 public:
   //
@@ -188,14 +196,6 @@
   TensorRefB operand_B_ref() {
     return TensorRefB{operand_B.data(), LayoutB()};
   }
-
-  CUTLASS_HOST_DEVICE
-  uint8_t *operand_zipped_B_ptr() { return operand_zipped_B.data(); }
-
-  CUTLASS_HOST_DEVICE
-  typename Operator::ElementB *operand_unzip_B_ptr() {
-    return operand_unzip_B.data();
-  }
 };
 
 protected:
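
To make the new shared-memory budget concrete, the sketch below evaluates the same arithmetic as GroupWiseParamShapeB and ColumnWiseParamShapeB for an assumed 64x64x64 threadblock tile, 3 stages, and a 2-byte ElementB; these example values are illustrative assumptions, not figures fixed by the commit.

#include <cstdio>

int main() {
  // Assumed example configuration (not mandated by the commit).
  constexpr int kTileK = 64;          // Shape::kK
  constexpr int kTileN = 64;          // Shape::kN
  constexpr int kStages = 3;
  constexpr int kElementBBytes = 2;   // e.g. a half-precision super_scale

  // local_scale: one uint4 per 64 K-elements, per column, per stage.
  constexpr int kGroupWiseParamRows = kTileK / 64;
  constexpr int kGroupWiseCount = kGroupWiseParamRows * kStages * kTileN;  // uint4 elements
  constexpr int kGroupWiseBytes = (kGroupWiseCount + 1) / 2;               // two uint4 per byte

  // column-wise params: code_scale (float) + code_zp (float) + super_scale
  // (ElementB) per column, stored as rows of bytes.
  constexpr int kColumnWiseParamRows = 2 * 4 + kElementBBytes;  // 2*sizeof(float) + sizeof(ElementB)
  constexpr int kColumnWiseBytes = kColumnWiseParamRows * kTileN;

  std::printf("local_scale buffer : %d uint4 (%d bytes)\n", kGroupWiseCount, kGroupWiseBytes);
  std::printf("column-wise buffer : %d bytes\n", kColumnWiseBytes);
}

With these assumed numbers the per-stage group-wise buffer stays small (96 bytes here), which is the point of switching local_scale storage from uint8 rows to packed uint4 rows.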

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 152 additions & 17 deletions

@@ -86,10 +86,10 @@ template <
     typename Policy_,
     /// Number of stages,
     int Stages,
+    /// Transform for input B applied in register after the LDS
+    typename TransformBAfterLDS_,
     /// Use zfill or predicate for out-of-bound cp.async
-    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
-    /// Used for partial specialization
-    typename Enable = bool>
+    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
 class Wint2xMmaMultistage :
   public Wint2xMmaBase<Shape_, Policy_, Stages> {
 public:
@@ -107,8 +107,10 @@ class Wint2xMmaMultistage :
   using LayoutC = LayoutC_;
   ///< Policy describing tuning details
   using Policy = Policy_;
+  /// Transform for input B applied in register after the LDS
+  using TransformBAfterLDS = TransformBAfterLDS_;
 
-  using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB;
+  static constexpr int kInterleave = IteratorB::Shape::kRow / Shape::kK;
 
   using SmemIteratorA = SmemIteratorA_;
   using SmemIteratorB = SmemIteratorB_;
@@ -131,12 +133,11 @@
 
   using LayoutScale = cutlass::layout::ColumnMajor;
   using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
-  using ElementB = typename WarpTransformedFragmentB::Element;
   using Dequantizer =
       warp::MmaTensorOpWin2xDequantizer<Operator,
                                         typename Base::WarpGemm,
                                         Operand::kB,
-                                        ElementB,
+                                        typename WarpTransformedFragmentB::Element,
                                         cutlass::layout::ColumnMajor,
                                         32,
                                         WeightOnlyQuantOp::UNDEFINED>;
@@ -199,6 +200,14 @@
     WarpTransformedFragmentB warp_transformed_frag_B_[2];
   };
 
+  using ElementA = typename IteratorA::Element;
+  using ElementB = typename IteratorB::Element;
+  using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
+
+  static constexpr bool IsTileInterleaveLayout =
+      layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
+  static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
+                "Layout K must match threadblockK");
 
 private:
 
@@ -224,10 +233,11 @@
   /// Shared memory read stage index
   int smem_read_stage_idx_;
 
-  uint8_t* column_wise_smem_ptr_B_;
+  /// Transform for B in register
+  TransformBAfterLDS transform_B_;
 
-  uint8_t* smem_zipped_ptr_B_;
-  int smem_zipped_bytes_per_stage_B_;
+  uint8_t* smem_ptr_B_;
+  uint8_t* ptr_B_;
 
 public:
 
@@ -261,16 +271,31 @@
     int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
     int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
 
+    CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d",
+                         Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave);
+    CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d",
+                         Policy::kPartitionsK, Base::kWarpGemmIterations,
+                         Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k);
+
     // Add per-warp offsets in units of warp-level tiles
     this->warp_tile_iterator_A_.add_tile_offset(
         {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
     this->warp_tile_iterator_B_.add_tile_offset(
         {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
 
-    column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr();
-
-    smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn;
-    smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn;
+    CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}",
+                         Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn);
+    CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d",
+                         shared_storage.operand_A.data(), static_cast<int>(Base::SharedStorage::ShapeA::kRow),
+                         static_cast<int>(Base::SharedStorage::ShapeA::kColumn));
+    CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVector=%d",
+                         shared_storage.operand_B.data(),
+                         static_cast<int>(Base::SharedStorage::ShapeB::kRow), static_cast<int>(Base::SharedStorage::ShapeB::kColumn),
+                         static_cast<int>(sizeof(shared_storage.operand_B)),
+                         static_cast<int>(IteratorB::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorB::AccessType)),
+                         static_cast<int>(Detail::AsyncCopyIterationsPerStageB), static_cast<int>(IteratorB::kAccessesPerVector));
+
+    smem_ptr_B_ = reinterpret_cast<uint8_t*>(shared_storage.operand_B.data());
   }
 
   /// Advance shared memory read-iterators to the next stage
@@ -371,6 +396,13 @@
         for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
          auto gmem_ptr = iterator_B.get();
 
+          if (group_start_B == 0 && j == 0 && v == 0) {
+            CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d",
+                                 reinterpret_cast<void*>(dst_ptr), reinterpret_cast<void*>(gmem_ptr),
+                                 static_cast<int>(Detail::kAccessesPerGroupB), static_cast<int>(IteratorB::kAccessesPerVector),
+                                 static_cast<int>(sizeof(typename IteratorB::Element)));
+          }
+
          if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
            cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
                dst_ptr + v, gmem_ptr, iterator_B.valid());
@@ -423,7 +455,7 @@
 
   template <bool GlobalToSharedB, bool InitStage>
   CUTLASS_DEVICE
-  void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B) {
+  void copy_tiles_and_advance_per_stage_B(IteratorB &iterator_B, int stage) {
    iterator_B.set_iteration_index(0);
    this->smem_iterator_B_.set_iteration_index(0);
 
@@ -443,6 +475,31 @@
            IteratorB::ThreadMap::kElementsPerAccess /
            IteratorB::kAccessesPerVector / 8;
 
+        if (v == 0) {
+          int gmem_offset = reinterpret_cast<int>(gmem_ptr) - reinterpret_cast<int>(ptr_B_);
+          int gmem_k = 8192 * kInterleave / 4;
+          int gmem_n = 1792 / kInterleave;
+          int gmem_row = gmem_offset / gmem_k;
+          int gmem_col = gmem_offset % gmem_k;
+
+          int smem_offset = reinterpret_cast<int>(dst_ptr) - reinterpret_cast<int>(smem_ptr_B_);
+          int smem_k = Shape::kK * kInterleave / 4;
+          int smem_n = Shape::kN / kInterleave;
+          int smem_row = smem_offset / smem_k;
+          int smem_col = smem_offset % smem_k;
+
+          uint8_t* gmem_uint8_ptr = reinterpret_cast<uint8_t*>(gmem_ptr);
+
+          CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};",
+                               stage, reinterpret_cast<void*>(gmem_ptr), reinterpret_cast<void*>(dst_ptr), kSrcBytes,
+                               gmem_n, gmem_k, gmem_row, gmem_col,
+                               static_cast<int>(gmem_uint8_ptr[0]), static_cast<int>(gmem_uint8_ptr[1]),
+                               static_cast<int>(gmem_uint8_ptr[2]), static_cast<int>(gmem_uint8_ptr[3]),
+                               static_cast<int>(gmem_uint8_ptr[4]), static_cast<int>(gmem_uint8_ptr[5]),
+                               static_cast<int>(gmem_uint8_ptr[6]), static_cast<int>(gmem_uint8_ptr[7]),
+                               smem_row, smem_col);
+        }
+
        if (InitStage) {
          cutlass::arch::copy_zfill<kSrcBytes, kCacheOpB, GlobalToSharedB>(
              dst_ptr + v, iterator_B.get(), iterator_B.valid());
@@ -484,7 +541,7 @@
    copy_tiles_and_advance_per_stage_A(iterator_A);
 
    // Async copy zipped B to shared memory.
-    copy_tiles_and_advance_per_stage_B<true, true>(iterator_B);
+    copy_tiles_and_advance_per_stage_B<true, true>(iterator_B, stage);
 
    // TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
 
@@ -666,6 +723,18 @@
      IteratorA &iterator_A,      ///< [in|out] iterator over A operand in global memory
      IteratorB &iterator_B)
  {
+#if 0
+    int smem_k = Shape::kK * kInterleave / 4;
+    int smem_n = Shape::kN / kInterleave;
+    for (int i = 0; i < 3 * smem_n; ++i) {
+      for (int j = 0; j < smem_k; ++j) {
+        if (i % 3 == 0) {
+          CUTLASS_TRACE_DEVICE(" [i=%d, j=%d, %dx%d] %d", i, j, smem_n, smem_k, static_cast<int>(smem_ptr_B_[i * smem_k + j]));
+        }
+      }
+    }
+#endif
+
    PipeState pipe_state;
 
    // Disable global fetching if done with global fetch iterations
@@ -682,6 +751,70 @@
    this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
    ++this->warp_tile_iterator_B_;
 
+    if (PipeState::WarpLoadedFragmentA::kElements == 8) {
+      ElementA* warp_frag_A_ptr = reinterpret_cast<ElementA*>(pipe_state.warp_loaded_frag_A_[0].data());
+      CUTLASS_TRACE_DEVICE(" warp_loaded_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes",
+                           static_cast<float>(warp_frag_A_ptr[0]), static_cast<float>(warp_frag_A_ptr[1]),
+                           static_cast<float>(warp_frag_A_ptr[2]), static_cast<float>(warp_frag_A_ptr[3]),
+                           static_cast<float>(warp_frag_A_ptr[4]), static_cast<float>(warp_frag_A_ptr[5]),
+                           static_cast<float>(warp_frag_A_ptr[6]), static_cast<float>(warp_frag_A_ptr[7]),
+                           sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8);
+    }
+    if (PipeState::WarpLoadedFragmentB::kElements == 64) {
+      uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_[0].data());
+      CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
+                           static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
+                           static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
+                           static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
+                           static_cast<int>(reg_uint8_ptr[6]), static_cast<int>(reg_uint8_ptr[7]),
+                           static_cast<int>(reg_uint8_ptr[8]), static_cast<int>(reg_uint8_ptr[9]),
+                           static_cast<int>(reg_uint8_ptr[10]), static_cast<int>(reg_uint8_ptr[11]),
+                           static_cast<int>(reg_uint8_ptr[12]), static_cast<int>(reg_uint8_ptr[13]),
+                           static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
+                           sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
+    }
+
+    typename TransformBAfterLDS::result_type unpacked_frag_B = transform_B_(pipe_state.warp_loaded_frag_B_[0]);
+    if (TransformBAfterLDS::result_type::kElements == 64) {
+      CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
+      CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
+                           static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
+                           static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
+                           static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
+                           static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
+                           static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
+                           static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
+                           static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
+                           static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
+      CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
+                           static_cast<float>(unpacked_frag_B[16]), static_cast<float>(unpacked_frag_B[17]),
+                           static_cast<float>(unpacked_frag_B[18]), static_cast<float>(unpacked_frag_B[19]),
+                           static_cast<float>(unpacked_frag_B[20]), static_cast<float>(unpacked_frag_B[21]),
+                           static_cast<float>(unpacked_frag_B[22]), static_cast<float>(unpacked_frag_B[23]),
+                           static_cast<float>(unpacked_frag_B[24]), static_cast<float>(unpacked_frag_B[25]),
+                           static_cast<float>(unpacked_frag_B[26]), static_cast<float>(unpacked_frag_B[27]),
+                           static_cast<float>(unpacked_frag_B[28]), static_cast<float>(unpacked_frag_B[29]),
+                           static_cast<float>(unpacked_frag_B[30]), static_cast<float>(unpacked_frag_B[31]));
+      CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
+                           static_cast<float>(unpacked_frag_B[32]), static_cast<float>(unpacked_frag_B[33]),
+                           static_cast<float>(unpacked_frag_B[34]), static_cast<float>(unpacked_frag_B[35]),
+                           static_cast<float>(unpacked_frag_B[36]), static_cast<float>(unpacked_frag_B[37]),
+                           static_cast<float>(unpacked_frag_B[38]), static_cast<float>(unpacked_frag_B[39]),
+                           static_cast<float>(unpacked_frag_B[40]), static_cast<float>(unpacked_frag_B[41]),
+                           static_cast<float>(unpacked_frag_B[42]), static_cast<float>(unpacked_frag_B[43]),
+                           static_cast<float>(unpacked_frag_B[44]), static_cast<float>(unpacked_frag_B[45]),
+                           static_cast<float>(unpacked_frag_B[46]), static_cast<float>(unpacked_frag_B[47]));
+      CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
+                           static_cast<float>(unpacked_frag_B[48]), static_cast<float>(unpacked_frag_B[49]),
+                           static_cast<float>(unpacked_frag_B[50]), static_cast<float>(unpacked_frag_B[51]),
+                           static_cast<float>(unpacked_frag_B[52]), static_cast<float>(unpacked_frag_B[53]),
+                           static_cast<float>(unpacked_frag_B[54]), static_cast<float>(unpacked_frag_B[55]),
+                           static_cast<float>(unpacked_frag_B[56]), static_cast<float>(unpacked_frag_B[57]),
+                           static_cast<float>(unpacked_frag_B[58]), static_cast<float>(unpacked_frag_B[59]),
+                           static_cast<float>(unpacked_frag_B[60]), static_cast<float>(unpacked_frag_B[61]),
+                           static_cast<float>(unpacked_frag_B[62]), static_cast<float>(unpacked_frag_B[63]));
+    }
+
    typename Dequantizer::FragmentLocalScale warp_frag_local_scale;
    typename Dequantizer::FragmentCodeScale warp_frag_code_scale;
    typename Dequantizer::FragmentCodeZp warp_frag_code_zp;
@@ -702,6 +835,7 @@
        warp_frag_code_zp,
        warp_frag_super_scale);
 
+#if 0
    // Transform, if necessary, the first warp-tile's shared memory fragments
    warp_mma_.transform(
        pipe_state.warp_transformed_frag_A_[0],
@@ -713,7 +847,6 @@
      pipe_state.tmp_accum_.clear();
    }
 
-#if 0
    int stage = Base::kStages - 1;
 
    // Mainloop
@@ -790,6 +923,8 @@
      ///< initial value of accumulator
      FragmentC const &src_accum) {
 
+    ptr_B_ = reinterpret_cast<uint8_t*>(iterator_B.get_origin_pointer());
+
    // Prologue (start fetching iterations of global fragments into shared memory)
    prologue(iterator_A, iterator_B, gemm_k_iterations);
 
@@ -800,7 +935,7 @@
    accum = src_accum;
 
    // Perform the MAC-iterations
-    //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B);
+    gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B);
  }
 };
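
The trace statements added above recover a (row, column) coordinate from a raw byte offset with row = offset / pitch and col = offset % pitch, where the pitch folds the interleave factor and the four-per-byte packing of 2-bit weights into one figure (the 8192 and 1792 constants in the debug code are problem-specific values hard-coded for that trace, not derived here). A small sketch of the same bookkeeping with assumed example numbers:

#include <cstdio>

// Decompose a byte offset into (row, col) of a row-major byte matrix,
// the same arithmetic the CUTLASS_TRACE_DEVICE calls use for gmem/smem.
struct Coord { int row; int col; };

constexpr Coord decompose(int byte_offset, int pitch_bytes) {
  return {byte_offset / pitch_bytes, byte_offset % pitch_bytes};
}

int main() {
  // Assumed example: threadblock tile K=64, N=64, interleave factor 4.
  constexpr int kTileK = 64;
  constexpr int kTileN = 64;
  constexpr int kInterleave = 4;

  // With 2-bit weights, four values share a byte, and the interleaved layout
  // stores kInterleave logical rows per physical row: the pitch in bytes is
  // kTileK * kInterleave / 4, and there are kTileN / kInterleave rows.
  constexpr int smem_pitch = kTileK * kInterleave / 4;  // 64 bytes
  constexpr int smem_rows  = kTileN / kInterleave;      // 16 rows

  constexpr Coord c = decompose(130, smem_pitch);
  std::printf("pitch=%d bytes, rows=%d, offset 130 -> (row=%d, col=%d)\n",
              smem_pitch, smem_rows, c.row, c.col);
}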
