Skip to content

Commit 6009642

Browse files
PietroGhg martygrant
authored and committed
[SYCL] [NATIVECPU] Implement urKernelSetArgLocal (#11101)
This PR adds support for `local_accessor`s by implementing `urKernelSetArgLocal`.
1 parent e1fbecc commit 6009642

File tree

3 files changed

+57
-5
lines changed

3 files changed

+57
-5
lines changed

sycl/plugins/unified_runtime/ur/adapters/native_cpu/enqueue.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6464
// TODO: add proper event dep management
6565
sycl::detail::NDRDescT ndr =
6666
getNDRDesc(workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize);
67+
hKernel->handleLocalArgs();
6768

6869
__nativecpu_state state(ndr.GlobalSize[0], ndr.GlobalSize[1],
6970
ndr.GlobalSize[2], ndr.LocalSize[0], ndr.LocalSize[1],
@@ -90,6 +91,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
9091
// TODO: we should avoid calling clear here by avoiding using push_back
9192
// in setKernelArgs.
9293
hKernel->_args.clear();
94+
hKernel->_localArgInfo.clear();
9395
return UR_RESULT_SUCCESS;
9496
}
9597

sycl/plugins/unified_runtime/ur/adapters/native_cpu/kernel.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
5454
/// Registers a local-memory argument on a kernel.
///
/// A null placeholder descriptor is appended to the kernel's argument list and
/// the (argIndex, argSize) pair is recorded in `_localArgInfo`. Before the
/// kernel is launched, `handleLocalArgs()` replaces the placeholder with a
/// pointer into the kernel's local memory pool.
///
/// NOTE(review): the placeholder is appended rather than stored at `argIndex`,
/// so this presumably relies on arguments being set in increasing index order
/// -- confirm against the other urKernelSetArg* entry points.
///
/// @param hKernel     Kernel whose argument is being set; must not be null.
/// @param argIndex    Index of the kernel argument being set.
/// @param argSize     Number of bytes of local memory requested for it.
/// @param pProperties Unused.
/// @return UR_RESULT_ERROR_INVALID_NULL_HANDLE when `hKernel` is null,
///         UR_RESULT_SUCCESS otherwise.
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal(
    ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
    const ur_kernel_arg_local_properties_t *pProperties) {
  std::ignore = pProperties;
  // The handle is dereferenced below, so reject a null handle up front.
  if (hKernel == nullptr) {
    return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
  }
  // emplace a placeholder kernel arg, gets replaced with a pointer to the
  // memory pool before enqueueing the kernel.
  hKernel->_args.emplace_back(nullptr);
  hKernel->_localArgInfo.emplace_back(argIndex, argSize);
  return UR_RESULT_SUCCESS;
}
6464

6565
UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,

sycl/plugins/unified_runtime/ur/adapters/native_cpu/kernel.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ using nativecpu_kernel_t = void(const sycl::detail::NativeCPUArgDesc *,
1717
using nativecpu_ptr_t = nativecpu_kernel_t *;
1818
using nativecpu_task_t = std::function<nativecpu_kernel_t>;
1919

20+
/// Bookkeeping record for one local-memory kernel argument.
struct local_arg_info_t {
  /// Index of the kernel argument this record describes.
  uint32_t argIndex;
  /// Number of bytes of local memory requested for the argument.
  size_t argSize;

  /// Builds a record from an (index, size) pair; this is the form used when
  /// records are created in place with `emplace_back(argIndex, argSize)`.
  local_arg_info_t(uint32_t index, size_t size)
      : argIndex{index}, argSize{size} {}
};
26+
2027
struct ur_kernel_handle_t_ : RefCounted {
2128

2229
ur_kernel_handle_t_(const char *name, nativecpu_task_t subhandler)
@@ -25,4 +32,47 @@ struct ur_kernel_handle_t_ : RefCounted {
2532
const char *_name;
2633
nativecpu_task_t _subhandler;
2734
std::vector<sycl::detail::NativeCPUArgDesc> _args;
35+
std::vector<local_arg_info_t> _localArgInfo;
36+
37+
// To be called before enqueing the kernel.
38+
void handleLocalArgs() {
39+
updateMemPool();
40+
size_t offset = 0;
41+
for (auto &entry : _localArgInfo) {
42+
_args[entry.argIndex].MPtr =
43+
reinterpret_cast<char *>(_localMemPool) + offset;
44+
// update offset in the memory pool
45+
// Todo: update this offset computation when we have work-group
46+
// level parallelism.
47+
offset += entry.argSize;
48+
}
49+
}
50+
51+
// Releases the local memory pool owned by this kernel handle.
~ur_kernel_handle_t_() {
  // free(nullptr) is a no-op, so a never-allocated pool needs no special case.
  free(_localMemPool);
}
56+
57+
private:
58+
void updateMemPool() {
59+
// compute requested size.
60+
// Todo: currently we execute only one work-group at a time, so for each
61+
// local arg we can allocate just 1 * argSize local arg. When we implement
62+
// work-group level parallelism we should allocate N * argSize where N is
63+
// the number of work groups being executed in parallel (e.g. number of
64+
// threads in the thread pool).
65+
size_t reqSize = 0;
66+
for (auto &entry : _localArgInfo) {
67+
reqSize += entry.argSize;
68+
}
69+
if (reqSize == 0 || reqSize == _localMemPoolSize) {
70+
return;
71+
}
72+
// realloc handles nullptr case
73+
_localMemPool = realloc(_localMemPool, reqSize);
74+
_localMemPoolSize = reqSize;
75+
}
76+
void *_localMemPool = nullptr;
77+
size_t _localMemPoolSize = 0;
2878
};

0 commit comments

Comments
 (0)