Skip to content

Commit a44f811

Browse files
Dan Holmeskbenzie
authored andcommitted
Fix-ups for first batch of unit tests
1 parent b517076 commit a44f811

File tree

2 files changed

+71
-48
lines changed

2 files changed

+71
-48
lines changed

source/loader/ur_lib.cpp

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
219219
uint32_t *pNumDevices) {
220220

221221
if (!hPlatform) {
222-
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
222+
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
223223
}
224224
// NumEntries is max number of devices wanted by the caller (max usable length of phDevices)
225225
if (NumEntries < 0) {
@@ -230,9 +230,22 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
230230
}
231231
// pNumDevices is the actual number of device handles added to phDevices by this function
232232
if (NumEntries == 0 && !pNumDevices) {
233-
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
233+
return UR_RESULT_ERROR_INVALID_SIZE;
234234
}
235235

236+
switch (DeviceType) {
237+
case UR_DEVICE_TYPE_ALL:
238+
case UR_DEVICE_TYPE_GPU:
239+
case UR_DEVICE_TYPE_DEFAULT:
240+
case UR_DEVICE_TYPE_CPU:
241+
case UR_DEVICE_TYPE_FPGA:
242+
case UR_DEVICE_TYPE_MCA:
243+
break;
244+
default:
245+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
246+
//urPrint("Unknown device type");
247+
break;
248+
}
236249
// plan:
237250
// 0. basic validation of argument values (see code above)
238251
// 1. conversion of argument values into useful data items
@@ -267,42 +280,6 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
267280
// however
268281
// discard "2,*" == "*,2"
269282

270-
ur_platform_backend_t platformBackend;
271-
if (UR_RESULT_SUCCESS !=
272-
urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND,
273-
sizeof(ur_platform_backend_t), &platformBackend, 0)) {
274-
return UR_RESULT_ERROR_INVALID_PLATFORM;
275-
}
276-
const std::string platformBackendName = // hPlatform->get_backend_name();
277-
[&platformBackend]() constexpr {
278-
switch (platformBackend) {
279-
case UR_PLATFORM_BACKEND_UNKNOWN:
280-
return "*"; // the only ODS string that matches
281-
break;
282-
case UR_PLATFORM_BACKEND_LEVEL_ZERO:
283-
return "level_zero";
284-
break;
285-
case UR_PLATFORM_BACKEND_OPENCL:
286-
return "opencl";
287-
break;
288-
case UR_PLATFORM_BACKEND_CUDA:
289-
return "cuda";
290-
break;
291-
case UR_PLATFORM_BACKEND_HIP:
292-
return "hip";
293-
break;
294-
case UR_PLATFORM_BACKEND_NATIVE_CPU:
295-
return "*"; // the only ODS string that matches
296-
break;
297-
case UR_PLATFORM_BACKEND_FORCE_UINT32:
298-
return ""; // no ODS string matches this
299-
break;
300-
default:
301-
return ""; // no ODS string matches this
302-
break;
303-
}
304-
}();
305-
306283
// The std::map is sorted by its key, so this method of parsing the ODS env var
307284
// alters the ordering of the terms, which makes it impossible to check whether
308285
// all discard terms appear after all accept terms and to preserve the ordering
@@ -314,7 +291,7 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
314291
// discard term, for that backend.
315292
// (If we wished to preserve the ordering of terms, we could replace `std::map`
316293
// with `std::queue<std::pair<key_type_t, value_type_t>>` or something similar.)
317-
auto &maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true);
294+
auto maybeEnvVarMap = getenv_to_map("ONEAPI_DEVICE_SELECTOR", true);
318295

319296
// if the ODS env var is not set at all, then pretend it was set to the default
320297
using EnvVarMap = std::map<std::string, std::vector<std::string>>;
@@ -361,6 +338,42 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
361338
")$",
362339
std::regex_constants::icase);
363340

341+
ur_platform_backend_t platformBackend;
342+
if (UR_RESULT_SUCCESS !=
343+
urPlatformGetInfo(hPlatform, UR_PLATFORM_INFO_BACKEND,
344+
sizeof(ur_platform_backend_t), &platformBackend, 0)) {
345+
return UR_RESULT_ERROR_INVALID_PLATFORM;
346+
}
347+
const std::string platformBackendName = // hPlatform->get_backend_name();
348+
[&platformBackend]() constexpr {
349+
switch (platformBackend) {
350+
case UR_PLATFORM_BACKEND_UNKNOWN:
351+
return "*"; // the only ODS string that matches
352+
break;
353+
case UR_PLATFORM_BACKEND_LEVEL_ZERO:
354+
return "level_zero";
355+
break;
356+
case UR_PLATFORM_BACKEND_OPENCL:
357+
return "opencl";
358+
break;
359+
case UR_PLATFORM_BACKEND_CUDA:
360+
return "cuda";
361+
break;
362+
case UR_PLATFORM_BACKEND_HIP:
363+
return "hip";
364+
break;
365+
case UR_PLATFORM_BACKEND_NATIVE_CPU:
366+
return "*"; // the only ODS string that matches
367+
break;
368+
case UR_PLATFORM_BACKEND_FORCE_UINT32:
369+
return ""; // no ODS string matches this
370+
break;
371+
default:
372+
return ""; // no ODS string matches this
373+
break;
374+
}
375+
}();
376+
364377
using DeviceHardwareType = ur_device_type_t;
365378

366379
enum class DevicePartLevel { ROOT, SUB, SUBSUB };
@@ -772,11 +785,15 @@ ur_result_t urDeviceGetSelected(ur_platform_handle_t hPlatform,
772785
if (NumEntries == 0) {
773786
*pNumDevices = static_cast<uint32_t>(selectedDevices.size());
774787
} else if (NumEntries > 0) {
775-
*pNumDevices = static_cast<uint32_t>(
776-
std::min((size_t)NumEntries, selectedDevices.size()));
777-
std::copy_n(selectedDevices.cbegin(), *pNumDevices, phDevices);
788+
size_t numToCopy = std::min((size_t)NumEntries, selectedDevices.size());
789+
std::copy_n(selectedDevices.cbegin(), numToCopy, phDevices);
790+
if (pNumDevices != nullptr) {
791+
*pNumDevices = static_cast<uint32_t>(numToCopy);
792+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
793+
}
778794
}
779795

796+
780797
return UR_RESULT_SUCCESS;
781798
}
782799
} // namespace ur_lib

test/conformance/device/urDeviceGetSelected.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,18 @@ using urDeviceGetSelectedTest = uur::urPlatformTest;
99

1010
TEST_F(urDeviceGetSelectedTest, Success) {
1111
uint32_t count = 0;
12-
ASSERT_SUCCESS(
13-
urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count));
12+
ur_result_t res1 =
13+
urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count);
14+
ASSERT_EQ_RESULT(res1, UR_RESULT_SUCCESS);
1415
ASSERT_NE(count, 0);
1516
std::vector<ur_device_handle_t> devices(count);
16-
ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count,
17-
devices.data(), nullptr));
18-
for (auto device : devices) {
17+
ASSERT_NE(devices.size(), 0);
18+
ASSERT_NE(devices.data(), nullptr);
19+
//FAIL() << "devices.size() = " << devices.size() << " whereas count = " << count;
20+
ur_result_t res2 = urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count,
21+
devices.data(), nullptr);
22+
ASSERT_EQ_RESULT(res2, UR_RESULT_SUCCESS);
23+
for (auto &device : devices) {
1924
ASSERT_NE(nullptr, device);
2025
}
2126
}
@@ -25,7 +30,8 @@ TEST_F(urDeviceGetSelectedTest, SuccessSubsetOfDevices) {
2530
ASSERT_SUCCESS(
2631
urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, 0, nullptr, &count));
2732
if (count < 2) {
28-
GTEST_SKIP();
33+
GTEST_SKIP() << "There are fewer than two devices in total for the "
34+
"platform so the subset test is impossible";
2935
}
3036
std::vector<ur_device_handle_t> devices(count - 1);
3137
ASSERT_SUCCESS(urDeviceGetSelected(platform, UR_DEVICE_TYPE_ALL, count - 1,

0 commit comments

Comments
 (0)