Skip to content

Commit d3123b9

Browse files
authored
[SYCL][NATIVECPU] Reference counting native_cpu mem handle kernel arguments (#17558)
This PR fixes mem handle arguments for asynchronous kernel launches. This makes at least the following e2e tests pass on native cpu: ``` SpecConstants/2020/hierarchy.cpp SpecConstants/2020/marray_vec.cpp ``` The native_cpu kernel handle now reference-counts the mem handle arguments to prevent their potential release by the SYCL runtime before the enqueued kernel started using them (if the thread takes longer to start). These arguments are then freed when the kernel handle is released. The native_cpu mem handle class now reuses the same reference counting as the other native_cpu handles instead of having its own reference counting members, which were removed.
1 parent 32fc89a commit d3123b9

File tree

5 files changed

+31
-17
lines changed

5 files changed

+31
-17
lines changed

unified-runtime/source/adapters/native_cpu/common.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,6 @@ namespace ur {
5353
} // namespace ur
5454
} // namespace detail
5555

56-
// Base class to store common data
57-
struct _ur_object {
58-
ur_shared_mutex Mutex;
59-
};
60-
6156
// Todo: replace this with a common helper once it is available
6257
struct RefCounted {
6358
std::atomic_uint32_t _refCount;
@@ -67,6 +62,11 @@ struct RefCounted {
6762
uint32_t getReferenceCount() const { return _refCount; }
6863
};
6964

65+
// Base class to store common data
66+
struct _ur_object : RefCounted {
67+
ur_shared_mutex Mutex;
68+
};
69+
7070
template <typename T> inline void decrementOrDelete(T *refC) {
7171
if (refC->decrementReferenceCount() == 0)
7272
delete refC;

unified-runtime/source/adapters/native_cpu/kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
#include "common.hpp"
1515
#include "kernel.hpp"
16-
#include "memory.hpp"
1716
#include "program.hpp"
1817

1918
UR_APIEXPORT ur_result_t UR_APICALL
@@ -271,6 +270,7 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
271270
return UR_RESULT_SUCCESS;
272271
}
273272

273+
hKernel->addArgReference(hArgValue);
274274
hKernel->addPtrArg(hArgValue->_mem, argIndex);
275275
return UR_RESULT_SUCCESS;
276276
}

unified-runtime/source/adapters/native_cpu/kernel.hpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include "common.hpp"
12+
#include "memory.hpp"
1213
#include "nativecpu_state.hpp"
1314
#include "program.hpp"
1415
#include <cstring>
@@ -35,9 +36,14 @@ struct ur_kernel_handle_t_ : RefCounted {
3536
ur_kernel_handle_t_(const ur_kernel_handle_t_ &other)
3637
: Args(other.Args), hProgram(other.hProgram), _name(other._name),
3738
_subhandler(other._subhandler), _localArgInfo(other._localArgInfo),
38-
ReqdWGSize(other.ReqdWGSize) {}
39+
ReqdWGSize(other.ReqdWGSize) {
40+
takeArgReferences(other);
41+
}
3942

40-
~ur_kernel_handle_t_() { free(_localMemPool); }
43+
~ur_kernel_handle_t_() {
44+
removeArgReferences();
45+
free(_localMemPool);
46+
}
4147

4248
ur_kernel_handle_t_(ur_program_handle_t hProgram, const char *name,
4349
nativecpu_task_t subhandler,
@@ -186,10 +192,26 @@ struct ur_kernel_handle_t_ : RefCounted {
186192

187193
void addPtrArg(void *Ptr, size_t Index) { Args.addPtrArg(Index, Ptr); }
188194

195+
void addArgReference(ur_mem_handle_t Arg) {
196+
Arg->incrementReferenceCount();
197+
ReferencedArgs.push_back(Arg);
198+
}
199+
200+
private:
201+
void removeArgReferences() {
202+
for (auto arg : ReferencedArgs)
203+
decrementOrDelete(arg);
204+
}
205+
void takeArgReferences(const ur_kernel_handle_t_ &other) {
206+
for (auto arg : other.ReferencedArgs)
207+
addArgReference(arg);
208+
}
209+
189210
private:
190211
char *_localMemPool = nullptr;
191212
size_t _localMemPoolSize = 0;
192213
std::optional<native_cpu::WGSize_t> ReqdWGSize = std::nullopt;
193214
std::optional<native_cpu::WGSize_t> MaxWGSize = std::nullopt;
194215
std::optional<uint64_t> MaxLinearWGSize = std::nullopt;
216+
std::vector<ur_mem_handle_t> ReferencedArgs;
195217
};

unified-runtime/source/adapters/native_cpu/memory.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) {
7070

7171
UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
7272
UR_ASSERT(hMem, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
73+
decrementOrDelete(hMem);
7374

74-
hMem->decrementRefCount();
75-
if (hMem->_refCount > 0) {
76-
return UR_RESULT_SUCCESS;
77-
}
78-
79-
delete hMem;
8075
return UR_RESULT_SUCCESS;
8176
}
8277

unified-runtime/source/adapters/native_cpu/memory.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,11 @@ struct ur_mem_handle_t_ : _ur_object {
3838
}
3939
}
4040

41-
void decrementRefCount() noexcept { _refCount--; }
42-
4341
// Method to get type of the derived object (image or buffer)
4442
bool isImage() const { return this->IsImage; }
4543

4644
char *_mem;
4745
bool _ownsMem;
48-
std::atomic_uint32_t _refCount = {1};
4946

5047
private:
5148
const bool IsImage;

0 commit comments

Comments
 (0)