10
10
#include < uur/fixtures.h>
11
11
#include < uur/raii.h>
12
12
13
- struct urEnqueueEventsWaitMultiDeviceTest : uur::urMultiQueueMultiDeviceTest {
14
- void SetUp () override { SetUp (2 ); /* we need at least 2 devices */ }
13
+ void checkDevicesSupportSharedUSM (
14
+ const std::vector<ur_device_handle_t > &devices) {
15
+ for (auto device : devices) {
16
+ ur_device_usm_access_capability_flags_t shared_usm_single = 0 ;
17
+ EXPECT_SUCCESS (
18
+ uur::GetDeviceUSMSingleSharedSupport (device, shared_usm_single));
19
+ if (!shared_usm_single) {
20
+ GTEST_SKIP () << " Shared USM is not supported by the device." ;
21
+ }
22
+ }
23
+ }
15
24
16
- void SetUp (size_t minDevices) {
25
+ struct urEnqueueEventsWaitMultiDeviceTest
26
+ : uur::urMultiQueueMultiDeviceTest<2 > {
27
+ void SetUp () override {
17
28
UUR_RETURN_ON_FATAL_FAILURE (
18
- uur::urMultiQueueMultiDeviceTest::SetUp (minDevices));
19
-
20
- for (auto device : devices) {
21
- ur_device_usm_access_capability_flags_t shared_usm_single = 0 ;
22
- EXPECT_SUCCESS (uur::GetDeviceUSMSingleSharedSupport (
23
- device, shared_usm_single));
24
- if (!shared_usm_single) {
25
- GTEST_SKIP () << " Shared USM is not supported by the device." ;
26
- }
27
- }
29
+ uur::urMultiQueueMultiDeviceTest<2 >::SetUp ());
30
+
31
+ checkDevicesSupportSharedUSM (devices);
28
32
29
33
ptrs.resize (devices.size ());
30
34
for (size_t i = 0 ; i < devices.size (); i++) {
@@ -40,7 +44,7 @@ struct urEnqueueEventsWaitMultiDeviceTest : uur::urMultiQueueMultiDeviceTest {
40
44
}
41
45
}
42
46
UUR_RETURN_ON_FATAL_FAILURE (
43
- uur::urMultiQueueMultiDeviceTest::TearDown ());
47
+ uur::urMultiQueueMultiDeviceTest< 2 > ::TearDown ());
44
48
}
45
49
46
50
void initData () {
@@ -98,12 +102,11 @@ TEST_P(urEnqueueEventsWaitMultiDeviceTest, EnqueueWaitOnADifferentQueue) {
98
102
99
103
verifyData (ptrs[1 ], pattern);
100
104
}
101
- /*
105
+
102
106
struct urEnqueueEventsWaitMultiDeviceMTTest
103
- : urEnqueueEventsWaitMultiDeviceTest,
104
- testing::WithParamInterface<uur::BoolTestParam> {
107
+ : uur::urMultiQueueMultiDeviceTestWithParam<8 , uur::BoolTestParam> {
105
108
void doComputation (std::function<void (size_t )> work) {
106
- auto multiThread = GetParam ().value;
109
+ auto multiThread = getParam ().value ;
107
110
std::vector<std::thread> threads;
108
111
for (size_t i = 0 ; i < devices.size (); i++) {
109
112
if (multiThread) {
@@ -118,29 +121,50 @@ struct urEnqueueEventsWaitMultiDeviceMTTest
118
121
}
119
122
120
123
void SetUp () override {
121
- const size_t minDevices = 8;
122
- UUR_RETURN_ON_FATAL_FAILURE(
123
- urEnqueueEventsWaitMultiDeviceTest::SetUp(minDevices));
124
+ UUR_RETURN_ON_FATAL_FAILURE (uur::urMultiQueueMultiDeviceTestWithParam<
125
+ 8 , uur::BoolTestParam>::SetUp ());
126
+ checkDevicesSupportSharedUSM (devices);
127
+
128
+ ptrs.resize (devices.size ());
129
+ for (size_t i = 0 ; i < devices.size (); i++) {
130
+ EXPECT_SUCCESS (urUSMSharedAlloc (context, devices[i], nullptr ,
131
+ nullptr , size, &ptrs[i]));
132
+ }
124
133
}
125
134
126
- void TearDown() override { urEnqueueEventsWaitMultiDeviceTest::TearDown(); }
127
- };
135
+ void TearDown () override {
136
+ for (auto ptr : ptrs) {
137
+ if (ptr) {
138
+ EXPECT_SUCCESS (urUSMFree (context, ptr));
139
+ }
140
+ }
141
+ UUR_RETURN_ON_FATAL_FAILURE (uur::urMultiQueueMultiDeviceTestWithParam<
142
+ 8 , uur::BoolTestParam>::TearDown ());
143
+ }
144
+
145
+ void initData () {
146
+ EXPECT_SUCCESS (urEnqueueUSMFill (queues[0 ], ptrs[0 ], sizeof (pattern),
147
+ &pattern, size, 0 , nullptr , nullptr ));
148
+ EXPECT_SUCCESS (urQueueFinish (queues[0 ]));
149
+ }
128
150
129
- template <typename T>
130
- inline std::string
131
- printParams(const testing::TestParamInfo<typename T::ParamType> &info) {
132
- std::stringstream ss;
151
+ void verifyData (void *ptr, uint32_t pattern) {
152
+ for (size_t i = 0 ; i < count; i++) {
153
+ ASSERT_EQ (reinterpret_cast <uint32_t *>(ptr)[i], pattern);
154
+ }
155
+ }
133
156
134
- auto param1 = info.param;
135
- ss << (param1.value ? "" : "No") << param1.name;
157
+ uint32_t pattern = 42 ;
158
+ const size_t count = 1024 ;
159
+ const size_t size = sizeof (uint32_t ) * count;
136
160
137
- return ss.str() ;
138
- }
161
+ std::vector< void *> ptrs ;
162
+ };
139
163
140
- INSTANTIATE_TEST_SUITE_P (
141
- , urEnqueueEventsWaitMultiDeviceMTTest,
164
+ UUR_PLATFORM_TEST_SUITE_P (
165
+ urEnqueueEventsWaitMultiDeviceMTTest,
142
166
testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" MultiThread" )),
143
- printParams<urEnqueueEventsWaitMultiDeviceMTTest >);
167
+ uur::platformTestWithParamPrinter<uur::BoolTestParam >);
144
168
145
169
TEST_P (urEnqueueEventsWaitMultiDeviceMTTest, EnqueueWaitSingleQueueMultiOps) {
146
170
std::vector<uint32_t > data (count, pattern);
@@ -216,4 +240,4 @@ TEST_P(urEnqueueEventsWaitMultiDeviceMTTest,
216
240
}
217
241
218
242
verifyData (ptrs[0 ], pattern);
219
- }*/
243
+ }
0 commit comments