@@ -219,7 +219,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
219
219
uint32_t *pNumDevices) {
220
220
221
221
if (!hPlatform) {
222
- return UR_RESULT_ERROR_INVALID_NULL_POINTER ;
222
+ return UR_RESULT_ERROR_INVALID_NULL_HANDLE ;
223
223
}
224
224
// NumEntries is max number of devices wanted by the caller (max usable length of phDevices)
225
225
if (NumEntries < 0 ) {
@@ -230,9 +230,22 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
230
230
}
231
231
// pNumDevices is the actual number of device handles added to phDevices by this function
232
232
if (NumEntries == 0 && !pNumDevices) {
233
- return UR_RESULT_ERROR_INVALID_NULL_POINTER ;
233
+ return UR_RESULT_ERROR_INVALID_SIZE ;
234
234
}
235
235
236
+ switch (DeviceType) {
237
+ case UR_DEVICE_TYPE_ALL:
238
+ case UR_DEVICE_TYPE_GPU:
239
+ case UR_DEVICE_TYPE_DEFAULT:
240
+ case UR_DEVICE_TYPE_CPU:
241
+ case UR_DEVICE_TYPE_FPGA:
242
+ case UR_DEVICE_TYPE_MCA:
243
+ break ;
244
+ default :
245
+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
246
+ // urPrint("Unknown device type");
247
+ break ;
248
+ }
236
249
// plan:
237
250
// 0. basic validation of argument values (see code above)
238
251
// 1. conversion of argument values into useful data items
@@ -267,42 +280,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
267
280
// however
268
281
// discard "2,*" == "*,2"
269
282
270
- ur_platform_backend_t platformBackend;
271
- if (UR_RESULT_SUCCESS !=
272
- urPlatformGetInfo (hPlatform, UR_PLATFORM_INFO_BACKEND,
273
- sizeof (ur_platform_backend_t ), &platformBackend, 0 )) {
274
- return UR_RESULT_ERROR_INVALID_PLATFORM;
275
- }
276
- const std::string platformBackendName = // hPlatform->get_backend_name();
277
- [&platformBackend]() constexpr {
278
- switch (platformBackend) {
279
- case UR_PLATFORM_BACKEND_UNKNOWN:
280
- return " *" ; // the only ODS string that matches
281
- break ;
282
- case UR_PLATFORM_BACKEND_LEVEL_ZERO:
283
- return " level_zero" ;
284
- break ;
285
- case UR_PLATFORM_BACKEND_OPENCL:
286
- return " opencl" ;
287
- break ;
288
- case UR_PLATFORM_BACKEND_CUDA:
289
- return " cuda" ;
290
- break ;
291
- case UR_PLATFORM_BACKEND_HIP:
292
- return " hip" ;
293
- break ;
294
- case UR_PLATFORM_BACKEND_NATIVE_CPU:
295
- return " *" ; // the only ODS string that matches
296
- break ;
297
- case UR_PLATFORM_BACKEND_FORCE_UINT32:
298
- return " " ; // no ODS string matches this
299
- break ;
300
- default :
301
- return " " ; // no ODS string matches this
302
- break ;
303
- }
304
- }();
305
-
306
283
// The std::map is sorted by its key, so this method of parsing the ODS env var
307
284
// alters the ordering of the terms, which makes it impossible to check whether
308
285
// all discard terms appear after all accept terms and to preserve the ordering
@@ -314,7 +291,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
314
291
// discard term, for that backend.
315
292
// (If we wished to preserve the ordering of terms, we could replace `std::map`
316
293
// with `std::queue<std::pair<key_type_t, value_type_t>>` or something similar.)
317
- auto & maybeEnvVarMap = getenv_to_map (" ONEAPI_DEVICE_SELECTOR" , true );
294
+ auto maybeEnvVarMap = getenv_to_map (" ONEAPI_DEVICE_SELECTOR" , true );
318
295
319
296
// if the ODS env var is not set at all, then pretend it was set to the default
320
297
using EnvVarMap = std::map<std::string, std::vector<std::string>>;
@@ -361,6 +338,42 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
361
338
" )$" ,
362
339
std::regex_constants::icase);
363
340
341
+ ur_platform_backend_t platformBackend;
342
+ if (UR_RESULT_SUCCESS !=
343
+ urPlatformGetInfo (hPlatform, UR_PLATFORM_INFO_BACKEND,
344
+ sizeof (ur_platform_backend_t ), &platformBackend, 0 )) {
345
+ return UR_RESULT_ERROR_INVALID_PLATFORM;
346
+ }
347
+ const std::string platformBackendName = // hPlatform->get_backend_name();
348
+ [&platformBackend]() constexpr {
349
+ switch (platformBackend) {
350
+ case UR_PLATFORM_BACKEND_UNKNOWN:
351
+ return " *" ; // the only ODS string that matches
352
+ break ;
353
+ case UR_PLATFORM_BACKEND_LEVEL_ZERO:
354
+ return " level_zero" ;
355
+ break ;
356
+ case UR_PLATFORM_BACKEND_OPENCL:
357
+ return " opencl" ;
358
+ break ;
359
+ case UR_PLATFORM_BACKEND_CUDA:
360
+ return " cuda" ;
361
+ break ;
362
+ case UR_PLATFORM_BACKEND_HIP:
363
+ return " hip" ;
364
+ break ;
365
+ case UR_PLATFORM_BACKEND_NATIVE_CPU:
366
+ return " *" ; // the only ODS string that matches
367
+ break ;
368
+ case UR_PLATFORM_BACKEND_FORCE_UINT32:
369
+ return " " ; // no ODS string matches this
370
+ break ;
371
+ default :
372
+ return " " ; // no ODS string matches this
373
+ break ;
374
+ }
375
+ }();
376
+
364
377
using DeviceHardwareType = ur_device_type_t ;
365
378
366
379
enum class DevicePartLevel { ROOT, SUB, SUBSUB };
@@ -772,11 +785,15 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
772
785
if (NumEntries == 0 ) {
773
786
*pNumDevices = static_cast <uint32_t >(selectedDevices.size ());
774
787
} else if (NumEntries > 0 ) {
775
- *pNumDevices = static_cast <uint32_t >(
776
- std::min ((size_t )NumEntries, selectedDevices.size ()));
777
- std::copy_n (selectedDevices.cbegin (), *pNumDevices, phDevices);
788
+ size_t numToCopy = std::min ((size_t )NumEntries, selectedDevices.size ());
789
+ std::copy_n (selectedDevices.cbegin (), numToCopy, phDevices);
790
+ if (pNumDevices != nullptr ) {
791
+ *pNumDevices = static_cast <uint32_t >(numToCopy);
792
+ return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
793
+ }
778
794
}
779
795
796
+
780
797
return UR_RESULT_SUCCESS;
781
798
}
782
799
} // namespace ur_lib
0 commit comments