Skip to content

Commit 5715889

Browse files
committed
Merge branch 'opt_wint2' of https://github.com/baoqiwen/FastDeploy into opt_wint2
Change-Id: Iee3d64458bf5ab1c2775b437ae6993533cafd68b
2 parents 0fabdbc + 0b60689 commit 5715889

File tree

5 files changed: 170 additions and 34 deletions.

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 131 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -44,9 +44,9 @@
4444
#include "cutlass/numeric_types.h"
4545

4646
#include "cutlass_extensions/arch/memory_copy_sm80.h"
47-
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
4847
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
4948
#include "cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
49+
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
5050

5151
/////////////////////////////////////////////////////////////////////////////////////////////////
5252

@@ -292,32 +292,32 @@ class Wint2xMmaMultistage :
292292
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
293293
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
294294

295-
CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d",
296-
Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave);
297-
CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d",
298-
Policy::kPartitionsK, Base::kWarpGemmIterations,
299-
Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k);
295+
//CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d",
296+
// Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave);
297+
//CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d",
298+
// Policy::kPartitionsK, Base::kWarpGemmIterations,
299+
// Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k);
300300

301301
// Add per-warp offsets in units of warp-level tiles
302302
this->warp_tile_iterator_A_.add_tile_offset(
303303
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
304304
this->warp_tile_iterator_B_.add_tile_offset(
305305
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
306306

307-
CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}",
308-
Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn);
309-
CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d",
310-
shared_storage.operand_A.data(),
311-
static_cast<int>(Base::SharedStorage::ShapeA::kRow), static_cast<int>(Base::SharedStorage::ShapeA::kColumn),
312-
static_cast<int>(sizeof(shared_storage.operand_A)),
313-
static_cast<int>(IteratorA::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorA::AccessType)),
314-
static_cast<int>(Detail::AsyncCopyIterationsPerStageA), static_cast<int>(IteratorA::kAccessesPerVector));
315-
CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVectorA=%d",
316-
shared_storage.operand_B.data(),
317-
static_cast<int>(Base::SharedStorage::ShapeB::kRow), static_cast<int>(Base::SharedStorage::ShapeB::kColumn),
318-
static_cast<int>(sizeof(shared_storage.operand_B)),
319-
static_cast<int>(IteratorB::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorB::AccessType)),
320-
static_cast<int>(Detail::AsyncCopyIterationsPerStageB), static_cast<int>(IteratorB::kAccessesPerVector));
307+
//CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}",
308+
// Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn);
309+
//CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d",
310+
// shared_storage.operand_A.data(),
311+
// static_cast<int>(Base::SharedStorage::ShapeA::kRow), static_cast<int>(Base::SharedStorage::ShapeA::kColumn),
312+
// static_cast<int>(sizeof(shared_storage.operand_A)),
313+
// static_cast<int>(IteratorA::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorA::AccessType)),
314+
// static_cast<int>(Detail::AsyncCopyIterationsPerStageA), static_cast<int>(IteratorA::kAccessesPerVector));
315+
//CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVectorA=%d",
316+
// shared_storage.operand_B.data(),
317+
// static_cast<int>(Base::SharedStorage::ShapeB::kRow), static_cast<int>(Base::SharedStorage::ShapeB::kColumn),
318+
// static_cast<int>(sizeof(shared_storage.operand_B)),
319+
// static_cast<int>(IteratorB::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorB::AccessType)),
320+
// static_cast<int>(Detail::AsyncCopyIterationsPerStageB), static_cast<int>(IteratorB::kAccessesPerVector));
321321

322322
smem_ptr_A_ = reinterpret_cast<ElementA*>(shared_storage.operand_A.data());
323323
smem_ptr_B_ = reinterpret_cast<uint8_t*>(shared_storage.operand_B.data());
@@ -678,9 +678,11 @@ class Wint2xMmaMultistage :
678678
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
679679
int stage)
680680
{
681+
681682
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
682683
CUTLASS_PRAGMA_UNROLL
683684
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
685+
684686
// Load the next warp-tile's A fragment from shared memory
685687
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
686688
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
@@ -699,27 +701,41 @@ class Wint2xMmaMultistage :
699701
}
700702

701703
// Execute the current warp-tile of MMA operations
704+
705+
// CUTLASS_TRACE_DEVICE("ElementA %d", PipeState::WarpTransformedFragmentA::kElements);
706+
// CUTLASS_TRACE_DEVICE("ElementB %d", PipeState::WarpTransformedFragmentB::kElements);
707+
// CUTLASS_TRACE_DEVICE("kStagedAccumulation %d", Detail::kStagedAccumulation);
708+
709+
// uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2].data());
710+
// CUTLASS_TRACE_DEVICE(" reg_uint8_ptr=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
711+
// static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
712+
// static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
713+
// static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
714+
// static_cast<int>(reg_uint8_ptr[6]), static_cast<int>(reg_uint8_ptr[7]),
715+
// static_cast<int>(reg_uint8_ptr[8]), static_cast<int>(reg_uint8_ptr[9]),
716+
// static_cast<int>(reg_uint8_ptr[10]), static_cast<int>(reg_uint8_ptr[11]),
717+
// static_cast<int>(reg_uint8_ptr[12]), static_cast<int>(reg_uint8_ptr[13]),
718+
// static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719+
// sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720+
702721
if (Detail::kStagedAccumulation) {
703722
//CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
704723
warp_mma_(
705724
pipe_state.tmp_accum_,
706725
pipe_state.warp_frag_A_[warp_mma_k % 2],
707726
pipe_state.warp_frag_B_,
727+
// unpacked_frag_B,
708728
pipe_state.tmp_accum_,
709729
warp_k_compute_offset_B
710730
);
711731

712-
if (warp_mma_k == 0) {
713-
plus<FragmentC> plus_accum;
714-
accum = plus_accum(accum, pipe_state.tmp_accum_);
715-
pipe_state.tmp_accum_.clear();
716-
}
717732
} else {
718733
//CUTLASS_TRACE_DEVICE(" [MMa][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
719734
warp_mma_(
720735
accum,
721736
pipe_state.warp_frag_A_[warp_mma_k % 2],
722737
pipe_state.warp_frag_B_,
738+
// unpacked_frag_B,
723739
accum,
724740
warp_k_compute_offset_B
725741
);
@@ -736,6 +752,33 @@ class Wint2xMmaMultistage :
736752
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
737753
}
738754
#endif
755+
756+
// CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]",
757+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]),
758+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]),
759+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]),
760+
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7]));
761+
762+
// CUTLASS_TRACE_DEVICE_TID(" now1 unpacked_frag_B[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
763+
// static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
764+
// static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
765+
// static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
766+
// static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
767+
// static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
768+
// static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
769+
// static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
770+
// static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
771+
772+
// CUTLASS_TRACE_DEVICE_TID(" warp_k_compute_offset_B = %d, now1 tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
773+
// warp_k_compute_offset_B,
774+
// static_cast<float>(accum[0]), static_cast<float>(accum[1]),
775+
// static_cast<float>(accum[2]), static_cast<float>(accum[3]),
776+
// static_cast<float>(accum[4]), static_cast<float>(accum[5]),
777+
// static_cast<float>(accum[6]), static_cast<float>(accum[7]),
778+
// static_cast<float>(accum[8]), static_cast<float>(accum[9]),
779+
// static_cast<float>(accum[10]), static_cast<float>(accum[11]),
780+
// static_cast<float>(accum[12]), static_cast<float>(accum[13]),
781+
// static_cast<float>(accum[14]), static_cast<float>(accum[15]));
739782
}
740783

741784
// Except for the last warp-tile, all warp-tiles issue their share of
@@ -832,6 +875,7 @@ class Wint2xMmaMultistage :
832875
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
833876
++this->warp_tile_iterator_B_;
834877

878+
#if 0
835879
if (PipeState::WarpLoadedFragmentA::kElements == 8) {
836880
ElementA* warp_frag_A_ptr = reinterpret_cast<ElementA*>(pipe_state.warp_frag_A_[0].data());
837881
CUTLASS_TRACE_DEVICE(" warp_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes",
@@ -841,6 +885,7 @@ class Wint2xMmaMultistage :
841885
static_cast<float>(warp_frag_A_ptr[6]), static_cast<float>(warp_frag_A_ptr[7]),
842886
sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8);
843887
}
888+
#endif
844889
#if 0
845890
if (PipeState::WarpLoadedFragmentB::kElements == 64) {
846891
uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_.data());
@@ -881,8 +926,69 @@ class Wint2xMmaMultistage :
881926
pipe_state.warp_frag_B_,
882927
0);
883928

929+
#if 0
930+
if (TransformBAfterLDS::result_type::kElements == 64) {
931+
CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
932+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
933+
static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
934+
static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
935+
static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
936+
static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
937+
static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
938+
static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
939+
static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
940+
static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
941+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
942+
static_cast<float>(unpacked_frag_B[16]), static_cast<float>(unpacked_frag_B[17]),
943+
static_cast<float>(unpacked_frag_B[18]), static_cast<float>(unpacked_frag_B[19]),
944+
static_cast<float>(unpacked_frag_B[20]), static_cast<float>(unpacked_frag_B[21]),
945+
static_cast<float>(unpacked_frag_B[22]), static_cast<float>(unpacked_frag_B[23]),
946+
static_cast<float>(unpacked_frag_B[24]), static_cast<float>(unpacked_frag_B[25]),
947+
static_cast<float>(unpacked_frag_B[26]), static_cast<float>(unpacked_frag_B[27]),
948+
static_cast<float>(unpacked_frag_B[28]), static_cast<float>(unpacked_frag_B[29]),
949+
static_cast<float>(unpacked_frag_B[30]), static_cast<float>(unpacked_frag_B[31]));
950+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
951+
static_cast<float>(unpacked_frag_B[32]), static_cast<float>(unpacked_frag_B[33]),
952+
static_cast<float>(unpacked_frag_B[34]), static_cast<float>(unpacked_frag_B[35]),
953+
static_cast<float>(unpacked_frag_B[36]), static_cast<float>(unpacked_frag_B[37]),
954+
static_cast<float>(unpacked_frag_B[38]), static_cast<float>(unpacked_frag_B[39]),
955+
static_cast<float>(unpacked_frag_B[40]), static_cast<float>(unpacked_frag_B[41]),
956+
static_cast<float>(unpacked_frag_B[42]), static_cast<float>(unpacked_frag_B[43]),
957+
static_cast<float>(unpacked_frag_B[44]), static_cast<float>(unpacked_frag_B[45]),
958+
static_cast<float>(unpacked_frag_B[46]), static_cast<float>(unpacked_frag_B[47]));
959+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
960+
static_cast<float>(unpacked_frag_B[48]), static_cast<float>(unpacked_frag_B[49]),
961+
static_cast<float>(unpacked_frag_B[50]), static_cast<float>(unpacked_frag_B[51]),
962+
static_cast<float>(unpacked_frag_B[52]), static_cast<float>(unpacked_frag_B[53]),
963+
static_cast<float>(unpacked_frag_B[54]), static_cast<float>(unpacked_frag_B[55]),
964+
static_cast<float>(unpacked_frag_B[56]), static_cast<float>(unpacked_frag_B[57]),
965+
static_cast<float>(unpacked_frag_B[58]), static_cast<float>(unpacked_frag_B[59]),
966+
static_cast<float>(unpacked_frag_B[60]), static_cast<float>(unpacked_frag_B[61]),
967+
static_cast<float>(unpacked_frag_B[62]), static_cast<float>(unpacked_frag_B[63]));
968+
}
969+
#endif
970+
884971
if (Detail::kStagedAccumulation) {
885972
pipe_state.tmp_accum_.clear();
973+
CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
974+
static_cast<float>(pipe_state.tmp_accum_[0]), static_cast<float>(pipe_state.tmp_accum_[1]),
975+
static_cast<float>(pipe_state.tmp_accum_[2]), static_cast<float>(pipe_state.tmp_accum_[3]),
976+
static_cast<float>(pipe_state.tmp_accum_[4]), static_cast<float>(pipe_state.tmp_accum_[5]),
977+
static_cast<float>(pipe_state.tmp_accum_[6]), static_cast<float>(pipe_state.tmp_accum_[7]),
978+
static_cast<float>(pipe_state.tmp_accum_[8]), static_cast<float>(pipe_state.tmp_accum_[9]),
979+
static_cast<float>(pipe_state.tmp_accum_[10]), static_cast<float>(pipe_state.tmp_accum_[11]),
980+
static_cast<float>(pipe_state.tmp_accum_[12]), static_cast<float>(pipe_state.tmp_accum_[13]),
981+
static_cast<float>(pipe_state.tmp_accum_[14]), static_cast<float>(pipe_state.tmp_accum_[15]));
982+
} else {
983+
CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
984+
static_cast<float>(accum[0]), static_cast<float>(accum[1]),
985+
static_cast<float>(accum[2]), static_cast<float>(accum[3]),
986+
static_cast<float>(accum[4]), static_cast<float>(accum[5]),
987+
static_cast<float>(accum[6]), static_cast<float>(accum[7]),
988+
static_cast<float>(accum[8]), static_cast<float>(accum[9]),
989+
static_cast<float>(accum[10]), static_cast<float>(accum[11]),
990+
static_cast<float>(accum[12]), static_cast<float>(accum[13]),
991+
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
886992
}
887993

888994
int stage = Base::kStages - 1;

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,8 @@ class Wint2ParamsAccessor {
187187
smem_write_stage_idx_(0),
188188
smem_read_stage_idx_(0)
189189
{
190-
CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, kSmemRows=%d, kSmemColumns=%d, kLocalScaleRows=%d, kStagesPerLocalScaleLoad=%d",
191-
Shape::kM, Shape::kN, Shape::kK, kSmemRows, kSmemColumns, kLocalScaleRows, kStagesPerLocalScaleLoad);
190+
//CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, kSmemRows=%d, kSmemColumns=%d, kLocalScaleRows=%d, kStagesPerLocalScaleLoad=%d",
191+
// Shape::kM, Shape::kN, Shape::kK, kSmemRows, kSmemColumns, kLocalScaleRows, kStagesPerLocalScaleLoad);
192192
//CUTLASS_TRACE_DEVICE(" IteratorSuperScale::Shape: {%d, %d}, kSuperScaleSmemOffset=%d, smem_ptr=%p",
193193
// IteratorSuperScale::Shape::kRow, IteratorSuperScale::Shape::kColumn, kSuperScaleSmemOffset, get_super_scale_smem_ptr());
194194
//CUTLASS_TRACE_DEVICE(" IteratorLocalScale::Shape: {%d, %d}, kLocalScaleSmemOffset=%d, smem_ptr=%p",

custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ class MmaTensorOpWin2xDequantizer<
287287
static_cast<ElementCompute>(shifted_local_scale) * static_cast<ElementCompute>(super_scale_frag[i]);
288288
}
289289

290-
#if 1
290+
#if 0
291291
if (FragmentCompute::kElements == 4) {
292292
CUTLASS_TRACE_DEVICE(" [stage=%d] tb_offset_k=%d, local_scale_shift=%d, scale_frag[0:3]=[%f, %f, %f, %f], sizeof(FragmentCompute)=%d bytes",
293293
stage, tb_offset_k, local_scale_shift,
@@ -312,7 +312,7 @@ class MmaTensorOpWin2xDequantizer<
312312
}
313313

314314
if (FragmentOutput::kElements == 64) {
315-
#if 1
315+
#if 0
316316
CUTLASS_TRACE_DEVICE(" [stage=%d] output_frag[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
317317
stage,
318318
static_cast<float>(output_frag[0]), static_cast<float>(output_frag[1]),

0 commit comments

Comments
 (0)