Skip to content

Commit 487d643

Browse files
committed
Merge branch 'opt_wint2' of https://github.com/baoqiwen/FastDeploy into opt_wint2
Change-Id: I365bc20cf33e8a73273a8a2b02fc20d10db85ccf
2 parents 5ce3424 + 8c6fa14 commit 487d643

File tree

2 files changed

+733
-0
lines changed

2 files changed

+733
-0
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "cutlass/numeric_types.h"
4545

4646
#include "cutlass_extensions/arch/memory_copy_sm80.h"
47+
#include "cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
4748
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
4849

4950
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -128,6 +129,19 @@ class Wint2xMmaMultistage :
128129
/// Minimum architecture is Sm80 to support cp.async
129130
using ArchTag = arch::Sm80;
130131

132+
using LayoutScale = cutlass::layout::ColumnMajor;
133+
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
134+
using ElementB = typename WarpTransformedFragmentB::Element;
135+
using Dequantizer =
136+
warp::MmaTensorOpWin2xDequantizer<Operator,
137+
typename Base::WarpGemm,
138+
Operand::kB,
139+
ElementB,
140+
cutlass::layout::ColumnMajor,
141+
32,
142+
WeightOnlyQuantOp::UNDEFINED>;
143+
static_assert(sizeof(Dequantizer) > 0, "Dequantizer template instantiation failed");
144+
131145
/// Complex transform on A operand
132146
static ComplexTransform const kTransformA = Operator::kTransformA;
133147

@@ -195,6 +209,9 @@ class Wint2xMmaMultistage :
195209
/// Warp-level MMA operator
196210
Operator warp_mma_;
197211

212+
// Wint2 unzip operator
213+
Dequantizer warp_dequantizer_;
214+
198215
/// Iterator to write threadblock-scoped tile of A operand to shared memory
199216
SmemIteratorA smem_iterator_A_;
200217

@@ -665,6 +682,26 @@ class Wint2xMmaMultistage :
665682
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
666683
++this->warp_tile_iterator_B_;
667684

685+
typename Dequantizer::FragmentLocalScale warp_frag_local_scale;
686+
typename Dequantizer::FragmentCodeScale warp_frag_code_scale;
687+
typename Dequantizer::FragmentCodeZp warp_frag_code_zp;
688+
typename Dequantizer::FragmentSuperScale warp_frag_super_scale;
689+
typename Dequantizer::FragmentOutOperand warp_frag_out;
690+
691+
CUTLASS_TRACE_DEVICE(" warp_dequantizer_ - start load");
692+
warp_dequantizer_.load(warp_frag_local_scale,
693+
warp_frag_code_scale,
694+
warp_frag_code_zp,
695+
warp_frag_super_scale);
696+
697+
CUTLASS_TRACE_DEVICE("warp_dequantizer_ - start dequant");
698+
warp_dequantizer_.dequantize(warp_frag_out,
699+
pipe_state.warp_loaded_frag_B_[0],
700+
warp_frag_local_scale,
701+
warp_frag_code_scale,
702+
warp_frag_code_zp,
703+
warp_frag_super_scale);
704+
668705
// Transform, if necessary, the first warp-tile's shared memory fragments
669706
warp_mma_.transform(
670707
pipe_state.warp_transformed_frag_A_[0],

0 commit comments

Comments
 (0)