Skip to content

Commit eb084bc

Browse files
authored
[UR] Add raii handle for UR context and device (#19062)
and fix context lifetime management in command_list_manager. There were two issues: - move ctor and assignment operators were declared as default but that was incorrect - they should set context on the moved-from command_list_manager to nullptr to avoid double free. - urContextRelease() was called in the desturctor, which means its context could have been released before other members. This was a problem since zeCommandList was trying to return command_list to the context's cache in it's dtor.
1 parent df48d76 commit eb084bc

File tree

3 files changed

+83
-35
lines changed

3 files changed

+83
-35
lines changed

unified-runtime/source/adapters/level_zero/v2/command_list_manager.cpp

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,7 @@ ur_command_list_manager::ur_command_list_manager(
2121
ur_context_handle_t context, ur_device_handle_t device,
2222
v2::raii::command_list_unique_handle &&commandList)
2323
: hContext(context), hDevice(device),
24-
zeCommandList(std::move(commandList)) {
25-
UR_CALL_THROWS(ur::level_zero::urContextRetain(context));
26-
UR_CALL_THROWS(ur::level_zero::urDeviceRetain(device));
27-
}
28-
29-
ur_command_list_manager::~ur_command_list_manager() {
30-
ur::level_zero::urContextRelease(hContext);
31-
ur::level_zero::urDeviceRelease(hDevice);
32-
}
24+
zeCommandList(std::move(commandList)) {}
3325

3426
ur_result_t ur_command_list_manager::appendGenericFillUnlocked(
3527
ur_mem_buffer_t *dst, size_t offset, size_t patternSize,
@@ -41,8 +33,8 @@ ur_result_t ur_command_list_manager::appendGenericFillUnlocked(
4133
auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);
4234

4335
auto pDst = ur_cast<char *>(dst->getDevicePtr(
44-
hDevice, ur_mem_buffer_t::device_access_mode_t::read_only, offset, size,
45-
zeCommandList.get(), waitListView));
36+
hDevice.get(), ur_mem_buffer_t::device_access_mode_t::read_only, offset,
37+
size, zeCommandList.get(), waitListView));
4638

4739
// PatternSize must be a power of two for zeCommandListAppendMemoryFill.
4840
// When it's not, the fill is emulated with zeCommandListAppendMemoryCopy.
@@ -78,12 +70,12 @@ ur_result_t ur_command_list_manager::appendGenericCopyUnlocked(
7870
auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);
7971

8072
auto pSrc = ur_cast<char *>(src->getDevicePtr(
81-
hDevice, ur_mem_buffer_t::device_access_mode_t::read_only, srcOffset,
82-
size, zeCommandList.get(), waitListView));
73+
hDevice.get(), ur_mem_buffer_t::device_access_mode_t::read_only,
74+
srcOffset, size, zeCommandList.get(), waitListView));
8375

8476
auto pDst = ur_cast<char *>(dst->getDevicePtr(
85-
hDevice, ur_mem_buffer_t::device_access_mode_t::write_only, dstOffset,
86-
size, zeCommandList.get(), waitListView));
77+
hDevice.get(), ur_mem_buffer_t::device_access_mode_t::write_only,
78+
dstOffset, size, zeCommandList.get(), waitListView));
8779

8880
ZE2UR_CALL(zeCommandListAppendMemoryCopy,
8981
(zeCommandList.get(), pDst, pSrc, size, zeSignalEvent,
@@ -110,10 +102,10 @@ ur_result_t ur_command_list_manager::appendRegionCopyUnlocked(
110102
auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);
111103

112104
auto pSrc = ur_cast<char *>(src->getDevicePtr(
113-
hDevice, ur_mem_buffer_t::device_access_mode_t::read_only, 0,
105+
hDevice.get(), ur_mem_buffer_t::device_access_mode_t::read_only, 0,
114106
src->getSize(), zeCommandList.get(), waitListView));
115107
auto pDst = ur_cast<char *>(dst->getDevicePtr(
116-
hDevice, ur_mem_buffer_t::device_access_mode_t::write_only, 0,
108+
hDevice.get(), ur_mem_buffer_t::device_access_mode_t::write_only, 0,
117109
dst->getSize(), zeCommandList.get(), waitListView));
118110

119111
ZE2UR_CALL(zeCommandListAppendMemoryCopyRegion,
@@ -168,22 +160,22 @@ ur_result_t ur_command_list_manager::appendKernelLaunchUnlocked(
168160
UR_ASSERT(workDim > 0, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
169161
UR_ASSERT(workDim < 4, UR_RESULT_ERROR_INVALID_WORK_DIMENSION);
170162

171-
ze_kernel_handle_t hZeKernel = hKernel->getZeHandle(hDevice);
163+
ze_kernel_handle_t hZeKernel = hKernel->getZeHandle(hDevice.get());
172164

173165
std::scoped_lock<ur_shared_mutex> Lock(hKernel->Mutex);
174166

175167
ze_group_count_t zeThreadGroupDimensions{1, 1, 1};
176168
uint32_t WG[3]{};
177-
UR_CALL(calculateKernelWorkDimensions(hZeKernel, hDevice,
169+
UR_CALL(calculateKernelWorkDimensions(hZeKernel, hDevice.get(),
178170
zeThreadGroupDimensions, WG, workDim,
179171
pGlobalWorkSize, pLocalWorkSize));
180172

181173
auto zeSignalEvent = getSignalEvent(phEvent, UR_COMMAND_KERNEL_LAUNCH);
182174
auto waitListView = getWaitListView(phEventWaitList, numEventsInWaitList);
183175

184-
UR_CALL(hKernel->prepareForSubmission(hContext, hDevice, pGlobalWorkOffset,
185-
workDim, WG[0], WG[1], WG[2],
186-
getZeCommandList(), waitListView));
176+
UR_CALL(hKernel->prepareForSubmission(
177+
hContext.get(), hDevice.get(), pGlobalWorkOffset, workDim, WG[0], WG[1],
178+
WG[2], getZeCommandList(), waitListView));
187179

188180
if (cooperative) {
189181
TRACK_SCOPE_LATENCY("ur_command_list_manager::"
@@ -284,7 +276,7 @@ ur_result_t ur_command_list_manager::appendUSMFill(
284276
ur_event_handle_t phEvent) {
285277
TRACK_SCOPE_LATENCY("ur_command_list_manager::appendUSMFill");
286278

287-
ur_usm_handle_t dstHandle(hContext, size, pMem);
279+
ur_usm_handle_t dstHandle(hContext.get(), size, pMem);
288280
return appendGenericFillUnlocked(&dstHandle, 0, patternSize, pPattern, size,
289281
numEventsInWaitList, phEventWaitList,
290282
phEvent, UR_COMMAND_USM_FILL);
@@ -351,7 +343,7 @@ ur_result_t ur_command_list_manager::appendMemBufferRead(
351343
auto hBuffer = hMem->getBuffer();
352344
UR_ASSERT(offset + size <= hBuffer->getSize(), UR_RESULT_ERROR_INVALID_SIZE);
353345

354-
ur_usm_handle_t dstHandle(hContext, size, pDst);
346+
ur_usm_handle_t dstHandle(hContext.get(), size, pDst);
355347

356348
std::scoped_lock<ur_shared_mutex> lock(hBuffer->getMutex());
357349

@@ -369,7 +361,7 @@ ur_result_t ur_command_list_manager::appendMemBufferWrite(
369361
auto hBuffer = hMem->getBuffer();
370362
UR_ASSERT(offset + size <= hBuffer->getSize(), UR_RESULT_ERROR_INVALID_SIZE);
371363

372-
ur_usm_handle_t srcHandle(hContext, size, pSrc);
364+
ur_usm_handle_t srcHandle(hContext.get(), size, pSrc);
373365

374366
std::scoped_lock<ur_shared_mutex> lock(hBuffer->getMutex());
375367

@@ -410,7 +402,7 @@ ur_result_t ur_command_list_manager::appendMemBufferReadRect(
410402
TRACK_SCOPE_LATENCY("ur_command_list_manager::appendMemBufferReadRect");
411403

412404
auto hBuffer = hMem->getBuffer();
413-
ur_usm_handle_t dstHandle(hContext, 0, pDst);
405+
ur_usm_handle_t dstHandle(hContext.get(), 0, pDst);
414406

415407
std::scoped_lock<ur_shared_mutex> lock(hBuffer->getMutex());
416408

@@ -430,7 +422,7 @@ ur_result_t ur_command_list_manager::appendMemBufferWriteRect(
430422
TRACK_SCOPE_LATENCY("ur_command_list_manager::appendMemBufferWriteRect");
431423

432424
auto hBuffer = hMem->getBuffer();
433-
ur_usm_handle_t srcHandle(hContext, 0, pSrc);
425+
ur_usm_handle_t srcHandle(hContext.get(), 0, pSrc);
434426

435427
std::scoped_lock<ur_shared_mutex> lock(hBuffer->getMutex());
436428

@@ -470,8 +462,8 @@ ur_result_t ur_command_list_manager::appendUSMMemcpy2D(
470462
ur_rect_offset_t zeroOffset{0, 0, 0};
471463
ur_rect_region_t region{width, height, 0};
472464

473-
ur_usm_handle_t srcHandle(hContext, 0, pSrc);
474-
ur_usm_handle_t dstHandle(hContext, 0, pDst);
465+
ur_usm_handle_t srcHandle(hContext.get(), 0, pSrc);
466+
ur_usm_handle_t dstHandle(hContext.get(), 0, pDst);
475467

476468
return appendRegionCopyUnlocked(&srcHandle, &dstHandle, blocking, zeroOffset,
477469
zeroOffset, region, srcPitch, 0, dstPitch, 0,
@@ -784,13 +776,14 @@ ur_result_t ur_command_list_manager::appendUSMAllocHelper(
784776
pPool = hContext->getAsyncPool();
785777
}
786778

787-
auto device = (type == UR_USM_TYPE_HOST) ? nullptr : hDevice;
779+
auto device = (type == UR_USM_TYPE_HOST) ? nullptr : hDevice.get();
788780

789781
ur_event_handle_t originAllocEvent = nullptr;
790-
auto asyncAlloc = pPool->allocateEnqueued(hContext, Queue, true, device,
782+
auto asyncAlloc = pPool->allocateEnqueued(hContext.get(), Queue, true, device,
791783
nullptr, type, size);
792784
if (!asyncAlloc) {
793-
auto Ret = pPool->allocate(hContext, device, nullptr, type, size, ppMem);
785+
auto Ret =
786+
pPool->allocate(hContext.get(), device, nullptr, type, size, ppMem);
794787
if (Ret) {
795788
return Ret;
796789
}

unified-runtime/source/adapters/level_zero/v2/command_list_manager.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct ur_command_list_manager {
4747
operator=(const ur_command_list_manager &src) = delete;
4848
ur_command_list_manager &operator=(ur_command_list_manager &&src) = default;
4949

50-
~ur_command_list_manager();
50+
~ur_command_list_manager() = default;
5151

5252
ze_command_list_handle_t getZeCommandList();
5353

@@ -273,8 +273,10 @@ struct ur_command_list_manager {
273273
const ur_event_handle_t *phEventWaitList, ur_event_handle_t phEvent,
274274
ur_command_t commandType);
275275

276-
ur_context_handle_t hContext;
277-
ur_device_handle_t hDevice;
276+
// Context needs to be a first member - it needs to be alive
277+
// until all other members are destroyed.
278+
v2::raii::ur_context_handle_t hContext;
279+
v2::raii::ur_device_handle_t hDevice;
278280

279281
std::vector<ur_kernel_handle_t> submittedKernels;
280282
v2::raii::command_list_unique_handle zeCommandList;

unified-runtime/source/adapters/level_zero/v2/common.hpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,5 +106,58 @@ HANDLE_WRAPPER_TYPE(ze_context_handle_t, zeContextDestroy)
106106
HANDLE_WRAPPER_TYPE(ze_command_list_handle_t, zeCommandListDestroy)
107107
HANDLE_WRAPPER_TYPE(ze_image_handle_t, zeImageDestroy)
108108

109+
template <typename RawHandle, ur_result_t (*retain)(RawHandle),
110+
ur_result_t (*release)(RawHandle)>
111+
struct ur_handle {
112+
ur_handle(RawHandle handle = nullptr) : handle(handle) {
113+
if (handle) {
114+
retain(handle);
115+
}
116+
}
117+
118+
ur_handle(const ur_handle &) = delete;
119+
ur_handle &operator=(const ur_handle &) = delete;
120+
121+
ur_handle(ur_handle &&rhs) {
122+
this->handle = rhs.handle;
123+
rhs.handle = nullptr;
124+
}
125+
126+
ur_handle &operator=(ur_handle &&rhs) {
127+
if (this == &rhs) {
128+
return *this;
129+
}
130+
131+
if (this->handle) {
132+
release(this->handle);
133+
}
134+
135+
this->handle = rhs.handle;
136+
rhs.handle = nullptr;
137+
138+
return *this;
139+
}
140+
141+
~ur_handle() {
142+
if (handle) {
143+
release(handle);
144+
}
145+
}
146+
147+
RawHandle get() const { return handle; }
148+
149+
RawHandle operator->() const { return get(); }
150+
151+
private:
152+
RawHandle handle;
153+
};
154+
155+
using ur_context_handle_t =
156+
ur_handle<::ur_context_handle_t, ur::level_zero::urContextRetain,
157+
ur::level_zero::urContextRelease>;
158+
using ur_device_handle_t =
159+
ur_handle<::ur_device_handle_t, ur::level_zero::urDeviceRetain,
160+
ur::level_zero::urDeviceRelease>;
161+
109162
} // namespace raii
110163
} // namespace v2

0 commit comments

Comments
 (0)