@@ -209,7 +209,7 @@ class Wint2xMmaMultistage :
209
209
WarpTransformedFragmentA warp_frag_A_[2 ];
210
210
211
211
// / Pair of B fragments used to overlap shared memory loads and math instructions
212
- WarpLoadedFragmentB warp_loaded_frag_B_;
212
+ WarpLoadedFragmentB warp_loaded_frag_B_[ 2 ] ;
213
213
WarpTransformedFragmentB warp_frag_B_;
214
214
};
215
215
@@ -691,10 +691,10 @@ class Wint2xMmaMultistage :
691
691
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB ;
692
692
int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB ;
693
693
694
- if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
694
+ if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
695
695
// Load the next warp-tile's B fragment from shared memory
696
696
this ->warp_tile_iterator_B_ .set_kgroup_index ((warp_mma_k_for_B + 1 ) % Base::kWarpGemmIterations );
697
- this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ );
697
+ this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ [(warp_mma_k_for_B + 1 ) % 2 ] );
698
698
++this ->warp_tile_iterator_B_ ;
699
699
700
700
warp_dequantizer_.load (pipe_state.warp_frag_local_scale_ );
@@ -718,6 +718,16 @@ class Wint2xMmaMultistage :
718
718
// static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719
719
// sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720
720
721
+ if (warp_k_compute_offset_B == 0 ) {
722
+ warp_dequantizer_.dequantize (pipe_state.warp_frag_local_scale_ ,
723
+ pipe_state.warp_frag_code_scale_ ,
724
+ pipe_state.warp_frag_code_zp_ ,
725
+ pipe_state.warp_frag_super_scale_ ,
726
+ pipe_state.warp_loaded_frag_B_ [warp_mma_k_for_B % 2 ],
727
+ pipe_state.warp_frag_B_ ,
728
+ (stage - Base::kStages + 2 ) * Shape::kK );
729
+ }
730
+
721
731
if (Detail::kStagedAccumulation ) {
722
732
// CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
723
733
warp_mma_ (
@@ -767,27 +777,6 @@ class Wint2xMmaMultistage :
767
777
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]),
768
778
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]),
769
779
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7]));
770
-
771
- // CUTLASS_TRACE_DEVICE_TID(" now1 unpacked_frag_B[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
772
- // static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
773
- // static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
774
- // static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
775
- // static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
776
- // static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
777
- // static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
778
- // static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
779
- // static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
780
-
781
- // CUTLASS_TRACE_DEVICE_TID(" warp_k_compute_offset_B = %d, now1 tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
782
- // warp_k_compute_offset_B,
783
- // static_cast<float>(accum[0]), static_cast<float>(accum[1]),
784
- // static_cast<float>(accum[2]), static_cast<float>(accum[3]),
785
- // static_cast<float>(accum[4]), static_cast<float>(accum[5]),
786
- // static_cast<float>(accum[6]), static_cast<float>(accum[7]),
787
- // static_cast<float>(accum[8]), static_cast<float>(accum[9]),
788
- // static_cast<float>(accum[10]), static_cast<float>(accum[11]),
789
- // static_cast<float>(accum[12]), static_cast<float>(accum[13]),
790
- // static_cast<float>(accum[14]), static_cast<float>(accum[15]));
791
780
#endif
792
781
}
793
782
@@ -835,16 +824,6 @@ class Wint2xMmaMultistage :
835
824
iterator_B.clear_mask (gemm_k_iterations == 0 );
836
825
quant_params_accessor_B_.clear_mask (mma_quant_args, gemm_k_iterations == 0 );
837
826
}
838
-
839
- if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1 ) {
840
- warp_dequantizer_.dequantize (pipe_state.warp_frag_local_scale_ ,
841
- pipe_state.warp_frag_code_scale_ ,
842
- pipe_state.warp_frag_code_zp_ ,
843
- pipe_state.warp_frag_super_scale_ ,
844
- pipe_state.warp_loaded_frag_B_ ,
845
- pipe_state.warp_frag_B_ ,
846
- (stage - Base::kStages + 2 ) * Shape::kK );
847
- }
848
827
}
849
828
}
850
829
@@ -882,7 +861,7 @@ class Wint2xMmaMultistage :
882
861
883
862
// Load first warp-tile's B fragment from shared memory
884
863
this ->warp_tile_iterator_B_ .set_kgroup_index (0 );
885
- this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ );
864
+ this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ [ 0 ] );
886
865
++this ->warp_tile_iterator_B_ ;
887
866
888
867
#if 0
@@ -928,14 +907,6 @@ class Wint2xMmaMultistage :
928
907
}
929
908
#endif
930
909
931
- warp_dequantizer_.dequantize (pipe_state.warp_frag_local_scale_ ,
932
- pipe_state.warp_frag_code_scale_ ,
933
- pipe_state.warp_frag_code_zp_ ,
934
- pipe_state.warp_frag_super_scale_ ,
935
- pipe_state.warp_loaded_frag_B_ ,
936
- pipe_state.warp_frag_B_ ,
937
- 0 );
938
-
939
910
#if 0
940
911
if (TransformBAfterLDS::result_type::kElements == 64) {
941
912
CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
0 commit comments