Skip to content

Commit 4c1c677

Browse files
committed
Check and correct the load and unpack of weights.
Change-Id: I9a95649b9f90fcf9300a4a10f266046e3adf1064
1 parent 567c4bf commit 4c1c677

File tree

3 files changed

+68
-234
lines changed

3 files changed

+68
-234
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 36 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ class Wint2xMmaMultistage :
332332
if (smem_read_stage_idx_ == Base::kStages) {
333333
// Wrap back around to the 'start' of the circular buffer in shared memory
334334
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
335-
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0});
335+
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
336336
smem_read_stage_idx_ = 0;
337337
}
338338
}
@@ -399,7 +399,7 @@ class Wint2xMmaMultistage :
399399
}
400400

401401
CUTLASS_DEVICE
402-
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0) {
402+
void copy_tiles_and_advance_B(IteratorB &iterator_B, int group_start_B = 0, int stage = 0) {
403403
iterator_B.set_iteration_index(group_start_B *
404404
IteratorB::kAccessesPerVector);
405405
this->smem_iterator_B_.set_iteration_index(group_start_B);
@@ -421,11 +421,10 @@ class Wint2xMmaMultistage :
421421
auto gmem_ptr = iterator_B.get();
422422

423423
#if 0
424-
if (group_start_B == 0 && j == 0 && v == 0) {
425-
CUTLASS_TRACE_DEVICE(" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d",
426-
reinterpret_cast<void*>(dst_ptr), reinterpret_cast<void*>(gmem_ptr),
427-
static_cast<int>(Detail::kAccessesPerGroupB), static_cast<int>(IteratorB::kAccessesPerVector),
428-
static_cast<int>(sizeof(typename IteratorB::Element)));
424+
if (j == 0 && v == 0) {
425+
CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, %d bytes, group_start_B=%d, valid=%d",
426+
stage, reinterpret_cast<void*>(gmem_ptr), reinterpret_cast<void*>(dst_ptr),
427+
kSrcBytes, group_start_B, static_cast<int>(iterator_B.valid()));
429428
}
430429
#endif
431430

@@ -543,9 +542,9 @@ class Wint2xMmaMultistage :
543542

544543
uint8_t* gmem_uint8_ptr = reinterpret_cast<uint8_t*>(gmem_ptr);
545544

546-
CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};",
545+
CUTLASS_TRACE_DEVICE(" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d, valid=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};",
547546
stage, reinterpret_cast<void*>(gmem_ptr), reinterpret_cast<void*>(dst_ptr), kSrcBytes,
548-
gmem_n, gmem_k, gmem_row, gmem_col,
547+
static_cast<int>(iterator_B.valid()), gmem_n, gmem_k, gmem_row, gmem_col,
549548
static_cast<int>(gmem_uint8_ptr[0]), static_cast<int>(gmem_uint8_ptr[1]),
550549
static_cast<int>(gmem_uint8_ptr[2]), static_cast<int>(gmem_uint8_ptr[3]),
551550
static_cast<int>(gmem_uint8_ptr[4]), static_cast<int>(gmem_uint8_ptr[5]),
@@ -689,53 +688,54 @@ class Wint2xMmaMultistage :
689688
++this->warp_tile_iterator_A_;
690689

691690
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
692-
int warp_mma_k_for_B = warp_mma_k / Base::kWarpGemmIterationsPerLoadForB;
693691

694-
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
692+
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
695693
// Load the next warp-tile's B fragment from shared memory
696-
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations);
694+
//this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k_for_B + 1) % Base::kWarpGemmIterations);
695+
this->warp_tile_iterator_B_.set_kgroup_index(0);
697696
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
698697
++this->warp_tile_iterator_B_;
699698

699+
if (PipeState::WarpLoadedFragmentB::kElements == 64) {
700+
uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_.data());
701+
CUTLASS_TRACE_DEVICE(" [stage=%d] warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
702+
stage - Base::kStages + 2,
703+
static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
704+
static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
705+
static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
706+
static_cast<int>(reg_uint8_ptr[6]), static_cast<int>(reg_uint8_ptr[7]),
707+
static_cast<int>(reg_uint8_ptr[8]), static_cast<int>(reg_uint8_ptr[9]),
708+
static_cast<int>(reg_uint8_ptr[10]), static_cast<int>(reg_uint8_ptr[11]),
709+
static_cast<int>(reg_uint8_ptr[12]), static_cast<int>(reg_uint8_ptr[13]),
710+
static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
711+
sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
712+
}
713+
700714
warp_dequantizer_.load(pipe_state.warp_frag_local_scale_);
701715
}
702716

703717
// Execute the current warp-tile of MMA operations
704-
705-
// CUTLASS_TRACE_DEVICE("ElementA %d", PipeState::WarpTransformedFragmentA::kElements);
706-
// CUTLASS_TRACE_DEVICE("ElementB %d", PipeState::WarpTransformedFragmentB::kElements);
707-
// CUTLASS_TRACE_DEVICE("kStagedAccumulation %d", Detail::kStagedAccumulation);
708-
709-
// uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_[warp_mma_k % 2].data());
710-
// CUTLASS_TRACE_DEVICE(" reg_uint8_ptr=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
711-
// static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
712-
// static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
713-
// static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
714-
// static_cast<int>(reg_uint8_ptr[6]), static_cast<int>(reg_uint8_ptr[7]),
715-
// static_cast<int>(reg_uint8_ptr[8]), static_cast<int>(reg_uint8_ptr[9]),
716-
// static_cast<int>(reg_uint8_ptr[10]), static_cast<int>(reg_uint8_ptr[11]),
717-
// static_cast<int>(reg_uint8_ptr[12]), static_cast<int>(reg_uint8_ptr[13]),
718-
// static_cast<int>(reg_uint8_ptr[14]), static_cast<int>(reg_uint8_ptr[15]),
719-
// sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8);
720-
721718
if (Detail::kStagedAccumulation) {
722719
//CUTLASS_TRACE_DEVICE(" [MMa-kStagedAccumulation][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
723720
warp_mma_(
724721
pipe_state.tmp_accum_,
725722
pipe_state.warp_frag_A_[warp_mma_k % 2],
726723
pipe_state.warp_frag_B_,
727-
// unpacked_frag_B,
728724
pipe_state.tmp_accum_,
729725
warp_k_compute_offset_B
730726
);
731727

728+
if (warp_mma_k == 0) {
729+
plus<FragmentC> plus_accum;
730+
accum = plus_accum(accum, pipe_state.tmp_accum_);
731+
pipe_state.tmp_accum_.clear();
732+
}
732733
} else {
733734
//CUTLASS_TRACE_DEVICE(" [MMa][stage=%d] warp_mma_k=%d, warp_k_compute_offset_B=%d", stage, warp_mma_k, warp_k_compute_offset_B);
734735
warp_mma_(
735736
accum,
736737
pipe_state.warp_frag_A_[warp_mma_k % 2],
737738
pipe_state.warp_frag_B_,
738-
// unpacked_frag_B,
739739
accum,
740740
warp_k_compute_offset_B
741741
);
@@ -761,33 +761,6 @@ class Wint2xMmaMultistage :
761761
static_cast<float>(accum[12]), static_cast<float>(accum[13]),
762762
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
763763
}
764-
765-
// CUTLASS_TRACE_DEVICE_TID(" now1 warp_loaded_frag_A_[0:7]=[%f, %f, %f, %f, %f, %f, %f, %f]",
766-
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][0]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][1]),
767-
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][2]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][3]),
768-
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][4]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][5]),
769-
// static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][6]), static_cast<float>(pipe_state.warp_loaded_frag_A_[warp_mma_k % 2][7]));
770-
771-
// CUTLASS_TRACE_DEVICE_TID(" now1 unpacked_frag_B[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
772-
// static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
773-
// static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
774-
// static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
775-
// static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
776-
// static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
777-
// static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
778-
// static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
779-
// static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
780-
781-
// CUTLASS_TRACE_DEVICE_TID(" warp_k_compute_offset_B = %d, now1 tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
782-
// warp_k_compute_offset_B,
783-
// static_cast<float>(accum[0]), static_cast<float>(accum[1]),
784-
// static_cast<float>(accum[2]), static_cast<float>(accum[3]),
785-
// static_cast<float>(accum[4]), static_cast<float>(accum[5]),
786-
// static_cast<float>(accum[6]), static_cast<float>(accum[7]),
787-
// static_cast<float>(accum[8]), static_cast<float>(accum[9]),
788-
// static_cast<float>(accum[10]), static_cast<float>(accum[11]),
789-
// static_cast<float>(accum[12]), static_cast<float>(accum[13]),
790-
// static_cast<float>(accum[14]), static_cast<float>(accum[15]));
791764
#endif
792765
}
793766

@@ -798,7 +771,7 @@ class Wint2xMmaMultistage :
798771
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
799772

800773
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
801-
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
774+
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B, stage);
802775
if (warp_mma_k == 0) {
803776
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
804777
}
@@ -813,7 +786,7 @@ class Wint2xMmaMultistage :
813786
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
814787

815788
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
816-
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
789+
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B, stage);
817790

818791
// Inserts a memory fence between stages of cp.async instructions.
819792
cutlass::arch::cp_async_fence();
@@ -858,22 +831,11 @@ class Wint2xMmaMultistage :
858831
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
859832
QuantArguments &mma_quant_args)
860833
{
861-
#if 0
862-
CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentA::kElements=%d, %d bytes",
863-
PipeState::WarpLoadedFragmentA::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8));
864-
CUTLASS_TRACE_DEVICE(" [PipeState] WarpLoadedFragmentB::kElements=%d, %d bytes",
865-
PipeState::WarpLoadedFragmentB::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8));
866-
CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentA::kElements=%d, %d bytes",
867-
PipeState::WarpTransformedFragmentA::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpTransformedFragmentA>::value / 8));
868-
CUTLASS_TRACE_DEVICE(" [PipeState] WarpTransformedFragmentB::kElements=%d, %d bytes",
869-
PipeState::WarpTransformedFragmentB::kElements, static_cast<int>(sizeof_bits<typename PipeState::WarpTransformedFragmentB>::value / 8));
870-
#endif
871-
872834
PipeState pipe_state;
873835

874836
// Disable global fetching if done with global fetch iterations
875837
iterator_A.clear_mask(gemm_k_iterations == 0);
876-
iterator_B.clear_mask(gemm_k_iterations == (-Base::kStages + 1));
838+
iterator_B.clear_mask(gemm_k_iterations == 0);
877839

878840
// Load first warp-tile's A fragment from shared memory
879841
this->warp_tile_iterator_A_.set_kgroup_index(0);
@@ -896,10 +858,10 @@ class Wint2xMmaMultistage :
896858
sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8);
897859
}
898860
#endif
899-
#if 0
861+
#if 1
900862
if (PipeState::WarpLoadedFragmentB::kElements == 64) {
901863
uint8_t* reg_uint8_ptr = reinterpret_cast<uint8_t*>(pipe_state.warp_loaded_frag_B_.data());
902-
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
864+
CUTLASS_TRACE_DEVICE(" [stage=0] warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes",
903865
static_cast<int>(reg_uint8_ptr[0]), static_cast<int>(reg_uint8_ptr[1]),
904866
static_cast<int>(reg_uint8_ptr[2]), static_cast<int>(reg_uint8_ptr[3]),
905867
static_cast<int>(reg_uint8_ptr[4]), static_cast<int>(reg_uint8_ptr[5]),
@@ -936,69 +898,8 @@ class Wint2xMmaMultistage :
936898
pipe_state.warp_frag_B_,
937899
0);
938900

939-
#if 0
940-
if (TransformBAfterLDS::result_type::kElements == 64) {
941-
CUTLASS_TRACE_DEVICE(" TransformBAfterLDS::result_type::kElements: 64, %d bytes", sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8);
942-
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
943-
static_cast<float>(unpacked_frag_B[0]), static_cast<float>(unpacked_frag_B[1]),
944-
static_cast<float>(unpacked_frag_B[2]), static_cast<float>(unpacked_frag_B[3]),
945-
static_cast<float>(unpacked_frag_B[4]), static_cast<float>(unpacked_frag_B[5]),
946-
static_cast<float>(unpacked_frag_B[6]), static_cast<float>(unpacked_frag_B[7]),
947-
static_cast<float>(unpacked_frag_B[8]), static_cast<float>(unpacked_frag_B[9]),
948-
static_cast<float>(unpacked_frag_B[10]), static_cast<float>(unpacked_frag_B[11]),
949-
static_cast<float>(unpacked_frag_B[12]), static_cast<float>(unpacked_frag_B[13]),
950-
static_cast<float>(unpacked_frag_B[14]), static_cast<float>(unpacked_frag_B[15]));
951-
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
952-
static_cast<float>(unpacked_frag_B[16]), static_cast<float>(unpacked_frag_B[17]),
953-
static_cast<float>(unpacked_frag_B[18]), static_cast<float>(unpacked_frag_B[19]),
954-
static_cast<float>(unpacked_frag_B[20]), static_cast<float>(unpacked_frag_B[21]),
955-
static_cast<float>(unpacked_frag_B[22]), static_cast<float>(unpacked_frag_B[23]),
956-
static_cast<float>(unpacked_frag_B[24]), static_cast<float>(unpacked_frag_B[25]),
957-
static_cast<float>(unpacked_frag_B[26]), static_cast<float>(unpacked_frag_B[27]),
958-
static_cast<float>(unpacked_frag_B[28]), static_cast<float>(unpacked_frag_B[29]),
959-
static_cast<float>(unpacked_frag_B[30]), static_cast<float>(unpacked_frag_B[31]));
960-
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
961-
static_cast<float>(unpacked_frag_B[32]), static_cast<float>(unpacked_frag_B[33]),
962-
static_cast<float>(unpacked_frag_B[34]), static_cast<float>(unpacked_frag_B[35]),
963-
static_cast<float>(unpacked_frag_B[36]), static_cast<float>(unpacked_frag_B[37]),
964-
static_cast<float>(unpacked_frag_B[38]), static_cast<float>(unpacked_frag_B[39]),
965-
static_cast<float>(unpacked_frag_B[40]), static_cast<float>(unpacked_frag_B[41]),
966-
static_cast<float>(unpacked_frag_B[42]), static_cast<float>(unpacked_frag_B[43]),
967-
static_cast<float>(unpacked_frag_B[44]), static_cast<float>(unpacked_frag_B[45]),
968-
static_cast<float>(unpacked_frag_B[46]), static_cast<float>(unpacked_frag_B[47]));
969-
CUTLASS_TRACE_DEVICE(" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
970-
static_cast<float>(unpacked_frag_B[48]), static_cast<float>(unpacked_frag_B[49]),
971-
static_cast<float>(unpacked_frag_B[50]), static_cast<float>(unpacked_frag_B[51]),
972-
static_cast<float>(unpacked_frag_B[52]), static_cast<float>(unpacked_frag_B[53]),
973-
static_cast<float>(unpacked_frag_B[54]), static_cast<float>(unpacked_frag_B[55]),
974-
static_cast<float>(unpacked_frag_B[56]), static_cast<float>(unpacked_frag_B[57]),
975-
static_cast<float>(unpacked_frag_B[58]), static_cast<float>(unpacked_frag_B[59]),
976-
static_cast<float>(unpacked_frag_B[60]), static_cast<float>(unpacked_frag_B[61]),
977-
static_cast<float>(unpacked_frag_B[62]), static_cast<float>(unpacked_frag_B[63]));
978-
}
979-
#endif
980-
981901
if (Detail::kStagedAccumulation) {
982902
pipe_state.tmp_accum_.clear();
983-
CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
984-
static_cast<float>(pipe_state.tmp_accum_[0]), static_cast<float>(pipe_state.tmp_accum_[1]),
985-
static_cast<float>(pipe_state.tmp_accum_[2]), static_cast<float>(pipe_state.tmp_accum_[3]),
986-
static_cast<float>(pipe_state.tmp_accum_[4]), static_cast<float>(pipe_state.tmp_accum_[5]),
987-
static_cast<float>(pipe_state.tmp_accum_[6]), static_cast<float>(pipe_state.tmp_accum_[7]),
988-
static_cast<float>(pipe_state.tmp_accum_[8]), static_cast<float>(pipe_state.tmp_accum_[9]),
989-
static_cast<float>(pipe_state.tmp_accum_[10]), static_cast<float>(pipe_state.tmp_accum_[11]),
990-
static_cast<float>(pipe_state.tmp_accum_[12]), static_cast<float>(pipe_state.tmp_accum_[13]),
991-
static_cast<float>(pipe_state.tmp_accum_[14]), static_cast<float>(pipe_state.tmp_accum_[15]));
992-
} else {
993-
CUTLASS_TRACE_DEVICE(" before tmp_accum_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]",
994-
static_cast<float>(accum[0]), static_cast<float>(accum[1]),
995-
static_cast<float>(accum[2]), static_cast<float>(accum[3]),
996-
static_cast<float>(accum[4]), static_cast<float>(accum[5]),
997-
static_cast<float>(accum[6]), static_cast<float>(accum[7]),
998-
static_cast<float>(accum[8]), static_cast<float>(accum[9]),
999-
static_cast<float>(accum[10]), static_cast<float>(accum[11]),
1000-
static_cast<float>(accum[12]), static_cast<float>(accum[13]),
1001-
static_cast<float>(accum[14]), static_cast<float>(accum[15]));
1002903
}
1003904

1004905
int stage = Base::kStages - 1;

0 commit comments

Comments
 (0)