Skip to content

Commit 03f180c

Browse files
Added support for m8n8k4 shape
1 parent 1bdce94 commit 03f180c

File tree

3 files changed

+287
-28
lines changed

3 files changed

+287
-28
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 86 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,42 +1309,106 @@ class SYCLGen : public SYCLGenBase {
13091309
if (Inst->getNumInputOperands() != 3)
13101310
return SYCLGenError();
13111311

1312-
if (!Inst->hasAttr(InstAttr::m16n8k16))
1312+
const InlineAsmVectorExpr *DMatVE =
1313+
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1314+
if (!DMatVE)
13131315
return SYCLGenError();
13141316

13151317
// Only row Layout is supported for of A matrix and
13161318
// only col Layout is supported for of B matrix
1317-
if (Inst->getAttr(3) != InstAttr::row ||
1318-
Inst->getAttr(4) != InstAttr::col) {
1319+
if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col)
13191320
return SYCLGenError();
1320-
}
13211321

13221322
// Only f16 type is supported for A and B matrix data
1323+
const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
13231324
const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(1));
13241325
const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(2));
1326+
const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(3));
13251327

1326-
std::string TypeStr;
1327-
if (!AType || !BType ||
1328-
(AType->getKind() != InlineAsmBuiltinType::f16 ||
1329-
BType->getKind() != InlineAsmBuiltinType::f16)) {
1328+
if (!(AType && BType && CType && DType))
13301329
return SYCLGenError();
1331-
} else {
1332-
if (tryEmitType(TypeStr, AType))
1333-
return SYCLGenError();
1334-
}
13351330

1336-
const InlineAsmVectorExpr *VE =
1337-
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1338-
if (VE && VE->getNumElements() != 4) {
1331+
// Data types of matrix elements for A&B and C&D matrices should be same
1332+
if ((AType->getKind() != BType->getKind()) ||
1333+
(CType->getKind() != DType->getKind()))
1334+
return SYCLGenError();
1335+
1336+
// Check the validity of AB & CD types
1337+
std::string ABType, CDType;
1338+
if (tryEmitType(ABType, AType))
1339+
return SYCLGenError();
1340+
1341+
if (tryEmitType(CDType, CType))
1342+
return SYCLGenError();
1343+
1344+
// Register sizes for vector elements of A, B, C & D matrices
1345+
int NumVecElements[4] = {0};
1346+
1347+
// Data type used to multiply A & B matrices
1348+
std::string MulType;
1349+
if (Inst->hasAttr(InstAttr::m16n8k16)) {
1350+
// Only f16 type is supported for A and B matrix data for m16n8k16
1351+
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1352+
// If A matrix type is f16, then C&D matrix types can only be f16
1353+
if (CType->getKind() == AType->getKind()) {
1354+
NumVecElements[0] = 4; // A
1355+
NumVecElements[1] = 2; // B
1356+
NumVecElements[2] = 4; // C
1357+
NumVecElements[3] = 4; // D
1358+
} else
1359+
return SYCLGenError();
1360+
} else
1361+
return SYCLGenError();
1362+
} else if (Inst->hasAttr(InstAttr::m8n8k4)) {
1363+
// f16 & f64 types are supported for A and B matrix data for m8n8k4
1364+
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1365+
// If A matrix type is f16, then C&D matrix types can only be f16/f32
1366+
if (CType->getKind() == AType->getKind()) {
1367+
NumVecElements[0] = 4; // A
1368+
NumVecElements[1] = 2; // B
1369+
NumVecElements[2] = 2; // C
1370+
NumVecElements[3] = 4; // D
1371+
} else if (CType->getKind() == InlineAsmBuiltinType::f32) {
1372+
NumVecElements[0] = 8; // A
1373+
NumVecElements[1] = 2; // B
1374+
NumVecElements[2] = 2; // C
1375+
NumVecElements[3] = 8; // D
1376+
} else
1377+
return SYCLGenError();
1378+
} else if (AType->getKind() == InlineAsmBuiltinType::f64) {
1379+
// If A matrix type is f64, then C&D matrix types can only be f64
1380+
if (CType->getKind() == AType->getKind()) {
1381+
NumVecElements[0] = 2; // A
1382+
NumVecElements[1] = 1; // B
1383+
NumVecElements[2] = 1; // C
1384+
NumVecElements[3] = 2; // D
1385+
} else
1386+
return SYCLGenError();
1387+
} else
1388+
return SYCLGenError();
1389+
} else
13391390
return SYCLGenError();
1391+
1392+
// Check the register sizes for vector elements of A, B, C & D matrices
1393+
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
1394+
InputOp++) {
1395+
if (auto VE =
1396+
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1397+
if (VE->getNumElements() != NumVecElements[InputOp])
1398+
return SYCLGenError();
1399+
} else
1400+
return SYCLGenError();
13401401
}
1402+
if (DMatVE->getNumElements() != NumVecElements[3])
1403+
return SYCLGenError();
13411404

1405+
MulType = ABType;
13421406
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
1343-
OS() << "<" << TypeStr << ">(";
1407+
OS() << "<" << MulType << ">(";
13441408

13451409
// Add D matrix address values to store the MAD result
1346-
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
1347-
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1410+
for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) {
1411+
if (isa<InlineAsmDiscardExpr>(DMatVE->getElement(Inst)))
13481412
continue;
13491413
OS() << "&";
13501414
if (emitStmt(VE->getElement(Inst)))
@@ -2607,11 +2671,10 @@ class SYCLGen : public SYCLGenBase {
26072671
Op = std::move(NewOp);
26082672
}
26092673

2610-
bool HasHalfOrBfloat16 =
2611-
SrcType->getKind() == InlineAsmBuiltinType::f16 ||
2612-
DesType->getKind() == InlineAsmBuiltinType::f16 ||
2613-
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
2614-
DesType->getKind() == InlineAsmBuiltinType::bf16;
2674+
bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 ||
2675+
DesType->getKind() == InlineAsmBuiltinType::f16 ||
2676+
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
2677+
DesType->getKind() == InlineAsmBuiltinType::bf16;
26152678
if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) {
26162679
insertHeader(HeaderType::HT_SYCL_Math);
26172680
if (SrcNeedBitCast)

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ MODIFIER(col, ".col")
280280

281281
// Matrix shape
282282
MODIFIER(m16n8k16, ".m16n8k16")
283+
MODIFIER(m8n8k4, ".m8n8k4")
283284

284285
STATE_SPACE(reg, ".reg")
285286
STATE_SPACE(sreg, ".sreg")

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 200 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#ifndef __DPCT_MATH_HPP__
1010
#define __DPCT_MATH_HPP__
1111

12-
#include <limits>
1312
#include <climits>
13+
#include <limits>
1414
#include <sycl/sycl.hpp>
1515
#include <type_traits>
1616

@@ -1636,7 +1636,8 @@ inline constexpr unsigned extend_vcompare2_add(AT a, BT b, unsigned c,
16361636
/// \returns The extend vectorized average of the two values
16371637
template <typename RetT, typename AT, typename BT>
16381638
inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c) {
1639-
return detail::extend_vbinary2<RetT, false, false>(a, b, c, detail::average());
1639+
return detail::extend_vbinary2<RetT, false, false>(a, b, c,
1640+
detail::average());
16401641
}
16411642

16421643
/// Compute vectorized average of \p a and \p b, with each value treated as a 2
@@ -1933,7 +1934,8 @@ inline constexpr unsigned extend_vcompare4_add(AT a, BT b, unsigned c,
19331934
/// \returns The extend vectorized average of the two values
19341935
template <typename RetT, typename AT, typename BT>
19351936
inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) {
1936-
return detail::extend_vbinary4<RetT, false, false>(a, b, c, detail::average());
1937+
return detail::extend_vbinary4<RetT, false, false>(a, b, c,
1938+
detail::average());
19371939
}
19381940

19391941
/// Compute vectorized average of \p a and \p b, with each value treated as a 4
@@ -2056,6 +2058,198 @@ class joint_matrix {
20562058
const size_t num_elements;
20572059
};
20582060

2061+
/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b16
2062+
/// matrix
2063+
/// Requires the sub-group size of kernel calling this function to be 32
2064+
/// \tparam [in] MulType The type of the multiplication result
2065+
/// \tparam [in] ABType The type of the input matrices
2066+
/// \tparam [in] CDType The type of the output matrix
2067+
/// \tparam [in] ItemT The type of the sycl::nd_item index space class
2068+
/// \param [in] d0 The 1st element to be written to the output D matrix
2069+
/// \param [in] d1 The 2nd element to be written to the output D matrix
2070+
/// \param [in] d2 The 3rd element to be written to the output D matrix
2071+
/// \param [in] d3 The 4th element to be written to the output D matrix
2072+
/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix
2073+
/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix
2074+
/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix
2075+
/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix
2076+
/// \param [in] c0 The 1st element from C matrix to be added with d0
2077+
/// \param [in] c1 The 2nd element from C matrix to be added with d1
2078+
/// \param [in] c2 The 3rd element from C matrix to be added with d2
2079+
/// \param [in] c3 The 4th element from C matrix to be added with d3
2080+
/// \param [in] item The sycl::nd_item index space class
2081+
template <typename MulType, typename ABType, typename CDType, typename ItemT>
2082+
void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
2083+
ABType b0, ABType b1, CDType c0, CDType c1, CDType c2, CDType c3,
2084+
const ItemT &item) {
2085+
int lane = item.get_sub_group().get_local_linear_id();
2086+
2087+
short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4);
2088+
2089+
ABType recv_a[2] recv_a[0] = a0;
2090+
recv_a[1] = a1;
2091+
2092+
MulType *ra = reinterpret_cast<MulType *>(recv_a);
2093+
MulType *c_h[4];
2094+
c_h[0] = reinterpret_cast<MulType *>(&c0);
2095+
c_h[1] = reinterpret_cast<MulType *>(&c1);
2096+
c_h[2] = reinterpret_cast<MulType *>(&c2);
2097+
c_h[3] = reinterpret_cast<MulType *>(&c3);
2098+
for (int i = 0; i < 4; i++) {
2099+
ABType recv_b[4];
2100+
2101+
recv_b[0] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2102+
COL_LOAD_OFFSET + i);
2103+
recv_b[1] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2104+
COL_LOAD_OFFSET + i);
2105+
recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2106+
COL_LOAD_OFFSET + 16 + i);
2107+
recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2108+
COL_LOAD_OFFSET + 16 + i);
2109+
2110+
MulType *rb = reinterpret_cast<MulType *>(recv_b);
2111+
// Iterate for k times
2112+
for (int j = 0; j < 4; j++) {
2113+
c_h[(i >> 1)][i % 2] +=
2114+
static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
2115+
c_h[2 + (i >> 1)][i % 2] +=
2116+
static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[4 + j]);
2117+
}
2118+
}
2119+
2120+
d0 = c0;
2121+
d1 = c1;
2122+
d2 = c2;
2123+
d3 = c3;
2124+
}
2125+
2126+
/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32
2127+
/// matrix
2128+
/// Requires the sub-group size of kernel calling this function to be 32
2129+
/// \tparam [in] MulType The type of the multiplication result
2130+
/// \tparam [in] ABType The type of the input matrices
2131+
/// \tparam [in] CDType The type of the output matrix
2132+
/// \tparam [in] ItemT The type of the sycl::nd_item index space class
2133+
/// \param [in] d0 The 1st element to be written to the output D matrix
2134+
/// \param [in] d1 The 2nd element to be written to the output D matrix
2135+
/// \param [in] d2 The 3rd element to be written to the output D matrix
2136+
/// \param [in] d3 The 4th element to be written to the output D matrix
2137+
/// \param [in] d4 The 5th element to be written to the output D matrix
2138+
/// \param [in] d5 The 6th element to be written to the output D matrix
2139+
/// \param [in] d6 The 7th element to be written to the output D matrix
2140+
/// \param [in] d7 The 8th element to be written to the output D matrix
2141+
/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix
2142+
/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix
2143+
/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix
2144+
/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix
2145+
/// \param [in] c0 The 1st element from C matrix to be added with d0
2146+
/// \param [in] c1 The 2nd element from C matrix to be added with d1
2147+
/// \param [in] c2 The 3rd element from C matrix to be added with d2
2148+
/// \param [in] c3 The 4th element from C matrix to be added with d3
2149+
/// \param [in] c4 The 5th element from C matrix to be added with d4
2150+
/// \param [in] c5 The 6th element from C matrix to be added with d5
2151+
/// \param [in] c6 The 7th element from C matrix to be added with d6
2152+
/// \param [in] c7 The 8th element from C matrix to be added with d7
2153+
/// \param [in] item The sycl::nd_item index space class
2154+
template <typename MulType, typename ABType, typename CDType, typename ItemT>
2155+
void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5,
2156+
CDType *d6, CDType *d7, ABType a0, ABType a1, ABType b0, ABType b1,
2157+
CDType c0, CDType c1, CDType c2, CDType c3, CDType c4, CDType c5,
2158+
CDType c6, CDType c7, const ItemT &item) {
2159+
int lane = item.get_sub_group().get_local_linear_id();
2160+
2161+
short ROW_LOAD_OFFSET = 4 * (lane / 4) + (lane % 2);
2162+
short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4) + 2 * ((lane / 2) % 2);
2163+
2164+
ABType recv_a[2 * 2], recv_b[4 * 2];
2165+
2166+
recv_a[0] =
2167+
dpct::select_from_sub_group(item.get_sub_group(), a0, ROW_LOAD_OFFSET);
2168+
recv_a[1] =
2169+
dpct::select_from_sub_group(item.get_sub_group(), a1, ROW_LOAD_OFFSET);
2170+
recv_a[2] = dpct::select_from_sub_group(item.get_sub_group(), a0,
2171+
ROW_LOAD_OFFSET + 2);
2172+
recv_a[3] = dpct::select_from_sub_group(item.get_sub_group(), a1,
2173+
ROW_LOAD_OFFSET + 2);
2174+
2175+
recv_b[0] =
2176+
dpct::select_from_sub_group(item.get_sub_group(), b0, COL_LOAD_OFFSET);
2177+
recv_b[1] =
2178+
dpct::select_from_sub_group(item.get_sub_group(), b1, COL_LOAD_OFFSET);
2179+
recv_b[2] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2180+
COL_LOAD_OFFSET + 1);
2181+
recv_b[3] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2182+
COL_LOAD_OFFSET + 1);
2183+
recv_b[4] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2184+
COL_LOAD_OFFSET + 16);
2185+
recv_b[5] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2186+
COL_LOAD_OFFSET + 16);
2187+
recv_b[6] = dpct::select_from_sub_group(item.get_sub_group(), b0,
2188+
COL_LOAD_OFFSET + 17);
2189+
recv_b[7] = dpct::select_from_sub_group(item.get_sub_group(), b1,
2190+
COL_LOAD_OFFSET + 17);
2191+
2192+
MulType *ra = reinterpret_cast<MulType *>(recv_a);
2193+
MulType *rb = reinterpret_cast<MulType *>(recv_b);
2194+
for (int i = 0; i < 4 /*k*/; i++) {
2195+
c0 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i]);
2196+
c1 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i + 4]);
2197+
c2 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i]);
2198+
c3 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i + 4]);
2199+
c4 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i + 8]);
2200+
c5 += static_cast<CDType>(ra[i]) * static_cast<CDType>(rb[i + 12]);
2201+
c6 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i + 8]);
2202+
c7 += static_cast<CDType>(ra[i + 4]) * static_cast<CDType>(rb[i + 12]);
2203+
}
2204+
2205+
d0 = c0;
2206+
d1 = c1;
2207+
d2 = c2;
2208+
d3 = c3;
2209+
d4 = c4;
2210+
d5 = c5;
2211+
d6 = c6;
2212+
d7 = c7;
2213+
}
2214+
2215+
/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32
2216+
/// matrix
2217+
/// Requires the sub-group size of kernel calling this function to be 32
2218+
/// \tparam [in] MulType The type of the multiplication result
2219+
/// \tparam [in] ABType The type of the input matrices
2220+
/// \tparam [in] CDType The type of the output matrix
2221+
/// \tparam [in] ItemT The type of the sycl::nd_item index space class
2222+
/// \param [in] d0 The 1st element to be written to the output D matrix
2223+
/// \param [in] d1 The 2nd element to be written to the output D matrix
2224+
/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix
2225+
/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix
2226+
/// \param [in] c0 The 1st element from C matrix to be added with d0
2227+
/// \param [in] c1 The 2nd element from C matrix to be added with d1
2228+
/// \param [in] item The sycl::nd_item index space class
2229+
template <typename MulType, typename ABType, typename CDType, typename ItemT>
2230+
void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1,
2231+
const ItemT &item) {
2232+
int lane = item.get_sub_group().get_local_linear_id();
2233+
2234+
short ROW_LOAD_OFFSET = 4 * (lane / 4);
2235+
short COL_LOAD_OFFSET = 8 * (lane % 4);
2236+
2237+
for (int i = 0; i < 4; i++) {
2238+
ABType recv_a = dpct::select_from_sub_group(item.get_sub_group(), a0,
2239+
ROW_LOAD_OFFSET + i);
2240+
ABType recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0,
2241+
COL_LOAD_OFFSET + i);
2242+
c0 += recv_a * recv_b;
2243+
2244+
recv_b = dpct::select_from_sub_group(item.get_sub_group(), b0,
2245+
COL_LOAD_OFFSET + i + 4);
2246+
c1 += recv_a * recv_b;
2247+
}
2248+
2249+
d0 = c0;
2250+
d1 = c1;
2251+
}
2252+
20592253
/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32
20602254
/// matrix
20612255
/// Requires the sub-group size of kernel calling this function to be 32
@@ -2084,7 +2278,7 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
20842278
CDType c2, CDType c3, const ItemT &item) {
20852279
int lane = item.get_sub_group().get_local_linear_id();
20862280

2087-
short ROW_LOAD_OFFSET = 4 * (lane / 4);
2281+
short ROW_LOAD_OFFSET = 4 * (lane >> 2);
20882282
short COL_LOAD_OFFSET = 8 * (lane % 4);
20892283

20902284
for (int i = 0; i < 4; i++) {
@@ -2113,7 +2307,8 @@ void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1,
21132307
auto *rb0 = reinterpret_cast<MulType *>(recv_b);
21142308
auto *rb1 = reinterpret_cast<MulType *>(recv_b + 2);
21152309

2116-
for (int j = 0; j < 2 * 2; j++) {
2310+
// Iterate for k (i * j) times
2311+
for (int j = 0; j < 4; j++) {
21172312
auto a0 = static_cast<CDType>(ra0[j]);
21182313
auto a1 = static_cast<CDType>(ra1[j]);
21192314
auto b0 = static_cast<CDType>(rb0[j]);

0 commit comments

Comments
 (0)