Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,36 +96,37 @@ ur_result_t calculateKernelWorkDimensions(
// New variable needed because GlobalWorkSize parameter might not be of size
// 3
size_t GlobalWorkSize3D[3]{1, 1, 1};
std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);

if (LocalWorkSize) {
WG[0] = ur_cast<uint32_t>(LocalWorkSize[0]);
WG[1] = WorkDim >= 2 ? ur_cast<uint32_t>(LocalWorkSize[1]) : 1;
WG[2] = WorkDim == 3 ? ur_cast<uint32_t>(LocalWorkSize[2]) : 1;
} else {
std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
UR_CALL(getSuggestedLocalWorkSize(Device, Kernel, GlobalWorkSize3D, WG));
}
const size_t *GlobalWorkSizePtr = LocalWorkSize ? GlobalWorkSize : GlobalWorkSize3D;

// TODO: assert if sizes do not fit into 32-bit?
switch (WorkDim) {
case 3:
ZeThreadGroupDimensions.groupCountX =
ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
ur_cast<uint32_t>(GlobalWorkSizePtr[0] / WG[0]);
ZeThreadGroupDimensions.groupCountY =
ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
ur_cast<uint32_t>(GlobalWorkSizePtr[1] / WG[1]);
ZeThreadGroupDimensions.groupCountZ =
ur_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
ur_cast<uint32_t>(GlobalWorkSizePtr[2] / WG[2]);
break;
case 2:
ZeThreadGroupDimensions.groupCountX =
ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
ur_cast<uint32_t>(GlobalWorkSizePtr[0] / WG[0]);
ZeThreadGroupDimensions.groupCountY =
ur_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
ur_cast<uint32_t>(GlobalWorkSizePtr[1] / WG[1]);
WG[2] = 1;
break;
case 1:
ZeThreadGroupDimensions.groupCountX =
ur_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
ur_cast<uint32_t>(GlobalWorkSizePtr[0] / WG[0]);
WG[1] = WG[2] = 1;
break;

Expand All @@ -135,19 +136,19 @@ ur_result_t calculateKernelWorkDimensions(
}

// Error handling for non-uniform group size case
if (GlobalWorkSize3D[0] !=
if (GlobalWorkSizePtr[0] !=
size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
UR_LOG(ERR, "calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 1st dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[1] !=
if (WorkDim >= 2 && GlobalWorkSizePtr[1] !=
size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
UR_LOG(ERR, "calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 2nd dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[2] !=
if (WorkDim == 3 && GlobalWorkSizePtr[2] !=
size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
UR_LOG(ERR, "calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 3rd dimension");
Expand Down
Loading