Skip to content

Commit 5ce3424

Browse files
committed
Remove TileDequanterB related codes.
Change-Id: Id8e65703b72a8984d367f584ff41b7726017fbb8
1 parent da648e8 commit 5ce3424

File tree

3 files changed

+6
-179
lines changed

3 files changed

+6
-179
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545

4646
#include "cutlass_extensions/arch/memory_copy_sm80.h"
4747
#include "cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
48-
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
4948

5049
/////////////////////////////////////////////////////////////////////////////////////////////////
5150

@@ -272,17 +271,12 @@ class Wint2xMmaMultistage :
272271
}
273272

274273
/// Advance global memory read-iterators and shared memory write-iterators to the stage
275-
template <typename TileDequanterB>
276274
CUTLASS_DEVICE
277-
void advance_smem_write_stage(
278-
IteratorA &iterator_A,
279-
IteratorB &iterator_B,
280-
TileDequanterB &tile_dequanter_B)
275+
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
281276
{
282277
// Advance global iterators
283278
iterator_A.add_tile_offset({0, 1});
284279
iterator_B.add_tile_offset({1, 0});
285-
//tile_dequanter_B.AddTileOffset({1, 0});
286280

287281
// Advance shared iterators
288282
smem_iterator_A_.add_tile_offset({0, 1});
@@ -455,12 +449,10 @@ class Wint2xMmaMultistage :
455449

456450
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
457451
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
458-
template <typename TileDequanterB>
459452
CUTLASS_DEVICE
460453
void prologue(
461454
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
462455
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
463-
TileDequanterB &tile_dequanter_B,
464456
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
465457
{
466458
// Issue several complete stages
@@ -478,11 +470,9 @@ class Wint2xMmaMultistage :
478470
copy_tiles_and_advance_per_stage_B<true, true>(iterator_B);
479471

480472
// TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
481-
//tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
482-
// column_wise_smem_ptr_B_, stage);
483473

484474
// Move to the next write stage
485-
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
475+
advance_smem_write_stage(iterator_A, iterator_B);
486476

487477
// Defines the boundary of a stage of cp.async.
488478
cutlass::arch::cp_async_fence();
@@ -544,14 +534,12 @@ class Wint2xMmaMultistage :
544534
}
545535

546536
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
547-
template <typename TileDequanterB>
548537
CUTLASS_DEVICE
549538
void mac_loop_iter(
550539
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
551540
FragmentC &accum, ///< [in|out] destination accumulator tile
552541
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
553542
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
554-
TileDequanterB &tile_dequanter_B, ///< [in|out] tile dequantizer for B operand
555543
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
556544
int stage)
557545
{
@@ -630,7 +618,7 @@ class Wint2xMmaMultistage :
630618
gmem_wait();
631619

632620
// Move to the next global fetch stage
633-
advance_smem_write_stage(iterator_A, iterator_B, tile_dequanter_B);
621+
advance_smem_write_stage(iterator_A, iterator_B);
634622
advance_smem_read_stage();
635623

636624
// Disable global fetching when done with global fetch iterations
@@ -654,14 +642,12 @@ class Wint2xMmaMultistage :
654642

655643
/// Perform the specified number of threadblock mainloop iterations of matrix
656644
/// multiply-accumulate. Assumes prologue has been initiated.
657-
template <typename TileDequanterB>
658645
CUTLASS_DEVICE
659646
void gemm_iters(
660647
int gemm_k_iterations, ///< number of threadblock mainloop iterations
661648
FragmentC &accum, ///< [in|out] accumulator tile
662649
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
663-
IteratorB &iterator_B,
664-
TileDequanterB &tile_dequanter_B) ///< [in|out] iterator over B operand in global memory
650+
IteratorB &iterator_B)
665651
{
666652
PipeState pipe_state;
667653

@@ -701,7 +687,6 @@ class Wint2xMmaMultistage :
701687
accum,
702688
iterator_A,
703689
iterator_B,
704-
tile_dequanter_B,
705690
gemm_k_iterations,
706691
stage);
707692
stage += 1;
@@ -755,7 +740,6 @@ class Wint2xMmaMultistage :
755740
}
756741

757742
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
758-
template <typename TileDequanterB>
759743
CUTLASS_DEVICE
760744
void operator()(
761745
///< problem size of GEMM
@@ -766,13 +750,11 @@ class Wint2xMmaMultistage :
766750
IteratorA iterator_A,
767751
///< iterator over B operand in global memory
768752
IteratorB iterator_B,
769-
///< pre-load and dequantize B to shared memory
770-
TileDequanterB tile_dequanter_B,
771753
///< initial value of accumulator
772754
FragmentC const &src_accum) {
773755

774756
// Prologue (start fetching iterations of global fragments into shared memory)
775-
prologue(iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
757+
prologue(iterator_A, iterator_B, gemm_k_iterations);
776758

777759
// Wait until we have at least one completed global fetch stage
778760
gmem_wait();
@@ -781,7 +763,7 @@ class Wint2xMmaMultistage :
781763
accum = src_accum;
782764

783765
// Perform the MAC-iterations
784-
//gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B);
766+
//gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B);
785767
}
786768
};
787769

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h

Lines changed: 0 additions & 133 deletions
This file was deleted.

custom_ops/gpu_ops/cutlass_kernels/moe_gemm/fused_moe_cutlass_kernel.h

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
#include "cutlass/trace.h"
4444

4545
#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h"
46-
#include "cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
4746
#include "cutlass_extensions/tile_interleaved_layout.h"
4847

4948
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -844,9 +843,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
844843
kInterleave >= 1,
845844
"B must be row major/col major OR col major interleaved.");
846845

847-
// LayoutB should be RowMajor
848-
using TileDequanterB = cutlass::gemm::threadblock::TileDequanter<ElementA, ElementScale, ThreadblockShape::kK, ThreadblockShape::kN, kStages, kThreadCount, QuantMethod>;
849-
850846
//
851847
// Problem visitor.
852848
//
@@ -916,30 +912,13 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
916912
platform::is_same<layout::RowMajor, LayoutB>::value
917913
? gemm_n
918914
: gemm_k * kInterleave;
919-
//typename LayoutB::LongIndex ldm_B_shared = TileDequanterB::kColumns;
920915

921916
// the begin threadblock_offset of B, which holds the same column id with C
922917
cutlass::MatrixCoord tb_offset_B{0,
923918
threadblock_offset.n() / kInterleave};
924919

925920
cutlass::MatrixCoord extent_B{problem_size.k() * kInterleave, problem_size.n() / kInterleave};
926-
//cutlass::MatrixCoord extent_B_shared{TileDequanterB::kRows, TileDequanterB::kColumns};
927921

928-
/*MmaElementB* smem_unzip_B_ptr = nullptr;
929-
if constexpr (QuantMethod == WintQuantMethod::kWeightOnlyInt2) {
930-
smem_unzip_B_ptr = shared_storage.main_loop.operand_unzip_B_ptr();
931-
}
932-
QuantArguments quant_args = get_quant_args(params, problem_idx, gemm_k, gemm_n);
933-
TileDequanterB tile_dequanter_B(smem_unzip_B_ptr,
934-
byte_ptr_B,
935-
ldm_B,
936-
extent_B,
937-
tb_offset_B,
938-
weight_scale_ptr,
939-
tb_offset_scale,
940-
quant_args);
941-
MmaElementB* ptr_B = tile_dequanter_B.GetOutPtr();*/
942-
TileDequanterB tile_dequanter_B;
943922
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
944923

945924
// Compute position within threadblock
@@ -989,7 +968,6 @@ struct Wint2xMoeFCGemm : public MoeFCGemm<Mma_, Epilogue_, ThreadblockSwizzle_,
989968
accumulators,
990969
iterator_A,
991970
iterator_B,
992-
tile_dequanter_B,
993971
accumulators);
994972

995973
//

0 commit comments

Comments
 (0)