Skip to content

Commit 73d85ef

Browse files
authored
Merge pull request #1289 from nrspruit/fix_multiDevice
[L0] Fix native kernel usage, multi device kernel pointer and WorkSize
2 parents e46dc35 + 7985d3e commit 73d85ef

File tree

2 files changed

+36
-31
lines changed

2 files changed

+36
-31
lines changed

source/adapters/level_zero/kernel.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -86,15 +86,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
8686
ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
8787
uint32_t WG[3]{};
8888

89-
// global_work_size of unused dimensions must be set to 1
90-
if (WorkDim >= 2) {
91-
UR_ASSERT(WorkDim >= 2 || GlobalWorkSize[1] == 1,
92-
UR_RESULT_ERROR_INVALID_VALUE);
93-
if (WorkDim == 3) {
94-
UR_ASSERT(WorkDim == 3 || GlobalWorkSize[2] == 1,
95-
UR_RESULT_ERROR_INVALID_VALUE);
96-
}
97-
}
89+
// New variable needed because GlobalWorkSize parameter might not be of size 3
90+
size_t GlobalWorkSize3D[3]{1, 1, 1};
91+
std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
92+
9893
if (LocalWorkSize) {
9994
// L0
10095
UR_ASSERT(LocalWorkSize[0] < (std::numeric_limits<uint32_t>::max)(),
@@ -111,14 +106,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
111106
// values do not fit to 32-bit that the API only supports currently.
112107
bool SuggestGroupSize = true;
113108
for (int I : {0, 1, 2}) {
114-
if (GlobalWorkSize[I] > UINT32_MAX) {
109+
if (GlobalWorkSize3D[I] > UINT32_MAX) {
115110
SuggestGroupSize = false;
116111
}
117112
}
118113
if (SuggestGroupSize) {
119114
ZE2UR_CALL(zeKernelSuggestGroupSize,
120-
(ZeKernel, GlobalWorkSize[0], GlobalWorkSize[1],
121-
GlobalWorkSize[2], &WG[0], &WG[1], &WG[2]));
115+
(ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1],
116+
GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2]));
122117
} else {
123118
for (int I : {0, 1, 2}) {
124119
// Try to find a I-dimension WG size that the GlobalWorkSize[I] is
@@ -128,11 +123,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
128123
Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX,
129124
Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY,
130125
Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ};
131-
GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize[I]);
132-
while (GlobalWorkSize[I] % GroupSize[I]) {
126+
GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]);
127+
while (GlobalWorkSize3D[I] % GroupSize[I]) {
133128
--GroupSize[I];
134129
}
135-
if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) {
130+
if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
136131
urPrint("urEnqueueKernelLaunch: can't find a WG size "
137132
"suitable for global work size > UINT32_MAX\n");
138133
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
@@ -149,22 +144,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
149144
switch (WorkDim) {
150145
case 3:
151146
ZeThreadGroupDimensions.groupCountX =
152-
static_cast<uint32_t>(GlobalWorkSize[0] / WG[0]);
147+
static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
153148
ZeThreadGroupDimensions.groupCountY =
154-
static_cast<uint32_t>(GlobalWorkSize[1] / WG[1]);
149+
static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
155150
ZeThreadGroupDimensions.groupCountZ =
156-
static_cast<uint32_t>(GlobalWorkSize[2] / WG[2]);
151+
static_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
157152
break;
158153
case 2:
159154
ZeThreadGroupDimensions.groupCountX =
160-
static_cast<uint32_t>(GlobalWorkSize[0] / WG[0]);
155+
static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
161156
ZeThreadGroupDimensions.groupCountY =
162-
static_cast<uint32_t>(GlobalWorkSize[1] / WG[1]);
157+
static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
163158
WG[2] = 1;
164159
break;
165160
case 1:
166161
ZeThreadGroupDimensions.groupCountX =
167-
static_cast<uint32_t>(GlobalWorkSize[0] / WG[0]);
162+
static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
168163
WG[1] = WG[2] = 1;
169164
break;
170165

@@ -174,19 +169,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
174169
}
175170

176171
// Error handling for non-uniform group size case
177-
if (GlobalWorkSize[0] !=
172+
if (GlobalWorkSize3D[0] !=
178173
size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
179174
urPrint("urEnqueueKernelLaunch: invalid work_dim. The range is not a "
180175
"multiple of the group size in the 1st dimension\n");
181176
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
182177
}
183-
if (GlobalWorkSize[1] !=
178+
if (GlobalWorkSize3D[1] !=
184179
size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
185180
urPrint("urEnqueueKernelLaunch: invalid work_dim. The range is not a "
186181
"multiple of the group size in the 2nd dimension\n");
187182
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
188183
}
189-
if (GlobalWorkSize[2] !=
184+
if (GlobalWorkSize3D[2] !=
190185
size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
191186
urPrint("urEnqueueKernelLaunch: invalid work_dim. The range is not a "
192187
"multiple of the group size in the 3rd dimension\n");
@@ -450,10 +445,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
450445
}
451446

452447
std::scoped_lock<ur_shared_mutex> Guard(Kernel->Mutex);
453-
for (auto It : Kernel->ZeKernelMap) {
454-
auto ZeKernel = It.second;
448+
if (Kernel->ZeKernelMap.empty()) {
449+
auto ZeKernel = Kernel->ZeKernel;
455450
ZE2UR_CALL(zeKernelSetArgumentValue,
456451
(ZeKernel, ArgIndex, ArgSize, PArgValue));
452+
} else {
453+
for (auto It : Kernel->ZeKernelMap) {
454+
auto ZeKernel = It.second;
455+
ZE2UR_CALL(zeKernelSetArgumentValue,
456+
(ZeKernel, ArgIndex, ArgSize, PArgValue));
457+
}
457458
}
458459

459460
return UR_RESULT_SUCCESS;

source/adapters/level_zero/program.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -529,16 +529,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramGetFunctionPointer(
529529
void **FunctionPointerRet ///< [out] Returns the pointer to the function if
530530
///< it is found in the program.
531531
) {
532-
std::ignore = Device;
533-
534532
std::shared_lock<ur_shared_mutex> Guard(Program->Mutex);
535533
if (Program->State != ur_program_handle_t_::Exe) {
536534
return UR_RESULT_ERROR_INVALID_PROGRAM_EXECUTABLE;
537535
}
538536

539-
ze_result_t ZeResult =
540-
ZE_CALL_NOCHECK(zeModuleGetFunctionPointer,
541-
(Program->ZeModule, FunctionName, FunctionPointerRet));
537+
ze_module_handle_t ZeModule{};
538+
auto It = Program->ZeModuleMap.find(Device->ZeDevice);
539+
if (It != Program->ZeModuleMap.end()) {
540+
ZeModule = It->second;
541+
} else {
542+
ZeModule = Program->ZeModule;
543+
}
544+
ze_result_t ZeResult = ZE_CALL_NOCHECK(
545+
zeModuleGetFunctionPointer, (ZeModule, FunctionName, FunctionPointerRet));
542546

543547
// zeModuleGetFunctionPointer currently fails for all
544548
// kernels regardless of if the kernel exist or not

0 commit comments

Comments
 (0)