@@ -86,10 +86,10 @@ template <
86
86
typename Policy_,
87
87
// / Number of stages,
88
88
int Stages,
89
+ // / Transform for input B applied in register after the LDS
90
+ typename TransformBAfterLDS_,
89
91
// / Use zfill or predicate for out-of-bound cp.async
90
- SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone ,
91
- // / Used for partial specialization
92
- typename Enable = bool >
92
+ SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone >
93
93
class Wint2xMmaMultistage :
94
94
public Wint2xMmaBase<Shape_, Policy_, Stages> {
95
95
public:
@@ -107,8 +107,10 @@ class Wint2xMmaMultistage :
107
107
using LayoutC = LayoutC_;
108
108
// /< Policy describing tuning details
109
109
using Policy = Policy_;
110
+ // / Transform for input B applied in register after the LDS
111
+ using TransformBAfterLDS = TransformBAfterLDS_;
110
112
111
- using ZippedShapeB = typename Base::SharedStorage::ZippedShapeB ;
113
+ static constexpr int kInterleave = IteratorB::Shape:: kRow / Shape:: kK ;
112
114
113
115
using SmemIteratorA = SmemIteratorA_;
114
116
using SmemIteratorB = SmemIteratorB_;
@@ -131,12 +133,11 @@ class Wint2xMmaMultistage :
131
133
132
134
using LayoutScale = cutlass::layout::ColumnMajor;
133
135
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
134
- using ElementB = typename WarpTransformedFragmentB::Element;
135
136
using Dequantizer =
136
137
warp::MmaTensorOpWin2xDequantizer<Operator,
137
138
typename Base::WarpGemm,
138
139
Operand::kB ,
139
- ElementB ,
140
+ typename WarpTransformedFragmentB::Element ,
140
141
cutlass::layout::ColumnMajor,
141
142
32 ,
142
143
WeightOnlyQuantOp::UNDEFINED>;
@@ -199,6 +200,14 @@ class Wint2xMmaMultistage :
199
200
WarpTransformedFragmentB warp_transformed_frag_B_[2 ];
200
201
};
201
202
203
+ using ElementA = typename IteratorA::Element;
204
+ using ElementB = typename IteratorB::Element;
205
+ using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
206
+
207
+ static constexpr bool IsTileInterleaveLayout =
208
+ layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
209
+ static_assert (!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
210
+ " Layout K must match threadblockK" );
202
211
203
212
private:
204
213
@@ -224,10 +233,11 @@ class Wint2xMmaMultistage :
224
233
// / Shared memory read stage index
225
234
int smem_read_stage_idx_;
226
235
227
- uint8_t * column_wise_smem_ptr_B_;
236
+ // / Transform for B in register
237
+ TransformBAfterLDS transform_B_;
228
238
229
- uint8_t * smem_zipped_ptr_B_ ;
230
- int smem_zipped_bytes_per_stage_B_ ;
239
+ uint8_t * smem_ptr_B_ ;
240
+ uint8_t * ptr_B_ ;
231
241
232
242
public:
233
243
@@ -261,16 +271,31 @@ class Wint2xMmaMultistage :
261
271
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM ;
262
272
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM ;
263
273
274
+ CUTLASS_TRACE_DEVICE (" Shape: {%d, %d, %d}, IteratorB::Shape: {%d, %d}, kInterleave: %d" ,
275
+ Shape::kM , Shape::kN , Shape::kK , IteratorB::Shape::kRow , IteratorB::Shape::kColumn , kInterleave );
276
+ CUTLASS_TRACE_DEVICE (" kPartitionsK=%d, kWarpGemmIterations=%d, WarpCount={%d, %d}, warp_idx_m=%d, warp_idx_n=%d, warp_idx_k=%d" ,
277
+ Policy::kPartitionsK , Base::kWarpGemmIterations ,
278
+ Base::WarpCount::kM , Base::WarpCount::kN , warp_idx_m, warp_idx_n, warp_idx_k);
279
+
264
280
// Add per-warp offsets in units of warp-level tiles
265
281
this ->warp_tile_iterator_A_ .add_tile_offset (
266
282
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
267
283
this ->warp_tile_iterator_B_ .add_tile_offset (
268
284
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
269
285
270
- column_wise_smem_ptr_B_ = shared_storage.operand_zipped_B_ptr ();
271
-
272
- smem_zipped_ptr_B_ = column_wise_smem_ptr_B_ + Base::SharedStorage::kColumnWiseParamsRows * ZippedShapeB::kColumn ;
273
- smem_zipped_bytes_per_stage_B_ = Base::SharedStorage::kZippedRowsPerStages * ZippedShapeB::kColumn ;
286
+ CUTLASS_TRACE_DEVICE (" Policy::SmemPaddingA: {%d, %d}; Policy::SmemPaddingB: {%d, %d}" ,
287
+ Policy::SmemPaddingA::kRow , Policy::SmemPaddingA::kColumn , Policy::SmemPaddingB::kRow , Policy::SmemPaddingB::kColumn );
288
+ CUTLASS_TRACE_DEVICE (" operand_A_ptr=%p, kRow=%d, kColumn=%d" ,
289
+ shared_storage.operand_A .data (), static_cast <int >(Base::SharedStorage::ShapeA::kRow ),
290
+ static_cast <int >(Base::SharedStorage::ShapeA::kColumn ));
291
+ CUTLASS_TRACE_DEVICE (" operand_B_ptr=%p, kRow=%d, kColumn=%d, %d bytes; kElementsPerAccess=%d, sizeof(AccessType)=%d, AsyncCopyIterationsPerStageB=%d, kAccessesPerVector=%d" ,
292
+ shared_storage.operand_B .data (),
293
+ static_cast <int >(Base::SharedStorage::ShapeB::kRow ), static_cast <int >(Base::SharedStorage::ShapeB::kColumn ),
294
+ static_cast <int >(sizeof (shared_storage.operand_B )),
295
+ static_cast <int >(IteratorB::ThreadMap::kElementsPerAccess ), static_cast <int >(sizeof (typename IteratorB::AccessType)),
296
+ static_cast <int >(Detail::AsyncCopyIterationsPerStageB), static_cast <int >(IteratorB::kAccessesPerVector ));
297
+
298
+ smem_ptr_B_ = reinterpret_cast <uint8_t *>(shared_storage.operand_B .data ());
274
299
}
275
300
276
301
// / Advance shared memory read-iterators to the next stage
@@ -371,6 +396,13 @@ class Wint2xMmaMultistage :
371
396
for (int v = 0 ; v < IteratorB::kAccessesPerVector ; ++v) {
372
397
auto gmem_ptr = iterator_B.get ();
373
398
399
+ if (group_start_B == 0 && j == 0 && v == 0 ) {
400
+ CUTLASS_TRACE_DEVICE (" dst_ptr=%p, iterator_B.get()=%p, kAccessesPerGroupB=%d, kAccessesPerVector=%d, sizeof(AccessType)=%d" ,
401
+ reinterpret_cast <void *>(dst_ptr), reinterpret_cast <void *>(gmem_ptr),
402
+ static_cast <int >(Detail::kAccessesPerGroupB ), static_cast <int >(IteratorB::kAccessesPerVector ),
403
+ static_cast <int >(sizeof (typename IteratorB::Element)));
404
+ }
405
+
374
406
if (SharedMemoryClear == SharedMemoryClearOption::kZfill ) {
375
407
cutlass::arch::copy_zfill<kSrcBytes , kCacheOpB , GlobalToSharedB>(
376
408
dst_ptr + v, gmem_ptr, iterator_B.valid ());
@@ -423,7 +455,7 @@ class Wint2xMmaMultistage :
423
455
424
456
template <bool GlobalToSharedB, bool InitStage>
425
457
CUTLASS_DEVICE
426
- void copy_tiles_and_advance_per_stage_B (IteratorB &iterator_B) {
458
+ void copy_tiles_and_advance_per_stage_B (IteratorB &iterator_B, int stage ) {
427
459
iterator_B.set_iteration_index (0 );
428
460
this ->smem_iterator_B_ .set_iteration_index (0 );
429
461
@@ -443,6 +475,31 @@ class Wint2xMmaMultistage :
443
475
IteratorB::ThreadMap::kElementsPerAccess /
444
476
IteratorB::kAccessesPerVector / 8 ;
445
477
478
+ if (v == 0 ) {
479
+ int gmem_offset = reinterpret_cast <int >(gmem_ptr) - reinterpret_cast <int >(ptr_B_);
480
+ int gmem_k = 8192 * kInterleave / 4 ;
481
+ int gmem_n = 1792 / kInterleave ;
482
+ int gmem_row = gmem_offset / gmem_k;
483
+ int gmem_col = gmem_offset % gmem_k;
484
+
485
+ int smem_offset = reinterpret_cast <int >(dst_ptr) - reinterpret_cast <int >(smem_ptr_B_);
486
+ int smem_k = Shape::kK * kInterleave / 4 ;
487
+ int smem_n = Shape::kN / kInterleave ;
488
+ int smem_row = smem_offset / smem_k;
489
+ int smem_col = smem_offset % smem_k;
490
+
491
+ uint8_t * gmem_uint8_ptr = reinterpret_cast <uint8_t *>(gmem_ptr);
492
+
493
+ CUTLASS_TRACE_DEVICE (" [stage=%d] gmem_ptr=%p, smem_ptr=%p, bytes=%d; gmem: %dx%d, {%d, %d}, [%d, %d, %d, %d, %d, %d, %d, %d]; smem: {%d, %d};" ,
494
+ stage, reinterpret_cast <void *>(gmem_ptr), reinterpret_cast <void *>(dst_ptr), kSrcBytes ,
495
+ gmem_n, gmem_k, gmem_row, gmem_col,
496
+ static_cast <int >(gmem_uint8_ptr[0 ]), static_cast <int >(gmem_uint8_ptr[1 ]),
497
+ static_cast <int >(gmem_uint8_ptr[2 ]), static_cast <int >(gmem_uint8_ptr[3 ]),
498
+ static_cast <int >(gmem_uint8_ptr[4 ]), static_cast <int >(gmem_uint8_ptr[5 ]),
499
+ static_cast <int >(gmem_uint8_ptr[6 ]), static_cast <int >(gmem_uint8_ptr[7 ]),
500
+ smem_row, smem_col);
501
+ }
502
+
446
503
if (InitStage) {
447
504
cutlass::arch::copy_zfill<kSrcBytes , kCacheOpB , GlobalToSharedB>(
448
505
dst_ptr + v, iterator_B.get (), iterator_B.valid ());
@@ -484,7 +541,7 @@ class Wint2xMmaMultistage :
484
541
copy_tiles_and_advance_per_stage_A (iterator_A);
485
542
486
543
// Async copy zipped B to shared memory.
487
- copy_tiles_and_advance_per_stage_B<true , true >(iterator_B);
544
+ copy_tiles_and_advance_per_stage_B<true , true >(iterator_B, stage );
488
545
489
546
// TODO: Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
490
547
@@ -666,6 +723,18 @@ class Wint2xMmaMultistage :
666
723
IteratorA &iterator_A, // /< [in|out] iterator over A operand in global memory
667
724
IteratorB &iterator_B)
668
725
{
726
+ #if 0
727
+ int smem_k = Shape::kK * kInterleave / 4;
728
+ int smem_n = Shape::kN / kInterleave;
729
+ for (int i = 0; i < 3 * smem_n; ++i) {
730
+ for (int j = 0; j < smem_k; ++j) {
731
+ if (i % 3 == 0) {
732
+ CUTLASS_TRACE_DEVICE(" [i=%d, j=%d, %dx%d] %d", i, j, smem_n, smem_k, static_cast<int>(smem_ptr_B_[i * smem_k + j]));
733
+ }
734
+ }
735
+ }
736
+ #endif
737
+
669
738
PipeState pipe_state;
670
739
671
740
// Disable global fetching if done with global fetch iterations
@@ -682,6 +751,70 @@ class Wint2xMmaMultistage :
682
751
this ->warp_tile_iterator_B_ .load (pipe_state.warp_loaded_frag_B_ [0 ]);
683
752
++this ->warp_tile_iterator_B_ ;
684
753
754
+ if (PipeState::WarpLoadedFragmentA::kElements == 8 ) {
755
+ ElementA* warp_frag_A_ptr = reinterpret_cast <ElementA*>(pipe_state.warp_loaded_frag_A_ [0 ].data ());
756
+ CUTLASS_TRACE_DEVICE (" warp_loaded_frag_A_=[%f, %f, %f, %f, %f, %f, %f, %f], %d bytes" ,
757
+ static_cast <float >(warp_frag_A_ptr[0 ]), static_cast <float >(warp_frag_A_ptr[1 ]),
758
+ static_cast <float >(warp_frag_A_ptr[2 ]), static_cast <float >(warp_frag_A_ptr[3 ]),
759
+ static_cast <float >(warp_frag_A_ptr[4 ]), static_cast <float >(warp_frag_A_ptr[5 ]),
760
+ static_cast <float >(warp_frag_A_ptr[6 ]), static_cast <float >(warp_frag_A_ptr[7 ]),
761
+ sizeof_bits<typename PipeState::WarpLoadedFragmentA>::value / 8 );
762
+ }
763
+ if (PipeState::WarpLoadedFragmentB::kElements == 64 ) {
764
+ uint8_t * reg_uint8_ptr = reinterpret_cast <uint8_t *>(pipe_state.warp_loaded_frag_B_ [0 ].data ());
765
+ CUTLASS_TRACE_DEVICE (" warp_loaded_frag_B_=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d], %d bytes" ,
766
+ static_cast <int >(reg_uint8_ptr[0 ]), static_cast <int >(reg_uint8_ptr[1 ]),
767
+ static_cast <int >(reg_uint8_ptr[2 ]), static_cast <int >(reg_uint8_ptr[3 ]),
768
+ static_cast <int >(reg_uint8_ptr[4 ]), static_cast <int >(reg_uint8_ptr[5 ]),
769
+ static_cast <int >(reg_uint8_ptr[6 ]), static_cast <int >(reg_uint8_ptr[7 ]),
770
+ static_cast <int >(reg_uint8_ptr[8 ]), static_cast <int >(reg_uint8_ptr[9 ]),
771
+ static_cast <int >(reg_uint8_ptr[10 ]), static_cast <int >(reg_uint8_ptr[11 ]),
772
+ static_cast <int >(reg_uint8_ptr[12 ]), static_cast <int >(reg_uint8_ptr[13 ]),
773
+ static_cast <int >(reg_uint8_ptr[14 ]), static_cast <int >(reg_uint8_ptr[15 ]),
774
+ sizeof_bits<typename PipeState::WarpLoadedFragmentB>::value / 8 );
775
+ }
776
+
777
+ typename TransformBAfterLDS::result_type unpacked_frag_B = transform_B_ (pipe_state.warp_loaded_frag_B_ [0 ]);
778
+ if (TransformBAfterLDS::result_type::kElements == 64 ) {
779
+ CUTLASS_TRACE_DEVICE (" TransformBAfterLDS::result_type::kElements: 64, %d bytes" , sizeof_bits<typename TransformBAfterLDS::result_type>::value / 8 );
780
+ CUTLASS_TRACE_DEVICE (" warp_loaded_frag_B_[0:15]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
781
+ static_cast <float >(unpacked_frag_B[0 ]), static_cast <float >(unpacked_frag_B[1 ]),
782
+ static_cast <float >(unpacked_frag_B[2 ]), static_cast <float >(unpacked_frag_B[3 ]),
783
+ static_cast <float >(unpacked_frag_B[4 ]), static_cast <float >(unpacked_frag_B[5 ]),
784
+ static_cast <float >(unpacked_frag_B[6 ]), static_cast <float >(unpacked_frag_B[7 ]),
785
+ static_cast <float >(unpacked_frag_B[8 ]), static_cast <float >(unpacked_frag_B[9 ]),
786
+ static_cast <float >(unpacked_frag_B[10 ]), static_cast <float >(unpacked_frag_B[11 ]),
787
+ static_cast <float >(unpacked_frag_B[12 ]), static_cast <float >(unpacked_frag_B[13 ]),
788
+ static_cast <float >(unpacked_frag_B[14 ]), static_cast <float >(unpacked_frag_B[15 ]));
789
+ CUTLASS_TRACE_DEVICE (" warp_loaded_frag_B_[16:31]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
790
+ static_cast <float >(unpacked_frag_B[16 ]), static_cast <float >(unpacked_frag_B[17 ]),
791
+ static_cast <float >(unpacked_frag_B[18 ]), static_cast <float >(unpacked_frag_B[19 ]),
792
+ static_cast <float >(unpacked_frag_B[20 ]), static_cast <float >(unpacked_frag_B[21 ]),
793
+ static_cast <float >(unpacked_frag_B[22 ]), static_cast <float >(unpacked_frag_B[23 ]),
794
+ static_cast <float >(unpacked_frag_B[24 ]), static_cast <float >(unpacked_frag_B[25 ]),
795
+ static_cast <float >(unpacked_frag_B[26 ]), static_cast <float >(unpacked_frag_B[27 ]),
796
+ static_cast <float >(unpacked_frag_B[28 ]), static_cast <float >(unpacked_frag_B[29 ]),
797
+ static_cast <float >(unpacked_frag_B[30 ]), static_cast <float >(unpacked_frag_B[31 ]));
798
+ CUTLASS_TRACE_DEVICE (" warp_loaded_frag_B_[32:47]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
799
+ static_cast <float >(unpacked_frag_B[32 ]), static_cast <float >(unpacked_frag_B[33 ]),
800
+ static_cast <float >(unpacked_frag_B[34 ]), static_cast <float >(unpacked_frag_B[35 ]),
801
+ static_cast <float >(unpacked_frag_B[36 ]), static_cast <float >(unpacked_frag_B[37 ]),
802
+ static_cast <float >(unpacked_frag_B[38 ]), static_cast <float >(unpacked_frag_B[39 ]),
803
+ static_cast <float >(unpacked_frag_B[40 ]), static_cast <float >(unpacked_frag_B[41 ]),
804
+ static_cast <float >(unpacked_frag_B[42 ]), static_cast <float >(unpacked_frag_B[43 ]),
805
+ static_cast <float >(unpacked_frag_B[44 ]), static_cast <float >(unpacked_frag_B[45 ]),
806
+ static_cast <float >(unpacked_frag_B[46 ]), static_cast <float >(unpacked_frag_B[47 ]));
807
+ CUTLASS_TRACE_DEVICE (" warp_loaded_frag_B_[48:63]=[%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f]" ,
808
+ static_cast <float >(unpacked_frag_B[48 ]), static_cast <float >(unpacked_frag_B[49 ]),
809
+ static_cast <float >(unpacked_frag_B[50 ]), static_cast <float >(unpacked_frag_B[51 ]),
810
+ static_cast <float >(unpacked_frag_B[52 ]), static_cast <float >(unpacked_frag_B[53 ]),
811
+ static_cast <float >(unpacked_frag_B[54 ]), static_cast <float >(unpacked_frag_B[55 ]),
812
+ static_cast <float >(unpacked_frag_B[56 ]), static_cast <float >(unpacked_frag_B[57 ]),
813
+ static_cast <float >(unpacked_frag_B[58 ]), static_cast <float >(unpacked_frag_B[59 ]),
814
+ static_cast <float >(unpacked_frag_B[60 ]), static_cast <float >(unpacked_frag_B[61 ]),
815
+ static_cast <float >(unpacked_frag_B[62 ]), static_cast <float >(unpacked_frag_B[63 ]));
816
+ }
817
+
685
818
typename Dequantizer::FragmentLocalScale warp_frag_local_scale;
686
819
typename Dequantizer::FragmentCodeScale warp_frag_code_scale;
687
820
typename Dequantizer::FragmentCodeZp warp_frag_code_zp;
@@ -702,6 +835,7 @@ class Wint2xMmaMultistage :
702
835
warp_frag_code_zp,
703
836
warp_frag_super_scale);
704
837
838
+ #if 0
705
839
// Transform, if necessary, the first warp-tile's shared memory fragments
706
840
warp_mma_.transform(
707
841
pipe_state.warp_transformed_frag_A_[0],
@@ -713,7 +847,6 @@ class Wint2xMmaMultistage :
713
847
pipe_state.tmp_accum_.clear();
714
848
}
715
849
716
- #if 0
717
850
int stage = Base::kStages - 1;
718
851
719
852
// Mainloop
@@ -790,6 +923,8 @@ class Wint2xMmaMultistage :
790
923
// /< initial value of accumulator
791
924
FragmentC const &src_accum) {
792
925
926
+ ptr_B_ = reinterpret_cast <uint8_t *>(iterator_B.get_origin_pointer ());
927
+
793
928
// Prologue (start fetching iterations of global fragments into shared memory)
794
929
prologue (iterator_A, iterator_B, gemm_k_iterations);
795
930
@@ -800,7 +935,7 @@ class Wint2xMmaMultistage :
800
935
accum = src_accum;
801
936
802
937
// Perform the MAC-iterations
803
- // gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B);
938
+ gemm_iters (gemm_k_iterations, accum, iterator_A, iterator_B);
804
939
}
805
940
};
806
941
0 commit comments