Skip to content

Commit 9e64a92

Browse files
authored
moved reference event to a static object in platform (#6156)
Instead of having a separate reference event for each context, this PR introduces a single common one, stored as a static member of `_pi_platform`. Closes #6155. I don't think adding tests to the test suite for this change makes sense, but I can add some if the reviewer disagrees.
1 parent caa696f commit 9e64a92

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,23 +409,25 @@ pi_uint64 _pi_event::get_queued_time() const {
409409
assert(is_started());
410410

411411
PI_CHECK_ERROR(
412-
cuEventElapsedTime(&miliSeconds, context_->evBase_, evQueued_));
412+
cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evQueued_));
413413
return static_cast<pi_uint64>(miliSeconds * 1.0e6);
414414
}
415415

416416
pi_uint64 _pi_event::get_start_time() const {
417417
float miliSeconds = 0.0f;
418418
assert(is_started());
419419

420-
PI_CHECK_ERROR(cuEventElapsedTime(&miliSeconds, context_->evBase_, evStart_));
420+
PI_CHECK_ERROR(
421+
cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evStart_));
421422
return static_cast<pi_uint64>(miliSeconds * 1.0e6);
422423
}
423424

424425
pi_uint64 _pi_event::get_end_time() const {
425426
float miliSeconds = 0.0f;
426427
assert(is_started() && is_recorded());
427428

428-
PI_CHECK_ERROR(cuEventElapsedTime(&miliSeconds, context_->evBase_, evEnd_));
429+
PI_CHECK_ERROR(
430+
cuEventElapsedTime(&miliSeconds, _pi_platform::evBase_, evEnd_));
429431
return static_cast<pi_uint64>(miliSeconds * 1.0e6);
430432
}
431433

@@ -1881,9 +1883,16 @@ pi_result cuda_piContextCreate(const pi_context_properties *properties,
18811883
_pi_context::kind::user_defined, newContext, *devices});
18821884
}
18831885

1884-
// Use default stream to record base event counter
1885-
PI_CHECK_ERROR(cuEventCreate(&piContextPtr->evBase_, CU_EVENT_DEFAULT));
1886-
PI_CHECK_ERROR(cuEventRecord(piContextPtr->evBase_, 0));
1886+
static std::once_flag initFlag;
1887+
std::call_once(
1888+
initFlag,
1889+
[](pi_result &err) {
1890+
// Use default stream to record base event counter
1891+
PI_CHECK_ERROR(
1892+
cuEventCreate(&_pi_platform::evBase_, CU_EVENT_DEFAULT));
1893+
PI_CHECK_ERROR(cuEventRecord(_pi_platform::evBase_, 0));
1894+
},
1895+
errcode_ret);
18871896

18881897
// For non-primary scoped contexts keep the last active on top of the stack
18891898
// as `cuCtxCreate` replaces it implicitly otherwise.
@@ -1913,8 +1922,6 @@ pi_result cuda_piContextRelease(pi_context ctxt) {
19131922

19141923
std::unique_ptr<_pi_context> context{ctxt};
19151924

1916-
PI_CHECK_ERROR(cuEventDestroy(context->evBase_));
1917-
19181925
if (!ctxt->is_primary()) {
19191926
CUcontext cuCtxt = ctxt->get();
19201927
CUcontext current = nullptr;
@@ -5137,3 +5144,5 @@ pi_result piPluginInit(pi_plugin *PluginInit) {
51375144
}
51385145

51395146
} // extern "C"
5147+
5148+
CUevent _pi_platform::evBase_{nullptr};

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ pi_result cuda_piKernelGetGroupInfo(pi_kernel kernel, pi_device device,
6161
/// when devices are used.
6262
///
6363
struct _pi_platform {
64+
static CUevent evBase_; // CUDA event used as base counter
6465
std::vector<std::unique_ptr<_pi_device>> devices_;
6566
};
6667

@@ -162,11 +163,8 @@ struct _pi_context {
162163
_pi_device *deviceId_;
163164
std::atomic_uint32_t refCount_;
164165

165-
CUevent evBase_; // CUDA event used as base counter
166-
167166
_pi_context(kind k, CUcontext ctxt, _pi_device *devId)
168-
: kind_{k}, cuContext_{ctxt}, deviceId_{devId}, refCount_{1},
169-
evBase_(nullptr) {
167+
: kind_{k}, cuContext_{ctxt}, deviceId_{devId}, refCount_{1} {
170168
cuda_piDeviceRetain(deviceId_);
171169
};
172170

0 commit comments

Comments
 (0)