14
14
15
15
#pragma once
16
16
17
+ #include " cutlass/arch/memory_sm80.h"
17
18
#include " cutlass/cutlass.h"
18
19
#include " cutlass/gemm/gemm.h"
19
20
#include " cutlass/matrix_shape.h"
@@ -67,7 +68,7 @@ class Wint2ParamsAccessor {
67
68
using ElementSuperScale = typename IteratorSuperScale::Element;
68
69
using LayoutSuperScale = typename IteratorSuperScale::Layout;
69
70
70
- // local_scale uint4 and group-wise
71
+ // / local_scale uint4 and group-wise
71
72
using ElementLocalScale = typename IteratorLocalScale::Element;
72
73
using LayoutLocalScale = typename IteratorLocalScale::Layout;
73
74
static_assert (platform::is_same<ElementLocalScale, uint4b_t >::value,
@@ -76,7 +77,7 @@ class Wint2ParamsAccessor {
76
77
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
77
78
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
78
79
79
- // 2 uint4b_t values are stored in a single uint8_t
80
+ // / 2 uint4b_t values are stored in a single uint8_t
80
81
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK ;
81
82
constexpr static int kLocalScaleRows = IteratorLocalScale::Shape::kRow ;
82
83
@@ -249,29 +250,37 @@ class Wint2ParamsAccessor {
249
250
if ((stage % kStagesPerLocalScaleLoad ) == 0 ) {
250
251
// Load group-wise local_scale to shared memory, which only needs to be done at each stage.
251
252
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
252
- typename IteratorLocalScale::Fragment tb_frag_local_scale;
253
- tb_frag_local_scale.clear ();
254
- quant_args.iterator_local_scale .load (tb_frag_local_scale);
255
- this ->smem_iterator_local_scale_ .store (tb_frag_local_scale);
253
+ using AccessType = typename IteratorLocalScale::AccessType;
254
+ cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128 )
255
+ ? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
256
+
257
+ quant_args.iterator_local_scale .set_iteration_index (0 );
258
+ this ->smem_iterator_local_scale_ .set_iteration_index (0 );
259
+
260
+ // Async Copy for local_scale
261
+ CUTLASS_PRAGMA_UNROLL
262
+ for (int j = 0 ; j < IteratorLocalScale::ThreadMap::Iterations::kCount ; ++j) {
263
+ AccessType *dst_ptr =
264
+ reinterpret_cast <AccessType *>(this ->smem_iterator_local_scale_ .get ());
265
+
266
+ CUTLASS_PRAGMA_UNROLL
267
+ for (int v = 0 ; v < IteratorLocalScale::kAccessesPerVector ; ++v) {
268
+ auto gmem_ptr = quant_args.iterator_local_scale .get ();
269
+
270
+ int const kSrcBytes =
271
+ sizeof_bits<typename IteratorLocalScale::Element>::value *
272
+ IteratorLocalScale::ThreadMap::kElementsPerAccess /
273
+ IteratorLocalScale::kAccessesPerVector / 8 ;
274
+
275
+ cutlass::arch::cp_async<kSrcBytes , kCacheOp >(
276
+ dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale .valid ());
277
+ }
278
+ ++quant_args.iterator_local_scale ;
279
+ }
280
+ ++this ->smem_iterator_local_scale_ ;
256
281
257
282
// CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] Shape: {%d, %d}",
258
283
// stage, IteratorLocalScale::Shape::kRow, IteratorLocalScale::Shape::kColumn);
259
- #if 0
260
- __syncthreads();
261
- if (IteratorLocalScale::Fragment::kElements == 32) {
262
- uint8_t* local_scale_ptr = reinterpret_cast<uint8_t*>(tb_frag_local_scale.data());
263
- CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] tb_frag_local_scale[0:15]=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d]",
264
- stage,
265
- static_cast<int>(local_scale_ptr[0]), static_cast<int>(local_scale_ptr[1]),
266
- static_cast<int>(local_scale_ptr[2]), static_cast<int>(local_scale_ptr[3]),
267
- static_cast<int>(local_scale_ptr[4]), static_cast<int>(local_scale_ptr[5]),
268
- static_cast<int>(local_scale_ptr[6]), static_cast<int>(local_scale_ptr[7]),
269
- static_cast<int>(local_scale_ptr[8]), static_cast<int>(local_scale_ptr[9]),
270
- static_cast<int>(local_scale_ptr[10]), static_cast<int>(local_scale_ptr[11]),
271
- static_cast<int>(local_scale_ptr[12]), static_cast<int>(local_scale_ptr[13]),
272
- static_cast<int>(local_scale_ptr[14]), static_cast<int>(local_scale_ptr[15]));
273
- }
274
- #endif
275
284
}
276
285
}
277
286
0 commit comments