Skip to content

Commit db0271c

Browse files
Added b16 limitation and comments
1 parent 0a34bcd commit db0271c

File tree

2 files changed

+88
-37
lines changed

2 files changed

+88
-37
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,6 +1316,11 @@ class SYCLGen : public SYCLGenBase {
13161316
if (Inst->getNumInputOperands() != 1)
13171317
return SYCLGenError();
13181318

1319+
const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
1320+
1321+
if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
1322+
return SYCLGenError();
1323+
13191324
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
13201325
CurrInst = Inst;
13211326
const auto *Dst =

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 83 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2056,27 +2056,33 @@ class joint_matrix {
20562056
const size_t num_elements;
20572057
};
20582058

2059-
/// Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2059+
/// Collectively stores 1 8x8 b16 (128 bytes) matrix from private memory to
2060+
/// local memory per sub-group.
20602061
/// Requires the sub-group size of kernel calling this function to be 32.
2061-
/// Each of the first 8 work items contain the starting address of their
2062-
/// respective matrix row.
2063-
/// Each of the 32 work items store 32-bits (2 packed 16-bit data) for a total
2064-
/// of 128 bytes.
2065-
/// Row Major: Each row of the matrix is stored by a group of 4 work items
2066-
/// r0: t0 t1 t2 t3
2067-
/// r1: t4 t5 t6 t7
2062+
/// 'mat' specifies the matrix index to be stored. The first '(mat + 1) * 8'
2063+
/// work items of sub-group contain the starting address of their respective
2064+
/// matrix row in 'addr'.
2065+
/// After distributing addresses to other work items, each of the 32 work items
2066+
/// stores 32 bits (2 packed 16-bit values) from 'm' for a total of 128 bytes.
2067+
/// 'trans' specifies to perform a transposed/non-transposed store by each work
2068+
/// item like below
2069+
/// Row Major: Each row of the matrix is stored by a group of 4 work items(wi)
2070+
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2071+
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
20682072
/// ...
2069-
/// r7: t24 t25 t26 t27
2070-
/// r7: t28 t29 t30 t31
2071-
/// Col Major: Each col of the matrix is stored by a group of 4 work items
2072-
/// r0: t0 t4 t8 ... t28
2073-
/// r1: t0 t4 t8 ... t28
2073+
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2074+
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2075+
/// Col Major: Each col of the matrix is stored by a group of 4 work items(wi)
2076+
/// row-0: wi0 wi4 wi8 ... wi28
2077+
/// row-1: wi0 wi4 wi8 ... wi28
20742078
/// ...
2075-
/// r6: t3 t7 t11 ... t31
2076-
/// r7: t3 t7 t11 ... t31
2077-
/// \tparam [in] T The type of matrix elements
2078-
/// \param [in] addr The address of the matrix in local memory
2079-
/// \param [in] m The private memory containing data of matrix
2079+
/// row-6: wi3 wi7 wi11 ... wi31
2080+
/// row-7: wi3 wi7 wi11 ... wi31
2081+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2082+
/// \param [in] addr The starting address of corresponding matrix row for a work
2083+
/// item in local memory
2084+
/// \param [in] m The private memory to store the matrix. It points to 2 b16
2085+
/// type elements.
20802086
/// \param [in] trans Indicates whether the matrix is to be stored transposed
20812087
/// \param [in] mat The matrix index to be stored
20822088
template <typename T>
@@ -2126,16 +2132,35 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
21262132
}
21272133
}
21282134

2129-
/// Stores 2 8x8 b16 matrix from private memory to local memory per sub-group.
2135+
/// Collectively stores 2 8x8 b16 (256 bytes) matrix from private memory to
2136+
/// local memory per sub-group.
21302137
/// Requires the sub-group size of kernel calling this function to be 32.
2131-
/// Each of the first 16 work items contain the starting address of their
2132-
/// respective matrix row.
2133-
/// Each of the 32 work items store 64-bits (32-bit per matrix) for a total
2134-
/// of 256 bytes.
2135-
/// \tparam [in] T The type of matrix elements
2136-
/// \param [in] addr The address of the matrix in local memory
2137-
/// \param [in] m1 The private memory containing data of 1st matrix
2138-
/// \param [in] m2 The private memory containing data of 2nd matrix
2138+
/// The first 16 work items of sub-group contain the starting address of their
2139+
/// respective matrix row in 'addr'.
2140+
/// After distributing addresses to other work items, each of the 32 work items
2141+
/// stores 64 bits (32 bits per matrix) from 'm1' & 'm2' for a total of 256
2142+
/// bytes.
2143+
/// 'trans' specifies to perform a transposed/non-transposed store by each work
2144+
/// item like below
2145+
/// Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2146+
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2147+
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2148+
/// ...
2149+
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2150+
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2151+
/// Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2152+
/// row-0: wi0 wi4 wi8 ... wi28
2153+
/// row-1: wi0 wi4 wi8 ... wi28
2154+
/// ...
2155+
/// row-6: wi3 wi7 wi11 ... wi31
2156+
/// row-7: wi3 wi7 wi11 ... wi31
2157+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2158+
/// \param [in] addr The starting address of corresponding matrix row for a work
2159+
/// item in local memory
2160+
/// \param [in] m1 The private memory to store the data of 1st matrix. It points
2161+
/// to 2 b16 type elements.
2162+
/// \param [in] m2 The private memory to store the data of 2nd matrix. It points
2163+
/// to 2 b16 type elements.
21392164
/// \param [in] trans Indicates whether the matrix is to be stored transposed
21402165
template <typename T>
21412166
void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
@@ -2145,18 +2170,39 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
21452170
stmatrix(addr, m2, trans, 1);
21462171
}
21472172

2148-
/// Stores 4 8x8 b16 matrix from private memory to local memory per sub-group.
2173+
/// Collectively stores 4 8x8 b16 (512 bytes) matrix from private memory to
2174+
/// local memory per sub-group.
21492175
/// Requires the sub-group size of kernel calling this function to be 32.
2150-
/// Each of the 32 work items contain the starting address of their
2151-
/// respective matrix row.
2152-
/// Each of the 32 work items store 128-bits (32-bit per matrix) for a total
2176+
/// Each work item of the sub-group contains the starting address of its
2177+
/// respective matrix row in 'addr'.
2178+
/// After distributing addresses to other work items, each of the 32 work items
2179+
/// stores 128 bits (32 bits per matrix) from 'm1', 'm2', 'm3', 'm4' for a total
21532180
/// of 512 bytes.
2154-
/// \tparam [in] T The type of matrix elements
2155-
/// \param [in] addr The address of the matrix in local memory
2156-
/// \param [in] m1 The private memory containing data of 1st matrix
2157-
/// \param [in] m2 The private memory containing data of 2nd matrix
2158-
/// \param [in] m3 The private memory containing data of 3rd matrix
2159-
/// \param [in] m4 The private memory containing data of 4th matrix
2181+
/// 'trans' specifies to perform a transposed/non-transposed store by each work
2182+
/// item like below
2183+
/// Row Major: Each row of the matrices is stored by a group of 4 work items(wi)
2184+
/// row-0: wi0 wi0 wi1 wi1 ... wi3 wi3
2185+
/// row-1: wi4 wi4 wi5 wi5 ... wi7 wi7
2186+
/// ...
2187+
/// row-6: wi24 wi24 wi25 wi25 ... wi27 wi27
2188+
/// row-7: wi28 wi28 wi29 wi29 ... wi31 wi31
2189+
/// Col Major: Each col of the matrices is stored by a group of 4 work items(wi)
2190+
/// row-0: wi0 wi4 wi8 ... wi28
2191+
/// row-1: wi0 wi4 wi8 ... wi28
2192+
/// ...
2193+
/// row-6: wi3 wi7 wi11 ... wi31
2194+
/// row-7: wi3 wi7 wi11 ... wi31
2195+
/// \tparam [in] T Type of result variable (currently only supports 16-bit type)
2196+
/// \param [in] addr The starting address of corresponding matrix row for a work
2197+
/// item in local memory
2198+
/// \param [in] m1 The private memory to store the data of 1st matrix. It points
2199+
/// to 2 b16 type elements.
2200+
/// \param [in] m2 The private memory to store the data of 2nd matrix. It points
2201+
/// to 2 b16 type elements.
2202+
/// \param [in] m3 The private memory to store the data of 3rd matrix. It points
2203+
/// to 2 b16 type elements.
2204+
/// \param [in] m4 The private memory to store the data of 4th matrix. It points
2205+
/// to 2 b16 type elements.
21602206
/// \param [in] trans Indicates whether the matrix is to be stored transposed
21612207
template <typename T>
21622208
void stmatrix(uintptr_t addr, T m1, T m2, T m3, T m4, bool trans = false) {

0 commit comments

Comments
 (0)