Skip to content

Commit a41862c

Browse files
committed
Add wint2x Dequantizer
1 parent da648e8 commit a41862c

File tree

2 files changed

+648
-9
lines changed

2 files changed

+648
-9
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
4848
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
4949

50+
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
51+
5052
/////////////////////////////////////////////////////////////////////////////////////////////////
5153

5254
namespace cutlass {
@@ -129,6 +131,17 @@ class Wint2xMmaMultistage :
129131
/// Minimum architecture is Sm80 to support cp.async
130132
using ArchTag = arch::Sm80;
131133

134+
using LayoutScale = cutlass::layout::ColumnMajor;
135+
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
136+
using ElementB = typename WarpTransformedFragmentB::Element;
137+
using Dequantizer = warp::MmaTensorOpWin2xDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementB,
138+
cutlass::layout::ColumnMajor, 32, WeightOnlyQuantOp::UNDEFINED>;
139+
140+
static_assert(
141+
sizeof(Dequantizer) > 0,
142+
"Dequantizer template instantiation failed"
143+
);
144+
132145
/// Complex transform on A operand
133146
static ComplexTransform const kTransformA = Operator::kTransformA;
134147

@@ -196,6 +209,9 @@ class Wint2xMmaMultistage :
196209
/// Warp-level MMA operator
197210
Operator warp_mma_;
198211

212+
// Wint2 unzip operator
213+
Dequantizer warp_dequantizer_;
214+
199215
/// Iterator to write threadblock-scoped tile of A operand to shared memory
200216
SmemIteratorA smem_iterator_A_;
201217

@@ -679,12 +695,41 @@ class Wint2xMmaMultistage :
679695
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
680696
++this->warp_tile_iterator_B_;
681697

682-
// Transform, if necessary, the first warp-tile's shared memory fragments
683-
warp_mma_.transform(
684-
pipe_state.warp_transformed_frag_A_[0],
685-
pipe_state.warp_transformed_frag_B_[0],
686-
pipe_state.warp_loaded_frag_A_[0],
687-
pipe_state.warp_loaded_frag_B_[0]);
698+
// // Transform, if necessary, the first warp-tile's shared memory fragments
699+
// warp_mma_.transform(
700+
// pipe_state.warp_transformed_frag_A_[0],
701+
// pipe_state.warp_transformed_frag_B_[0],
702+
// pipe_state.warp_loaded_frag_A_[0],
703+
// pipe_state.warp_loaded_frag_B_[0]);
704+
705+
__syncthreads(); // 确保所有线程执行到此处
706+
if (threadIdx.x == 0) { // 仅让一个线程打印,避免重复输出
707+
// printf("DEBUG: warp_loaded_frag_A_[0] values:\n");
708+
for (int i = 0; i < pipe_state.warp_loaded_frag_A_[0].size(); ++i) {
709+
// 读取 fragment 中的元素
710+
auto val = pipe_state.warp_loaded_frag_A_[0][i];
711+
712+
// 以 16-bit 形式 reinterpret 为 uint16_t 查看原始位模式
713+
uint16_t bits = reinterpret_cast<const uint16_t*>(&val)[0];
714+
715+
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_A_[%d] = 0x%04x", i, bits);
716+
}
717+
}
718+
__syncthreads();
719+
720+
typename Dequantizer::FragmentLocalScale warp_frag_local_scale;
721+
typename Dequantizer::FragmentCodeScale warp_frag_code_scale;
722+
typename Dequantizer::FragmentCodeZp warp_frag_code_zp;
723+
typename Dequantizer::FragmentSuperScale warp_frag_super_scale;
724+
725+
typename Dequantizer::FragmentOutOperand warp_frag_out;
726+
727+
CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load");
728+
warp_dequantizer_.load(warp_frag_local_scale, warp_frag_code_scale, warp_frag_code_zp, warp_frag_super_scale);
729+
__syncthreads();
730+
731+
CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant");
732+
warp_dequantizer_.dequantize(warp_frag_out, pipe_state.warp_loaded_frag_B_[0], warp_frag_local_scale, warp_frag_code_scale, warp_frag_code_zp, warp_frag_super_scale);
688733

689734
if (Detail::kStagedAccumulation) {
690735
pipe_state.tmp_accum_.clear();
@@ -770,18 +815,17 @@ class Wint2xMmaMultistage :
770815
TileDequanterB tile_dequanter_B,
771816
///< initial value of accumulator
772817
FragmentC const &src_accum) {
773-
774818
// Prologue (start fetching iterations of global fragments into shared memory)
775819
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
776820

777821
// Wait until we have at least one completed global fetch stage
778822
gmem_wait();
779-
823+
780824
// Initialize destination accumulators with source accumulators
781825
accum = src_accum;
782826

783827
// Perform the MAC-iterations
784-
//gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
828+
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
785829
}
786830
};
787831

0 commit comments

Comments
 (0)