@@ -17,6 +17,13 @@ using nativecpu_kernel_t = void(const sycl::detail::NativeCPUArgDesc *,
17
17
using nativecpu_ptr_t = nativecpu_kernel_t *;
18
18
using nativecpu_task_t = std::function<nativecpu_kernel_t >;
19
19
20
// Bookkeeping record for one local-memory kernel argument: which slot of the
// kernel's argument list it occupies and how many bytes it asks for.
struct local_arg_info_t {
  uint32_t argIndex; // position of the argument in the kernel's argument list
  size_t argSize;    // number of bytes requested for this argument

  local_arg_info_t(uint32_t index, size_t size)
      : argIndex{index}, argSize{size} {}
};
26
+
20
27
struct ur_kernel_handle_t_ : RefCounted {
21
28
22
29
ur_kernel_handle_t_ (const char *name, nativecpu_task_t subhandler)
@@ -25,4 +32,47 @@ struct ur_kernel_handle_t_ : RefCounted {
25
32
const char *_name;
26
33
nativecpu_task_t _subhandler;
27
34
std::vector<sycl::detail::NativeCPUArgDesc> _args;
35
+ std::vector<local_arg_info_t > _localArgInfo;
36
+
37
+ // To be called before enqueuing the kernel.
38
+ void handleLocalArgs () {
39
+ updateMemPool ();
40
+ size_t offset = 0 ;
41
+ for (auto &entry : _localArgInfo) {
42
+ _args[entry.argIndex ].MPtr =
43
+ reinterpret_cast <char *>(_localMemPool) + offset;
44
+ // update offset in the memory pool
45
+ // Todo: update this offset computation when we have work-group
46
+ // level parallelism.
47
+ offset += entry.argSize ;
48
+ }
49
+ }
50
+
51
+ ~ur_kernel_handle_t_ () {
52
+ if (_localMemPool) {
53
+ free (_localMemPool);
54
+ }
55
+ }
56
+
57
+ private:
58
+ void updateMemPool () {
59
+ // compute requested size.
60
+ // Todo: currently we execute only one work-group at a time, so for each
61
+ // local arg we can allocate just 1 * argSize local arg. When we implement
62
+ // work-group level parallelism we should allocate N * argSize where N is
63
+ // the number of work groups being executed in parallel (e.g. number of
64
+ // threads in the thread pool).
65
+ size_t reqSize = 0 ;
66
+ for (auto &entry : _localArgInfo) {
67
+ reqSize += entry.argSize ;
68
+ }
69
+ if (reqSize == 0 || reqSize == _localMemPoolSize) {
70
+ return ;
71
+ }
72
+ // realloc handles nullptr case
73
+ _localMemPool = realloc (_localMemPool, reqSize);
74
+ _localMemPoolSize = reqSize;
75
+ }
76
+ void *_localMemPool = nullptr ;
77
+ size_t _localMemPoolSize = 0 ;
28
78
};
0 commit comments