Skip to content

Commit 81c704b

Browse files
Xreki and baoqiwen
authored and committed
Merge branch 'opt_wint2' of https://github.com/baoqiwen/FastDeploy into opt_wint2
Change-Id: Iee3d64458bf5ab1c2775b437ae6993533cafd68b
2 parents 0fabdbc + 0b60689 commit 81c704b

File tree

5 files changed

+169
-34
lines changed

5 files changed

+169
-34
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 120 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,9 @@
4444
#include "cutlass/numeric_types.h"
4545

4646
#include "cutlass_extensions/arch/memory_copy_sm80.h"
47-
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
4847
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
4948
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
49+
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
5050

5151
/////////////////////////////////////////////////////////////////////////////////////////////////
5252

@@ -292,32 +292,32 @@ class Wint2xMmaMultistage :
292292
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
293293
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
294294

295-
CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d",
296-
Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave);
297-
CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d",
298-
Policy::kPartitionsK, Base::kWarpGemmIterations,
299-
Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k);
295+
//CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d",
296+
// Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave);
297+
//CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d",
298+
// Policy::kPartitionsK, Base::kWarpGemmIterations,
299+
// Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k);
300300

301301
// Add per-warp offsets in units of warp-level tiles
302302
this->warp_tile_iterator_A_.add_tile_offset(
303303
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
304304
this->warp_tile_iterator_B_.add_tile_offset(
305305
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
306306

307-
CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}",
308-
Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn);
309-
CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d",
310-
shared_storage.operand_A.data(),
311-
static_cast<int>(Base::SharedStorage::ShapeA::kRow), static_cast<int>(Base::SharedStorage::ShapeA::kColumn),
312-
static_cast<int>(sizeof(shared_storage.operand_A)),
313-
static_cast<int>(IteratorA::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorA::AccessType)),
314-
static_cast<int>(Detail::AsyncCopyIterationsPerStageA), static_cast<int>(IteratorA::kAccessesPerVector));
315-
CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVectorA=%d",
316-
shared_storage.operand_B.data(),
317-
static_cast<int>(Base::SharedStorage::ShapeB::kRow), static_cast<int>(Base::SharedStorage::ShapeB::kColumn),
318-
static_cast<int>(sizeof(shared_storage.operand_B)),
319-
static_cast<int>(IteratorB::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorB::AccessType)),
320-
static_cast<int>(Detail::AsyncCopyIterationsPerStageB), static_cast<int>(IteratorB::kAccessesPerVector));
307+
//CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}",
308+
// Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn);
309+
//CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d",
310+
// shared_storage.operand_A.data(),
311+
// static_cast<int>(Base::SharedStorage::ShapeA::kRow), static_cast<int>(Base::SharedStorage::ShapeA::kColumn),
312+
// static_cast<int>(sizeof(shared_storage.operand_A)),
313+
// static_cast<int>(IteratorA::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorA::AccessType)),
314+
// static_cast<int>(Detail::AsyncCopyIterationsPerStageA), static_cast<int>(IteratorA::kAccessesPerVector));
315+
//CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVectorA=%d",
316+
// shared_storage.operand_B.data(),
317+
// static_cast<int>(Base::SharedStorage::ShapeB::kRow), static_cast<int>(Base::SharedStorage::ShapeB::kColumn),
318+
// static_cast<int>(sizeof(shared_storage.operand_B)),
319+
// static_cast<int>(IteratorB::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorB::AccessType)),
320+
// static_cast<int>(Detail::AsyncCopyIterationsPerStageB), static_cast<int>(IteratorB::kAccessesPerVector));
321321

322322
smem_ptr_A_ = reinterpret_cast<ElementA*>(shared_storage.operand_A.data());
323323
smem_ptr_B_ = reinterpret_cast<uint8_t*>(shared_storage.operand_B.data());
@@ -678,9 +678,11 @@ class Wint2xMmaMultistage :
678678
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
679679
int stage)
680680
{
681+
681682
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
682683
CUTLASS_PRAGMA_UNROLL
683684
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
685+
684686
// Load the next warp-tile's A fragment from shared memory
685687
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
686688
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
@@ -699,31 +701,55 @@ class Wint2xMmaMultistage :
699701
}
700702

701703
// Execute the current warp-tile of MMA operations
704+
705+
// CUTLASS_TRACE_DEVICE("ElementA %d", PipeState::WarpTransformedFragmentA::kElements);
706+
// CUTLASS_TRACE_DEVICE("ElementB %d", PipeState::WarpTransformedFragmentB::kElements);
707+
// CUTLASS_TRACE_DEVICE("kStagedAccumulation %d", Detail::kStagedAccumulation);
708+
709+
// uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2].data());
710+
// CUTLASS_TRACE_DEVICE(" reg_uint8_ptr=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
711+
// static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
712+
// static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
713+
// static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
714+
// static_cast<int>(reg_uint8_ptr[6]), static_cast<int>(reg_uint8_ptr[7]),
715+
// static_cast<int>(reg_uint8_ptr[8]), static_cast<int>(reg_uint8_ptr[9]),
716+
// static_cast<int>(reg_uint8_ptr[10]), static_cast<int>(reg_uint8_ptr[11]),
717+
// static_cast<int>(reg_uint8_ptr[12]), static_cast<int>(reg_uint8_ptr[13]),
718+
// static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719+
// sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720+
702721
if (Detail::kStagedAccumulation) {
703722
//CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
704723
warp_mma_(
705724
pipe_state.tmp_accum_,
706725
pipe_state.warp_frag_A_[warp_mma_k % 2],
707726
pipe_state.warp_frag_B_,
727+
// unpacked_frag_B,
708728
pipe_state.tmp_accum_,
709729
warp_k_compute_offset_B
710730
);
711731

712-
if (warp_mma_k == 0) {
713-
plus<FragmentC> plus_accum;
714-
accum = plus_accum(accum, pipe_state.tmp_accum_);
715-
pipe_state.tmp_accum_.clear();
716-
}
717732
} else {
718733
//CUTLASS_TRACE_DEVICE(" [MMa][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
719734
warp_mma_(
720735
accum,
721736
pipe_state.warp_frag_A_[warp_mma_k % 2],
722737
pipe_state.warp_frag_B_,
738+
// unpacked_frag_B,
723739
accum,
724740
warp_k_compute_offset_B
725741
);
726742
#if 0
743+
CUTLASS_TRACE_DEVICE(" pipe_state.warp_frag_B_=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
744+
static_cast<float>(pipe_state.warp_frag_B_[0]), static_cast<float>(pipe_state.warp_frag_B_[1]),
745+
static_cast<float>(pipe_state.warp_frag_B_[2]), static_cast<float>(pipe_state.warp_frag_B_[3]),
746+
static_cast<float>(pipe_state.warp_frag_B_[4]), static_cast<float>(pipe_state.warp_frag_B_[5]),
747+
static_cast<float>(pipe_state.warp_frag_B_[6]), static_cast<float>(pipe_state.warp_frag_B_[7]),
748+
static_cast<float>(pipe_state.warp_frag_B_[8]), static_cast<float>(pipe_state.warp_frag_B_[9]),
749+
static_cast<float>(pipe_state.warp_frag_B_[10]), static_cast<float>(pipe_state.warp_frag_B_[11]),
750+
static_cast<float>(pipe_state.warp_frag_B_[12]), static_cast<float>(pipe_state.warp_frag_B_[13]),
751+
static_cast<float>(pipe_state.warp_frag_B_[14]), static_cast<float>(pipe_state.warp_frag_B_[15]));
752+
727753
if (FragmentC::kElements == 16) {
728754
CUTLASS_TRACE_DEVICE(" tile_C[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
729755
static_cast<float>(accum[0]), static_cast<float>(accum[1]),
@@ -735,6 +761,12 @@ class Wint2xMmaMultistage :
735761
static_cast<float>(accum[12]), static_cast<float>(accum[13]),
736762
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
737763
}
764+
765+
// CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]",
766+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]),
767+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]),
768+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]),
769+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7]));
738770
#endif
739771
}
740772

@@ -832,6 +864,7 @@ class Wint2xMmaMultistage :
832864
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
833865
++this->warp_tile_iterator_B_;
834866

867+
#if 0
835868
if (PipeState::WarpLoadedFragmentA::kElements == 8) {
836869
ElementA* warp_frag_A_ptr = reinterpret_cast<ElementA*>(pipe_state.warp_frag_A_[0].data());
837870
CUTLASS_TRACE_DEVICE(" warp_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes",
@@ -841,6 +874,7 @@ class Wint2xMmaMultistage :
841874
static_cast<float>(warp_frag_A_ptr[6]), static_cast<float>(warp_frag_A_ptr[7]),
842875
sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8);
843876
}
877+
#endif
844878
#if 0
845879
if (PipeState::WarpLoadedFragmentB::kElements == 64) {
846880
uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_.data());
@@ -881,8 +915,69 @@ class Wint2xMmaMultistage :
881915
pipe_state.warp_frag_B_,
882916
0);
883917

918+
#if 0
919+
if (TransformBAfterLDS::result_type::kElements == 64) {
920+
CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
921+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
922+
static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
923+
static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
924+
static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
925+
static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
926+
static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
927+
static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
928+
static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
929+
static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
930+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
931+
static_cast<float>(unpacked_frag_B[16]), static_cast<float>(unpacked_frag_B[17]),
932+
static_cast<float>(unpacked_frag_B[18]), static_cast<float>(unpacked_frag_B[19]),
933+
static_cast<float>(unpacked_frag_B[20]), static_cast<float>(unpacked_frag_B[21]),
934+
static_cast<float>(unpacked_frag_B[22]), static_cast<float>(unpacked_frag_B[23]),
935+
static_cast<float>(unpacked_frag_B[24]), static_cast<float>(unpacked_frag_B[25]),
936+
static_cast<float>(unpacked_frag_B[26]), static_cast<float>(unpacked_frag_B[27]),
937+
static_cast<float>(unpacked_frag_B[28]), static_cast<float>(unpacked_frag_B[29]),
938+
static_cast<float>(unpacked_frag_B[30]), static_cast<float>(unpacked_frag_B[31]));
939+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
940+
static_cast<float>(unpacked_frag_B[32]), static_cast<float>(unpacked_frag_B[33]),
941+
static_cast<float>(unpacked_frag_B[34]), static_cast<float>(unpacked_frag_B[35]),
942+
static_cast<float>(unpacked_frag_B[36]), static_cast<float>(unpacked_frag_B[37]),
943+
static_cast<float>(unpacked_frag_B[38]), static_cast<float>(unpacked_frag_B[39]),
944+
static_cast<float>(unpacked_frag_B[40]), static_cast<float>(unpacked_frag_B[41]),
945+
static_cast<float>(unpacked_frag_B[42]), static_cast<float>(unpacked_frag_B[43]),
946+
static_cast<float>(unpacked_frag_B[44]), static_cast<float>(unpacked_frag_B[45]),
947+
static_cast<float>(unpacked_frag_B[46]), static_cast<float>(unpacked_frag_B[47]));
948+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
949+
static_cast<float>(unpacked_frag_B[48]), static_cast<float>(unpacked_frag_B[49]),
950+
static_cast<float>(unpacked_frag_B[50]), static_cast<float>(unpacked_frag_B[51]),
951+
static_cast<float>(unpacked_frag_B[52]), static_cast<float>(unpacked_frag_B[53]),
952+
static_cast<float>(unpacked_frag_B[54]), static_cast<float>(unpacked_frag_B[55]),
953+
static_cast<float>(unpacked_frag_B[56]), static_cast<float>(unpacked_frag_B[57]),
954+
static_cast<float>(unpacked_frag_B[58]), static_cast<float>(unpacked_frag_B[59]),
955+
static_cast<float>(unpacked_frag_B[60]), static_cast<float>(unpacked_frag_B[61]),
956+
static_cast<float>(unpacked_frag_B[62]), static_cast<float>(unpacked_frag_B[63]));
957+
}
958+
#endif
959+
884960
if (Detail::kStagedAccumulation) {
885961
pipe_state.tmp_accum_.clear();
962+
CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
963+
static_cast<float>(pipe_state.tmp_accum_[0]), static_cast<float>(pipe_state.tmp_accum_[1]),
964+
static_cast<float>(pipe_state.tmp_accum_[2]), static_cast<float>(pipe_state.tmp_accum_[3]),
965+
static_cast<float>(pipe_state.tmp_accum_[4]), static_cast<float>(pipe_state.tmp_accum_[5]),
966+
static_cast<float>(pipe_state.tmp_accum_[6]), static_cast<float>(pipe_state.tmp_accum_[7]),
967+
static_cast<float>(pipe_state.tmp_accum_[8]), static_cast<float>(pipe_state.tmp_accum_[9]),
968+
static_cast<float>(pipe_state.tmp_accum_[10]), static_cast<float>(pipe_state.tmp_accum_[11]),
969+
static_cast<float>(pipe_state.tmp_accum_[12]), static_cast<float>(pipe_state.tmp_accum_[13]),
970+
static_cast<float>(pipe_state.tmp_accum_[14]), static_cast<float>(pipe_state.tmp_accum_[15]));
971+
} else {
972+
CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
973+
static_cast<float>(accum[0]), static_cast<float>(accum[1]),
974+
static_cast<float>(accum[2]), static_cast<float>(accum[3]),
975+
static_cast<float>(accum[4]), static_cast<float>(accum[5]),
976+
static_cast<float>(accum[6]), static_cast<float>(accum[7]),
977+
static_cast<float>(accum[8]), static_cast<float>(accum[9]),
978+
static_cast<float>(accum[10]), static_cast<float>(accum[11]),
979+
static_cast<float>(accum[12]), static_cast<float>(accum[13]),
980+
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
886981
}
887982

888983
int stage = Base::kStages - 1;

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -187,8 +187,8 @@ class Wint2ParamsAccessor {
187187
smem_write_stage_idx_(0),
188188
smem_read_stage_idx_(0)
189189
{
190-
CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, kSmemRows=%d, kSmemColumns=%d, kLocalScaleRows=%d, kStagesPerLocalScaleLoad=%d",
191-
Shape::kM, Shape::kN, Shape::kK, kSmemRows, kSmemColumns, kLocalScaleRows, kStagesPerLocalScaleLoad);
190+
//CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, kSmemRows=%d, kSmemColumns=%d, kLocalScaleRows=%d, kStagesPerLocalScaleLoad=%d",
191+
// Shape::kM, Shape::kN, Shape::kK, kSmemRows, kSmemColumns, kLocalScaleRows, kStagesPerLocalScaleLoad);
192192
//CUTLASS_TRACE_DEVICE(" IteratorSuperScale::Shape: {%d, %d}, kSuperScaleSmemOffset=%d, smem_ptr=%p",
193193
// IteratorSuperScale::Shape::kRow, IteratorSuperScale::Shape::kColumn, kSuperScaleSmemOffset, get_super_scale_smem_ptr());
194194
//CUTLASS_TRACE_DEVICE(" IteratorLocalScale::Shape: {%d, %d}, kLocalScaleSmemOffset=%d, smem_ptr=%p",

custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -287,7 +287,7 @@ class MmaTensorOpWin2xDequantizer<
287287
static_cast<ElementCompute>(shifted_local_scale) * static_cast<ElementCompute>(super_scale_frag[i]);
288288
}
289289

290-
#if 1
290+
#if 0
291291
if (FragmentCompute::kElements == 4) {
292292
CUTLASS_TRACE_DEVICE(" [stage=%d] tb_offset_k=%d, local_scale_shift=%d, scale_frag[0:3]=[%f, %f, %f, %f], sizeof(FragmentCompute)=%d bytes",
293293
stage, tb_offset_k, local_scale_shift,
@@ -312,7 +312,7 @@ class MmaTensorOpWin2xDequantizer<
312312
}
313313

314314
if (FragmentOutput::kElements == 64) {
315-
#if 1
315+
#if 0
316316
CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
317317
stage,
318318
static_cast<float>(output_frag[0]), static_cast<float>(output_frag[1]),
@@ -362,6 +362,16 @@ class MmaTensorOpWin2xDequantizer<
362362
// avoid numerous conversion instructions in GEMM main loop.
363363
arch::device_breakpoint();
364364
#endif
365+
366+
const int fixed_values[64] = {
367+
0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57,
368+
2, 3, 10, 11, 18, 19, 26, 27, 34, 35, 42, 43, 50, 51, 58, 59,
369+
4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61,
370+
6, 7, 14, 15, 22, 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63
371+
};
372+
for (int i = 0; i < FragmentUnpack::kElements; ++i) {
373+
output_frag[i] = static_cast<typename FragmentUnpack::Element>(fixed_values[(i % 16) + (threadIdx.x % 4) * 16]);
374+
}
365375
}
366376

367377
/// Add an offset to pointer in units of elements.

0 commit comments

Comments
 (0)