Skip to content

Commit e0b366f

Browse files
committed
add pingpong buffer for loaded_b_frag
1 parent 567c4bf commit e0b366f

File tree

1 file changed

+14
-43
lines changed

1 file changed

+14
-43
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 14 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ class Wint2xMmaMultistage :
209209
WarpTransformedFragmentA warp_frag_A_[2];
210210

211211
/// Pair of B fragments used to overlap shared memory loads and math instructions
212-
WarpLoadedFragmentB warp_loaded_frag_B_;
212+
WarpLoadedFragmentB warp_loaded_frag_B_[2];
213213
WarpTransformedFragmentB warp_frag_B_;
214214
};
215215

@@ -691,10 +691,10 @@ class Wint2xMmaMultistage :
691691
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
692692
int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB;
693693

694-
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
694+
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
695695
// Load the next warp-tile's B fragment from shared memory
696696
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations);
697-
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
697+
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[(warp_mma_k_for_B + 1) % 2]);
698698
++this->warp_tile_iterator_B_;
699699

700700
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
@@ -718,6 +718,16 @@ class Wint2xMmaMultistage :
718718
// static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719719
// sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720720

721+
if (warp_k_compute_offset_B == 0) {
722+
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
723+
pipe_state.warp_frag_code_scale_,
724+
pipe_state.warp_frag_code_zp_,
725+
pipe_state.warp_frag_super_scale_,
726+
pipe_state.warp_loaded_frag_B_[warp_mma_k_for_B % 2],
727+
pipe_state.warp_frag_B_,
728+
(stage - Base::kStages + 2) * Shape::kK);
729+
}
730+
721731
if (Detail::kStagedAccumulation) {
722732
//CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
723733
warp_mma_(
@@ -767,27 +777,6 @@ class Wint2xMmaMultistage :
767777
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]),
768778
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]),
769779
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7]));
770-
771-
// CUTLASS_TRACE_DEVICE_TID(" now1 unpacked_frag_B[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
772-
// static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
773-
// static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
774-
// static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
775-
// static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
776-
// static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
777-
// static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
778-
// static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
779-
// static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
780-
781-
// CUTLASS_TRACE_DEVICE_TID(" warp_k_compute_offset_B = %d, now1 tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
782-
// warp_k_compute_offset_B,
783-
// static_cast<float>(accum[0]), static_cast<float>(accum[1]),
784-
// static_cast<float>(accum[2]), static_cast<float>(accum[3]),
785-
// static_cast<float>(accum[4]), static_cast<float>(accum[5]),
786-
// static_cast<float>(accum[6]), static_cast<float>(accum[7]),
787-
// static_cast<float>(accum[8]), static_cast<float>(accum[9]),
788-
// static_cast<float>(accum[10]), static_cast<float>(accum[11]),
789-
// static_cast<float>(accum[12]), static_cast<float>(accum[13]),
790-
// static_cast<float>(accum[14]), static_cast<float>(accum[15]));
791780
#endif
792781
}
793782

@@ -835,16 +824,6 @@ class Wint2xMmaMultistage :
835824
iterator_B.clear_mask(gemm_k_iterations == 0);
836825
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
837826
}
838-
839-
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
840-
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
841-
pipe_state.warp_frag_code_scale_,
842-
pipe_state.warp_frag_code_zp_,
843-
pipe_state.warp_frag_super_scale_,
844-
pipe_state.warp_loaded_frag_B_,
845-
pipe_state.warp_frag_B_,
846-
(stage - Base::kStages + 2) * Shape::kK);
847-
}
848827
}
849828
}
850829

@@ -882,7 +861,7 @@ class Wint2xMmaMultistage :
882861

883862
// Load first warp-tile's B fragment from shared memory
884863
this->warp_tile_iterator_B_.set_kgroup_index(0);
885-
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
864+
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_[0]);
886865
++this->warp_tile_iterator_B_;
887866

888867
#if 0
@@ -928,14 +907,6 @@ class Wint2xMmaMultistage :
928907
}
929908
#endif
930909

931-
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
932-
pipe_state.warp_frag_code_scale_,
933-
pipe_state.warp_frag_code_zp_,
934-
pipe_state.warp_frag_super_scale_,
935-
pipe_state.warp_loaded_frag_B_,
936-
pipe_state.warp_frag_B_,
937-
0);
938-
939910
#if 0
940911
if (TransformBAfterLDS::result_type::kElements == 64) {
941912
CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);

0 commit comments

Comments
 (0)