Skip to content

Commit d43c8fe

Browse files
Added more info about functionality
1 parent 8047a00 commit d43c8fe

File tree

1 file changed

+30
-6
lines changed
  • clang/runtime/dpct-rt/include/dpct

1 file changed

+30
-6
lines changed

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,8 +2218,24 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218
ldmatrix(addr, m4, trans, 3);
22192219
}
22202220

2221-
/// Stores 1 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2222-
/// Requires the sub-group size of kernel calling this function to be 32
2221+
/// Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2222+
/// Requires the sub-group size of kernel calling this function to be 32.
2223+
/// Each of the first 8 work items contains the starting address of its
2224+
/// respective matrix row.
2225+
/// Each of the 32 work items stores 32 bits (2 packed 16-bit values) for a total
2226+
/// of 128 bytes.
2227+
/// Row Major: Each row of the matrix is stored by a group of 4 work items
2228+
/// r0: t0 t1 t2 t3
2229+
/// r1: t4 t5 t6 t7
2230+
/// ...
2231+
/// r6: t24 t25 t26 t27
2232+
/// r7: t28 t29 t30 t31
2233+
/// Col Major: Each col of the matrix is stored by a group of 4 work items
2234+
/// r0: t0 t4 t8 ... t28
2235+
/// r1: t0 t4 t8 ... t28
2236+
/// ...
2237+
/// r6: t3 t7 t11 ... t31
2238+
/// r7: t3 t7 t11 ... t31
22232239
/// \tparam [in] T The type of matrix elements
22242240
/// \param [in] addr The address of the matrix in local memory
22252241
/// \param [in] m The private memory containing data of matrix
@@ -2272,8 +2288,12 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
22722288
}
22732289
}
22742290

2275-
/// Stores 2 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2276-
/// Requires the sub-group size of kernel calling this function to be 32
2291+
/// Stores 2 8x8 b16 matrices from private memory to local memory per sub-group.
2292+
/// Requires the sub-group size of kernel calling this function to be 32.
2293+
/// Each of the first 16 work items contains the starting address of its
2294+
/// respective matrix row.
2295+
/// Each of the 32 work items stores 64 bits (32 bits per matrix) for a total
2296+
/// of 256 bytes.
22772297
/// \tparam [in] T The type of matrix elements
22782298
/// \param [in] addr The address of the matrix in local memory
22792299
/// \param [in] m1 The private memory containing data of 1st matrix
@@ -2287,8 +2307,12 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
22872307
stmatrix(addr, m2, trans, 1);
22882308
}
22892309

2290-
/// Stores 4 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2291-
/// Requires the sub-group size of kernel calling this function to be 32
2310+
/// Stores 4 8x8 b16 matrices from private memory to local memory per sub-group.
2311+
/// Requires the sub-group size of kernel calling this function to be 32.
2312+
/// Each of the 32 work items contains the starting address of its
2313+
/// respective matrix row.
2314+
/// Each of the 32 work items stores 128 bits (32 bits per matrix) for a total
2315+
/// of 512 bytes.
22922316
/// \tparam [in] T The type of matrix elements
22932317
/// \param [in] addr The address of the matrix in local memory
22942318
/// \param [in] m1 The private memory containing data of 1st matrix

0 commit comments

Comments
 (0)