Skip to content

Commit 0fabdbc

Browse files
committed
Use async copy for local_scale.
Change-Id: Ib882ba41c3d2354bda4d25b40e2408ad3b2f7893
1 parent e86c13d commit 0fabdbc

File tree

3 files changed

+36
-32
lines changed

3 files changed

+36
-32
lines changed

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/default_wint2x_mma.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ struct DefaultQuantParamsIterators {
4747
MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor, 0,
4848
IteratorThreadMap, kAlignment>;
4949
using SmemIterator = Iterator;
50-
51-
//using AccessType = cutlass::Array<ElementT, kAlignment>;
52-
//using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
53-
// MatrixShape<kRows, kColumns>, ElementT, layout::RowMajor,
54-
// 0, IteratorThreadMap, AccessType>;
5550
};
5651

5752
template <typename ThreadblockShape, int GroupSize>
@@ -70,10 +65,11 @@ struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
7065
kColumns / kAlignment, kAlignment>;
7166

7267
public:
73-
using Iterator =
74-
cutlass::transform::threadblock::PredicatedTileIterator<
75-
cutlass::MatrixShape<kRows, kColumns>, uint4b_t,
76-
layout::RowMajor, 0, IteratorThreadMap, kAlignment>;
68+
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
69+
using Iterator = cutlass::transform::threadblock::PredicatedTileAccessIterator<
70+
MatrixShape<kRows, kColumns>, uint4b_t, layout::RowMajor,
71+
0, IteratorThreadMap, AccessType>;
72+
7773
using SmemIterator = Iterator;
7874
};
7975

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_mma_multistage.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,6 @@ class Wint2xMmaMultistage :
572572

573573
++this->smem_iterator_B_;
574574
}
575-
__syncthreads();
576575
}
577576

578577
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching

custom_ops/gpu_ops/cutlass_extensions/gemm/threadblock/wint2x_params_accessor.h

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#pragma once
1616

17+
#include "cutlass/arch/memory_sm80.h"
1718
#include "cutlass/cutlass.h"
1819
#include "cutlass/gemm/gemm.h"
1920
#include "cutlass/matrix_shape.h"
@@ -67,7 +68,7 @@ class Wint2ParamsAccessor {
6768
using ElementSuperScale = typename IteratorSuperScale::Element;
6869
using LayoutSuperScale = typename IteratorSuperScale::Layout;
6970

70-
// local_scale uint4 and group-wise
71+
/// local_scale uint4 and group-wise
7172
using ElementLocalScale = typename IteratorLocalScale::Element;
7273
using LayoutLocalScale = typename IteratorLocalScale::Layout;
7374
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
@@ -76,7 +77,7 @@ class Wint2ParamsAccessor {
7677
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
7778
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
7879

79-
// 2 uint4b_t values are stored in a single uint8_t
80+
/// 2 uint4b_t values are stored in a single uint8_t
8081
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
8182
constexpr static int kLocalScaleRows = IteratorLocalScale::Shape::kRow;
8283

@@ -249,29 +250,37 @@ class Wint2ParamsAccessor {
249250
if ((stage % kStagesPerLocalScaleLoad) == 0) {
250251
// Load group-wise local_scale to shared memory, which only needs to be done at each stage.
251252
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
252-
typename IteratorLocalScale::Fragment tb_frag_local_scale;
253-
tb_frag_local_scale.clear();
254-
quant_args.iterator_local_scale.load(tb_frag_local_scale);
255-
this->smem_iterator_local_scale_.store(tb_frag_local_scale);
253+
using AccessType = typename IteratorLocalScale::AccessType;
254+
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
255+
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
256+
257+
quant_args.iterator_local_scale.set_iteration_index(0);
258+
this->smem_iterator_local_scale_.set_iteration_index(0);
259+
260+
// Async Copy for local_scale
261+
CUTLASS_PRAGMA_UNROLL
262+
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
263+
AccessType *dst_ptr =
264+
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
265+
266+
CUTLASS_PRAGMA_UNROLL
267+
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
268+
auto gmem_ptr = quant_args.iterator_local_scale.get();
269+
270+
int const kSrcBytes =
271+
sizeof_bits<typename IteratorLocalScale::Element>::value *
272+
IteratorLocalScale::ThreadMap::kElementsPerAccess /
273+
IteratorLocalScale::kAccessesPerVector / 8;
274+
275+
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
276+
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
277+
}
278+
++quant_args.iterator_local_scale;
279+
}
280+
++this->smem_iterator_local_scale_;
256281

257282
//CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] Shape: {%d, %d}",
258283
// stage, IteratorLocalScale::Shape::kRow, IteratorLocalScale::Shape::kColumn);
259-
#if 0
260-
__syncthreads();
261-
if (IteratorLocalScale::Fragment::kElements == 32) {
262-
uint8_t* local_scale_ptr = reinterpret_cast<uint8_t*>(tb_frag_local_scale.data());
263-
CUTLASS_TRACE_DEVICE(" [stage=%d][LocalScale] tb_frag_local_scale[0:15]=[%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d]",
264-
stage,
265-
static_cast<int>(local_scale_ptr[0]), static_cast<int>(local_scale_ptr[1]),
266-
static_cast<int>(local_scale_ptr[2]), static_cast<int>(local_scale_ptr[3]),
267-
static_cast<int>(local_scale_ptr[4]), static_cast<int>(local_scale_ptr[5]),
268-
static_cast<int>(local_scale_ptr[6]), static_cast<int>(local_scale_ptr[7]),
269-
static_cast<int>(local_scale_ptr[8]), static_cast<int>(local_scale_ptr[9]),
270-
static_cast<int>(local_scale_ptr[10]), static_cast<int>(local_scale_ptr[11]),
271-
static_cast<int>(local_scale_ptr[12]), static_cast<int>(local_scale_ptr[13]),
272-
static_cast<int>(local_scale_ptr[14]), static_cast<int>(local_scale_ptr[15]));
273-
}
274-
#endif
275284
}
276285
}
277286

0 commit comments

Comments
 (0)