Skip to content

Commit 5e5b7ac

Browse files
committed
Support reqd_work_group_size on native cpu
1 parent 4347e0c commit 5e5b7ac

File tree

6 files changed

+109
-19
lines changed

6 files changed

+109
-19
lines changed

source/adapters/native_cpu/device.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
112112
// '0x8086' : 'Intel HD graphics vendor ID'
113113
return ReturnValue(uint32_t{0x8086});
114114
case UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE:
115-
return ReturnValue(size_t{256});
115+
// TODO: provide a mechanism to estimate/configure this.
116+
return ReturnValue(size_t{2048});
116117
case UR_DEVICE_INFO_MEM_BASE_ADDR_ALIGN:
117118
// Imported from level_zero
118119
return ReturnValue(uint32_t{8});
@@ -151,7 +152,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
151152
case UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES: {
152153
struct {
153154
size_t Arr[3];
154-
} MaxGroupSize = {{256, 256, 1}};
155+
} MaxGroupSize = {{256, 256, 256}};
155156
return ReturnValue(MaxGroupSize);
156157
}
157158
case UR_DEVICE_INFO_PREFERRED_VECTOR_WIDTH_CHAR:

source/adapters/native_cpu/enqueue.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct NDRDescT {
3131
for (uint32_t I = 0; I < WorkDim; I++) {
3232
GlobalOffset[I] = GlobalWorkOffset[I];
3333
GlobalSize[I] = GlobalWorkSize[I];
34-
LocalSize[I] = LocalWorkSize[I];
34+
LocalSize[I] = LocalWorkSize ? LocalWorkSize[I] : 1;
3535
}
3636
for (uint32_t I = WorkDim; I < 3; I++) {
3737
GlobalSize[I] = 1;
@@ -81,6 +81,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
8181
DIE_NO_IMPLEMENTATION;
8282
}
8383

84+
// Check reqd_work_group_size
85+
if (hKernel->hasReqdWGSize() && pLocalWorkSize != nullptr) {
86+
const auto &Reqd = hKernel->getReqdWGSize();
87+
for (uint32_t Dim = 0; Dim < workDim; Dim++) {
88+
if (pLocalWorkSize[Dim] != Reqd[Dim]) {
89+
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
90+
}
91+
}
92+
}
93+
8494
// TODO: add proper error checking
8595
// TODO: add proper event dep management
8696
native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,

source/adapters/native_cpu/kernel.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,16 @@ urKernelCreate(ur_program_handle_t hProgram, const char *pKernelName,
2828

2929
auto f = reinterpret_cast<nativecpu_ptr_t>(
3030
const_cast<unsigned char *>(kernelEntry->second));
31-
auto kernel = new ur_kernel_handle_t_(pKernelName, *f);
31+
ur_kernel_handle_t_ *kernel;
32+
33+
// Set reqd_work_group_size for kernel if needed
34+
const auto &ReqdMap = hProgram->KernelReqdWorkGroupSizeMD;
35+
auto ReqdIt = ReqdMap.find(pKernelName);
36+
if (ReqdIt != ReqdMap.end()) {
37+
kernel = new ur_kernel_handle_t_(hProgram, pKernelName, *f, ReqdIt->second);
38+
} else {
39+
kernel = new ur_kernel_handle_t_(hProgram, pKernelName, *f);
40+
}
3241

3342
*phKernel = kernel;
3443

@@ -84,13 +93,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
8493
// case UR_KERNEL_INFO_PROGRAM:
8594
// return ReturnValue(ur_program_handle_t{ Kernel->Program });
8695
case UR_KERNEL_INFO_FUNCTION_NAME:
87-
if (hKernel->_name) {
88-
return ReturnValue(hKernel->_name);
89-
}
90-
return UR_RESULT_ERROR_INVALID_FUNCTION_NAME;
91-
// case UR_KERNEL_INFO_NUM_ARGS:
92-
// return ReturnValue(uint32_t{ Kernel->ZeKernelProperties->numKernelArgs
93-
// });
96+
return ReturnValue(hKernel->_name);
9497
case UR_KERNEL_INFO_REFERENCE_COUNT:
9598
return ReturnValue(uint32_t{hKernel->getReferenceCount()});
9699
case UR_KERNEL_INFO_ATTRIBUTES:
@@ -121,8 +124,16 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
121124
return returnValue(max_threads);
122125
}
123126
case UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE: {
124-
size_t group_size[3] = {1, 1, 1};
125-
return returnValue(group_size, 3);
127+
size_t GroupSize[3] = {0, 0, 0};
128+
const auto &ReqdWGSizeMDMap = hKernel->hProgram->KernelReqdWorkGroupSizeMD;
129+
const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(hKernel->_name);
130+
if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
131+
const auto ReqdWGSize = ReqdWGSizeMD->second;
132+
GroupSize[0] = std::get<0>(ReqdWGSize);
133+
GroupSize[1] = std::get<1>(ReqdWGSize);
134+
GroupSize[2] = std::get<2>(ReqdWGSize);
135+
}
136+
return returnValue(GroupSize, 3);
126137
}
127138
case UR_KERNEL_GROUP_INFO_LOCAL_MEM_SIZE: {
128139
int bytes = 0;

source/adapters/native_cpu/kernel.hpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include "common.hpp"
1212
#include "nativecpu_state.hpp"
13+
#include "program.hpp"
14+
#include <array>
1315
#include <ur_api.h>
1416
#include <utility>
1517

@@ -37,13 +39,16 @@ struct local_arg_info_t {
3739

3840
struct ur_kernel_handle_t_ : RefCounted {
3941

40-
ur_kernel_handle_t_(const char *name, nativecpu_task_t subhandler)
41-
: _name{name}, _subhandler{std::move(subhandler)} {}
42+
ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
43+
nativecpu_task_t subhandler)
44+
: hProgram(hProgram), _name{name}, _subhandler{std::move(subhandler)},
45+
HasReqdWGSize(false) {}
4246

4347
ur_kernel_handle_t_(const ur_kernel_handle_t_ &other)
4448
: _name(other._name), _subhandler(other._subhandler), _args(other._args),
4549
_localArgInfo(other._localArgInfo), _localMemPool(other._localMemPool),
46-
_localMemPoolSize(other._localMemPoolSize) {
50+
_localMemPoolSize(other._localMemPoolSize),
51+
HasReqdWGSize(other.HasReqdWGSize), ReqdWGSize(other.ReqdWGSize) {
4752
incrementReferenceCount();
4853
}
4954

@@ -52,13 +57,22 @@ struct ur_kernel_handle_t_ : RefCounted {
5257
free(_localMemPool);
5358
}
5459
}
55-
56-
const char *_name;
60+
ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
61+
nativecpu_task_t subhandler,
62+
const native_cpu::ReqdWGSize_t &ReqdWGSize)
63+
: hProgram(hProgram), _name{name}, _subhandler{std::move(subhandler)},
64+
HasReqdWGSize(true), ReqdWGSize(ReqdWGSize) {}
65+
66+
ur_program_handle_t hProgram;
67+
std::string _name;
5768
nativecpu_task_t _subhandler;
5869
std::vector<native_cpu::NativeCPUArgDesc> _args;
5970
std::vector<local_arg_info_t> _localArgInfo;
6071

61-
// To be called before enqueueing the kernel.
72+
bool hasReqdWGSize() const { return HasReqdWGSize; }
73+
74+
const native_cpu::ReqdWGSize_t &getReqdWGSize() const { return ReqdWGSize; }
75+
6276
void updateMemPool(size_t numParallelThreads) {
6377
// compute requested size.
6478
size_t reqSize = 0;
@@ -88,4 +102,6 @@ struct ur_kernel_handle_t_ : RefCounted {
88102
private:
89103
char *_localMemPool = nullptr;
90104
size_t _localMemPoolSize = 0;
105+
bool HasReqdWGSize;
106+
native_cpu::ReqdWGSize_t ReqdWGSize;
91107
};

source/adapters/native_cpu/program.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "common.hpp"
1414
#include "program.hpp"
15+
#include <cstdint>
1516

1617
UR_APIEXPORT ur_result_t UR_APICALL
1718
urProgramCreateWithIL(ur_context_handle_t hContext, const void *pIL,
@@ -26,6 +27,39 @@ urProgramCreateWithIL(ur_context_handle_t hContext, const void *pIL,
2627
DIE_NO_IMPLEMENTATION
2728
}
2829

30+
// TODO: taken from CUDA adapter, move this to a common header?
31+
static std::pair<std::string, std::string>
32+
splitMetadataName(const std::string &metadataName) {
33+
size_t splitPos = metadataName.rfind('@');
34+
if (splitPos == std::string::npos)
35+
return std::make_pair(metadataName, std::string{});
36+
return std::make_pair(metadataName.substr(0, splitPos),
37+
metadataName.substr(splitPos, metadataName.length()));
38+
}
39+
40+
static ur_result_t getReqdWGSize(const ur_program_metadata_t &MetadataElement,
41+
native_cpu::ReqdWGSize_t &res) {
42+
size_t MDElemsSize = MetadataElement.size - sizeof(std::uint64_t);
43+
44+
// Expect between 1 and 3 32-bit integer values.
45+
UR_ASSERT(MDElemsSize == sizeof(std::uint32_t) ||
46+
MDElemsSize == sizeof(std::uint32_t) * 2 ||
47+
MDElemsSize == sizeof(std::uint32_t) * 3,
48+
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
49+
50+
// Get pointer to data, skipping 64-bit size at the start of the data.
51+
const char *ValuePtr =
52+
reinterpret_cast<const char *>(MetadataElement.value.pData) +
53+
sizeof(std::uint64_t);
54+
// Read values and pad with 1's for values not present.
55+
std::uint32_t ReqdWorkGroupElements[] = {1, 1, 1};
56+
std::memcpy(ReqdWorkGroupElements, ValuePtr, MDElemsSize);
57+
std::get<0>(res) = ReqdWorkGroupElements[0];
58+
std::get<1>(res) = ReqdWorkGroupElements[1];
59+
std::get<2>(res) = ReqdWorkGroupElements[2];
60+
return UR_RESULT_SUCCESS;
61+
}
62+
2963
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
3064
ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
3165
const uint8_t *pBinary, const ur_program_properties_t *pProperties,
@@ -40,6 +74,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
4074

4175
auto hProgram = new ur_program_handle_t_(
4276
hContext, reinterpret_cast<const unsigned char *>(pBinary));
77+
if (pProperties != nullptr) {
78+
for (uint32_t i = 0; i < pProperties->count; i++) {
79+
auto mdNode = pProperties->pMetadatas[i];
80+
std::string mdName(mdNode.pName);
81+
auto [Prefix, Tag] = splitMetadataName(mdName);
82+
if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
83+
native_cpu::ReqdWGSize_t reqdWGSize;
84+
getReqdWGSize(mdNode, reqdWGSize);
85+
hProgram->KernelReqdWorkGroupSizeMD[Prefix] = std::move(reqdWGSize);
86+
}
87+
}
88+
}
4389

4490
const nativecpu_entry *nativecpu_it =
4591
reinterpret_cast<const nativecpu_entry *>(pBinary);

source/adapters/native_cpu/program.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
#include "context.hpp"
1616
#include <map>
1717

18+
namespace native_cpu {
19+
using ReqdWGSize_t = std::array<uint32_t, 3>;
20+
}
21+
1822
struct ur_program_handle_t_ : RefCounted {
1923
ur_program_handle_t_(ur_context_handle_t ctx, const unsigned char *pBinary)
2024
: _ctx{ctx}, _ptr{pBinary} {}
@@ -30,6 +34,8 @@ struct ur_program_handle_t_ : RefCounted {
3034
};
3135

3236
std::map<const char *, const unsigned char *, _compare> _kernels;
37+
std::unordered_map<std::string, native_cpu::ReqdWGSize_t>
38+
KernelReqdWorkGroupSizeMD;
3339
};
3440

3541
// The nativecpu_entry struct is also defined as LLVM-IR in the

0 commit comments

Comments
 (0)