|
44 | 44 | #include "cutlass/numeric_types.h"
|
45 | 45 |
|
46 | 46 | #include "cutlass_extensions/arch/memory_copy_sm80.h"
|
| 47 | +#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h" |
47 | 48 | #include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
|
48 | 49 |
|
49 | 50 | /////////////////////////////////////////////////////////////////////////////////////////////////
|
@@ -128,6 +129,19 @@ class Wint2xMmaMultistage :
|
128 | 129 | /// Minimum architecture is Sm80 to support cp.async
|
129 | 130 | using ArchTag = arch::Sm80;
|
130 | 131 |
|
| 132 | + using LayoutScale = cutlass::layout::ColumnMajor; |
| 133 | + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; |
| 134 | + using ElementB = typename WarpTransformedFragmentB::Element; |
| 135 | + using Dequantizer = |
| 136 | + warp::MmaTensorOpWin2xDequantizer<Operator, |
| 137 | + typename Base::WarpGemm, |
| 138 | + Operand::kB, |
| 139 | + ElementB, |
| 140 | + cutlass::layout::ColumnMajor, |
| 141 | + 32, |
| 142 | + WeightOnlyQuantOp::UNDEFINED>; |
| 143 | + static_assert(sizeof(Dequantizer) > 0, "Dequantizer template instantiation failed"); |
| 144 | + |
131 | 145 | /// Complex transform on A operand
|
132 | 146 | static ComplexTransform const kTransformA = Operator::kTransformA;
|
133 | 147 |
|
@@ -195,6 +209,9 @@ class Wint2xMmaMultistage :
|
195 | 209 | /// Warp-level MMA operator
|
196 | 210 | Operator warp_mma_;
|
197 | 211 |
|
| 212 | + /// Warp-level Wint2x dequantizer used to unzip the packed B operand |
| 213 | + Dequantizer warp_dequantizer_; |
| 214 | + |
198 | 215 | /// Iterator to write threadblock-scoped tile of A operand to shared memory
|
199 | 216 | SmemIteratorA smem_iterator_A_;
|
200 | 217 |
|
@@ -665,6 +682,26 @@ class Wint2xMmaMultistage :
|
665 | 682 | this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
|
666 | 683 | ++this->warp_tile_iterator_B_;
|
667 | 684 |
|
| 685 | + typename Dequantizer::FragmentLocalScale warp_frag_local_scale; |
| 686 | + typename Dequantizer::FragmentCodeScale warp_frag_code_scale; |
| 687 | + typename Dequantizer::FragmentCodeZp warp_frag_code_zp; |
| 688 | + typename Dequantizer::FragmentSuperScale warp_frag_super_scale; |
| 689 | + typename Dequantizer::FragmentOutOperand warp_frag_out; |
| 690 | + |
| 691 | + CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load"); |
| 692 | + warp_dequantizer_.load(warp_frag_local_scale, |
| 693 | + warp_frag_code_scale, |
| 694 | + warp_frag_code_zp, |
| 695 | + warp_frag_super_scale); |
| 696 | + |
| 697 | + CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant"); |
| 698 | + warp_dequantizer_.dequantize(warp_frag_out, |
| 699 | + pipe_state.warp_loaded_frag_B_[0], |
| 700 | + warp_frag_local_scale, |
| 701 | + warp_frag_code_scale, |
| 702 | + warp_frag_code_zp, |
| 703 | + warp_frag_super_scale); |
| 704 | + |
668 | 705 | // Transform, if necessary, the first warp-tile's shared memory fragments
|
669 | 706 | warp_mma_.transform(
|
670 | 707 | pipe_state.warp_transformed_frag_A_[0],
|
|
0 commit comments