|
8 | 8 | //
|
9 | 9 | //===----------------------------------------------------------------------===//
|
10 | 10 | #include "command_buffer.hpp"
|
| 11 | +#include "helpers/kernel_helpers.hpp" |
11 | 12 | #include "logger/ur_logger.hpp"
|
12 | 13 | #include "ur_level_zero.hpp"
|
13 | 14 |
|
@@ -78,130 +79,6 @@ preferCopyEngineForFill(ur_exp_command_buffer_handle_t CommandBuffer,
|
78 | 79 | return UR_RESULT_SUCCESS;
|
79 | 80 | }
|
80 | 81 |
|
81 |
| -/** |
82 |
| - * Calculates a work group size for the kernel based on the GlobalWorkSize or |
83 |
| - * the LocalWorkSize if provided. |
84 |
| - * @param[in][optional] Kernel The Kernel. Used when LocalWorkSize is not |
85 |
| - * provided. |
86 |
| - * @param[in][optional] Device The device associated with the kernel. Used when |
87 |
| - * LocalWorkSize is not provided. |
88 |
| - * @param[out] ZeThreadGroupDimensions Number of work groups in each dimension. |
89 |
| - * @param[out] WG The work group size for each dimension. |
90 |
| - * @param[in] WorkDim The number of dimensions in the kernel. |
91 |
| - * @param[in] GlobalWorkSize The global work size. |
92 |
| - * @param[in][optional] LocalWorkSize The local work size. |
93 |
| - * @return UR_RESULT_SUCCESS or an error code on failure. |
94 |
| - */ |
95 |
| -ur_result_t calculateKernelWorkDimensions( |
96 |
| - ur_kernel_handle_t Kernel, ur_device_handle_t Device, |
97 |
| - ze_group_count_t &ZeThreadGroupDimensions, uint32_t (&WG)[3], |
98 |
| - uint32_t WorkDim, const size_t *GlobalWorkSize, |
99 |
| - const size_t *LocalWorkSize) { |
100 |
| - |
101 |
| - UR_ASSERT(GlobalWorkSize, UR_RESULT_ERROR_INVALID_VALUE); |
102 |
| - // If LocalWorkSize is not provided then Kernel must be provided to query |
103 |
| - // suggested group size. |
104 |
| - UR_ASSERT(LocalWorkSize || Kernel, UR_RESULT_ERROR_INVALID_VALUE); |
105 |
| - |
106 |
| - // New variable needed because GlobalWorkSize parameter might not be of size |
107 |
| - // 3 |
108 |
| - size_t GlobalWorkSize3D[3]{1, 1, 1}; |
109 |
| - std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D); |
110 |
| - |
111 |
| - if (LocalWorkSize) { |
112 |
| - WG[0] = ur_cast<uint32_t>(LocalWorkSize[0]); |
113 |
| - WG[1] = WorkDim >= 2 ? ur_cast<uint32_t>(LocalWorkSize[1]) : 1; |
114 |
| - WG[2] = WorkDim == 3 ? ur_cast<uint32_t>(LocalWorkSize[2]) : 1; |
115 |
| - } else { |
116 |
| - // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize3D |
117 |
| - // values do not fit to 32-bit that the API only supports currently. |
118 |
| - bool SuggestGroupSize = true; |
119 |
| - for (int I : {0, 1, 2}) { |
120 |
| - if (GlobalWorkSize3D[I] > UINT32_MAX) { |
121 |
| - SuggestGroupSize = false; |
122 |
| - } |
123 |
| - } |
124 |
| - if (SuggestGroupSize) { |
125 |
| - ZE2UR_CALL(zeKernelSuggestGroupSize, |
126 |
| - (Kernel->ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1], |
127 |
| - GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2])); |
128 |
| - } else { |
129 |
| - for (int I : {0, 1, 2}) { |
130 |
| - // Try to find a I-dimension WG size that the GlobalWorkSize3D[I] is |
131 |
| - // fully divisable with. Start with the max possible size in |
132 |
| - // each dimension. |
133 |
| - uint32_t GroupSize[] = { |
134 |
| - Device->ZeDeviceComputeProperties->maxGroupSizeX, |
135 |
| - Device->ZeDeviceComputeProperties->maxGroupSizeY, |
136 |
| - Device->ZeDeviceComputeProperties->maxGroupSizeZ}; |
137 |
| - GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]); |
138 |
| - while (GlobalWorkSize3D[I] % GroupSize[I]) { |
139 |
| - --GroupSize[I]; |
140 |
| - } |
141 |
| - if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) { |
142 |
| - logger::debug("calculateKernelWorkDimensions: can't find a WG size " |
143 |
| - "suitable for global work size > UINT32_MAX"); |
144 |
| - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
145 |
| - } |
146 |
| - WG[I] = GroupSize[I]; |
147 |
| - } |
148 |
| - logger::debug("calculateKernelWorkDimensions: using computed WG " |
149 |
| - "size = {{{}, {}, {}}}", |
150 |
| - WG[0], WG[1], WG[2]); |
151 |
| - } |
152 |
| - } |
153 |
| - |
154 |
| - // TODO: assert if sizes do not fit into 32-bit? |
155 |
| - switch (WorkDim) { |
156 |
| - case 3: |
157 |
| - ZeThreadGroupDimensions.groupCountX = |
158 |
| - ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]); |
159 |
| - ZeThreadGroupDimensions.groupCountY = |
160 |
| - ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]); |
161 |
| - ZeThreadGroupDimensions.groupCountZ = |
162 |
| - ur_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]); |
163 |
| - break; |
164 |
| - case 2: |
165 |
| - ZeThreadGroupDimensions.groupCountX = |
166 |
| - ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]); |
167 |
| - ZeThreadGroupDimensions.groupCountY = |
168 |
| - ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]); |
169 |
| - WG[2] = 1; |
170 |
| - break; |
171 |
| - case 1: |
172 |
| - ZeThreadGroupDimensions.groupCountX = |
173 |
| - ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]); |
174 |
| - WG[1] = WG[2] = 1; |
175 |
| - break; |
176 |
| - |
177 |
| - default: |
178 |
| - logger::error("calculateKernelWorkDimensions: unsupported work_dim"); |
179 |
| - return UR_RESULT_ERROR_INVALID_VALUE; |
180 |
| - } |
181 |
| - |
182 |
| - // Error handling for non-uniform group size case |
183 |
| - if (GlobalWorkSize3D[0] != |
184 |
| - size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) { |
185 |
| - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " |
186 |
| - "is not a multiple of the group size in the 1st dimension"); |
187 |
| - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
188 |
| - } |
189 |
| - if (GlobalWorkSize3D[1] != |
190 |
| - size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) { |
191 |
| - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " |
192 |
| - "is not a multiple of the group size in the 2nd dimension"); |
193 |
| - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
194 |
| - } |
195 |
| - if (GlobalWorkSize3D[2] != |
196 |
| - size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) { |
197 |
| - logger::error("calculateKernelWorkDimensions: invalid work_dim. The range " |
198 |
| - "is not a multiple of the group size in the 3rd dimension"); |
199 |
| - return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE; |
200 |
| - } |
201 |
| - |
202 |
| - return UR_RESULT_SUCCESS; |
203 |
| -} |
204 |
| - |
205 | 82 | /**
|
206 | 83 | * Helper function for finding the Level Zero events associated with the
|
207 | 84 | * commands in a command-buffer, each event is pointed to by a sync-point in the
|
@@ -880,7 +757,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
|
880 | 757 |
|
881 | 758 | ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
|
882 | 759 | uint32_t WG[3];
|
883 |
| - UR_CALL(calculateKernelWorkDimensions(Kernel, CommandBuffer->Device, |
| 760 | + UR_CALL(calculateKernelWorkDimensions(Kernel->ZeKernel, CommandBuffer->Device, |
884 | 761 | ZeThreadGroupDimensions, WG, WorkDim,
|
885 | 762 | GlobalWorkSize, LocalWorkSize));
|
886 | 763 |
|
@@ -1584,8 +1461,8 @@ ur_result_t updateKernelCommand(
|
1584 | 1461 |
|
1585 | 1462 | uint32_t WG[3];
|
1586 | 1463 | UR_CALL(calculateKernelWorkDimensions(
|
1587 |
| - Command->Kernel, CommandBuffer->Device, ZeThreadGroupDimensions, WG, |
1588 |
| - Dim, NewGlobalWorkSize, NewLocalWorkSize)); |
| 1464 | + Command->Kernel->ZeKernel, CommandBuffer->Device, |
| 1465 | + ZeThreadGroupDimensions, WG, Dim, NewGlobalWorkSize, NewLocalWorkSize)); |
1589 | 1466 |
|
1590 | 1467 | auto MutableGroupCountDesc =
|
1591 | 1468 | std::make_unique<ZeStruct<ze_mutable_group_count_exp_desc_t>>();
|
|
0 commit comments