Skip to content

Commit 588966f

Browse files
Dan Holmeskbenzie
authored andcommitted
Adding tests for ODS function urDeviceGetSelected
1 parent 085d21b commit 588966f

File tree

2 files changed

+260
-38
lines changed

2 files changed

+260
-38
lines changed

source/loader/ur_lib.cpp

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,12 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
292292
// discard term, for that backend.
293293
// (If we wished to preserve the ordering of terms, we could replace `std::map`
294294
// with `std::queue<std::pair<key_type_t, value_type_t>>` or something similar.)
295-
auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true);
295+
auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", false);
296+
std::cout
297+
<< "DEBUG: " << (maybeEnvVarMap.has_value()
298+
? "getenv_to_map parsed env var and produced a map"
299+
: "getenv_to_map parsed env var and failed to produce a map")
300+
<< std::endl;
296301

297302
// if the ODS env var is not set at all, then pretend it was set to the default
298303
using EnvVarMap = std::map<std::string, std::vector<std::string>>;
@@ -419,53 +424,63 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
419424
std::vector<DeviceSpec> acceptDeviceList;
420425
std::vector<DeviceSpec> discardDeviceList;
421426

422-
std::vector<std::string> acceptFilters;
423-
std::vector<std::string> discardFilters;
424427
for (auto &termPair : mapODS) {
425428
std::string backend = termPair.first;
426-
if (backend.empty()) {
429+
if (backend.empty()) { // FIXME: never true because getenv_to_map rejects this case
427430
// malformed term: missing backend -- output ERROR, then continue
428431
// TODO: replace std::cout with URT message output mechanism
429432
std::cout << "ERROR: missing backend, format of filter = "
430-
"'[!]backend:filterStrings'";
433+
"'[!]backend:filterStrings'"
434+
<< std::endl;
431435
continue;
432436
}
433437
enum FilterType {
434438
AcceptFilter,
435439
DiscardFilter,
436440
} termType = (backend.front() != '!') ? AcceptFilter : DiscardFilter;
437-
auto &deviceList = acceptDeviceList;
441+
std::cout << "DEBUG: termType is"
442+
<< (termType != AcceptFilter ? "DiscardFilter"
443+
: "AcceptFilter")
444+
<< std::endl;
445+
auto &deviceList =
446+
(termType != AcceptFilter) ? discardDeviceList : acceptDeviceList;
438447
if (termType != AcceptFilter) {
448+
std::cout << "DEBUG: backend was '" << backend << "'" << std::endl;
439449
backend.erase(backend.cbegin());
440-
deviceList = discardDeviceList;
450+
std::cout << "DEBUG: backend now '" << backend << "'" << std::endl;
441451
}
442452
// Note the hPlatform -> platformBackend -> platformBackendName conversion above
443453
// guarantees minimal sanity for the comparison with backend from the ODS string
444-
if (backend != "*" &&
445-
std::equal(platformBackendName.cbegin(), platformBackendName.cend(),
454+
if (backend.front() != '*' &&
455+
!std::equal(platformBackendName.cbegin(), platformBackendName.cend(),
446456
backend.cbegin(), backend.cend(),
447457
[](const auto &a, const auto &b) {
448458
// case-insensitive comparison by converting both tolower
449459
return std::tolower(static_cast<unsigned char>(a)) ==
450460
std::tolower(static_cast<unsigned char>(b));
451461
})) {
452462
// irrelevant term for current request: different backend -- silently ignore
463+
// TODO: replace std::cout with URT message output mechanism
464+
std::cout << "WARNING: ignoring term with irrelevant backend"
465+
<< std::endl;
453466
continue;
454467
}
455468
if (termPair.second.size() == 0) {
456469
// malformed term: missing filterStrings -- output ERROR, then continue
457470
// TODO: replace std::cout with URT message output mechanism
458471
std::cout << "ERROR missing filterStrings, format of filter = "
459-
"'[!]backend:filterStrings'";
472+
"'[!]backend:filterStrings'"
473+
<< std::endl;
460474
continue;
461475
}
462476
if (std::find_if(termPair.second.cbegin(), termPair.second.cend(),
463477
[](const auto &s) { return s.empty(); }) !=
464-
termPair.second.cend()) {
478+
termPair.second.cend()) { // FIXME: never true because getenv_to_map rejects this case
465479
// malformed term: missing filterString -- output warning, then continue
466480
// TODO: replace std::cout with URT message output mechanism
467481
std::cout << "WARNING: empty filterString, format of filterStrings "
468-
"= 'filterString[,filterString[,...]]'";
482+
"= 'filterString[,filterString[,...]]'"
483+
<< std::endl;
469484
continue;
470485
}
471486
if (std::find_if(termPair.second.cbegin(), termPair.second.cend(),
@@ -475,7 +490,8 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
475490
// malformed term: too many dots in filterString -- output warning, then continue
476491
// TODO: replace std::cout with URT message output mechanism
477492
std::cout << "WARNING: too many dots in filterString, format of "
478-
"filterString = 'root[.sub[.subsub]]'";
493+
"filterString = 'root[.sub[.subsub]]'"
494+
<< std::endl;
479495
continue;
480496
}
481497
if (std::find_if(
@@ -497,7 +513,8 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
497513
// malformed term: star dot no-star in filterString -- output warning, then continue
498514
// TODO: replace std::cout with URT message output mechanism
499515
std::cout
500-
<< "WARNING: invalid wildcard in filterString, '*.' => '*.*'";
516+
<< "WARNING: invalid wildcard in filterString, '*.' => '*.*'"
517+
<< std::endl;
501518
continue;
502519
}
503520

@@ -511,7 +528,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
511528
const auto firstDeviceId = getDeviceId(firstPart);
512529
// first dot found, look for another
513530
std::string::size_type locationDot2 =
514-
filterString.find('.', locationDot1);
531+
filterString.find('.', locationDot1+1);
515532
std::string secondPart = filterString.substr(
516533
locationDot1 + 1, locationDot2 == std::string::npos
517534
? std::string::npos
@@ -539,25 +556,26 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
539556
hardwareType, firstDeviceId});
540557
}
541558
}
542-
543-
if (termType != AcceptFilter) {
544-
discardFilters.insert(discardFilters.end(),
545-
termPair.second.cbegin(),
546-
termPair.second.cend());
547-
} else {
548-
acceptFilters.insert(acceptFilters.end(), termPair.second.cbegin(),
549-
termPair.second.cend());
550-
}
551559
}
552560

553-
// if no accept filters are specified by the user, we must add a default "all root devices"
554-
if (acceptFilters.size() == 0) {
555-
acceptFilters.insert(acceptFilters.end(), 1, "*");
556-
}
557-
if (acceptDeviceList.size() == 0) {
561+
if (acceptDeviceList.size() == 0 && discardDeviceList.size() == 0) {
562+
// nothing in env var was understood as a valid term
563+
return UR_RESULT_ERROR_INVALID_VALUE;
564+
} else if (acceptDeviceList.size() == 0) {
565+
// no accept terms were understood, but at least one discard term was
566+
// we are magnanimous to the user when there were bad/ignored accept terms
567+
// by pretending there were no bad/ignored accept terms in the env var
568+
// for example, we pretend that "garbage:0;!cuda:*" was just "!cuda:*"
569+
// so we add an implicit accept-all term (equivalent to prepending "*:*;")
570+
// as we would have done if the user had given us the corrected string
558571
acceptDeviceList.push_back(DeviceSpec{
559572
DevicePartLevel::ROOT, ::UR_DEVICE_TYPE_ALL, DeviceIdTypeALL});
560573
}
574+
575+
std::cout << "DEBUG: size of acceptDeviceList = " << acceptDeviceList.size()
576+
<< std::endl
577+
<< "DEBUG: size of discardDeviceList = "
578+
<< discardDeviceList.size() << std::endl;
561579

562580
std::vector<DeviceSpec> rootDevices;
563581
std::vector<DeviceSpec> subDevices;
@@ -688,30 +706,46 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
688706
// if this is a subsubdevice filter, then it must be '*.*.*'
689707
matches = (filter.hwType == device.hwType) ||
690708
(filter.hwType == DeviceHardwareType::UR_DEVICE_TYPE_ALL);
709+
std::cout << "DEBUG: In ApplyFilter, if block case 1, matches = "
710+
<< matches << std::endl;
691711
} else if (filter.rootId != device.rootId) {
692712
// root part in filter is a number but does not match the number in the root part of device
693713
matches = false;
714+
std::cout << "DEBUG: In ApplyFilter, if block case 2, matches = "
715+
<< matches << std::endl;
694716
} else if (filter.level == DevicePartLevel::ROOT) {
695717
// this is a root device filter with a number that matches
696718
matches = true;
719+
std::cout << "DEBUG: In ApplyFilter, if block case 3, matches = "
720+
<< matches << std::endl;
697721
} else if (filter.subId == DeviceIdTypeALL) {
698722
// sub type of star always matches (when root part matches, which we already know here)
699723
// if this is a subdevice filter, then it must be 'matches.*'
700724
// if this is a subsubdevice filter, then it must be 'matches.*.*'
701725
matches = true;
726+
std::cout << "DEBUG: In ApplyFilter, if block case 4, matches = "
727+
<< matches << std::endl;
702728
} else if (filter.subId != device.subId) {
703729
// sub part in filter is a number but does not match the number in the sub part of device
704730
matches = false;
731+
std::cout << "DEBUG: In ApplyFilter, if block case 5, matches = "
732+
<< matches << std::endl;
705733
} else if (filter.level == DevicePartLevel::SUB) {
706734
// this is a sub device number filter, numbers match in both parts
707735
matches = true;
736+
std::cout << "DEBUG: In ApplyFilter, if block case 6, matches = "
737+
<< matches << std::endl;
708738
} else if (filter.subsubId == DeviceIdTypeALL) {
709739
// subsub type of star always matches (when other parts match, which we already know here)
710740
// this is a subsub device filter, it must be 'matches.matches.*'
711741
matches = true;
742+
std::cout << "DEBUG: In ApplyFilter, if block case 7, matches = "
743+
<< matches << std::endl;
712744
} else {
713745
// this is a subsub device filter, numbers in all three parts match
714746
matches = (filter.subsubId == device.subsubId);
747+
std::cout << "DEBUG: In ApplyFilter, if block case 8, matches = "
748+
<< matches << std::endl;
715749
}
716750
return matches;
717751
};
@@ -759,6 +793,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
759793
}
760794
return matches;
761795
};
796+
auto numAlreadySelected = selectedDevices.size();
762797
if (accept.level == DevicePartLevel::ROOT) {
763798
rootDevices.erase(std::remove_if(rootDevices.begin(),
764799
rootDevices.end(),
@@ -777,6 +812,13 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
777812
ApplyAcceptFilter),
778813
subSubDevices.end());
779814
}
815+
if (numAlreadySelected == selectedDevices.size()) {
816+
std::cout << "WARNING: an accept term was ignored because it "
817+
"does not select any additional devices"
818+
"selectedDevices.size() = "
819+
<< selectedDevices.size()
820+
<< std::endl;
821+
}
780822
}
781823

782824
// selectedDevices is now a vector containing all the right device handles

0 commit comments

Comments
 (0)