Skip to content

Commit 567c4bf

Browse files
committed
Merge branch 'opt_wint2' of https://github.com/baoqiwen/FastDeploy into opt_wint2
Change-Id: Id74c7b781f639b8dd063a5b5fa0a2b697cf503f0
2 parents 5715889 + 81c704b commit 567c4bf

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,16 @@ class Wint2xMmaMultistage :
740740
warp_k_compute_offset_B
741741
);
742742
#if 0
743+
CUTLASS_TRACE_DEVICE(" pipe_state.warp_frag_B_=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
744+
static_cast<float>(pipe_state.warp_frag_B_[0]), static_cast<float>(pipe_state.warp_frag_B_[1]),
745+
static_cast<float>(pipe_state.warp_frag_B_[2]), static_cast<float>(pipe_state.warp_frag_B_[3]),
746+
static_cast<float>(pipe_state.warp_frag_B_[4]), static_cast<float>(pipe_state.warp_frag_B_[5]),
747+
static_cast<float>(pipe_state.warp_frag_B_[6]), static_cast<float>(pipe_state.warp_frag_B_[7]),
748+
static_cast<float>(pipe_state.warp_frag_B_[8]), static_cast<float>(pipe_state.warp_frag_B_[9]),
749+
static_cast<float>(pipe_state.warp_frag_B_[10]), static_cast<float>(pipe_state.warp_frag_B_[11]),
750+
static_cast<float>(pipe_state.warp_frag_B_[12]), static_cast<float>(pipe_state.warp_frag_B_[13]),
751+
static_cast<float>(pipe_state.warp_frag_B_[14]), static_cast<float>(pipe_state.warp_frag_B_[15]));
752+
743753
if (FragmentC::kElements == 16) {
744754
CUTLASS_TRACE_DEVICE(" tile_C[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
745755
static_cast<float>(accum[0]), static_cast<float>(accum[1]),
@@ -751,7 +761,6 @@ class Wint2xMmaMultistage :
751761
static_cast<float>(accum[12]), static_cast<float>(accum[13]),
752762
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
753763
}
754-
#endif
755764

756765
// CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]",
757766
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]),
@@ -779,6 +788,7 @@ class Wint2xMmaMultistage :
779788
// static_cast<float>(accum[10]), static_cast<float>(accum[11]),
780789
// static_cast<float>(accum[12]), static_cast<float>(accum[13]),
781790
// static_cast<float>(accum[14]), static_cast<float>(accum[15]));
791+
#endif
782792
}
783793

784794
// Except for the last warp-tile, all warp-tiles issue their share of

custom_ops/gpu_ops/cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,16 @@ class MmaTensorOpWin2xDequantizer<
362362
// avoid numerous conversion instructions in GEMM main loop.
363363
arch::device_breakpoint();
364364
#endif
365+
366+
const int fixed_values[64] = {
367+
0, 1, 8, 9, 16, 17, 24, 25, 32, 33, 40, 41, 48, 49, 56, 57,
368+
2, 3, 10, 11, 18, 19, 26, 27, 34, 35, 42, 43, 50, 51, 58, 59,
369+
4, 5, 12, 13, 20, 21, 28, 29, 36, 37, 44, 45, 52, 53, 60, 61,
370+
6, 7, 14, 15, 22, 23, 30, 31, 38, 39, 46, 47, 54, 55, 62, 63
371+
};
372+
for (int i = 0; i < FragmentUnpack::kElements; ++i) {
373+
output_frag[i] = static_cast<typename FragmentUnpack::Element>(fixed_values[(i % 16) + (threadIdx.x % 4) * 16]);
374+
}
365375
}
366376

367377
/// Add an offset to pointer in units of elements.

0 commit comments

Comments
 (0)