@@ -475,7 +475,7 @@ class Wint2xMmaMultistage :
475
475
copy_tiles_and_advance_per_stage_A (iterator_A);
476
476
477
477
// Async copy zipped B to shared memory.
478
- copy_tiles_and_advance_per_stage_B<false , true >(iterator_B);
478
+ copy_tiles_and_advance_per_stage_B<true , true >(iterator_B);
479
479
480
480
// TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
481
481
// tile_dequanter_B.Load(smem_zipped_ptr_B_ + (stage % Base::kStages) * smem_zipped_bytes_per_stage_B_,
@@ -609,7 +609,7 @@ class Wint2xMmaMultistage :
609
609
int group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB ;
610
610
611
611
copy_tiles_and_advance_A (iterator_A, group_start_iteration_A);
612
- copy_tiles_and_advance_B<false >(iterator_B, group_start_iteration_B);
612
+ copy_tiles_and_advance_B<true >(iterator_B, group_start_iteration_B);
613
613
}
614
614
615
615
// The second-to-last warp-tile also:
@@ -621,7 +621,7 @@ class Wint2xMmaMultistage :
621
621
int group_start_iteration_B = (warp_mma_k + 1 ) * Detail::kAccessesPerGroupB ;
622
622
623
623
copy_tiles_and_advance_A (iterator_A, group_start_iteration_A);
624
- copy_tiles_and_advance_B<false >(iterator_B, group_start_iteration_B);
624
+ copy_tiles_and_advance_B<true >(iterator_B, group_start_iteration_B);
625
625
626
626
// Inserts a memory fence between stages of cp.async instructions.
627
627
cutlass::arch::cp_async_fence ();
0 commit comments