3
3
// See LICENSE.TXT
4
4
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5
5
6
+ #include " helpers.h"
7
+
6
8
#include < uur/fixtures.h>
7
9
#include < uur/raii.h>
8
10
13
15
std::tuple<size_t , size_t , size_t > minL0DriverVersion = {1 , 3 , 29534 };
14
16
15
17
template <typename T>
16
- struct urMultiQueueLaunchMemcpyTest : uur::urMultiDeviceContextTestTemplate< 1 > ,
18
+ struct urMultiQueueLaunchMemcpyTest : uur::urMultiQueueMultiDeviceTest ,
17
19
testing::WithParamInterface<T> {
18
20
std::string KernelName;
19
21
std::vector<ur_program_handle_t > programs;
20
22
std::vector<ur_kernel_handle_t > kernels;
21
23
std::vector<void *> SharedMem;
22
24
23
- std::vector<ur_queue_handle_t > queues;
24
- std::vector<ur_device_handle_t > devices;
25
-
26
- std::function<void (void )> createQueues;
27
-
28
25
static constexpr char ProgramName[] = " increment" ;
29
26
static constexpr size_t ArraySize = 100 ;
30
27
static constexpr size_t InitialValue = 1 ;
31
28
32
- void SetUp () override {
33
- UUR_RETURN_ON_FATAL_FAILURE (
34
- uur::urMultiDeviceContextTestTemplate<1 >::SetUp ());
29
+ void SetUp () override { throw std::runtime_error (" Not implemented" ); }
35
30
36
- createQueues ();
31
+ void SetUp (std::vector<ur_device_handle_t > srcDevices,
32
+ size_t duplicateDevices) {
33
+ UUR_RETURN_ON_FATAL_FAILURE (uur::urMultiQueueMultiDeviceTest::SetUp (
34
+ srcDevices, duplicateDevices));
37
35
38
36
for (auto &device : devices) {
39
37
SKIP_IF_DRIVER_TOO_OLD (" Level-Zero" , minL0DriverVersion, platform,
@@ -87,9 +85,6 @@ struct urMultiQueueLaunchMemcpyTest : uur::urMultiDeviceContextTestTemplate<1>,
87
85
for (auto &Ptr : SharedMem) {
88
86
urUSMFree (context, Ptr);
89
87
}
90
- for (const auto &queue : queues) {
91
- EXPECT_SUCCESS (urQueueRelease (queue));
92
- }
93
88
for (const auto &kernel : kernels) {
94
89
urKernelRelease (kernel);
95
90
}
@@ -136,23 +131,8 @@ struct urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam
136
131
using urMultiQueueLaunchMemcpyTest<Param>::SharedMem;
137
132
138
133
void SetUp () override {
139
- this ->createQueues = [&] {
140
- for (size_t i = 0 ; i < duplicateDevices; i++) {
141
- devices.insert (
142
- devices.end (),
143
- uur::KernelsEnvironment::instance->devices .begin (),
144
- uur::KernelsEnvironment::instance->devices .end ());
145
- }
146
-
147
- for (auto &device : devices) {
148
- ur_queue_handle_t queue = nullptr ;
149
- ASSERT_SUCCESS (urQueueCreate (context, device, 0 , &queue));
150
- queues.push_back (queue);
151
- }
152
- };
153
-
154
- UUR_RETURN_ON_FATAL_FAILURE (
155
- urMultiQueueLaunchMemcpyTest<Param>::SetUp ());
134
+ UUR_RETURN_ON_FATAL_FAILURE (urMultiQueueLaunchMemcpyTest<Param>::SetUp (
135
+ uur::KernelsEnvironment::instance->devices , duplicateDevices));
156
136
}
157
137
158
138
void TearDown () override {
@@ -166,8 +146,6 @@ struct urEnqueueKernelLaunchIncrementTest
166
146
std::tuple<ur_device_handle_t , uur::BoolTestParam>> {
167
147
static constexpr size_t numOps = 50 ;
168
148
169
- ur_queue_handle_t queue;
170
-
171
149
using Param = std::tuple<ur_device_handle_t , uur::BoolTestParam>;
172
150
using urMultiQueueLaunchMemcpyTest<Param>::context;
173
151
using urMultiQueueLaunchMemcpyTest<Param>::queues;
@@ -176,26 +154,12 @@ struct urEnqueueKernelLaunchIncrementTest
176
154
using urMultiQueueLaunchMemcpyTest<Param>::SharedMem;
177
155
178
156
void SetUp () override {
179
- auto device = std::get<0 >(GetParam ());
180
-
181
- this ->createQueues = [&] {
182
- ASSERT_SUCCESS (urQueueCreate (context, device, 0 , &queue));
183
-
184
- // use the same queue and device for all operations
185
- for (size_t i = 0 ; i < numOps; i++) {
186
- urQueueRetain (queue);
187
-
188
- queues.push_back (queue);
189
- devices.push_back (device);
190
- }
191
- };
192
-
193
- UUR_RETURN_ON_FATAL_FAILURE (
194
- urMultiQueueLaunchMemcpyTest<Param>::SetUp ());
157
+ UUR_RETURN_ON_FATAL_FAILURE (urMultiQueueLaunchMemcpyTest<Param>::SetUp (
158
+ std::vector<ur_device_handle_t >{std::get<0 >(GetParam ())},
159
+ numOps)); // Use single device, duplicated numOps times
195
160
}
196
161
197
162
void TearDown () override {
198
- urQueueRelease (queue);
199
163
UUR_RETURN_ON_FATAL_FAILURE (
200
164
urMultiQueueLaunchMemcpyTest<Param>::TearDown ());
201
165
}
@@ -219,6 +183,9 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
219
183
ur_event_handle_t *kernelEvent = nullptr ;
220
184
ur_event_handle_t *memcpyEvent = nullptr ;
221
185
186
+ // This is a single device test
187
+ auto queue = queues[0 ];
188
+
222
189
for (size_t i = 0 ; i < numOps; i++) {
223
190
if (useEvents) {
224
191
lastMemcpyEvent = memcpyEvent;
0 commit comments