@@ -29,6 +29,28 @@ class platform;
29
29
30
30
namespace detail {
31
31
32
+ // Note that UR's enums have weird *_FORCE_UINT32 values, we ignore them in the
33
+ // callers. But we also can't write a fully-covered switch without mentioning it
34
+ // there, which wouldn't make any sense. As such, ensure that "real" values
35
+ // match and then just `static_cast` them (in the caller).
36
+ template <typename T0, typename T1>
37
+ constexpr bool enums_match (std::initializer_list<T0> l0,
38
+ std::initializer_list<T1> l1) {
39
+ using U0 = std::underlying_type_t <T0>;
40
+ using U1 = std::underlying_type_t <T1>;
41
+ using C = std::common_type_t <U0, U1>;
42
+ // std::equal isn't constexpr until C++20.
43
+ if (l0.size () != l1.size ())
44
+ return false ;
45
+ auto i0 = l0.begin ();
46
+ auto e = l0.end ();
47
+ auto i1 = l1.begin ();
48
+ for (; i0 != e; ++i0, ++i1)
49
+ if (static_cast <C>(*i0) != static_cast <C>(*i1))
50
+ return false ;
51
+ return true ;
52
+ }
53
+
32
54
// Forward declaration
33
55
class platform_impl ;
34
56
@@ -208,6 +230,30 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
208
230
209
231
// device_traits.def
210
232
233
+ CASE (info::device::device_type) {
234
+ using device_type = info::device_type;
235
+ switch (get_info_impl<ur_device_type_t , UR_DEVICE_INFO_TYPE>()) {
236
+ case UR_DEVICE_TYPE_DEFAULT:
237
+ return device_type::automatic;
238
+ case UR_DEVICE_TYPE_ALL:
239
+ return device_type::all;
240
+ case UR_DEVICE_TYPE_GPU:
241
+ return device_type::gpu;
242
+ case UR_DEVICE_TYPE_CPU:
243
+ return device_type::cpu;
244
+ case UR_DEVICE_TYPE_FPGA:
245
+ return device_type::accelerator;
246
+ case UR_DEVICE_TYPE_MCA:
247
+ case UR_DEVICE_TYPE_VPU:
248
+ return device_type::custom;
249
+ default : {
250
+ assert (false );
251
+ // FIXME: what is that???
252
+ return device_type::custom;
253
+ }
254
+ }
255
+ }
256
+
211
257
CASE (info::device::max_work_item_sizes<3 >) {
212
258
auto result = get_info_impl<std::array<size_t , 3 >,
213
259
UR_DEVICE_INFO_MAX_WORK_ITEM_SIZES>();
@@ -242,24 +288,46 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
242
288
return get_fp_config<UR_DEVICE_INFO_DOUBLE_FP_CONFIG>();
243
289
}
244
290
291
+ CASE (info::device::global_mem_cache_type) {
292
+ using cache = info::global_mem_cache_type;
293
+ static_assert (
294
+ enums_match ({UR_DEVICE_MEM_CACHE_TYPE_NONE,
295
+ UR_DEVICE_MEM_CACHE_TYPE_READ_ONLY_CACHE,
296
+ UR_DEVICE_MEM_CACHE_TYPE_READ_WRITE_CACHE},
297
+ {cache::none, cache::read_only, cache::read_write}));
298
+ return static_cast <cache>(
299
+ get_info_impl<ur_device_mem_cache_type_t ,
300
+ UR_DEVICE_INFO_GLOBAL_MEM_CACHE_TYPE>());
301
+ }
302
+
303
+ CASE (info::device::local_mem_type) {
304
+ using mem = info::local_mem_type;
305
+ static_assert (enums_match ({UR_DEVICE_LOCAL_MEM_TYPE_NONE,
306
+ UR_DEVICE_LOCAL_MEM_TYPE_LOCAL,
307
+ UR_DEVICE_LOCAL_MEM_TYPE_GLOBAL},
308
+ {mem::none, mem::local, mem::global}));
309
+ return static_cast <mem>(get_info_impl<ur_device_local_mem_type_t ,
310
+ UR_DEVICE_INFO_LOCAL_MEM_TYPE>());
311
+ }
312
+
245
313
CASE (info::device::atomic_memory_order_capabilities) {
246
314
return readMemoryOrderBitfield (
247
- get_info_impl<ur_memory_order_capability_flag_t ,
315
+ get_info_impl<ur_memory_order_capability_flags_t ,
248
316
UR_DEVICE_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES>());
249
317
}
250
318
CASE (info::device::atomic_fence_order_capabilities) {
251
319
return readMemoryOrderBitfield (
252
- get_info_impl<ur_memory_order_capability_flag_t ,
320
+ get_info_impl<ur_memory_order_capability_flags_t ,
253
321
UR_DEVICE_INFO_ATOMIC_FENCE_ORDER_CAPABILITIES>());
254
322
}
255
323
CASE (info::device::atomic_memory_scope_capabilities) {
256
324
return readMemoryScopeBitfield (
257
- get_info_impl<size_t ,
325
+ get_info_impl<ur_memory_scope_capability_flags_t ,
258
326
UR_DEVICE_INFO_ATOMIC_MEMORY_SCOPE_CAPABILITIES>());
259
327
}
260
328
CASE (info::device::atomic_fence_scope_capabilities) {
261
329
return readMemoryScopeBitfield (
262
- get_info_impl<size_t ,
330
+ get_info_impl<ur_memory_scope_capability_flags_t ,
263
331
UR_DEVICE_INFO_ATOMIC_FENCE_SCOPE_CAPABILITIES>());
264
332
}
265
333
@@ -269,8 +337,8 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
269
337
" info::device::execution_capabilities is available for "
270
338
" backend::opencl only" );
271
339
272
- ur_device_exec_capability_flag_t bits =
273
- get_info_impl<ur_device_exec_capability_flag_t ,
340
+ ur_device_exec_capability_flags_t bits =
341
+ get_info_impl<ur_device_exec_capability_flags_t ,
274
342
UR_DEVICE_INFO_EXECUTION_CAPABILITIES>();
275
343
std::vector<info::execution_capability> result;
276
344
if (bits & UR_DEVICE_EXEC_CAPABILITY_FLAG_KERNEL)
@@ -593,6 +661,12 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
593
661
return get_matrix_combinations ();
594
662
}
595
663
664
+ CASE (ext::oneapi::experimental::info::device::mipmap_max_anisotropy) {
665
+ // Implicit conversion:
666
+ return get_info_impl<uint32_t ,
667
+ UR_DEVICE_INFO_MIPMAP_MAX_ANISOTROPY_EXP>();
668
+ }
669
+
596
670
CASE (ext::oneapi::experimental::info::device::component_devices) {
597
671
auto Devs = get_info_impl_nocheck<std::vector<ur_device_handle_t >,
598
672
UR_DEVICE_INFO_COMPONENT_DEVICES>();
@@ -628,6 +702,10 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
628
702
" A component with aspect::ext_oneapi_is_component "
629
703
" must have a composite device." );
630
704
}
705
+ CASE (ext::oneapi::info::device::num_compute_units) {
706
+ // uint32_t -> size_t
707
+ return get_info_impl<uint32_t , UR_DEVICE_INFO_NUM_COMPUTE_UNITS>();
708
+ }
631
709
632
710
// ext_intel_device_traits.def
633
711
@@ -718,6 +796,11 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
718
796
" The device does not have the ext_intel_memory_bus_width aspect" );
719
797
return get_info_impl<uint32_t , UR_DEVICE_INFO_MEMORY_BUS_WIDTH>();
720
798
}
799
+ CASE (ext::intel::info::device::max_compute_queue_indices) {
800
+ // uint32_t->int implicit conversion.
801
+ return get_info_impl<uint32_t ,
802
+ UR_DEVICE_INFO_MAX_COMPUTE_QUEUE_INDICES>();
803
+ }
721
804
CASE (ext::intel::esimd::info::device::has_2d_block_io_support) {
722
805
if (!has (aspect::ext_intel_esimd))
723
806
return false ;
@@ -904,6 +987,17 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
904
987
MDevice, Desc, 0 , nullptr , &return_size) == UR_RESULT_SUCCESS;
905
988
}
906
989
990
+ template <ur_device_info_t Desc> static constexpr auto ur_ret_type_impl () {
991
+ if constexpr (false ) {
992
+ }
993
+ #define MAP (VALUE, ...) else if constexpr (Desc == VALUE) return __VA_ARGS__{};
994
+ #include " ur_device_info_ret_types.inc"
995
+ #undef MAP
996
+ }
997
+
998
+ template <ur_device_info_t Desc>
999
+ using ur_ret_type = decltype (ur_ret_type_impl<Desc>());
1000
+
907
1001
// This should really be
908
1002
// std::expected<ReturnT, ur_result_t>
909
1003
// but we don't have C++23. Emulate close enough with as little code as
@@ -932,62 +1026,70 @@ class device_impl : public std::enable_shared_from_this<device_impl> {
932
1026
933
1027
template <typename ReturnT, ur_device_info_t Desc>
934
1028
expected<ReturnT, ur_result_t > get_info_impl_nocheck () const {
1029
+ using ur_ret_t = ur_ret_type<Desc>;
935
1030
static_assert (!std::is_same_v<ReturnT, std::string>,
936
1031
" Wasn't needed before." );
937
1032
if constexpr (std::is_same_v<ReturnT, bool >) {
938
1033
return get_info_impl_nocheck<ur_bool_t , Desc>();
939
- } else if constexpr (is_std_vector_v<ReturnT>) {
940
- static_assert (
941
- !check_type_in_v<typename ReturnT::value_type, bool , std::string>);
942
- size_t ResultSize = 0 ;
943
- ur_result_t Error =
944
- getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
945
- getHandleRef (), Desc, 0 , nullptr , &ResultSize);
946
- if (Error != UR_RESULT_SUCCESS)
947
- return {Error};
948
- if (ResultSize == 0 )
949
- return {ReturnT{}};
950
-
951
- ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
952
- Error = getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
953
- getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
954
- if (Error != UR_RESULT_SUCCESS)
955
- return {Error};
956
- return {Result};
957
1034
} else {
958
- ReturnT Result;
959
- ur_result_t Error =
960
- getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
961
- getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
962
- if (Error == UR_RESULT_SUCCESS)
1035
+ static_assert (std::is_same_v<ur_ret_t , ReturnT>);
1036
+ if constexpr (is_std_vector_v<ReturnT>) {
1037
+ static_assert (
1038
+ !check_type_in_v<typename ReturnT::value_type, bool , std::string>);
1039
+ size_t ResultSize = 0 ;
1040
+ ur_result_t Error =
1041
+ getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
1042
+ getHandleRef (), Desc, 0 , nullptr , &ResultSize);
1043
+ if (Error != UR_RESULT_SUCCESS)
1044
+ return {Error};
1045
+ if (ResultSize == 0 )
1046
+ return {ReturnT{}};
1047
+
1048
+ ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
1049
+ Error = getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
1050
+ getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
1051
+ if (Error != UR_RESULT_SUCCESS)
1052
+ return {Error};
963
1053
return {Result};
964
- else
965
- return {Error};
1054
+ } else {
1055
+ ReturnT Result;
1056
+ ur_result_t Error =
1057
+ getAdapter ()->call_nocheck <UrApiKind::urDeviceGetInfo>(
1058
+ getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
1059
+ if (Error == UR_RESULT_SUCCESS)
1060
+ return {Result};
1061
+ else
1062
+ return {Error};
1063
+ }
966
1064
}
967
1065
}
968
1066
969
1067
template <typename ReturnT, ur_device_info_t Desc>
970
1068
ReturnT get_info_impl () const {
1069
+ using ur_ret_t = ur_ret_type<Desc>;
971
1070
if constexpr (std::is_same_v<ReturnT, bool >) {
972
1071
return get_info_impl<ur_bool_t , Desc>();
973
- } else if constexpr (std::is_same_v<ReturnT, std::string>) {
974
- return urGetInfoString<UrApiKind::urDeviceGetInfo>(*this , Desc);
975
- } else if constexpr (is_std_vector_v<ReturnT>) {
976
- size_t ResultSize = 0 ;
977
- getAdapter ()->call <UrApiKind::urDeviceGetInfo>(getHandleRef (), Desc, 0 ,
978
- nullptr , &ResultSize);
979
- if (ResultSize == 0 )
980
- return {};
981
-
982
- ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
983
- getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
984
- getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
985
- return Result;
986
1072
} else {
987
- ReturnT Result;
988
- getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
989
- getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
990
- return Result;
1073
+ static_assert (std::is_same_v<ur_ret_t , ReturnT>);
1074
+ if constexpr (std::is_same_v<ReturnT, std::string>) {
1075
+ return urGetInfoString<UrApiKind::urDeviceGetInfo>(*this , Desc);
1076
+ } else if constexpr (is_std_vector_v<ReturnT>) {
1077
+ size_t ResultSize = 0 ;
1078
+ getAdapter ()->call <UrApiKind::urDeviceGetInfo>(getHandleRef (), Desc, 0 ,
1079
+ nullptr , &ResultSize);
1080
+ if (ResultSize == 0 )
1081
+ return {};
1082
+
1083
+ ReturnT Result (ResultSize / sizeof (typename ReturnT::value_type));
1084
+ getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1085
+ getHandleRef (), Desc, ResultSize, Result.data (), nullptr );
1086
+ return Result;
1087
+ } else {
1088
+ ReturnT Result;
1089
+ getAdapter ()->call <UrApiKind::urDeviceGetInfo>(
1090
+ getHandleRef (), Desc, sizeof (Result), &Result, nullptr );
1091
+ return Result;
1092
+ }
991
1093
}
992
1094
}
993
1095
0 commit comments