@@ -251,6 +251,13 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
251
251
uint32_t NumEntries,
252
252
ur_device_handle_t *phDevices,
253
253
uint32_t *pNumDevices) {
254
+ constexpr std::pair<const ur_platform_backend_t , const char *> adapters[6 ] = {
255
+ {UR_PLATFORM_BACKEND_UNKNOWN, " *" },
256
+ {UR_PLATFORM_BACKEND_LEVEL_ZERO, " level_zero" },
257
+ {UR_PLATFORM_BACKEND_OPENCL, " opencl" },
258
+ {UR_PLATFORM_BACKEND_CUDA, " cuda" },
259
+ {UR_PLATFORM_BACKEND_HIP, " hip" },
260
+ {UR_PLATFORM_BACKEND_NATIVE_CPU, " native_cpu" }};
254
261
255
262
if (!hPlatform) {
256
263
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
@@ -323,7 +330,9 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
323
330
// (If we wished to preserve the ordering of terms, we could replace
324
331
// `std::map` with `std::queue<std::pair<key_type_t, value_type_t>>` or
325
332
// something similar.)
326
- auto maybeEnvVarMap = getenv_to_map (" ONEAPI_DEVICE_SELECTOR" , false );
333
+ auto maybeEnvVarMap =
334
+ getenv_to_map (" ONEAPI_DEVICE_SELECTOR" , /* reject_empty= */ false ,
335
+ /* allow_duplicate= */ false , /* lower= */ true );
327
336
logger::debug (
328
337
" getenv_to_map parsed env var and {} a map" ,
329
338
(maybeEnvVarMap.has_value () ? " produced" : " failed to produce" ));
@@ -380,35 +389,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
380
389
sizeof (ur_platform_backend_t ), &platformBackend, 0 )) {
381
390
return UR_RESULT_ERROR_INVALID_PLATFORM;
382
391
}
383
- const std::string platformBackendName = // hPlatform->get_backend_name();
384
- [&platformBackend]() constexpr {
385
- switch (platformBackend) {
386
- case UR_PLATFORM_BACKEND_UNKNOWN:
387
- return " *" ; // the only ODS string that matches
388
- break ;
389
- case UR_PLATFORM_BACKEND_LEVEL_ZERO:
390
- return " level_zero" ;
391
- break ;
392
- case UR_PLATFORM_BACKEND_OPENCL:
393
- return " opencl" ;
394
- break ;
395
- case UR_PLATFORM_BACKEND_CUDA:
396
- return " cuda" ;
397
- break ;
398
- case UR_PLATFORM_BACKEND_HIP:
399
- return " hip" ;
400
- break ;
401
- case UR_PLATFORM_BACKEND_NATIVE_CPU:
402
- return " *" ; // the only ODS string that matches
403
- break ;
404
- case UR_PLATFORM_BACKEND_FORCE_UINT32:
405
- return " " ; // no ODS string matches this
406
- break ;
407
- default :
408
- return " " ; // no ODS string matches this
409
- break ;
410
- }
411
- }();
412
392
413
393
using DeviceHardwareType = ur_device_type_t ;
414
394
@@ -483,18 +463,18 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
483
463
// Note the hPlatform -> platformBackend -> platformBackendName conversion
484
464
// above guarantees minimal sanity for the comparison with backend from the
485
465
// ODS string
486
- if (backend.front () != ' *' &&
487
- ! std::equal (platformBackendName. cbegin (), platformBackendName. cend (),
488
- backend. cbegin (), backend. cend () ,
489
- []( const auto &a, const auto &b ) {
490
- // case-insensitive comparison by converting both tolower
491
- return std::tolower ( static_cast < unsigned char >(a)) ==
492
- std::tolower ( static_cast < unsigned char >(b) );
493
- })) {
494
- // irrelevant term for current request: different backend -- silently
495
- // ignore
496
- logger::error ( " unrecognised backend '{}' " , backend) ;
497
- return UR_RESULT_ERROR_INVALID_VALUE;
466
+ if (backend.front () != ' *' ) {
467
+ auto cend = &adapters[ sizeof (adapters) / sizeof (adapters[ 0 ])];
468
+ auto found = std::find_if (adapters, cend,
469
+ [&]( auto &p ) { return p. second == backend; });
470
+ if (found == cend) {
471
+ // It's not a legal backend
472
+ logger::error ( " unrecognised backend '{}' " , backend );
473
+ return UR_RESULT_ERROR_INVALID_VALUE;
474
+ } else if (found-> first != platformBackend) {
475
+ // If it's a rule for a different backend, ignore it
476
+ continue ;
477
+ }
498
478
}
499
479
if (termPair.second .size () == 0 ) {
500
480
// malformed term: missing filterStrings -- output ERROR
0 commit comments