Skip to content

Commit bdb235e

Browse files
committed
Fix MultiQueueMultiDevice param tests.
1 parent e3cddda commit bdb235e

File tree

6 files changed

+186
-89
lines changed

6 files changed

+186
-89
lines changed

test/conformance/enqueue/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ add_conformance_test_with_kernels_environment(enqueue
1010
urEnqueueEventsWaitMultiDevice.cpp
1111
urEnqueueEventsWaitWithBarrier.cpp
1212
urEnqueueKernelLaunch.cpp
13-
#urEnqueueKernelLaunchAndMemcpyInOrder.cpp
13+
urEnqueueKernelLaunchAndMemcpyInOrder.cpp
1414
urEnqueueMemBufferCopyRect.cpp
1515
urEnqueueMemBufferCopy.cpp
1616
urEnqueueMemBufferFill.cpp

test/conformance/enqueue/helpers.h

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,17 @@ printFillTestString(const testing::TestParamInfo<typename T::ParamType> &info) {
154154
return test_name.str();
155155
}
156156

157-
struct urMultiQueueMultiDeviceTest : uur::urMultiDeviceContextTestTemplate<1> {
158-
void initQueues(size_t minDevices) {
157+
// Similar to urMultiDeviceContextTestWithParam this fixture allows a min
158+
// device count to be specified, but in this case we duplicate existing
159+
// devices to reach the min device count rather than skipping if it isn't met.
160+
template <size_t minDevices>
161+
struct urMultiQueueMultiDeviceTest : uur::urAllDevicesTest {
162+
void SetUp() override {
163+
UUR_RETURN_ON_FATAL_FAILURE(uur::urAllDevicesTest::SetUp());
164+
165+
ASSERT_SUCCESS(
166+
urContextCreate(devices.size(), devices.data(), nullptr, &context));
167+
159168
// Duplicate our devices until we hit the minimum size specified.
160169
auto srcDevices = devices;
161170
while (devices.size() < minDevices) {
@@ -169,31 +178,49 @@ struct urMultiQueueMultiDeviceTest : uur::urMultiDeviceContextTestTemplate<1> {
169178
}
170179
}
171180

172-
// Default implementation that uses all available devices
173-
void SetUp() override {
174-
UUR_RETURN_ON_FATAL_FAILURE(
175-
uur::urMultiDeviceContextTestTemplate<1>::SetUp());
176-
initQueues(1);
181+
void TearDown() override {
182+
for (auto &queue : queues) {
183+
EXPECT_SUCCESS(urQueueRelease(queue));
184+
}
185+
UUR_RETURN_ON_FATAL_FAILURE(uur::urAllDevicesTest::TearDown());
177186
}
178187

179-
// Specialized implementation that duplicates all devices and queues
180-
void SetUp(size_t numDuplicate) {
181-
UUR_RETURN_ON_FATAL_FAILURE(
182-
uur::urMultiDeviceContextTestTemplate<1>::SetUp());
183-
initQueues(numDuplicate);
188+
ur_context_handle_t context;
189+
std::vector<ur_queue_handle_t> queues;
190+
};
191+
192+
template <size_t minDevices, class T>
193+
struct urMultiQueueMultiDeviceTestWithParam
194+
: uur::urAllDevicesTestWithParam<T> {
195+
using uur::urAllDevicesTestWithParam<T>::devices;
196+
void SetUp() override {
197+
UUR_RETURN_ON_FATAL_FAILURE(uur::urAllDevicesTestWithParam<T>::SetUp());
198+
199+
ASSERT_SUCCESS(
200+
urContextCreate(devices.size(), devices.data(), nullptr, &context));
201+
202+
// Duplicate our devices until we hit the minimum size specified.
203+
auto srcDevices = devices;
204+
while (devices.size() < minDevices) {
205+
devices.insert(devices.end(), srcDevices.begin(), srcDevices.end());
206+
}
207+
208+
for (auto &device : devices) {
209+
ur_queue_handle_t queue = nullptr;
210+
ASSERT_SUCCESS(urQueueCreate(context, device, nullptr, &queue));
211+
queues.push_back(queue);
212+
}
184213
}
185214

186215
void TearDown() override {
187216
for (auto &queue : queues) {
188217
EXPECT_SUCCESS(urQueueRelease(queue));
189218
}
190219
UUR_RETURN_ON_FATAL_FAILURE(
191-
uur::urMultiDeviceContextTestTemplate<1>::TearDown());
220+
uur::urAllDevicesTestWithParam<T>::TearDown());
192221
}
193-
std::function<std::tuple<std::vector<ur_device_handle_t>,
194-
std::vector<ur_queue_handle_t>>(void)>
195-
makeQueues;
196222

223+
ur_context_handle_t context;
197224
std::vector<ur_queue_handle_t> queues;
198225
};
199226

test/conformance/enqueue/urEnqueueEventsWaitMultiDevice.cpp

Lines changed: 59 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,25 @@
1010
#include <uur/fixtures.h>
1111
#include <uur/raii.h>
1212

13-
struct urEnqueueEventsWaitMultiDeviceTest : uur::urMultiQueueMultiDeviceTest {
14-
void SetUp() override { SetUp(2); /* we need at least 2 devices */ }
13+
void checkDevicesSupportSharedUSM(
14+
const std::vector<ur_device_handle_t> &devices) {
15+
for (auto device : devices) {
16+
ur_device_usm_access_capability_flags_t shared_usm_single = 0;
17+
EXPECT_SUCCESS(
18+
uur::GetDeviceUSMSingleSharedSupport(device, shared_usm_single));
19+
if (!shared_usm_single) {
20+
GTEST_SKIP() << "Shared USM is not supported by the device.";
21+
}
22+
}
23+
}
1524

16-
void SetUp(size_t minDevices) {
25+
struct urEnqueueEventsWaitMultiDeviceTest
26+
: uur::urMultiQueueMultiDeviceTest<2> {
27+
void SetUp() override {
1728
UUR_RETURN_ON_FATAL_FAILURE(
18-
uur::urMultiQueueMultiDeviceTest::SetUp(minDevices));
19-
20-
for (auto device : devices) {
21-
ur_device_usm_access_capability_flags_t shared_usm_single = 0;
22-
EXPECT_SUCCESS(uur::GetDeviceUSMSingleSharedSupport(
23-
device, shared_usm_single));
24-
if (!shared_usm_single) {
25-
GTEST_SKIP() << "Shared USM is not supported by the device.";
26-
}
27-
}
29+
uur::urMultiQueueMultiDeviceTest<2>::SetUp());
30+
31+
checkDevicesSupportSharedUSM(devices);
2832

2933
ptrs.resize(devices.size());
3034
for (size_t i = 0; i < devices.size(); i++) {
@@ -40,7 +44,7 @@ struct urEnqueueEventsWaitMultiDeviceTest : uur::urMultiQueueMultiDeviceTest {
4044
}
4145
}
4246
UUR_RETURN_ON_FATAL_FAILURE(
43-
uur::urMultiQueueMultiDeviceTest::TearDown());
47+
uur::urMultiQueueMultiDeviceTest<2>::TearDown());
4448
}
4549

4650
void initData() {
@@ -98,12 +102,11 @@ TEST_P(urEnqueueEventsWaitMultiDeviceTest, EnqueueWaitOnADifferentQueue) {
98102

99103
verifyData(ptrs[1], pattern);
100104
}
101-
/*
105+
102106
struct urEnqueueEventsWaitMultiDeviceMTTest
103-
: urEnqueueEventsWaitMultiDeviceTest,
104-
testing::WithParamInterface<uur::BoolTestParam> {
107+
: uur::urMultiQueueMultiDeviceTestWithParam<8, uur::BoolTestParam> {
105108
void doComputation(std::function<void(size_t)> work) {
106-
auto multiThread = GetParam().value;
109+
auto multiThread = getParam().value;
107110
std::vector<std::thread> threads;
108111
for (size_t i = 0; i < devices.size(); i++) {
109112
if (multiThread) {
@@ -118,29 +121,50 @@ struct urEnqueueEventsWaitMultiDeviceMTTest
118121
}
119122

120123
void SetUp() override {
121-
const size_t minDevices = 8;
122-
UUR_RETURN_ON_FATAL_FAILURE(
123-
urEnqueueEventsWaitMultiDeviceTest::SetUp(minDevices));
124+
UUR_RETURN_ON_FATAL_FAILURE(uur::urMultiQueueMultiDeviceTestWithParam<
125+
8, uur::BoolTestParam>::SetUp());
126+
checkDevicesSupportSharedUSM(devices);
127+
128+
ptrs.resize(devices.size());
129+
for (size_t i = 0; i < devices.size(); i++) {
130+
EXPECT_SUCCESS(urUSMSharedAlloc(context, devices[i], nullptr,
131+
nullptr, size, &ptrs[i]));
132+
}
124133
}
125134

126-
void TearDown() override { urEnqueueEventsWaitMultiDeviceTest::TearDown(); }
127-
};
135+
void TearDown() override {
136+
for (auto ptr : ptrs) {
137+
if (ptr) {
138+
EXPECT_SUCCESS(urUSMFree(context, ptr));
139+
}
140+
}
141+
UUR_RETURN_ON_FATAL_FAILURE(uur::urMultiQueueMultiDeviceTestWithParam<
142+
8, uur::BoolTestParam>::TearDown());
143+
}
144+
145+
void initData() {
146+
EXPECT_SUCCESS(urEnqueueUSMFill(queues[0], ptrs[0], sizeof(pattern),
147+
&pattern, size, 0, nullptr, nullptr));
148+
EXPECT_SUCCESS(urQueueFinish(queues[0]));
149+
}
128150

129-
template <typename T>
130-
inline std::string
131-
printParams(const testing::TestParamInfo<typename T::ParamType> &info) {
132-
std::stringstream ss;
151+
void verifyData(void *ptr, uint32_t pattern) {
152+
for (size_t i = 0; i < count; i++) {
153+
ASSERT_EQ(reinterpret_cast<uint32_t *>(ptr)[i], pattern);
154+
}
155+
}
133156

134-
auto param1 = info.param;
135-
ss << (param1.value ? "" : "No") << param1.name;
157+
uint32_t pattern = 42;
158+
const size_t count = 1024;
159+
const size_t size = sizeof(uint32_t) * count;
136160

137-
return ss.str();
138-
}
161+
std::vector<void *> ptrs;
162+
};
139163

140-
INSTANTIATE_TEST_SUITE_P(
141-
, urEnqueueEventsWaitMultiDeviceMTTest,
164+
UUR_PLATFORM_TEST_SUITE_P(
165+
urEnqueueEventsWaitMultiDeviceMTTest,
142166
testing::ValuesIn(uur::BoolTestParam::makeBoolParam("MultiThread")),
143-
printParams<urEnqueueEventsWaitMultiDeviceMTTest>);
167+
uur::platformTestWithParamPrinter<uur::BoolTestParam>);
144168

145169
TEST_P(urEnqueueEventsWaitMultiDeviceMTTest, EnqueueWaitSingleQueueMultiOps) {
146170
std::vector<uint32_t> data(count, pattern);
@@ -216,4 +240,4 @@ TEST_P(urEnqueueEventsWaitMultiDeviceMTTest,
216240
}
217241

218242
verifyData(ptrs[0], pattern);
219-
}*/
243+
}

0 commit comments

Comments
 (0)