Skip to content

Commit 0a34bcd

Browse files
Added more info about functionality
1 parent 9292a81 commit 0a34bcd

File tree

1 file changed

+30
-6
lines changed
  • clang/runtime/dpct-rt/include/dpct

1 file changed

+30
-6
lines changed

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2056,8 +2056,24 @@ class joint_matrix {
20562056
const size_t num_elements;
20572057
};
20582058

2059-
/// Stores 1 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2060-
/// Requires the sub-group size of kernel calling this function to be 32
2059+
/// Stores 1 8x8 b16 matrix from private memory to local memory per sub-group.
2060+
/// Requires the sub-group size of kernel calling this function to be 32.
2061+
/// Each of the first 8 work items contain the starting address of their
2062+
/// respective matrix row.
2063+
/// Each of the 32 work items store 32-bits (2 packed 16-bit data) for a total
2064+
/// of 128 bytes.
2065+
/// Row Major: Each row of the matrix is stored by a group of 4 work items
2066+
/// r0: t0 t1 t2 t3
2067+
/// r1: t4 t5 t6 t7
2068+
/// ...
2069+
/// r7: t24 t25 t26 t27
2070+
/// r7: t28 t29 t30 t31
2071+
/// Col Major: Each col of the matrix is stored by a group of 4 work items
2072+
/// r0: t0 t4 t8 ... t28
2073+
/// r1: t0 t4 t8 ... t28
2074+
/// ...
2075+
/// r6: t3 t7 t11 ... t31
2076+
/// r7: t3 t7 t11 ... t31
20612077
/// \tparam [in] T The type of matrix elements
20622078
/// \param [in] addr The address of the matrix in local memory
20632079
/// \param [in] m The private memory containing data of matrix
@@ -2110,8 +2126,12 @@ void stmatrix(uintptr_t addr, T m, bool trans = false, unsigned mat = 0) {
21102126
}
21112127
}
21122128

2113-
/// Stores 2 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2114-
/// Requires the sub-group size of kernel calling this function to be 32
2129+
/// Stores 2 8x8 b16 matrix from private memory to local memory per sub-group.
2130+
/// Requires the sub-group size of kernel calling this function to be 32.
2131+
/// Each of the first 16 work items contain the starting address of their
2132+
/// respective matrix row.
2133+
/// Each of the 32 work items store 64-bits (32-bit per matrix) for a total
2134+
/// of 256 bytes.
21152135
/// \tparam [in] T The type of matrix elements
21162136
/// \param [in] addr The address of the matrix in local memory
21172137
/// \param [in] m1 The private memory containing data of 1st matrix
@@ -2125,8 +2145,12 @@ void stmatrix(uintptr_t addr, T m1, T m2, bool trans = false) {
21252145
stmatrix(addr, m2, trans, 1);
21262146
}
21272147

2128-
/// Stores 4 8x8 b16 matrix from private memory to local memory (32-bits per wi)
2129-
/// Requires the sub-group size of kernel calling this function to be 32
2148+
/// Stores 4 8x8 b16 matrix from private memory to local memory per sub-group.
2149+
/// Requires the sub-group size of kernel calling this function to be 32.
2150+
/// Each of the 32 work items contain the starting address of their
2151+
/// respective matrix row.
2152+
/// Each of the 32 work items store 128-bits (32-bit per matrix) for a total
2153+
/// of 512 bytes.
21302154
/// \tparam [in] T The type of matrix elements
21312155
/// \param [in] addr The address of the matrix in local memory
21322156
/// \param [in] m1 The private memory containing data of 1st matrix

0 commit comments

Comments
 (0)