Skip to content

Commit ddfba84

Browse files
committed
[SYCL][NFC] Make kernel_impl::getAdapter() return by reference
1 parent 4aef322 commit ddfba84

18 files changed

+140
-141
lines changed

sycl/source/detail/buffer_impl.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,12 @@ buffer_impl::getNativeVector(backend BackendName) const {
8383
if (Platform.getBackend() != BackendName)
8484
continue;
8585

86-
auto Adapter = Platform.getAdapter();
87-
86+
adapter_impl &Adapter = Platform.getAdapter();
8887
ur_native_handle_t Handle = 0;
8988
// When doing buffer interop we don't know what device the memory should be
9089
// resident on, so pass nullptr for Device param. Buffer interop may not be
9190
// supported by all backends.
92-
Adapter->call<UrApiKind::urMemGetNativeHandle>(NativeMem, /*Dev*/ nullptr,
91+
Adapter.call<UrApiKind::urMemGetNativeHandle>(NativeMem, /*Dev*/ nullptr,
9392
&Handle);
9493
Handles.push_back(Handle);
9594

sycl/source/detail/context_impl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
9494
const async_handler &get_async_handler() const;
9595

9696
/// \return the Adapter associated with the platform of this context.
97-
const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); }
97+
const AdapterPtr &getAdapter() const { return &MPlatform->getAdapter(); }
9898

9999
/// \return the PlatformImpl associated with this context.
100100
platform_impl &getPlatformImpl() const { return *MPlatform; }
@@ -382,7 +382,7 @@ inline auto get_ur_handles(const sycl::device &syclDevice,
382382
inline auto get_ur_handles(const sycl::device &syclDevice) {
383383
auto &implDevice = *sycl::detail::getSyclObjImpl(syclDevice);
384384
ur_device_handle_t urDevice = implDevice.getHandleRef();
385-
return std::tuple{urDevice, implDevice.getAdapter()};
385+
return std::tuple{urDevice, &implDevice.getAdapter()};
386386
}
387387

388388
} // namespace _V1

sycl/source/detail/device_impl.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ device_impl::device_impl(ur_device_handle_t Device, platform_impl &Platform,
3232
MCache{*this} {
3333
// Interoperability Constructor already calls DeviceRetain in
3434
// urDeviceCreateWithNativeHandle.
35-
getAdapter()->call<UrApiKind::urDeviceRetain>(MDevice);
35+
getAdapter().call<UrApiKind::urDeviceRetain>(MDevice);
3636
}
3737

3838
device_impl::~device_impl() {
3939
try {
4040
// TODO catch an exception and put it to list of asynchronous exceptions
41-
const AdapterPtr &Adapter = getAdapter();
41+
adapter_impl &Adapter = getAdapter();
4242
ur_result_t Err =
43-
Adapter->call_nocheck<UrApiKind::urDeviceRelease>(MDevice);
44-
__SYCL_CHECK_UR_CODE_NO_EXC(Err, Adapter->getBackend());
43+
Adapter.call_nocheck<UrApiKind::urDeviceRelease>(MDevice);
44+
__SYCL_CHECK_UR_CODE_NO_EXC(Err, Adapter.getBackend());
4545
} catch (std::exception &e) {
4646
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~device_impl", e);
4747
}
@@ -123,8 +123,8 @@ std::vector<device> device_impl::create_sub_devices(
123123
size_t SubDevicesCount) const {
124124
std::vector<ur_device_handle_t> SubDevices(SubDevicesCount);
125125
uint32_t ReturnedSubDevices = 0;
126-
const AdapterPtr &Adapter = getAdapter();
127-
Adapter->call<sycl::errc::invalid, UrApiKind::urDevicePartition>(
126+
adapter_impl &Adapter = getAdapter();
127+
Adapter.call<sycl::errc::invalid, UrApiKind::urDevicePartition>(
128128
MDevice, Properties, SubDevicesCount, SubDevices.data(),
129129
&ReturnedSubDevices);
130130
if (ReturnedSubDevices != SubDevicesCount) {
@@ -270,8 +270,8 @@ std::vector<device> device_impl::create_sub_devices(
270270
Properties.pProperties = &Prop;
271271

272272
uint32_t SubDevicesCount = 0;
273-
const AdapterPtr &Adapter = getAdapter();
274-
Adapter->call<sycl::errc::invalid, UrApiKind::urDevicePartition>(
273+
adapter_impl &Adapter = getAdapter();
274+
Adapter.call<sycl::errc::invalid, UrApiKind::urDevicePartition>(
275275
MDevice, &Properties, 0, nullptr, &SubDevicesCount);
276276

277277
return create_sub_devices(&Properties, SubDevicesCount);
@@ -295,17 +295,17 @@ std::vector<device> device_impl::create_sub_devices() const {
295295
Properties.PropCount = 1;
296296

297297
uint32_t SubDevicesCount = 0;
298-
const AdapterPtr &Adapter = getAdapter();
299-
Adapter->call<UrApiKind::urDevicePartition>(MDevice, &Properties, 0, nullptr,
298+
adapter_impl &Adapter = getAdapter();
299+
Adapter.call<UrApiKind::urDevicePartition>(MDevice, &Properties, 0, nullptr,
300300
&SubDevicesCount);
301301

302302
return create_sub_devices(&Properties, SubDevicesCount);
303303
}
304304

305305
ur_native_handle_t device_impl::getNative() const {
306-
auto Adapter = getAdapter();
306+
adapter_impl &Adapter = getAdapter();
307307
ur_native_handle_t Handle;
308-
Adapter->call<UrApiKind::urDeviceGetNativeHandle>(getHandleRef(), &Handle);
308+
Adapter.call<UrApiKind::urDeviceGetNativeHandle>(getHandleRef(), &Handle);
309309
if (getBackend() == backend::opencl) {
310310
__SYCL_OCL_CALL(clRetainDevice, ur::cast<cl_device_id>(Handle));
311311
}
@@ -327,7 +327,7 @@ uint64_t device_impl::getCurrentDeviceTime() {
327327
auto GetGlobalTimestamps = [this](ur_device_handle_t Device,
328328
uint64_t *DeviceTime, uint64_t *HostTime) {
329329
auto Result =
330-
getAdapter()->call_nocheck<UrApiKind::urDeviceGetGlobalTimestamps>(
330+
getAdapter().call_nocheck<UrApiKind::urDeviceGetGlobalTimestamps>(
331331
Device, DeviceTime, HostTime);
332332
if (Result == UR_RESULT_ERROR_INVALID_OPERATION) {
333333
// NOTE(UR port): Removed the call to GetLastError because we shouldn't
@@ -339,7 +339,7 @@ uint64_t device_impl::getCurrentDeviceTime() {
339339
"Device and/or backend does not support querying timestamp."),
340340
UR_RESULT_ERROR_INVALID_OPERATION);
341341
} else {
342-
getAdapter()->checkUrResult<errc::feature_not_supported>(Result);
342+
getAdapter().checkUrResult<errc::feature_not_supported>(Result);
343343
}
344344
};
345345

sycl/source/detail/device_impl.hpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
113113

114114
bool has_info_desc(ur_device_info_t Desc) const {
115115
size_t return_size = 0;
116-
return getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
116+
return getAdapter().call_nocheck<UrApiKind::urDeviceGetInfo>(
117117
MDevice, Desc, 0, nullptr, &return_size) == UR_RESULT_SUCCESS;
118118
}
119119

@@ -153,23 +153,23 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
153153
!check_type_in_v<typename ur_ret_t::value_type, bool, std::string>);
154154
size_t ResultSize = 0;
155155
ur_result_t Error =
156-
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
156+
getAdapter().call_nocheck<UrApiKind::urDeviceGetInfo>(
157157
getHandleRef(), Desc, 0, nullptr, &ResultSize);
158158
if (Error != UR_RESULT_SUCCESS)
159159
return {Error};
160160
if (ResultSize == 0)
161161
return {ur_ret_t{}};
162162

163163
ur_ret_t Result(ResultSize / sizeof(typename ur_ret_t::value_type));
164-
Error = getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
164+
Error = getAdapter().call_nocheck<UrApiKind::urDeviceGetInfo>(
165165
getHandleRef(), Desc, ResultSize, Result.data(), nullptr);
166166
if (Error != UR_RESULT_SUCCESS)
167167
return {Error};
168168
return {Result};
169169
} else {
170170
ur_ret_t Result;
171171
ur_result_t Error =
172-
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
172+
getAdapter().call_nocheck<UrApiKind::urDeviceGetInfo>(
173173
getHandleRef(), Desc, sizeof(Result), &Result, nullptr);
174174
if (Error == UR_RESULT_SUCCESS)
175175
return {Result};
@@ -188,18 +188,18 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
188188
return urGetInfoString<UrApiKind::urDeviceGetInfo>(*this, Desc);
189189
} else if constexpr (is_std_vector_v<ur_ret_t>) {
190190
size_t ResultSize = 0;
191-
getAdapter()->call<UrApiKind::urDeviceGetInfo>(getHandleRef(), Desc, 0,
191+
getAdapter().call<UrApiKind::urDeviceGetInfo>(getHandleRef(), Desc, 0,
192192
nullptr, &ResultSize);
193193
if (ResultSize == 0)
194194
return ur_ret_t{};
195195

196196
ur_ret_t Result(ResultSize / sizeof(typename ur_ret_t::value_type));
197-
getAdapter()->call<UrApiKind::urDeviceGetInfo>(
197+
getAdapter().call<UrApiKind::urDeviceGetInfo>(
198198
getHandleRef(), Desc, ResultSize, Result.data(), nullptr);
199199
return Result;
200200
} else {
201201
ur_ret_t Result;
202-
getAdapter()->call<UrApiKind::urDeviceGetInfo>(
202+
getAdapter().call<UrApiKind::urDeviceGetInfo>(
203203
getHandleRef(), Desc, sizeof(Result), &Result, nullptr);
204204
return Result;
205205
}
@@ -468,7 +468,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
468468
platform get_platform() const;
469469

470470
/// \return the associated adapter with this device.
471-
const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); }
471+
adapter_impl &getAdapter() const { return MPlatform->getAdapter(); }
472472

473473
/// Check SYCL extension support by device
474474
///
@@ -724,7 +724,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
724724
CASE(info::device::platform) {
725725
return createSyclObjFromImpl<platform>(
726726
platform_impl::getOrMakePlatformImpl(
727-
get_info_impl<UR_DEVICE_INFO_PLATFORM>(), *getAdapter()));
727+
get_info_impl<UR_DEVICE_INFO_PLATFORM>(), getAdapter()));
728728
}
729729

730730
CASE(info::device::profile) {
@@ -940,7 +940,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
940940

941941
// TODO: std::array<size_t, 3> ?
942942
size_t result[3];
943-
getAdapter()->call<UrApiKind::urDeviceGetInfo>(
943+
getAdapter().call<UrApiKind::urDeviceGetInfo>(
944944
getHandleRef(), UR_DEVICE_INFO_MAX_WORK_GROUPS_3D, sizeof(result),
945945
&result, nullptr);
946946
return id<3>(std::min(Limit, result[2]), std::min(Limit, result[1]),
@@ -1011,7 +1011,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
10111011
ur_result_t Err = Devs.error();
10121012
if (Err == UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION)
10131013
return std::vector<sycl::device>{};
1014-
getAdapter()->checkUrResult(Err);
1014+
getAdapter().checkUrResult(Err);
10151015
}
10161016

10171017
std::vector<sycl::device> Result;
@@ -1488,7 +1488,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
14881488
CASE(ext_oneapi_graph) {
14891489
ur_device_command_buffer_update_capability_flags_t UpdateCapabilities;
14901490
bool CallSuccessful =
1491-
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
1491+
getAdapter().call_nocheck<UrApiKind::urDeviceGetInfo>(
14921492
MDevice, UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_CAPABILITIES_EXP,
14931493
sizeof(UpdateCapabilities), &UpdateCapabilities,
14941494
nullptr) == UR_RESULT_SUCCESS;
@@ -1510,7 +1510,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
15101510
CASE(ext_oneapi_limited_graph) {
15111511
bool SupportsCommandBuffers = false;
15121512
bool CallSuccessful =
1513-
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
1513+
getAdapter().call_nocheck<UrApiKind::urDeviceGetInfo>(
15141514
MDevice, UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP,
15151515
sizeof(SupportsCommandBuffers), &SupportsCommandBuffers,
15161516
nullptr) == UR_RESULT_SUCCESS;
@@ -1875,7 +1875,7 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
18751875
// Not all devices support this device info query
18761876
return std::nullopt;
18771877
}
1878-
getAdapter()->checkUrResult(Err);
1878+
getAdapter().checkUrResult(Err);
18791879
}
18801880

18811881
auto Val = static_cast<int>(DeviceIp.value());

sycl/source/detail/error_handling/error_handling.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ void handleOutOfResources(const device_impl &DeviceImpl,
3737
const size_t TotalNumberOfWIs =
3838
NDRDesc.LocalSize[0] * NDRDesc.LocalSize[1] * NDRDesc.LocalSize[2];
3939

40-
const AdapterPtr &Adapter = DeviceImpl.getAdapter();
40+
adapter_impl &Adapter = DeviceImpl.getAdapter();
4141
uint32_t NumRegisters = 0;
42-
Adapter->call<UrApiKind::urKernelGetInfo>(Kernel, UR_KERNEL_INFO_NUM_REGS,
42+
Adapter.call<UrApiKind::urKernelGetInfo>(Kernel, UR_KERNEL_INFO_NUM_REGS,
4343
sizeof(NumRegisters),
4444
&NumRegisters, nullptr);
4545

@@ -96,32 +96,32 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl,
9696
IsLevelZero = true;
9797
}
9898

99-
const AdapterPtr &Adapter = DeviceImpl.getAdapter();
99+
adapter_impl &Adapter = DeviceImpl.getAdapter();
100100
ur_device_handle_t Device = DeviceImpl.getHandleRef();
101101

102102
size_t CompileWGSize[3] = {0};
103-
Adapter->call<UrApiKind::urKernelGetGroupInfo>(
103+
Adapter.call<UrApiKind::urKernelGetGroupInfo>(
104104
Kernel, Device, UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
105105
sizeof(size_t) * 3, CompileWGSize, nullptr);
106106

107107
size_t CompileMaxWGSize[3] = {0};
108-
ur_result_t URRes = Adapter->call_nocheck<UrApiKind::urKernelGetGroupInfo>(
108+
ur_result_t URRes = Adapter.call_nocheck<UrApiKind::urKernelGetGroupInfo>(
109109
Kernel, Device, UR_KERNEL_GROUP_INFO_COMPILE_MAX_WORK_GROUP_SIZE,
110110
sizeof(size_t) * 3, CompileMaxWGSize, nullptr);
111111
if (URRes != UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION) {
112-
Adapter->checkUrResult(URRes);
112+
Adapter.checkUrResult(URRes);
113113
}
114114

115115
size_t CompileMaxLinearWGSize = 0;
116-
URRes = Adapter->call_nocheck<UrApiKind::urKernelGetGroupInfo>(
116+
URRes = Adapter.call_nocheck<UrApiKind::urKernelGetGroupInfo>(
117117
Kernel, Device, UR_KERNEL_GROUP_INFO_COMPILE_MAX_LINEAR_WORK_GROUP_SIZE,
118118
sizeof(size_t), &CompileMaxLinearWGSize, nullptr);
119119
if (URRes != UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION) {
120-
Adapter->checkUrResult(URRes);
120+
Adapter.checkUrResult(URRes);
121121
}
122122

123123
size_t MaxWGSize = 0;
124-
Adapter->call<UrApiKind::urDeviceGetInfo>(
124+
Adapter.call<UrApiKind::urDeviceGetInfo>(
125125
Device, UR_DEVICE_INFO_MAX_WORK_GROUP_SIZE, sizeof(size_t), &MaxWGSize,
126126
nullptr);
127127

@@ -186,7 +186,7 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl,
186186
}
187187

188188
size_t MaxThreadsPerBlock[3] = {};
189-
Adapter->call<UrApiKind::urDeviceGetInfo>(
189+
Adapter.call<UrApiKind::urDeviceGetInfo>(
190190
Device, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES, sizeof(MaxThreadsPerBlock),
191191
MaxThreadsPerBlock, nullptr);
192192

@@ -232,7 +232,7 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl,
232232
// than the value specified by UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE in
233233
// table 5.21.
234234
size_t KernelWGSize = 0;
235-
Adapter->call<UrApiKind::urKernelGetGroupInfo>(
235+
Adapter.call<UrApiKind::urKernelGetGroupInfo>(
236236
Kernel, Device, UR_KERNEL_GROUP_INFO_WORK_GROUP_SIZE, sizeof(size_t),
237237
&KernelWGSize, nullptr);
238238
if (TotalNumberOfWIs > KernelWGSize)
@@ -284,15 +284,15 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl,
284284
// work-group given by local_work_size
285285

286286
ur_program_handle_t Program = nullptr;
287-
Adapter->call<UrApiKind::urKernelGetInfo>(
287+
Adapter.call<UrApiKind::urKernelGetInfo>(
288288
Kernel, UR_KERNEL_INFO_PROGRAM, sizeof(ur_program_handle_t),
289289
&Program, nullptr);
290290
size_t OptsSize = 0;
291-
Adapter->call<UrApiKind::urProgramGetBuildInfo>(
291+
Adapter.call<UrApiKind::urProgramGetBuildInfo>(
292292
Program, Device, UR_PROGRAM_BUILD_INFO_OPTIONS, 0, nullptr,
293293
&OptsSize);
294294
std::string Opts(OptsSize, '\0');
295-
Adapter->call<UrApiKind::urProgramGetBuildInfo>(
295+
Adapter.call<UrApiKind::urProgramGetBuildInfo>(
296296
Program, Device, UR_PROGRAM_BUILD_INFO_OPTIONS, OptsSize,
297297
&Opts.front(), nullptr);
298298
const bool HasStd20 = Opts.find("-cl-std=CL2.0") != std::string::npos;
@@ -351,12 +351,12 @@ void handleInvalidWorkGroupSize(const device_impl &DeviceImpl,
351351
void handleInvalidWorkItemSize(const device_impl &DeviceImpl,
352352
const NDRDescT &NDRDesc) {
353353

354-
const AdapterPtr &Adapter = DeviceImpl.getAdapter();
354+
adapter_impl &Adapter = DeviceImpl.getAdapter();
355355
ur_device_handle_t Device = DeviceImpl.getHandleRef();
356356

357357
size_t MaxWISize[] = {0, 0, 0};
358358

359-
Adapter->call<UrApiKind::urDeviceGetInfo>(
359+
Adapter.call<UrApiKind::urDeviceGetInfo>(
360360
Device, UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES, sizeof(MaxWISize), &MaxWISize,
361361
nullptr);
362362
for (unsigned I = 0; I < NDRDesc.Dims; I++) {
@@ -371,11 +371,11 @@ void handleInvalidWorkItemSize(const device_impl &DeviceImpl,
371371

372372
void handleInvalidValue(const device_impl &DeviceImpl,
373373
const NDRDescT &NDRDesc) {
374-
const AdapterPtr &Adapter = DeviceImpl.getAdapter();
374+
adapter_impl &Adapter = DeviceImpl.getAdapter();
375375
ur_device_handle_t Device = DeviceImpl.getHandleRef();
376376

377377
size_t MaxNWGs[] = {0, 0, 0};
378-
Adapter->call<UrApiKind::urDeviceGetInfo>(Device,
378+
Adapter.call<UrApiKind::urDeviceGetInfo>(Device,
379379
UR_DEVICE_INFO_MAX_WORK_GROUPS_3D,
380380
sizeof(MaxNWGs), &MaxNWGs, nullptr);
381381
for (unsigned int I = 0; I < NDRDesc.Dims; I++) {
@@ -452,7 +452,7 @@ void handleErrorOrWarning(ur_result_t Error, const device_impl &DeviceImpl,
452452
// an error or a warning. It also ensures that the contents of the error
453453
// message buffer (used only by UR_RESULT_ERROR_ADAPTER_SPECIFIC_ERROR) get
454454
// handled correctly.
455-
return DeviceImpl.getAdapter()->checkUrResult(Error);
455+
return DeviceImpl.getAdapter().checkUrResult(Error);
456456

457457
// TODO: Handle other error codes
458458

@@ -469,7 +469,7 @@ void handleErrorOrWarning(ur_result_t Error, const device_impl &DeviceImpl,
469469

470470
namespace detail::kernel_get_group_info {
471471
void handleErrorOrWarning(ur_result_t Error, ur_kernel_group_info_t Descriptor,
472-
const AdapterPtr &Adapter) {
472+
adapter_impl &Adapter) {
473473
assert(Error != UR_RESULT_SUCCESS &&
474474
"Success is expected to be handled on caller side");
475475
switch (Error) {
@@ -483,7 +483,7 @@ void handleErrorOrWarning(ur_result_t Error, ur_kernel_group_info_t Descriptor,
483483
break;
484484
// TODO: Handle other error codes
485485
default:
486-
Adapter->checkUrResult(Error);
486+
Adapter.checkUrResult(Error);
487487
break;
488488
}
489489
}

sycl/source/detail/error_handling/error_handling.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ void handleErrorOrWarning(ur_result_t, const device_impl &, ur_kernel_handle_t,
3131

3232
namespace kernel_get_group_info {
3333
/// Analyzes error code of urKernelGetGroupInfo.
34-
void handleErrorOrWarning(ur_result_t, ur_kernel_group_info_t,
35-
const AdapterPtr &);
34+
void handleErrorOrWarning(ur_result_t, ur_kernel_group_info_t, adapter_impl &);
3635
} // namespace kernel_get_group_info
3736

3837
} // namespace detail

0 commit comments

Comments
 (0)