Skip to content

Commit 8047a00

Browse files
Removed item_ct1 in favor of free functions
1 parent d46c138 commit 8047a00

File tree

3 files changed

+27
-34
lines changed

3 files changed

+27
-34
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,15 +1415,13 @@ class SYCLGen : public SYCLGenBase {
14151415
if (emitStmt(Dst)) {
14161416
return SYCLGenError();
14171417
}
1418-
OS() << ", ";
14191418
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
14201419
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
14211420
continue;
1421+
OS() << ", ";
14221422
if (emitStmt(VE->getElement(Inst)))
14231423
return SYCLGenError();
1424-
OS() << ", ";
14251424
}
1426-
OS() << DpctGlobalInfo::getItem(GAS);
14271425
if (Inst->hasAttr(InstAttr::trans))
14281426
OS() << ", true";
14291427
OS() << ");";

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2223,13 +2223,12 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22232223
/// \tparam [in] T The type of matrix elements
22242224
/// \param [in] addr The address of the matrix in local memory
22252225
/// \param [in] m The private memory containing data of matrix
2226-
/// \param [in] item The sycl::nd_item index space class
22272226
/// \param [in] trans Indicates whether the matrix is to be stored transposed
22282227
/// \param [in] mat The matrix index to be stored
2229-
template <typename T, typename ItemT>
2230-
void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
2231-
unsigned mat = 0) {
2232-
int lane = item.get_sub_group().get_local_linear_id();
2228+
template <typename T>
2229+
void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
2230+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
2231+
int lane = sg.get_local_linear_id();
22332232

22342233
int lane_group8_row = lane / 8;
22352234
int lane_group8_col = lane % 8;
@@ -2241,8 +2240,8 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
22412240
src_lane += 1;
22422241

22432242
// Broadcast the address from the source lane
2244-
auto recv_addr_uintp = dpct::select_from_sub_group(
2245-
item.get_sub_group(), addr, mat * 8 + src_lane);
2243+
auto recv_addr_uintp =
2244+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
22462245

22472246
// Cast the received address from uintptr_t to the type of 'm'
22482247
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
@@ -2254,10 +2253,10 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
22542253
int src_lane = (lane % 4) * 2;
22552254

22562255
// Broadcast the address from the source lane
2257-
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2258-
item.get_sub_group(), addr, mat * 8 + src_lane);
2259-
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2260-
item.get_sub_group(), addr, mat * 8 + src_lane + 1);
2256+
auto recv_addr_uintp_1 =
2257+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
2258+
auto recv_addr_uintp_2 =
2259+
dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
22612260

22622261
// Cast the received address from uintptr_t to 'half *'
22632262
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
@@ -2279,15 +2278,13 @@ void stmatrix(uintptr_t addr, T m, const ItemT &item, bool trans = false,
22792278
/// \param [in] addr The address of the matrix in local memory
22802279
/// \param [in] m1 The private memory containing data of 1st matrix
22812280
/// \param [in] m2 The private memory containing data of 2nd matrix
2282-
/// \param [in] item The sycl::nd_item index space class
22832281
/// \param [in] trans Indicates whether the matrices are to be stored transposed
2284-
template <typename T, typename ItemT>
2285-
void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
2286-
bool trans = false) {
2282+
template <typename T>
2283+
void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
22872284
// Store 1st matrix
2288-
stmatrix(addr, m1, item, trans, 0);
2285+
stmatrix(addr, m1, trans, 0);
22892286
// Store 2nd matrix
2290-
stmatrix(addr, m2, item, trans, 1);
2287+
stmatrix(addr, m2, trans, 1);
22912288
}
22922289

22932290
/// Stores 4 8x8 b16 matrices from private memory to local memory (32 bits per wi)
@@ -2298,19 +2295,17 @@ void stmatrix(uintptr_t addr, T m1, T m2, const ItemT &item,
22982295
/// \param [in] m2 The private memory containing data of 2nd matrix
22992296
/// \param [in] m3 The private memory containing data of 3rd matrix
23002297
/// \param [in] m4 The private memory containing data of 4th matrix
2301-
/// \param [in] item The sycl::nd_item index space class
23022298
/// \param [in] trans Indicates whether the matrices are to be stored transposed
2303-
template <typename T, typename ItemT>
2304-
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, const ItemT &item,
2305-
bool trans = false) {
2299+
template <typename T>
2300+
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false) {
23062301
// Store 1st matrix
2307-
stmatrix(addr, m1, item, trans, 0);
2302+
stmatrix(addr, m1, trans, 0);
23082303
// Store 2nd matrix
2309-
stmatrix(addr, m2, item, trans, 1);
2304+
stmatrix(addr, m2, trans, 1);
23102305
// Store 3rd matrix
2311-
stmatrix(addr, m3, item, trans, 2);
2306+
stmatrix(addr, m3, trans, 2);
23122307
// Store 4th matrix
2313-
stmatrix(addr, m4, item, trans, 3);
2308+
stmatrix(addr, m4, trans, 3);
23142309
}
23152310

23162311
/// A helper struct that defines the pack type for the input matrix fragments

clang/test/dpct/asm/stmatrix.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ __device__ void store_matrix_x1(void *sh_r_addr, int *r) {
2222
// CHECK: auto addr = sh_r_addr;
2323
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
2424

25-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], item_ct1);
25+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0]);
2626
asm volatile("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n"
2727
:
2828
: "r"(addr), "r"(r[0]));
@@ -32,7 +32,7 @@ __device__ void store_matrix_x2(void *sh_r_addr, int *r) {
3232
// CHECK: auto addr = sh_r_addr;
3333
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
3434

35-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], item_ct1);
35+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1]);
3636
asm volatile("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n"
3737
:
3838
: "r"(addr), "r"(r[0]), "r"(r[1]));
@@ -42,7 +42,7 @@ __device__ void store_matrix_x4(void *sh_r_addr, int *r) {
4242
// CHECK: auto addr = sh_r_addr;
4343
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
4444

45-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], item_ct1);
45+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3]);
4646
asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n"
4747
:
4848
: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]), "r"(r[3]));
@@ -52,7 +52,7 @@ __device__ void store_matrix_x1_trans(void *sh_r_addr, int *r) {
5252
// CHECK: auto addr = sh_r_addr;
5353
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
5454

55-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], item_ct1, true);
55+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], true);
5656
asm volatile("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n"
5757
:
5858
: "r"(addr), "r"(r[0]));
@@ -62,7 +62,7 @@ __device__ void store_matrix_x2_trans(void *sh_r_addr, int *r) {
6262
// CHECK: auto addr = sh_r_addr;
6363
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
6464

65-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], item_ct1, true);
65+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], true);
6666
asm volatile("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n"
6767
:
6868
: "r"(addr), "r"(r[0]), "r"(r[1]));
@@ -72,7 +72,7 @@ __device__ void store_matrix_x4_trans(void *sh_r_addr, int *r) {
7272
// CHECK: auto addr = sh_r_addr;
7373
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
7474

75-
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], item_ct1, true);
75+
// CHECK: dpct::experimental::matrix::stmatrix((uintptr_t)addr, r[0], r[1], r[2], r[3], true);
7676
asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n"
7777
:
7878
: "r"(addr), "r"(r[0]), "r"(r[1]), "r"(r[2]), "r"(r[3]));

0 commit comments

Comments
 (0)