@@ -332,7 +332,7 @@ class Wint2xMmaMultistage :
332
332
if (smem_read_stage_idx_ == Base::kStages ) {
333
333
// Wrap back around to the 'start' of the circular buffer in shared memory
334
334
this ->warp_tile_iterator_A_ .add_tile_offset ({0 , -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations });
335
- this ->warp_tile_iterator_B_ .add_tile_offset ({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations , 0 });
335
+ this ->warp_tile_iterator_B_ .add_tile_offset ({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB , 0 });
336
336
smem_read_stage_idx_ = 0 ;
337
337
}
338
338
}
@@ -399,7 +399,7 @@ class Wint2xMmaMultistage :
399
399
}
400
400
401
401
CUTLASS_DEVICE
402
- void copy_tiles_and_advance_B (IteratorB &iterator_B, int group_start_B = 0 ) {
402
+ void copy_tiles_and_advance_B (IteratorB &iterator_B, int group_start_B = 0 , int stage = 0 ) {
403
403
iterator_B.set_iteration_index (group_start_B *
404
404
IteratorB::kAccessesPerVector );
405
405
this ->smem_iterator_B_ .set_iteration_index (group_start_B);
@@ -421,11 +421,10 @@ class Wint2xMmaMultistage :
421
421
auto gmem_ptr = iterator_B.get ();
422
422
423
423
#if 0
424
- if (group_start_B == 0 && j == 0 && v == 0) {
425
- CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d",
426
- reinterpret_cast<void*>(dst_ptr), reinterpret_cast<void*>(gmem_ptr),
427
- static_cast<int>(Detail::kAccessesPerGroupB), static_cast<int>(IteratorB::kAccessesPerVector),
428
- static_cast<int>(sizeof(typename IteratorB::Element)));
424
+ if (j == 0 && v == 0) {
425
+ CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, %d bytes, group_start_B=%d, valid=%d",
426
+ stage, reinterpret_cast<void*>(gmem_ptr), reinterpret_cast<void*>(dst_ptr),
427
+ kSrcBytes, group_start_B, static_cast<int>(iterator_B.valid()));
429
428
}
430
429
#endif
431
430
@@ -543,9 +542,9 @@ class Wint2xMmaMultistage :
543
542
544
543
uint8_t* gmem_uint8_ptr = reinterpret_cast<uint8_t*>(gmem_ptr);
545
544
546
- CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};",
545
+ CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d, valid=%d ; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};",
547
546
stage, reinterpret_cast<void*>(gmem_ptr), reinterpret_cast<void*>(dst_ptr), kSrcBytes,
548
- gmem_n, gmem_k, gmem_row, gmem_col,
547
+ static_cast<int>(iterator_B.valid()), gmem_n, gmem_k, gmem_row, gmem_col,
549
548
static_cast<int>(gmem_uint8_ptr[0]), static_cast<int>(gmem_uint8_ptr[1]),
550
549
static_cast<int>(gmem_uint8_ptr[2]), static_cast<int>(gmem_uint8_ptr[3]),
551
550
static_cast<int>(gmem_uint8_ptr[4]), static_cast<int>(gmem_uint8_ptr[5]),
@@ -689,53 +688,54 @@ class Wint2xMmaMultistage :
689
688
++this ->warp_tile_iterator_A_ ;
690
689
691
690
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB ;
692
- int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB ;
693
691
694
- if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
692
+ if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
695
693
// Load the next warp-tile's B fragment from shared memory
696
- this ->warp_tile_iterator_B_ .set_kgroup_index ((warp_mma_k_for_B + 1 ) % Base::kWarpGemmIterations );
694
+ // this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations);
695
+ this ->warp_tile_iterator_B_ .set_kgroup_index (0 );
697
696
this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ );
698
697
++this ->warp_tile_iterator_B_ ;
699
698
699
+ if (PipeState::WarpLoadedFragmentB::kElements == 64 ) {
700
+ uint8_t * reg_uint8_ptr = reinterpret_cast <uint8_t *>(pipe_state.warp_loaded_frag_B_ .data ());
701
+ CUTLASS_TRACE_DEVICE (" [stage=%d] warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes" ,
702
+ stage - Base::kStages + 2 ,
703
+ static_cast <int >(reg_uint8_ptr[0 ]), static_cast <int >(reg_uint8_ptr[1 ]),
704
+ static_cast <int >(reg_uint8_ptr[2 ]), static_cast <int >(reg_uint8_ptr[3 ]),
705
+ static_cast <int >(reg_uint8_ptr[4 ]), static_cast <int >(reg_uint8_ptr[5 ]),
706
+ static_cast <int >(reg_uint8_ptr[6 ]), static_cast <int >(reg_uint8_ptr[7 ]),
707
+ static_cast <int >(reg_uint8_ptr[8 ]), static_cast <int >(reg_uint8_ptr[9 ]),
708
+ static_cast <int >(reg_uint8_ptr[10 ]), static_cast <int >(reg_uint8_ptr[11 ]),
709
+ static_cast <int >(reg_uint8_ptr[12 ]), static_cast <int >(reg_uint8_ptr[13 ]),
710
+ static_cast <int >(reg_uint8_ptr[14 ]), static_cast <int >(reg_uint8_ptr[15 ]),
711
+ sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8 );
712
+ }
713
+
700
714
warp_dequantizer_.load (pipe_state.warp_frag_local_scale_ );
701
715
}
702
716
703
717
// Execute the current warp-tile of MMA operations
704
-
705
- // CUTLASS_TRACE_DEVICE("ElementA %d", PipeState::WarpTransformedFragmentA::kElements);
706
- // CUTLASS_TRACE_DEVICE("ElementB %d", PipeState::WarpTransformedFragmentB::kElements);
707
- // CUTLASS_TRACE_DEVICE("kStagedAccumulation %d", Detail::kStagedAccumulation);
708
-
709
- // uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2].data());
710
- // CUTLASS_TRACE_DEVICE(" reg_uint8_ptr=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
711
- // static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
712
- // static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
713
- // static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
714
- // static_cast<int>(reg_uint8_ptr[6]), static_cast<int>(reg_uint8_ptr[7]),
715
- // static_cast<int>(reg_uint8_ptr[8]), static_cast<int>(reg_uint8_ptr[9]),
716
- // static_cast<int>(reg_uint8_ptr[10]), static_cast<int>(reg_uint8_ptr[11]),
717
- // static_cast<int>(reg_uint8_ptr[12]), static_cast<int>(reg_uint8_ptr[13]),
718
- // static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719
- // sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720
-
721
718
if (Detail::kStagedAccumulation ) {
722
719
// CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
723
720
warp_mma_ (
724
721
pipe_state.tmp_accum_ ,
725
722
pipe_state.warp_frag_A_ [warp_mma_k % 2 ],
726
723
pipe_state.warp_frag_B_ ,
727
- // unpacked_frag_B,
728
724
pipe_state.tmp_accum_ ,
729
725
warp_k_compute_offset_B
730
726
);
731
727
728
+ if (warp_mma_k == 0 ) {
729
+ plus<FragmentC> plus_accum;
730
+ accum = plus_accum (accum, pipe_state.tmp_accum_ );
731
+ pipe_state.tmp_accum_ .clear ();
732
+ }
732
733
} else {
733
734
// CUTLASS_TRACE_DEVICE(" [MMa][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
734
735
warp_mma_ (
735
736
accum,
736
737
pipe_state.warp_frag_A_ [warp_mma_k % 2 ],
737
738
pipe_state.warp_frag_B_ ,
738
- // unpacked_frag_B,
739
739
accum,
740
740
warp_k_compute_offset_B
741
741
);
@@ -761,33 +761,6 @@ class Wint2xMmaMultistage :
761
761
static_cast<float>(accum[12]), static_cast<float>(accum[13]),
762
762
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
763
763
}
764
-
765
- // CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]",
766
- // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]),
767
- // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]),
768
- // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]),
769
- // static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7]));
770
-
771
- // CUTLASS_TRACE_DEVICE_TID(" now1 unpacked_frag_B[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
772
- // static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
773
- // static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
774
- // static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
775
- // static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
776
- // static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
777
- // static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
778
- // static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
779
- // static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
780
-
781
- // CUTLASS_TRACE_DEVICE_TID(" warp_k_compute_offset_B = %d, now1 tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
782
- // warp_k_compute_offset_B,
783
- // static_cast<float>(accum[0]), static_cast<float>(accum[1]),
784
- // static_cast<float>(accum[2]), static_cast<float>(accum[3]),
785
- // static_cast<float>(accum[4]), static_cast<float>(accum[5]),
786
- // static_cast<float>(accum[6]), static_cast<float>(accum[7]),
787
- // static_cast<float>(accum[8]), static_cast<float>(accum[9]),
788
- // static_cast<float>(accum[10]), static_cast<float>(accum[11]),
789
- // static_cast<float>(accum[12]), static_cast<float>(accum[13]),
790
- // static_cast<float>(accum[14]), static_cast<float>(accum[15]));
791
764
#endif
792
765
}
793
766
@@ -798,7 +771,7 @@ class Wint2xMmaMultistage :
798
771
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB ;
799
772
800
773
copy_tiles_and_advance_A (iterator_A, group_start_iteration_A);
801
- copy_tiles_and_advance_B (iterator_B, group_start_iteration_B);
774
+ copy_tiles_and_advance_B (iterator_B, group_start_iteration_B, stage );
802
775
if (warp_mma_k == 0 ) {
803
776
quant_params_accessor_B_.copy_tiles_and_advance_per_stage <false >(mma_quant_args, stage);
804
777
}
@@ -813,7 +786,7 @@ class Wint2xMmaMultistage :
813
786
int group_start_iteration_B = (warp_mma_k + 1 ) * Detail::kAccessesPerGroupB ;
814
787
815
788
copy_tiles_and_advance_A (iterator_A, group_start_iteration_A);
816
- copy_tiles_and_advance_B (iterator_B, group_start_iteration_B);
789
+ copy_tiles_and_advance_B (iterator_B, group_start_iteration_B, stage );
817
790
818
791
// Inserts a memory fence between stages of cp.async instructions.
819
792
cutlass::arch::cp_async_fence ();
@@ -858,22 +831,11 @@ class Wint2xMmaMultistage :
858
831
IteratorB &iterator_B, // /< [in|out] iterator over B operand in global memory
859
832
QuantArguments &mma_quant_args)
860
833
{
861
- #if 0
862
- CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentA::kElements=%d, %d bytes",
863
- PipeState::WarpLoadedFragmentA::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8));
864
- CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentB::kElements=%d, %d bytes",
865
- PipeState::WarpLoadedFragmentB::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8));
866
- CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentA::kElements=%d, %d bytes",
867
- PipeState::WarpTransformedFragmentA::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpTransformedFragmentA>::value / 8));
868
- CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentB::kElements=%d, %d bytes",
869
- PipeState::WarpTransformedFragmentB::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpTransformedFragmentB>::value / 8));
870
- #endif
871
-
872
834
PipeState pipe_state;
873
835
874
836
// Disable global fetching if done with global fetch iterations
875
837
iterator_A.clear_mask (gemm_k_iterations == 0 );
876
- iterator_B.clear_mask (gemm_k_iterations == (-Base:: kStages + 1 ) );
838
+ iterator_B.clear_mask (gemm_k_iterations == 0 );
877
839
878
840
// Load first warp-tile's A fragment from shared memory
879
841
this ->warp_tile_iterator_A_ .set_kgroup_index (0 );
@@ -896,10 +858,10 @@ class Wint2xMmaMultistage :
896
858
sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8);
897
859
}
898
860
#endif
899
- #if 0
861
+ #if 1
900
862
if (PipeState::WarpLoadedFragmentB::kElements == 64 ) {
901
863
uint8_t * reg_uint8_ptr = reinterpret_cast <uint8_t *>(pipe_state.warp_loaded_frag_B_ .data ());
902
- CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
864
+ CUTLASS_TRACE_DEVICE (" [stage=0] warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes" ,
903
865
static_cast <int >(reg_uint8_ptr[0 ]), static_cast <int >(reg_uint8_ptr[1 ]),
904
866
static_cast <int >(reg_uint8_ptr[2 ]), static_cast <int >(reg_uint8_ptr[3 ]),
905
867
static_cast <int >(reg_uint8_ptr[4 ]), static_cast <int >(reg_uint8_ptr[5 ]),
@@ -936,69 +898,8 @@ class Wint2xMmaMultistage :
936
898
pipe_state.warp_frag_B_ ,
937
899
0 );
938
900
939
- #if 0
940
- if (TransformBAfterLDS::result_type::kElements == 64) {
941
- CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
942
- CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
943
- static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
944
- static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
945
- static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
946
- static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
947
- static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
948
- static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
949
- static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
950
- static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
951
- CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
952
- static_cast<float>(unpacked_frag_B[16]), static_cast<float>(unpacked_frag_B[17]),
953
- static_cast<float>(unpacked_frag_B[18]), static_cast<float>(unpacked_frag_B[19]),
954
- static_cast<float>(unpacked_frag_B[20]), static_cast<float>(unpacked_frag_B[21]),
955
- static_cast<float>(unpacked_frag_B[22]), static_cast<float>(unpacked_frag_B[23]),
956
- static_cast<float>(unpacked_frag_B[24]), static_cast<float>(unpacked_frag_B[25]),
957
- static_cast<float>(unpacked_frag_B[26]), static_cast<float>(unpacked_frag_B[27]),
958
- static_cast<float>(unpacked_frag_B[28]), static_cast<float>(unpacked_frag_B[29]),
959
- static_cast<float>(unpacked_frag_B[30]), static_cast<float>(unpacked_frag_B[31]));
960
- CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
961
- static_cast<float>(unpacked_frag_B[32]), static_cast<float>(unpacked_frag_B[33]),
962
- static_cast<float>(unpacked_frag_B[34]), static_cast<float>(unpacked_frag_B[35]),
963
- static_cast<float>(unpacked_frag_B[36]), static_cast<float>(unpacked_frag_B[37]),
964
- static_cast<float>(unpacked_frag_B[38]), static_cast<float>(unpacked_frag_B[39]),
965
- static_cast<float>(unpacked_frag_B[40]), static_cast<float>(unpacked_frag_B[41]),
966
- static_cast<float>(unpacked_frag_B[42]), static_cast<float>(unpacked_frag_B[43]),
967
- static_cast<float>(unpacked_frag_B[44]), static_cast<float>(unpacked_frag_B[45]),
968
- static_cast<float>(unpacked_frag_B[46]), static_cast<float>(unpacked_frag_B[47]));
969
- CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
970
- static_cast<float>(unpacked_frag_B[48]), static_cast<float>(unpacked_frag_B[49]),
971
- static_cast<float>(unpacked_frag_B[50]), static_cast<float>(unpacked_frag_B[51]),
972
- static_cast<float>(unpacked_frag_B[52]), static_cast<float>(unpacked_frag_B[53]),
973
- static_cast<float>(unpacked_frag_B[54]), static_cast<float>(unpacked_frag_B[55]),
974
- static_cast<float>(unpacked_frag_B[56]), static_cast<float>(unpacked_frag_B[57]),
975
- static_cast<float>(unpacked_frag_B[58]), static_cast<float>(unpacked_frag_B[59]),
976
- static_cast<float>(unpacked_frag_B[60]), static_cast<float>(unpacked_frag_B[61]),
977
- static_cast<float>(unpacked_frag_B[62]), static_cast<float>(unpacked_frag_B[63]));
978
- }
979
- #endif
980
-
981
901
if (Detail::kStagedAccumulation ) {
982
902
pipe_state.tmp_accum_ .clear ();
983
- CUTLASS_TRACE_DEVICE (" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
984
- static_cast <float >(pipe_state.tmp_accum_ [0 ]), static_cast <float >(pipe_state.tmp_accum_ [1 ]),
985
- static_cast <float >(pipe_state.tmp_accum_ [2 ]), static_cast <float >(pipe_state.tmp_accum_ [3 ]),
986
- static_cast <float >(pipe_state.tmp_accum_ [4 ]), static_cast <float >(pipe_state.tmp_accum_ [5 ]),
987
- static_cast <float >(pipe_state.tmp_accum_ [6 ]), static_cast <float >(pipe_state.tmp_accum_ [7 ]),
988
- static_cast <float >(pipe_state.tmp_accum_ [8 ]), static_cast <float >(pipe_state.tmp_accum_ [9 ]),
989
- static_cast <float >(pipe_state.tmp_accum_ [10 ]), static_cast <float >(pipe_state.tmp_accum_ [11 ]),
990
- static_cast <float >(pipe_state.tmp_accum_ [12 ]), static_cast <float >(pipe_state.tmp_accum_ [13 ]),
991
- static_cast <float >(pipe_state.tmp_accum_ [14 ]), static_cast <float >(pipe_state.tmp_accum_ [15 ]));
992
- } else {
993
- CUTLASS_TRACE_DEVICE (" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
994
- static_cast <float >(accum[0 ]), static_cast <float >(accum[1 ]),
995
- static_cast <float >(accum[2 ]), static_cast <float >(accum[3 ]),
996
- static_cast <float >(accum[4 ]), static_cast <float >(accum[5 ]),
997
- static_cast <float >(accum[6 ]), static_cast <float >(accum[7 ]),
998
- static_cast <float >(accum[8 ]), static_cast <float >(accum[9 ]),
999
- static_cast <float >(accum[10 ]), static_cast <float >(accum[11 ]),
1000
- static_cast <float >(accum[12 ]), static_cast <float >(accum[13 ]),
1001
- static_cast <float >(accum[14 ]), static_cast <float >(accum[15 ]));
1002
903
}
1003
904
1004
905
int stage = Base::kStages - 1 ;
0 commit comments