@@ -99,6 +99,29 @@ struct urMultiQueueLaunchMemcpyTest : uur::urMultiDeviceContextTestTemplate<1>,
99
99
UUR_RETURN_ON_FATAL_FAILURE (
100
100
uur::urMultiDeviceContextTestTemplate<1 >::TearDown ());
101
101
}
102
+
103
+ void runBackgroundCheck (std::vector<uur::raii::Event> &Events) {
104
+ std::vector<std::thread> threads;
105
+ for (size_t i = 0 ; i < Events.size (); i++) {
106
+ threads.emplace_back ([&, i] {
107
+ ur_event_status_t status;
108
+ do {
109
+ ASSERT_SUCCESS (urEventGetInfo (
110
+ Events[i].get (), UR_EVENT_INFO_COMMAND_EXECUTION_STATUS,
111
+ sizeof (ur_event_status_t ), &status, nullptr ));
112
+ } while (status != UR_EVENT_STATUS_COMPLETE);
113
+
114
+ auto ExpectedValue = InitialValue + i + 1 ;
115
+ for (uint32_t j = 0 ; j < ArraySize; ++j) {
116
+ ASSERT_EQ (reinterpret_cast <uint32_t *>(SharedMem[i])[j],
117
+ ExpectedValue);
118
+ }
119
+ });
120
+ }
121
+ for (auto &thread : threads) {
122
+ thread.join ();
123
+ }
124
+ }
102
125
};
103
126
104
127
template <typename Param>
@@ -189,26 +212,24 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
189
212
190
213
auto useEvents = std::get<1 >(GetParam ()).value ;
191
214
192
- std::vector<uur::raii::Event> Events (numOps * 2 );
193
- for (size_t i = 0 ; i < numOps; i++) {
194
- size_t waitNum = 0 ;
195
- ur_event_handle_t *lastEvent = nullptr ;
196
- ur_event_handle_t *kernelEvent = nullptr ;
197
- ur_event_handle_t *memcpyEvent = nullptr ;
215
+ std::vector<uur::raii::Event> kernelEvents (numOps);
216
+ std::vector<uur::raii::Event> memcpyEvents (numOps - 1 );
198
217
199
- if (useEvents) {
200
- // Events are: kernelEvent0, memcpyEvent0, kernelEvent1, ...
201
- waitNum = i > 0 ? 1 : 0 ;
202
- lastEvent = i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
218
+ ur_event_handle_t *lastMemcpyEvent = nullptr ;
219
+ ur_event_handle_t *kernelEvent = nullptr ;
220
+ ur_event_handle_t *memcpyEvent = nullptr ;
203
221
204
- kernelEvent = Events[i * 2 ].ptr ();
205
- memcpyEvent = Events[i * 2 + 1 ].ptr ();
222
+ for (size_t i = 0 ; i < numOps; i++) {
223
+ if (useEvents) {
224
+ lastMemcpyEvent = memcpyEvent;
225
+ kernelEvent = kernelEvents[i].ptr ();
226
+ memcpyEvent = i < numOps - 1 ? memcpyEvents[i].ptr () : nullptr ;
206
227
}
207
228
208
229
// execute kernel that increments each element by 1
209
230
ASSERT_SUCCESS (urEnqueueKernelLaunch (
210
231
queue, kernels[i], n_dimensions, &global_offset, &ArraySize,
211
- nullptr , waitNum, lastEvent , kernelEvent));
232
+ nullptr , bool (lastMemcpyEvent), lastMemcpyEvent , kernelEvent));
212
233
213
234
// copy the memory (input for the next kernel)
214
235
if (i < numOps - 1 ) {
@@ -220,11 +241,9 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
220
241
}
221
242
222
243
if (useEvents) {
223
- // TODO: just wait on the last event, once urEventWait is implemented
224
- // by V2 L0 adapter
225
- urQueueFinish (queue);
244
+ ASSERT_SUCCESS (urEventWait (1 , kernelEvents.back ().ptr ()));
226
245
} else {
227
- urQueueFinish (queue);
246
+ ASSERT_SUCCESS ( urQueueFinish (queue) );
228
247
}
229
248
230
249
size_t ExpectedValue = InitialValue;
@@ -237,12 +256,41 @@ TEST_P(urEnqueueKernelLaunchIncrementTest, Success) {
237
256
}
238
257
}
239
258
240
- struct VoidParam {};
259
+ template <typename T>
260
+ inline std::string
261
+ printParams (const testing::TestParamInfo<typename T::ParamType> &info) {
262
+ std::stringstream ss;
263
+
264
+ auto param1 = std::get<0 >(info.param );
265
+ ss << (param1.value ? " " : " No" ) << param1.name ;
266
+
267
+ auto param2 = std::get<1 >(info.param );
268
+ ss << (param2.value ? " " : " No" ) << param2.name ;
269
+
270
+ if constexpr (std::tuple_size_v < typename T::ParamType >> 2 ) {
271
+ auto param3 = std::get<2 >(info.param );
272
+ }
273
+
274
+ return ss.str ();
275
+ }
276
+
241
277
using urEnqueueKernelLaunchIncrementMultiDeviceTest =
242
- urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<VoidParam>;
278
+ urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<
279
+ std::tuple<uur::BoolTestParam, uur::BoolTestParam>>;
280
+
281
+ INSTANTIATE_TEST_SUITE_P (
282
+ , urEnqueueKernelLaunchIncrementMultiDeviceTest,
283
+ testing::Combine (
284
+ testing::ValuesIn (uur::BoolTestParam::makeBoolParam(" UseEventWait" )),
285
+ testing::ValuesIn(
286
+ uur::BoolTestParam::makeBoolParam (" RunBackgroundCheck" ))),
287
+ printParams<urEnqueueKernelLaunchIncrementMultiDeviceTest>);
243
288
244
289
// Do a chain of kernelLaunch(dev0) -> memcpy(dev0, dev1) -> kernelLaunch(dev1) ... ops
245
- TEST_F (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
290
+ TEST_P (urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
291
+ auto waitOnEvent = std::get<0 >(GetParam ()).value ;
292
+ auto runBackgroundCheck = std::get<1 >(GetParam ()).value ;
293
+
246
294
size_t returned_size;
247
295
ASSERT_SUCCESS (urDeviceGetInfo (devices[0 ], UR_DEVICE_INFO_EXTENSIONS, 0 ,
248
296
nullptr , &returned_size));
@@ -265,19 +313,22 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
265
313
constexpr size_t global_offset = 0 ;
266
314
constexpr size_t n_dimensions = 1 ;
267
315
268
- std::vector<uur::raii::Event> Events (devices.size () * 2 );
316
+ std::vector<uur::raii::Event> kernelEvents (devices.size ());
317
+ std::vector<uur::raii::Event> memcpyEvents (devices.size () - 1 );
318
+
319
+ ur_event_handle_t *lastMemcpyEvent = nullptr ;
320
+ ur_event_handle_t *kernelEvent = nullptr ;
321
+ ur_event_handle_t *memcpyEvent = nullptr ;
322
+
269
323
for (size_t i = 0 ; i < devices.size (); i++) {
270
- // Events are: kernelEvent0, memcpyEvent0, kernelEvent1, ...
271
- size_t waitNum = i > 0 ? 1 : 0 ;
272
- ur_event_handle_t *lastEvent =
273
- i > 0 ? Events[i * 2 - 1 ].ptr () : nullptr ;
274
- ur_event_handle_t *kernelEvent = Events[i * 2 ].ptr ();
275
- ur_event_handle_t *memcpyEvent = Events[i * 2 + 1 ].ptr ();
324
+ lastMemcpyEvent = memcpyEvent;
325
+ kernelEvent = kernelEvents[i].ptr ();
326
+ memcpyEvent = i < devices.size () - 1 ? memcpyEvents[i].ptr () : nullptr ;
276
327
277
328
// execute kernel that increments each element by 1
278
329
ASSERT_SUCCESS (urEnqueueKernelLaunch (
279
330
queues[i], kernels[i], n_dimensions, &global_offset, &ArraySize,
280
- nullptr , waitNum, lastEvent , kernelEvent));
331
+ nullptr , bool (lastMemcpyEvent), lastMemcpyEvent , kernelEvent));
281
332
282
333
// copy the memory to next device
283
334
if (i < devices.size () - 1 ) {
@@ -287,9 +338,18 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
287
338
}
288
339
}
289
340
290
- // synchronize on the last queue only, this has to ensure all the operations
341
+ // While the device(s) execute, loop over the events and if completed, verify the results
342
+ if (runBackgroundCheck) {
343
+ this ->runBackgroundCheck (kernelEvents);
344
+ }
345
+
346
+ // synchronize on the last queue/event only, this has to ensure all the operations
291
347
// are completed
292
- urQueueFinish (queues.back ());
348
+ if (waitOnEvent) {
349
+ ASSERT_SUCCESS (urEventWait (1 , kernelEvents.back ().ptr ()));
350
+ } else {
351
+ ASSERT_SUCCESS (urQueueFinish (queues.back ()));
352
+ }
293
353
294
354
size_t ExpectedValue = InitialValue;
295
355
for (size_t i = 0 ; i < devices.size (); i++) {
@@ -301,20 +361,6 @@ TEST_F(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
301
361
}
302
362
}
303
363
304
- template <typename T>
305
- inline std::string
306
- printParams (const testing::TestParamInfo<typename T::ParamType> &info) {
307
- std::stringstream ss;
308
-
309
- auto param1 = std::get<0 >(info.param );
310
- auto param2 = std::get<1 >(info.param );
311
-
312
- ss << (param1.value ? " " : " No" ) << param1.name ;
313
- ss << (param2.value ? " " : " No" ) << param2.name ;
314
-
315
- return ss.str ();
316
- }
317
-
318
364
using urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest =
319
365
urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<
320
366
std::tuple<uur::BoolTestParam, uur::BoolTestParam>>;
@@ -374,9 +420,11 @@ TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
374
420
ArraySize * sizeof (uint32_t ), useEvents,
375
421
lastEvent, signalEvent));
376
422
377
- urQueueFinish (queue);
378
- // TODO: when useEvents is implemented for L0 v2 adapter
379
- // wait on event instead
423
+ if (useEvents) {
424
+ ASSERT_SUCCESS (urEventWait (1 , Events.back ().ptr ()));
425
+ } else {
426
+ ASSERT_SUCCESS (urQueueFinish (queue));
427
+ }
380
428
381
429
size_t ExpectedValue = InitialValue;
382
430
ExpectedValue += numOpsPerThread;
0 commit comments