44
44
#include " cutlass/numeric_types.h"
45
45
46
46
#include " cutlass_extensions/arch/memory_copy_sm80.h"
47
- #include " cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
48
47
#include " cutlass_extensions/gemm/threadblock/wint2x_mma_base.h"
49
48
#include " cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h"
49
+ #include " cutlass_extensions/gemm/warp/mma_tensorop_wint2x_dequantizer.h"
50
50
51
51
// ///////////////////////////////////////////////////////////////////////////////////////////////
52
52
@@ -292,32 +292,32 @@ class Wint2xMmaMultistage :
292
292
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM ;
293
293
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM ;
294
294
295
- CUTLASS_TRACE_DEVICE (" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d" ,
296
- Shape::kM , Shape::kN , Shape::kK , IteratorB::Shape::kRow , IteratorB::Shape::kColumn , kInterleave );
297
- CUTLASS_TRACE_DEVICE (" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d" ,
298
- Policy::kPartitionsK , Base::kWarpGemmIterations ,
299
- Base::WarpCount::kM , Base::WarpCount::kN , warp_idx_m, warp_idx_n, warp_idx_k);
295
+ // CUTLASS_TRACE_DEVICE(" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d",
296
+ // Shape::kM, Shape::kN, Shape::kK, IteratorB::Shape::kRow, IteratorB::Shape::kColumn, kInterleave);
297
+ // CUTLASS_TRACE_DEVICE(" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d",
298
+ // Policy::kPartitionsK, Base::kWarpGemmIterations,
299
+ // Base::WarpCount::kM, Base::WarpCount::kN, warp_idx_m, warp_idx_n, warp_idx_k);
300
300
301
301
// Add per-warp offsets in units of warp-level tiles
302
302
this ->warp_tile_iterator_A_ .add_tile_offset (
303
303
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
304
304
this ->warp_tile_iterator_B_ .add_tile_offset (
305
305
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
306
306
307
- CUTLASS_TRACE_DEVICE (" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}" ,
308
- Policy::SmemPaddingA::kRow , Policy::SmemPaddingA::kColumn , Policy::SmemPaddingB::kRow , Policy::SmemPaddingB::kColumn );
309
- CUTLASS_TRACE_DEVICE (" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d" ,
310
- shared_storage.operand_A .data (),
311
- static_cast <int >(Base::SharedStorage::ShapeA::kRow ), static_cast <int >(Base::SharedStorage::ShapeA::kColumn ),
312
- static_cast <int >(sizeof (shared_storage.operand_A )),
313
- static_cast <int >(IteratorA::ThreadMap::kElementsPerAccess ), static_cast <int >(sizeof (typename IteratorA::AccessType)),
314
- static_cast <int >(Detail::AsyncCopyIterationsPerStageA), static_cast <int >(IteratorA::kAccessesPerVector ));
315
- CUTLASS_TRACE_DEVICE (" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVectorA=%d" ,
316
- shared_storage.operand_B .data (),
317
- static_cast <int >(Base::SharedStorage::ShapeB::kRow ), static_cast <int >(Base::SharedStorage::ShapeB::kColumn ),
318
- static_cast <int >(sizeof (shared_storage.operand_B )),
319
- static_cast <int >(IteratorB::ThreadMap::kElementsPerAccess ), static_cast <int >(sizeof (typename IteratorB::AccessType)),
320
- static_cast <int >(Detail::AsyncCopyIterationsPerStageB), static_cast <int >(IteratorB::kAccessesPerVector ));
307
+ // CUTLASS_TRACE_DEVICE(" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}",
308
+ // Policy::SmemPaddingA::kRow, Policy::SmemPaddingA::kColumn, Policy::SmemPaddingB::kRow, Policy::SmemPaddingB::kColumn);
309
+ // CUTLASS_TRACE_DEVICE(" operand_A_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageA=%d, kAccessesPerVectorA=%d",
310
+ // shared_storage.operand_A.data(),
311
+ // static_cast<int>(Base::SharedStorage::ShapeA::kRow), static_cast<int>(Base::SharedStorage::ShapeA::kColumn),
312
+ // static_cast<int>(sizeof(shared_storage.operand_A)),
313
+ // static_cast<int>(IteratorA::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorA::AccessType)),
314
+ // static_cast<int>(Detail::AsyncCopyIterationsPerStageA), static_cast<int>(IteratorA::kAccessesPerVector));
315
+ // CUTLASS_TRACE_DEVICE(" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVectorA=%d",
316
+ // shared_storage.operand_B.data(),
317
+ // static_cast<int>(Base::SharedStorage::ShapeB::kRow), static_cast<int>(Base::SharedStorage::ShapeB::kColumn),
318
+ // static_cast<int>(sizeof(shared_storage.operand_B)),
319
+ // static_cast<int>(IteratorB::ThreadMap::kElementsPerAccess), static_cast<int>(sizeof(typename IteratorB::AccessType)),
320
+ // static_cast<int>(Detail::AsyncCopyIterationsPerStageB), static_cast<int>(IteratorB::kAccessesPerVector));
321
321
322
322
smem_ptr_A_ = reinterpret_cast <ElementA*>(shared_storage.operand_A .data ());
323
323
smem_ptr_B_ = reinterpret_cast <uint8_t *>(shared_storage.operand_B .data ());
@@ -678,9 +678,11 @@ class Wint2xMmaMultistage :
678
678
int &gemm_k_iterations, // /< [in|out] number of threadblock mainloop iterations remaining
679
679
int stage)
680
680
{
681
+
681
682
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
682
683
CUTLASS_PRAGMA_UNROLL
683
684
for (int warp_mma_k = 0 ; warp_mma_k < Base::kWarpGemmIterations ; ++warp_mma_k) {
685
+
684
686
// Load the next warp-tile's A fragment from shared memory
685
687
this ->warp_tile_iterator_A_ .set_kgroup_index ((warp_mma_k + 1 ) % Base::kWarpGemmIterations );
686
688
this ->warp_tile_iterator_A_ .load (pipe_state.warp_frag_A_ [(warp_mma_k + 1 ) % 2 ]);
@@ -699,27 +701,41 @@ class Wint2xMmaMultistage :
699
701
}
700
702
701
703
// Execute the current warp-tile of MMA operations
704
+
705
+ // CUTLASS_TRACE_DEVICE("ElementA %d", PipeState::WarpTransformedFragmentA::kElements);
706
+ // CUTLASS_TRACE_DEVICE("ElementB %d", PipeState::WarpTransformedFragmentB::kElements);
707
+ // CUTLASS_TRACE_DEVICE("kStagedAccumulation %d", Detail::kStagedAccumulation);
708
+
709
+ // uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2].data());
710
+ // CUTLASS_TRACE_DEVICE(" reg_uint8_ptr=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
711
+ // static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
712
+ // static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
713
+ // static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
714
+ // static_cast<int>(reg_uint8_ptr[6]), static_cast<int>(reg_uint8_ptr[7]),
715
+ // static_cast<int>(reg_uint8_ptr[8]), static_cast<int>(reg_uint8_ptr[9]),
716
+ // static_cast<int>(reg_uint8_ptr[10]), static_cast<int>(reg_uint8_ptr[11]),
717
+ // static_cast<int>(reg_uint8_ptr[12]), static_cast<int>(reg_uint8_ptr[13]),
718
+ // static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719
+ // sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720
+
702
721
if (Detail::kStagedAccumulation ) {
703
722
// CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
704
723
warp_mma_ (
705
724
pipe_state.tmp_accum_ ,
706
725
pipe_state.warp_frag_A_ [warp_mma_k % 2 ],
707
726
pipe_state.warp_frag_B_ ,
727
+ // unpacked_frag_B,
708
728
pipe_state.tmp_accum_ ,
709
729
warp_k_compute_offset_B
710
730
);
711
731
712
- if (warp_mma_k == 0 ) {
713
- plus<FragmentC> plus_accum;
714
- accum = plus_accum (accum, pipe_state.tmp_accum_ );
715
- pipe_state.tmp_accum_ .clear ();
716
- }
717
732
} else {
718
733
// CUTLASS_TRACE_DEVICE(" [MMa][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
719
734
warp_mma_ (
720
735
accum,
721
736
pipe_state.warp_frag_A_ [warp_mma_k % 2 ],
722
737
pipe_state.warp_frag_B_ ,
738
+ // unpacked_frag_B,
723
739
accum,
724
740
warp_k_compute_offset_B
725
741
);
@@ -736,6 +752,33 @@ class Wint2xMmaMultistage :
736
752
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
737
753
}
738
754
#endif
755
+
756
+ // CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]",
757
+ // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]),
758
+ // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]),
759
+ // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]),
760
+ // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7]));
761
+
762
+ // CUTLASS_TRACE_DEVICE_TID(" now1 unpacked_frag_B[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
763
+ // static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
764
+ // static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
765
+ // static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
766
+ // static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
767
+ // static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
768
+ // static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
769
+ // static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
770
+ // static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
771
+
772
+ // CUTLASS_TRACE_DEVICE_TID(" warp_k_compute_offset_B = %d, now1 tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
773
+ // warp_k_compute_offset_B,
774
+ // static_cast<float>(accum[0]), static_cast<float>(accum[1]),
775
+ // static_cast<float>(accum[2]), static_cast<float>(accum[3]),
776
+ // static_cast<float>(accum[4]), static_cast<float>(accum[5]),
777
+ // static_cast<float>(accum[6]), static_cast<float>(accum[7]),
778
+ // static_cast<float>(accum[8]), static_cast<float>(accum[9]),
779
+ // static_cast<float>(accum[10]), static_cast<float>(accum[11]),
780
+ // static_cast<float>(accum[12]), static_cast<float>(accum[13]),
781
+ // static_cast<float>(accum[14]), static_cast<float>(accum[15]));
739
782
}
740
783
741
784
// Except for the last warp-tile, all warp-tiles issue their share of
@@ -832,6 +875,7 @@ class Wint2xMmaMultistage :
832
875
this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ );
833
876
++this ->warp_tile_iterator_B_ ;
834
877
878
+ #if 0
835
879
if (PipeState::WarpLoadedFragmentA::kElements == 8) {
836
880
ElementA* warp_frag_A_ptr = reinterpret_cast<ElementA*>(pipe_state.warp_frag_A_[0].data());
837
881
CUTLASS_TRACE_DEVICE(" warp_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes",
@@ -841,6 +885,7 @@ class Wint2xMmaMultistage :
841
885
static_cast<float>(warp_frag_A_ptr[6]), static_cast<float>(warp_frag_A_ptr[7]),
842
886
sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8);
843
887
}
888
+ #endif
844
889
#if 0
845
890
if (PipeState::WarpLoadedFragmentB::kElements == 64) {
846
891
uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_.data());
@@ -881,8 +926,69 @@ class Wint2xMmaMultistage :
881
926
pipe_state.warp_frag_B_ ,
882
927
0 );
883
928
929
+ #if 0
930
+ if (TransformBAfterLDS::result_type::kElements == 64) {
931
+ CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
932
+ CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
933
+ static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
934
+ static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
935
+ static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
936
+ static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
937
+ static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
938
+ static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
939
+ static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
940
+ static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
941
+ CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
942
+ static_cast<float>(unpacked_frag_B[16]), static_cast<float>(unpacked_frag_B[17]),
943
+ static_cast<float>(unpacked_frag_B[18]), static_cast<float>(unpacked_frag_B[19]),
944
+ static_cast<float>(unpacked_frag_B[20]), static_cast<float>(unpacked_frag_B[21]),
945
+ static_cast<float>(unpacked_frag_B[22]), static_cast<float>(unpacked_frag_B[23]),
946
+ static_cast<float>(unpacked_frag_B[24]), static_cast<float>(unpacked_frag_B[25]),
947
+ static_cast<float>(unpacked_frag_B[26]), static_cast<float>(unpacked_frag_B[27]),
948
+ static_cast<float>(unpacked_frag_B[28]), static_cast<float>(unpacked_frag_B[29]),
949
+ static_cast<float>(unpacked_frag_B[30]), static_cast<float>(unpacked_frag_B[31]));
950
+ CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
951
+ static_cast<float>(unpacked_frag_B[32]), static_cast<float>(unpacked_frag_B[33]),
952
+ static_cast<float>(unpacked_frag_B[34]), static_cast<float>(unpacked_frag_B[35]),
953
+ static_cast<float>(unpacked_frag_B[36]), static_cast<float>(unpacked_frag_B[37]),
954
+ static_cast<float>(unpacked_frag_B[38]), static_cast<float>(unpacked_frag_B[39]),
955
+ static_cast<float>(unpacked_frag_B[40]), static_cast<float>(unpacked_frag_B[41]),
956
+ static_cast<float>(unpacked_frag_B[42]), static_cast<float>(unpacked_frag_B[43]),
957
+ static_cast<float>(unpacked_frag_B[44]), static_cast<float>(unpacked_frag_B[45]),
958
+ static_cast<float>(unpacked_frag_B[46]), static_cast<float>(unpacked_frag_B[47]));
959
+ CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
960
+ static_cast<float>(unpacked_frag_B[48]), static_cast<float>(unpacked_frag_B[49]),
961
+ static_cast<float>(unpacked_frag_B[50]), static_cast<float>(unpacked_frag_B[51]),
962
+ static_cast<float>(unpacked_frag_B[52]), static_cast<float>(unpacked_frag_B[53]),
963
+ static_cast<float>(unpacked_frag_B[54]), static_cast<float>(unpacked_frag_B[55]),
964
+ static_cast<float>(unpacked_frag_B[56]), static_cast<float>(unpacked_frag_B[57]),
965
+ static_cast<float>(unpacked_frag_B[58]), static_cast<float>(unpacked_frag_B[59]),
966
+ static_cast<float>(unpacked_frag_B[60]), static_cast<float>(unpacked_frag_B[61]),
967
+ static_cast<float>(unpacked_frag_B[62]), static_cast<float>(unpacked_frag_B[63]));
968
+ }
969
+ #endif
970
+
884
971
if (Detail::kStagedAccumulation ) {
885
972
pipe_state.tmp_accum_ .clear ();
973
+ CUTLASS_TRACE_DEVICE (" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
974
+ static_cast <float >(pipe_state.tmp_accum_ [0 ]), static_cast <float >(pipe_state.tmp_accum_ [1 ]),
975
+ static_cast <float >(pipe_state.tmp_accum_ [2 ]), static_cast <float >(pipe_state.tmp_accum_ [3 ]),
976
+ static_cast <float >(pipe_state.tmp_accum_ [4 ]), static_cast <float >(pipe_state.tmp_accum_ [5 ]),
977
+ static_cast <float >(pipe_state.tmp_accum_ [6 ]), static_cast <float >(pipe_state.tmp_accum_ [7 ]),
978
+ static_cast <float >(pipe_state.tmp_accum_ [8 ]), static_cast <float >(pipe_state.tmp_accum_ [9 ]),
979
+ static_cast <float >(pipe_state.tmp_accum_ [10 ]), static_cast <float >(pipe_state.tmp_accum_ [11 ]),
980
+ static_cast <float >(pipe_state.tmp_accum_ [12 ]), static_cast <float >(pipe_state.tmp_accum_ [13 ]),
981
+ static_cast <float >(pipe_state.tmp_accum_ [14 ]), static_cast <float >(pipe_state.tmp_accum_ [15 ]));
982
+ } else {
983
+ CUTLASS_TRACE_DEVICE (" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
984
+ static_cast <float >(accum[0 ]), static_cast <float >(accum[1 ]),
985
+ static_cast <float >(accum[2 ]), static_cast <float >(accum[3 ]),
986
+ static_cast <float >(accum[4 ]), static_cast <float >(accum[5 ]),
987
+ static_cast <float >(accum[6 ]), static_cast <float >(accum[7 ]),
988
+ static_cast <float >(accum[8 ]), static_cast <float >(accum[9 ]),
989
+ static_cast <float >(accum[10 ]), static_cast <float >(accum[11 ]),
990
+ static_cast <float >(accum[12 ]), static_cast <float >(accum[13 ]),
991
+ static_cast <float >(accum[14 ]), static_cast <float >(accum[15 ]));
886
992
}
887
993
888
994
int stage = Base::kStages - 1 ;
0 commit comments