45
45
46
46
#include " cutlass_extensions/arch/memory_copy_sm80.h"
47
47
#include " cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
48
- #include " cutlass_extensions/gemm/threadblock/wint2x_tile_dequanter.h"
49
48
50
49
// ///////////////////////////////////////////////////////////////////////////////////////////////
51
50
@@ -272,17 +271,12 @@ class Wint2xMmaMultistage :
272
271
}
273
272
274
273
// / Advance global memory read-iterators and shared memory write-iterators to the stage
275
- template <typename TileDequanterB>
276
274
CUTLASS_DEVICE
277
- void advance_smem_write_stage (
278
- IteratorA &iterator_A,
279
- IteratorB &iterator_B,
280
- TileDequanterB &tile_dequanter_B)
275
+ void advance_smem_write_stage (IteratorA &iterator_A, IteratorB &iterator_B)
281
276
{
282
277
// Advance global iterators
283
278
iterator_A.add_tile_offset ({0 , 1 });
284
279
iterator_B.add_tile_offset ({1 , 0 });
285
- // tile_dequanter_B.AddTileOffset({1, 0});
286
280
287
281
// Advance shared iterators
288
282
smem_iterator_A_.add_tile_offset ({0 , 1 });
@@ -455,12 +449,10 @@ class Wint2xMmaMultistage :
455
449
456
450
// / GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
457
451
// / the global fragments needed by the first kStages-1 threadblock mainloop iterations
458
- template <typename TileDequanterB>
459
452
CUTLASS_DEVICE
460
453
void prologue (
461
454
IteratorA &iterator_A, // /< [in|out] iterator over A operand in global memory
462
455
IteratorB &iterator_B, // /< [in|out] iterator over B operand in global memory
463
- TileDequanterB &tile_dequanter_B,
464
456
int &gemm_k_iterations) // /< [in|out] number of threadblock mainloop iterations remaining
465
457
{
466
458
// Issue several complete stages
@@ -478,11 +470,9 @@ class Wint2xMmaMultistage :
478
470
copy_tiles_and_advance_per_stage_B<true , true >(iterator_B);
479
471
480
472
// TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
481
- // tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
482
- // column_wise_smem_ptr_B_, stage);
483
473
484
474
// Move to the next write stage
485
- advance_smem_write_stage (iterator_A, iterator_B, tile_dequanter_B );
475
+ advance_smem_write_stage (iterator_A, iterator_B);
486
476
487
477
// Defines the boundary of a stage of cp.async.
488
478
cutlass::arch::cp_async_fence ();
@@ -544,14 +534,12 @@ class Wint2xMmaMultistage :
544
534
}
545
535
546
536
// / Perform a threadblock mainloop iteration of matrix multiply-accumulate
547
- template <typename TileDequanterB>
548
537
CUTLASS_DEVICE
549
538
void mac_loop_iter (
550
539
PipeState &pipe_state, // /< [in|out] loop-carried pipeline state
551
540
FragmentC &accum, // /< [in|out] destination accumulator tile
552
541
IteratorA &iterator_A, // /< [in|out] iterator over A operand in global memory
553
542
IteratorB &iterator_B, // /< [in|out] iterator over B operand in global memory
554
- TileDequanterB &tile_dequanter_B, // /< [in|out] tile dequantizer for B operand
555
543
int &gemm_k_iterations, // /< [in|out] number of threadblock mainloop iterations remaining
556
544
int stage)
557
545
{
@@ -630,7 +618,7 @@ class Wint2xMmaMultistage :
630
618
gmem_wait ();
631
619
632
620
// Move to the next global fetch stage
633
- advance_smem_write_stage (iterator_A, iterator_B, tile_dequanter_B );
621
+ advance_smem_write_stage (iterator_A, iterator_B);
634
622
advance_smem_read_stage ();
635
623
636
624
// Disable global fetching when done with global fetch iterations
@@ -654,14 +642,12 @@ class Wint2xMmaMultistage :
654
642
655
643
// / Perform the specified number of threadblock mainloop iterations of matrix
656
644
// / multiply-accumulate. Assumes prologue has been initiated.
657
- template <typename TileDequanterB>
658
645
CUTLASS_DEVICE
659
646
void gemm_iters (
660
647
int gemm_k_iterations, // /< number of threadblock mainloop iterations
661
648
FragmentC &accum, // /< [in|out] accumulator tile
662
649
IteratorA &iterator_A, // /< [in|out] iterator over A operand in global memory
663
- IteratorB &iterator_B,
664
- TileDequanterB &tile_dequanter_B) // /< [in|out] iterator over B operand in global memory
650
+ IteratorB &iterator_B)
665
651
{
666
652
PipeState pipe_state;
667
653
@@ -701,7 +687,6 @@ class Wint2xMmaMultistage :
701
687
accum,
702
688
iterator_A,
703
689
iterator_B,
704
- tile_dequanter_B,
705
690
gemm_k_iterations,
706
691
stage);
707
692
stage += 1;
@@ -755,7 +740,6 @@ class Wint2xMmaMultistage :
755
740
}
756
741
757
742
// / Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
758
- template <typename TileDequanterB>
759
743
CUTLASS_DEVICE
760
744
void operator ()(
761
745
// /< problem size of GEMM
@@ -766,13 +750,11 @@ class Wint2xMmaMultistage :
766
750
IteratorA iterator_A,
767
751
// /< iterator over B operand in global memory
768
752
IteratorB iterator_B,
769
- // /< pre-load and dequantize B to shared memory
770
- TileDequanterB tile_dequanter_B,
771
753
// /< initial value of accumulator
772
754
FragmentC const &src_accum) {
773
755
774
756
// Prologue (start fetching iterations of global fragments into shared memory)
775
- prologue (iterator_A, iterator_B, tile_dequanter_B, gemm_k_iterations);
757
+ prologue (iterator_A, iterator_B, gemm_k_iterations);
776
758
777
759
// Wait until we have at least one completed global fetch stage
778
760
gmem_wait ();
@@ -781,7 +763,7 @@ class Wint2xMmaMultistage :
781
763
accum = src_accum;
782
764
783
765
// Perform the MAC-iterations
784
- // gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, tile_dequanter_B );
766
+ // gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B);
785
767
}
786
768
};
787
769
0 commit comments