|
47 | 47 | #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
|
48 | 48 | #include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
|
49 | 49 |
|
| 50 | +#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" |
| 51 | + |
50 | 52 | /////////////////////////////////////////////////////////////////////////////////////////////////
|
51 | 53 |
|
52 | 54 | namespace cutlass {
|
@@ -129,6 +131,17 @@ class Wint2xMmaMultistage :
|
129 | 131 | /// Minimum architecture is Sm80 to support cp.async
|
130 | 132 | using ArchTag = arch::Sm80;
|
131 | 133 |
|
| 134 | + using LayoutScale = cutlass::layout::ColumnMajor; |
| 135 | + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; |
| 136 | + using ElementB = typename WarpTransformedFragmentB::Element; |
| 137 | + using Dequantizer = warp::MmaTensorOpWin2xDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementB, |
| 138 | + cutlass::layout::ColumnMajor, 32, WeightOnlyQuantOp::UNDEFINED>; |
| 139 | + |
| 140 | + static_assert( |
| 141 | + sizeof(Dequantizer) > 0, |
| 142 | + "Dequantizer template instantiation failed" |
| 143 | + ); |
| 144 | + |
132 | 145 | /// Complex transform on A operand
|
133 | 146 | static ComplexTransform const kTransformA = Operator::kTransformA;
|
134 | 147 |
|
@@ -196,6 +209,9 @@ class Wint2xMmaMultistage :
|
196 | 209 | /// Warp-level MMA operator
|
197 | 210 | Operator warp_mma_;
|
198 | 211 |
|
| 212 | + // Wint2 unzip operator |
| 213 | + Dequantizer warp_dequantizer_; |
| 214 | + |
199 | 215 | /// Iterator to write threadblock-scoped tile of A operand to shared memory
|
200 | 216 | SmemIteratorA smem_iterator_A_;
|
201 | 217 |
|
@@ -679,12 +695,41 @@ class Wint2xMmaMultistage :
|
679 | 695 | this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
|
680 | 696 | ++this->warp_tile_iterator_B_;
|
681 | 697 |
|
682 |
| - // Transform, if necessary, the first warp-tile's shared memory fragments |
683 |
| - warp_mma_.transform( |
684 |
| - pipe_state.warp_transformed_frag_A_[0], |
685 |
| - pipe_state.warp_transformed_frag_B_[0], |
686 |
| - pipe_state.warp_loaded_frag_A_[0], |
687 |
| - pipe_state.warp_loaded_frag_B_[0]); |
| 698 | + // // Transform, if necessary, the first warp-tile's shared memory fragments |
| 699 | + // warp_mma_.transform( |
| 700 | + // pipe_state.warp_transformed_frag_A_[0], |
| 701 | + // pipe_state.warp_transformed_frag_B_[0], |
| 702 | + // pipe_state.warp_loaded_frag_A_[0], |
| 703 | + // pipe_state.warp_loaded_frag_B_[0]); |
| 704 | + |
| 705 | + __syncthreads(); // 确保所有线程执行到此处 |
| 706 | + if (threadIdx.x == 0) { // 仅让一个线程打印,避免重复输出 |
| 707 | + // printf("DEBUG: warp_loaded_frag_A_[0] values:\n"); |
| 708 | + for (int i = 0; i < pipe_state.warp_loaded_frag_A_[0].size(); ++i) { |
| 709 | + // 读取 fragment 中的元素 |
| 710 | + auto val = pipe_state.warp_loaded_frag_A_[0][i]; |
| 711 | + |
| 712 | + // 以 16-bit 形式 reinterpret 为 uint16_t 查看原始位模式 |
| 713 | + uint16_t bits = reinterpret_cast<const uint16_t*>(&val)[0]; |
| 714 | + |
| 715 | + CUTLASS_TRACE_DEVICE(" warp_loaded_frag_A_[%d] = 0x%04x", i, bits); |
| 716 | + } |
| 717 | + } |
| 718 | + __syncthreads(); |
| 719 | + |
| 720 | + typename Dequantizer::FragmentLocalScale warp_frag_local_scale; |
| 721 | + typename Dequantizer::FragmentCodeScale warp_frag_code_scale; |
| 722 | + typename Dequantizer::FragmentCodeZp warp_frag_code_zp; |
| 723 | + typename Dequantizer::FragmentSuperScale warp_frag_super_scale; |
| 724 | + |
| 725 | + typename Dequantizer::FragmentOutOperand warp_frag_out; |
| 726 | + |
| 727 | + CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load"); |
| 728 | + warp_dequantizer_.load(warp_frag_local_scale, warp_frag_code_scale, warp_frag_code_zp, warp_frag_super_scale); |
| 729 | + __syncthreads(); |
| 730 | + |
| 731 | + CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant"); |
| 732 | + warp_dequantizer_.dequantize(warp_frag_out, pipe_state.warp_loaded_frag_B_[0], warp_frag_local_scale, warp_frag_code_scale, warp_frag_code_zp, warp_frag_super_scale); |
688 | 733 |
|
689 | 734 | if (Detail::kStagedAccumulation) {
|
690 | 735 | pipe_state.tmp_accum_.clear();
|
@@ -770,18 +815,17 @@ class Wint2xMmaMultistage :
|
770 | 815 | TileDequanterB tile_dequanter_B,
|
771 | 816 | ///< initial value of accumulator
|
772 | 817 | FragmentC const &src_accum) {
|
773 |
| - |
774 | 818 | // Prologue (start fetching iterations of global fragments into shared memory)
|
775 | 819 | prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
|
776 | 820 |
|
777 | 821 | // Wait until we have at least one completed global fetch stage
|
778 | 822 | gmem_wait();
|
779 |
| - |
| 823 | + |
780 | 824 | // Initialize destination accumulators with source accumulators
|
781 | 825 | accum = src_accum;
|
782 | 826 |
|
783 | 827 | // Perform the MAC-iterations
|
784 |
| - //gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); |
| 828 | + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B); |
785 | 829 | }
|
786 | 830 | };
|
787 | 831 |
|
|
0 commit comments