Skip to content

Commit 8d8f695

Browse files
[SYCL] More type safety for device_impl::get_info_impl* (#18425)
Map `ur_device_info_t` value to UR query's return type. In the next PR I will remove the `ReturnT` template parameter from `device_impl::get_info_impl*` but I decided that an intermediate PR with `static_asert` instead of full transition would be beneficial to the reviewers as it would ensure that next PR would actually be a no-op. On the other hand, this one addresses some discrepances between what SYCL RT expected the UR''s return type to be vs. what UR actually claims to return.
1 parent 570f74d commit 8d8f695

File tree

2 files changed

+341
-48
lines changed

2 files changed

+341
-48
lines changed

sycl/source/detail/device_impl.hpp

Lines changed: 150 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,28 @@ class platform;
2929

3030
namespace detail {
3131

32+
// Note that UR's enums have weird *_FORCE_UINT32 values, we ignore them in the
33+
// callers. But we also can't write a fully-covered switch without mentioning it
34+
// there, which wouldn't make any sense. As such, ensure that "real" values
35+
// match and then just `static_cast` them (in the caller).
36+
template <typename T0, typename T1>
37+
constexpr bool enums_match(std::initializer_list<T0> l0,
38+
std::initializer_list<T1> l1) {
39+
using U0 = std::underlying_type_t<T0>;
40+
using U1 = std::underlying_type_t<T1>;
41+
using C = std::common_type_t<U0, U1>;
42+
// std::equal isn't constexpr until C++20.
43+
if (l0.size() != l1.size())
44+
return false;
45+
auto i0 = l0.begin();
46+
auto e = l0.end();
47+
auto i1 = l1.begin();
48+
for (; i0 != e; ++i0, ++i1)
49+
if (static_cast<C>(*i0) != static_cast<C>(*i1))
50+
return false;
51+
return true;
52+
}
53+
3254
// Forward declaration
3355
class platform_impl;
3456

@@ -208,6 +230,30 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
208230

209231
// device_traits.def
210232

233+
CASE(info::device::device_type) {
234+
using device_type = info::device_type;
235+
switch (get_info_impl<ur_device_type_t, UR_DEVICE_INFO_TYPE>()) {
236+
case UR_DEVICE_TYPE_DEFAULT:
237+
return device_type::automatic;
238+
case UR_DEVICE_TYPE_ALL:
239+
return device_type::all;
240+
case UR_DEVICE_TYPE_GPU:
241+
return device_type::gpu;
242+
case UR_DEVICE_TYPE_CPU:
243+
return device_type::cpu;
244+
case UR_DEVICE_TYPE_FPGA:
245+
return device_type::accelerator;
246+
case UR_DEVICE_TYPE_MCA:
247+
case UR_DEVICE_TYPE_VPU:
248+
return device_type::custom;
249+
default: {
250+
assert(false);
251+
// FIXME: what is that???
252+
return device_type::custom;
253+
}
254+
}
255+
}
256+
211257
CASE(info::device::max_work_item_sizes<3>) {
212258
auto result = get_info_impl<std::array<size_t, 3>,
213259
UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES>();
@@ -242,24 +288,46 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
242288
return get_fp_config<UR_DEVICE_INFO_DOUBLE_FP_CONFIG>();
243289
}
244290

291+
CASE(info::device::global_mem_cache_type) {
292+
using cache = info::global_mem_cache_type;
293+
static_assert(
294+
enums_match({UR_DEVICE_MEM_CACHE_TYPE_NONE,
295+
UR_DEVICE_MEM_CACHE_TYPE_READ_ONLY_CACHE,
296+
UR_DEVICE_MEM_CACHE_TYPE_READ_WRITE_CACHE},
297+
{cache::none, cache::read_only, cache::read_write}));
298+
return static_cast<cache>(
299+
get_info_impl<ur_device_mem_cache_type_t,
300+
UR_DEVICE_INFO_GLOBAL_MEM_CACHE_TYPE>());
301+
}
302+
303+
CASE(info::device::local_mem_type) {
304+
using mem = info::local_mem_type;
305+
static_assert(enums_match({UR_DEVICE_LOCAL_MEM_TYPE_NONE,
306+
UR_DEVICE_LOCAL_MEM_TYPE_LOCAL,
307+
UR_DEVICE_LOCAL_MEM_TYPE_GLOBAL},
308+
{mem::none, mem::local, mem::global}));
309+
return static_cast<mem>(get_info_impl<ur_device_local_mem_type_t,
310+
UR_DEVICE_INFO_LOCAL_MEM_TYPE>());
311+
}
312+
245313
CASE(info::device::atomic_memory_order_capabilities) {
246314
return readMemoryOrderBitfield(
247-
get_info_impl<ur_memory_order_capability_flag_t,
315+
get_info_impl<ur_memory_order_capability_flags_t,
248316
UR_DEVICE_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES>());
249317
}
250318
CASE(info::device::atomic_fence_order_capabilities) {
251319
return readMemoryOrderBitfield(
252-
get_info_impl<ur_memory_order_capability_flag_t,
320+
get_info_impl<ur_memory_order_capability_flags_t,
253321
UR_DEVICE_INFO_ATOMIC_FENCE_ORDER_CAPABILITIES>());
254322
}
255323
CASE(info::device::atomic_memory_scope_capabilities) {
256324
return readMemoryScopeBitfield(
257-
get_info_impl<size_t,
325+
get_info_impl<ur_memory_scope_capability_flags_t,
258326
UR_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES>());
259327
}
260328
CASE(info::device::atomic_fence_scope_capabilities) {
261329
return readMemoryScopeBitfield(
262-
get_info_impl<size_t,
330+
get_info_impl<ur_memory_scope_capability_flags_t,
263331
UR_DEVICE_INFO_ATOMIC_FENCE_SCOPE_CAPABILITIES>());
264332
}
265333

@@ -269,8 +337,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
269337
"info::device::execution_capabilities is available for "
270338
"backend::opencl only");
271339

272-
ur_device_exec_capability_flag_t bits =
273-
get_info_impl<ur_device_exec_capability_flag_t,
340+
ur_device_exec_capability_flags_t bits =
341+
get_info_impl<ur_device_exec_capability_flags_t,
274342
UR_DEVICE_INFO_EXECUTION_CAPABILITIES>();
275343
std::vector<info::execution_capability> result;
276344
if (bits & UR_DEVICE_EXEC_CAPABILITY_FLAG_KERNEL)
@@ -593,6 +661,12 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
593661
return get_matrix_combinations();
594662
}
595663

664+
CASE(ext::oneapi::experimental::info::device::mipmap_max_anisotropy) {
665+
// Implicit conversion:
666+
return get_info_impl<uint32_t,
667+
UR_DEVICE_INFO_MIPMAP_MAX_ANISOTROPY_EXP>();
668+
}
669+
596670
CASE(ext::oneapi::experimental::info::device::component_devices) {
597671
auto Devs = get_info_impl_nocheck<std::vector<ur_device_handle_t>,
598672
UR_DEVICE_INFO_COMPONENT_DEVICES>();
@@ -628,6 +702,10 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
628702
"A component with aspect::ext_oneapi_is_component "
629703
"must have a composite device.");
630704
}
705+
CASE(ext::oneapi::info::device::num_compute_units) {
706+
// uint32_t -> size_t
707+
return get_info_impl<uint32_t, UR_DEVICE_INFO_NUM_COMPUTE_UNITS>();
708+
}
631709

632710
// ext_intel_device_traits.def
633711

@@ -718,6 +796,11 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
718796
"The device does not have the ext_intel_memory_bus_width aspect");
719797
return get_info_impl<uint32_t, UR_DEVICE_INFO_MEMORY_BUS_WIDTH>();
720798
}
799+
CASE(ext::intel::info::device::max_compute_queue_indices) {
800+
// uint32_t->int implicit conversion.
801+
return get_info_impl<uint32_t,
802+
UR_DEVICE_INFO_MAX_COMPUTE_QUEUE_INDICES>();
803+
}
721804
CASE(ext::intel::esimd::info::device::has_2d_block_io_support) {
722805
if (!has(aspect::ext_intel_esimd))
723806
return false;
@@ -904,6 +987,17 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
904987
MDevice, Desc, 0, nullptr, &return_size) == UR_RESULT_SUCCESS;
905988
}
906989

990+
template <ur_device_info_t Desc> static constexpr auto ur_ret_type_impl() {
991+
if constexpr (false) {
992+
}
993+
#define MAP(VALUE, ...) else if constexpr (Desc == VALUE) return __VA_ARGS__{};
994+
#include "ur_device_info_ret_types.inc"
995+
#undef MAP
996+
}
997+
998+
template <ur_device_info_t Desc>
999+
using ur_ret_type = decltype(ur_ret_type_impl<Desc>());
1000+
9071001
// This should really be
9081002
// std::expected<ReturnT, ur_result_t>
9091003
// but we don't have C++23. Emulate close enough with as little code as
@@ -932,62 +1026,70 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
9321026

9331027
template <typename ReturnT, ur_device_info_t Desc>
9341028
expected<ReturnT, ur_result_t> get_info_impl_nocheck() const {
1029+
using ur_ret_t = ur_ret_type<Desc>;
9351030
static_assert(!std::is_same_v<ReturnT, std::string>,
9361031
"Wasn't needed before.");
9371032
if constexpr (std::is_same_v<ReturnT, bool>) {
9381033
return get_info_impl_nocheck<ur_bool_t, Desc>();
939-
} else if constexpr (is_std_vector_v<ReturnT>) {
940-
static_assert(
941-
!check_type_in_v<typename ReturnT::value_type, bool, std::string>);
942-
size_t ResultSize = 0;
943-
ur_result_t Error =
944-
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
945-
getHandleRef(), Desc, 0, nullptr, &ResultSize);
946-
if (Error != UR_RESULT_SUCCESS)
947-
return {Error};
948-
if (ResultSize == 0)
949-
return {ReturnT{}};
950-
951-
ReturnT Result(ResultSize / sizeof(typename ReturnT::value_type));
952-
Error = getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
953-
getHandleRef(), Desc, ResultSize, Result.data(), nullptr);
954-
if (Error != UR_RESULT_SUCCESS)
955-
return {Error};
956-
return {Result};
9571034
} else {
958-
ReturnT Result;
959-
ur_result_t Error =
960-
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
961-
getHandleRef(), Desc, sizeof(Result), &Result, nullptr);
962-
if (Error == UR_RESULT_SUCCESS)
1035+
static_assert(std::is_same_v<ur_ret_t, ReturnT>);
1036+
if constexpr (is_std_vector_v<ReturnT>) {
1037+
static_assert(
1038+
!check_type_in_v<typename ReturnT::value_type, bool, std::string>);
1039+
size_t ResultSize = 0;
1040+
ur_result_t Error =
1041+
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
1042+
getHandleRef(), Desc, 0, nullptr, &ResultSize);
1043+
if (Error != UR_RESULT_SUCCESS)
1044+
return {Error};
1045+
if (ResultSize == 0)
1046+
return {ReturnT{}};
1047+
1048+
ReturnT Result(ResultSize / sizeof(typename ReturnT::value_type));
1049+
Error = getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
1050+
getHandleRef(), Desc, ResultSize, Result.data(), nullptr);
1051+
if (Error != UR_RESULT_SUCCESS)
1052+
return {Error};
9631053
return {Result};
964-
else
965-
return {Error};
1054+
} else {
1055+
ReturnT Result;
1056+
ur_result_t Error =
1057+
getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
1058+
getHandleRef(), Desc, sizeof(Result), &Result, nullptr);
1059+
if (Error == UR_RESULT_SUCCESS)
1060+
return {Result};
1061+
else
1062+
return {Error};
1063+
}
9661064
}
9671065
}
9681066

9691067
template <typename ReturnT, ur_device_info_t Desc>
9701068
ReturnT get_info_impl() const {
1069+
using ur_ret_t = ur_ret_type<Desc>;
9711070
if constexpr (std::is_same_v<ReturnT, bool>) {
9721071
return get_info_impl<ur_bool_t, Desc>();
973-
} else if constexpr (std::is_same_v<ReturnT, std::string>) {
974-
return urGetInfoString<UrApiKind::urDeviceGetInfo>(*this, Desc);
975-
} else if constexpr (is_std_vector_v<ReturnT>) {
976-
size_t ResultSize = 0;
977-
getAdapter()->call<UrApiKind::urDeviceGetInfo>(getHandleRef(), Desc, 0,
978-
nullptr, &ResultSize);
979-
if (ResultSize == 0)
980-
return {};
981-
982-
ReturnT Result(ResultSize / sizeof(typename ReturnT::value_type));
983-
getAdapter()->call<UrApiKind::urDeviceGetInfo>(
984-
getHandleRef(), Desc, ResultSize, Result.data(), nullptr);
985-
return Result;
9861072
} else {
987-
ReturnT Result;
988-
getAdapter()->call<UrApiKind::urDeviceGetInfo>(
989-
getHandleRef(), Desc, sizeof(Result), &Result, nullptr);
990-
return Result;
1073+
static_assert(std::is_same_v<ur_ret_t, ReturnT>);
1074+
if constexpr (std::is_same_v<ReturnT, std::string>) {
1075+
return urGetInfoString<UrApiKind::urDeviceGetInfo>(*this, Desc);
1076+
} else if constexpr (is_std_vector_v<ReturnT>) {
1077+
size_t ResultSize = 0;
1078+
getAdapter()->call<UrApiKind::urDeviceGetInfo>(getHandleRef(), Desc, 0,
1079+
nullptr, &ResultSize);
1080+
if (ResultSize == 0)
1081+
return {};
1082+
1083+
ReturnT Result(ResultSize / sizeof(typename ReturnT::value_type));
1084+
getAdapter()->call<UrApiKind::urDeviceGetInfo>(
1085+
getHandleRef(), Desc, ResultSize, Result.data(), nullptr);
1086+
return Result;
1087+
} else {
1088+
ReturnT Result;
1089+
getAdapter()->call<UrApiKind::urDeviceGetInfo>(
1090+
getHandleRef(), Desc, sizeof(Result), &Result, nullptr);
1091+
return Result;
1092+
}
9911093
}
9921094
}
9931095

0 commit comments

Comments
 (0)