Skip to content

Commit af818da

Browse files
[NFC][SYCL] Prepare unittests/scheduler for getSyclObjImpl to return raw ref (#19219)
I'm planning to change `getSyclObjImpl` to return a raw reference in a later patch, uploading a bunch of PRs in preparation to that to make the subsequent review easier.
1 parent e452bf0 commit af818da

9 files changed

+167
-170
lines changed

sycl/unittests/scheduler/BarrierDependencies.cpp

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,9 @@ TEST_F(SchedulerTest, BarrierWithDependsOn) {
6969

7070
auto EventA =
7171
QueueA.submit([&](sycl::handler &h) { h.ext_oneapi_barrier(); });
72-
std::shared_ptr<detail::event_impl> EventAImpl =
73-
detail::getSyclObjImpl(EventA);
72+
detail::event_impl &EventAImpl = *detail::getSyclObjImpl(EventA);
7473
// it means that command is enqueued
75-
ASSERT_NE(EventAImpl->getHandle(), nullptr);
74+
ASSERT_NE(EventAImpl.getHandle(), nullptr);
7675

7776
ASSERT_FALSE(EventsWaitVisited);
7877
ASSERT_TRUE(BarrierEventsWaitVisited);
@@ -83,14 +82,13 @@ TEST_F(SchedulerTest, BarrierWithDependsOn) {
8382
h.depends_on(EventA);
8483
h.ext_oneapi_barrier();
8584
});
86-
std::shared_ptr<detail::event_impl> EventBImpl =
87-
detail::getSyclObjImpl(EventB);
85+
detail::event_impl &EventBImpl = *detail::getSyclObjImpl(EventB);
8886
// it means that command is enqueued
89-
ASSERT_NE(EventBImpl->getHandle(), nullptr);
87+
ASSERT_NE(EventBImpl.getHandle(), nullptr);
9088

9189
ASSERT_TRUE(EventsWaitVisited);
9290
ASSERT_EQ(EventsInWaitList.size(), 1u);
93-
EXPECT_EQ(EventsInWaitList[0], EventAImpl->getHandle());
91+
EXPECT_EQ(EventsInWaitList[0], EventAImpl.getHandle());
9492

9593
ASSERT_TRUE(BarrierEventsWaitVisited);
9694
ASSERT_EQ(BarrierEventsInWaitList.size(), 0u);
@@ -118,13 +116,11 @@ TEST_F(SchedulerTest, BarrierWaitListWithDependsOn) {
118116
QueueA.submit([&](sycl::handler &h) { h.ext_oneapi_barrier(); });
119117
auto EventA2 =
120118
QueueA.submit([&](sycl::handler &h) { h.ext_oneapi_barrier(); });
121-
std::shared_ptr<detail::event_impl> EventAImpl =
122-
detail::getSyclObjImpl(EventA);
123-
std::shared_ptr<detail::event_impl> EventA2Impl =
124-
detail::getSyclObjImpl(EventA2);
119+
detail::event_impl &EventAImpl = *detail::getSyclObjImpl(EventA);
120+
detail::event_impl &EventA2Impl = *detail::getSyclObjImpl(EventA2);
125121
// it means that command is enqueued
126-
ASSERT_NE(EventAImpl->getHandle(), nullptr);
127-
ASSERT_NE(EventA2Impl->getHandle(), nullptr);
122+
ASSERT_NE(EventAImpl.getHandle(), nullptr);
123+
ASSERT_NE(EventA2Impl.getHandle(), nullptr);
128124

129125
ASSERT_FALSE(EventsWaitVisited);
130126
ASSERT_TRUE(BarrierEventsWaitVisited);
@@ -135,16 +131,15 @@ TEST_F(SchedulerTest, BarrierWaitListWithDependsOn) {
135131
h.depends_on(EventA);
136132
h.ext_oneapi_barrier({EventA2});
137133
});
138-
std::shared_ptr<detail::event_impl> EventBImpl =
139-
detail::getSyclObjImpl(EventB);
134+
detail::event_impl &EventBImpl = *detail::getSyclObjImpl(EventB);
140135
// it means that command is enqueued
141-
ASSERT_NE(EventBImpl->getHandle(), nullptr);
136+
ASSERT_NE(EventBImpl.getHandle(), nullptr);
142137

143138
ASSERT_FALSE(EventsWaitVisited);
144139
ASSERT_TRUE(BarrierEventsWaitVisited);
145140
ASSERT_EQ(BarrierEventsInWaitList.size(), 2u);
146-
EXPECT_EQ(BarrierEventsInWaitList[0], EventA2Impl->getHandle());
147-
EXPECT_EQ(BarrierEventsInWaitList[1], EventAImpl->getHandle());
141+
EXPECT_EQ(BarrierEventsInWaitList[0], EventA2Impl.getHandle());
142+
EXPECT_EQ(BarrierEventsInWaitList[1], EventAImpl.getHandle());
148143

149144
QueueA.wait();
150145
QueueB.wait();

sycl/unittests/scheduler/CommandsWaitForEvents.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,8 @@ struct TestCtx {
3434
bool EventCtx2WasWaited = false;
3535

3636
TestCtx(queue &Queue1, queue &Queue2)
37-
: Q1(Queue1), Q2(Queue2),
38-
Ctx1(*detail::getSyclObjImpl(Q1.get_context()).get()),
39-
Ctx2(*detail::getSyclObjImpl(Q2.get_context()).get()) {
37+
: Q1(Queue1), Q2(Queue2), Ctx1(*detail::getSyclObjImpl(Q1.get_context())),
38+
Ctx2(*detail::getSyclObjImpl(Q2.get_context())) {
4039

4140
EventCtx1 = mock::createDummyHandle<ur_event_handle_t>();
4241
EventCtx2 = mock::createDummyHandle<ur_event_handle_t>();
@@ -152,12 +151,10 @@ TEST_F(SchedulerTest, StreamAUXCmdsWait) {
152151
ASSERT_TRUE(QueueImpl.MStreamsServiceEvents.size() == 1)
153152
<< "Expected 1 service stream event";
154153

155-
std::shared_ptr<sycl::detail::event_impl> EventImpl =
156-
detail::getSyclObjImpl(Event);
154+
auto &EventImplProxy =
155+
static_cast<EventImplProxyT &>(*detail::getSyclObjImpl(Event));
157156

158-
auto EventImplProxy = std::static_pointer_cast<EventImplProxyT>(EventImpl);
159-
160-
ASSERT_EQ(EventImplProxy->MWeakPostCompleteEvents.size(), 1u)
157+
ASSERT_EQ(EventImplProxy.MWeakPostCompleteEvents.size(), 1u)
161158
<< "Expected 1 post complete event";
162159

163160
Q.wait();

sycl/unittests/scheduler/EnqueueWithDependsOnDeps.cpp

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
namespace {
2121
using namespace sycl;
2222
using EventImplPtr = std::shared_ptr<detail::event_impl>;
23+
using sycl::detail::getSyclObjImpl;
2324

2425
constexpr auto DisableCleanupName = "SYCL_DISABLE_EXECUTION_GRAPH_CLEANUP";
2526

@@ -45,7 +46,7 @@ class DependsOnTests : public ::testing::Test {
4546
GTEST_SKIP();
4647

4748
queue QueueDev(context(Plt), default_selector_v);
48-
QueueDevImpl = detail::getSyclObjImpl(QueueDev);
49+
QueueDevImpl = getSyclObjImpl(QueueDev);
4950
}
5051

5152
void TearDown() {}
@@ -326,25 +327,23 @@ TEST_F(DependsOnTests, ShortcutFunctionWithWaitList) {
326327
// Mock up an incomplete host task
327328
auto HostTaskEvent =
328329
Queue.submit([&](sycl::handler &cgh) { cgh.host_task([=]() {}); });
329-
std::shared_ptr<detail::event_impl> HostTaskEventImpl =
330-
detail::getSyclObjImpl(HostTaskEvent);
330+
detail::event_impl &HostTaskEventImpl = *getSyclObjImpl(HostTaskEvent);
331331
HostTaskEvent.wait();
332-
auto *Cmd = static_cast<detail::Command *>(HostTaskEventImpl->getCommand());
332+
auto *Cmd = static_cast<detail::Command *>(HostTaskEventImpl.getCommand());
333333
ASSERT_NE(Cmd, nullptr);
334334
Cmd->MIsBlockable = true;
335335
Cmd->MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
336-
HostTaskEventImpl->setStateIncomplete();
336+
HostTaskEventImpl.setStateIncomplete();
337337

338338
auto SingleTaskEvent = Queue.submit([&](sycl::handler &cgh) {
339339
cgh.depends_on(HostTaskEvent);
340340
cgh.single_task<TestKernel<>>([] {});
341341
});
342-
std::shared_ptr<detail::event_impl> SingleTaskEventImpl =
343-
detail::getSyclObjImpl(SingleTaskEvent);
344-
EXPECT_EQ(SingleTaskEventImpl->getHandle(), nullptr);
342+
detail::event_impl &SingleTaskEventImpl = *getSyclObjImpl(SingleTaskEvent);
343+
EXPECT_EQ(SingleTaskEventImpl.getHandle(), nullptr);
345344

346345
// make HostTaskEvent completed, so SingleTaskEvent can be enqueued
347-
HostTaskEventImpl->setComplete();
346+
HostTaskEventImpl.setComplete();
348347
Cmd->MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueSuccess;
349348
EventsInWaitList.clear();
350349

@@ -356,9 +355,9 @@ TEST_F(DependsOnTests, ShortcutFunctionWithWaitList) {
356355
QueueDevImpl->get_context());
357356
auto ShortcutFuncEvent = Queue.memcpy(
358357
SecondBuf, FirstBuf, sizeof(int) * ArraySize, {SingleTaskEvent});
359-
EXPECT_NE(SingleTaskEventImpl->getHandle(), nullptr);
358+
EXPECT_NE(SingleTaskEventImpl.getHandle(), nullptr);
360359
ASSERT_EQ(EventsInWaitList.size(), 1u);
361-
EXPECT_EQ(EventsInWaitList[0], SingleTaskEventImpl->getHandle());
360+
EXPECT_EQ(EventsInWaitList[0], SingleTaskEventImpl.getHandle());
362361
Queue.wait();
363362
sycl::free(FirstBuf, Queue);
364363
sycl::free(SecondBuf, Queue);
@@ -372,31 +371,29 @@ TEST_F(DependsOnTests, BarrierWithWaitList) {
372371

373372
auto HostTaskEvent =
374373
Queue.submit([&](sycl::handler &cgh) { cgh.host_task([=]() {}); });
375-
std::shared_ptr<detail::event_impl> HostTaskEventImpl =
376-
detail::getSyclObjImpl(HostTaskEvent);
374+
detail::event_impl &HostTaskEventImpl = *getSyclObjImpl(HostTaskEvent);
377375
HostTaskEvent.wait();
378-
auto *Cmd = static_cast<detail::Command *>(HostTaskEventImpl->getCommand());
376+
auto *Cmd = static_cast<detail::Command *>(HostTaskEventImpl.getCommand());
379377
ASSERT_NE(Cmd, nullptr);
380378
Cmd->MIsBlockable = true;
381379
Cmd->MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueBlocked;
382-
HostTaskEventImpl->setStateIncomplete();
380+
HostTaskEventImpl.setStateIncomplete();
383381

384382
auto SingleTaskEvent = Queue.submit([&](sycl::handler &cgh) {
385383
cgh.depends_on(HostTaskEvent);
386384
cgh.single_task<TestKernel<>>([] {});
387385
});
388-
std::shared_ptr<detail::event_impl> SingleTaskEventImpl =
389-
detail::getSyclObjImpl(SingleTaskEvent);
390-
EXPECT_EQ(SingleTaskEventImpl->getHandle(), nullptr);
386+
detail::event_impl &SingleTaskEventImpl = *getSyclObjImpl(SingleTaskEvent);
387+
EXPECT_EQ(SingleTaskEventImpl.getHandle(), nullptr);
391388

392-
HostTaskEventImpl->setComplete();
389+
HostTaskEventImpl.setComplete();
393390
Cmd->MEnqueueStatus = detail::EnqueueResultT::SyclEnqueueSuccess;
394391
EventsInWaitList.clear();
395392

396393
Queue.ext_oneapi_submit_barrier(std::vector<sycl::event>{SingleTaskEvent});
397-
EXPECT_NE(SingleTaskEventImpl->getHandle(), nullptr);
394+
EXPECT_NE(SingleTaskEventImpl.getHandle(), nullptr);
398395
ASSERT_EQ(EventsInWaitList.size(), 1u);
399-
EXPECT_EQ(EventsInWaitList[0], SingleTaskEventImpl->getHandle());
396+
EXPECT_EQ(EventsInWaitList[0], SingleTaskEventImpl.getHandle());
400397
Queue.wait();
401398
}
402399
} // anonymous namespace

sycl/unittests/scheduler/GraphCleanup.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ static void checkCleanupOnEnqueue(MockScheduler &MS,
159159
MS.addCopyBack(&MockReq);
160160
verifyCleanup(Record, AllocaCmd, MockCmd, CommandDeleted);
161161

162-
MS.removeRecordForMemObj(detail::getSyclObjImpl(Buf).get());
162+
MS.removeRecordForMemObj(&*detail::getSyclObjImpl(Buf));
163163
}
164164

165165
static void checkCleanupOnLeafUpdate(
@@ -191,7 +191,7 @@ static void checkCleanupOnLeafUpdate(
191191
EXPECT_FALSE(CommandDeleted);
192192
SchedulerCall(Record);
193193
EXPECT_TRUE(CommandDeleted);
194-
MS.removeRecordForMemObj(detail::getSyclObjImpl(Buf).get());
194+
MS.removeRecordForMemObj(&*detail::getSyclObjImpl(Buf));
195195
}
196196

197197
TEST_F(SchedulerTest, PostEnqueueCleanup) {
@@ -214,10 +214,9 @@ TEST_F(SchedulerTest, PostEnqueueCleanup) {
214214
MockScheduler MS;
215215

216216
buffer<int, 1> Buf{range<1>(1)};
217-
std::shared_ptr<detail::buffer_impl> BufImpl = detail::getSyclObjImpl(Buf);
218217
detail::Requirement MockReq = getMockRequirement(Buf);
219218
MockReq.MDims = 1;
220-
MockReq.MSYCLMemObj = BufImpl.get();
219+
MockReq.MSYCLMemObj = &*detail::getSyclObjImpl(Buf);
221220

222221
checkCleanupOnEnqueue(MS, QueueImpl, Buf, MockReq);
223222
std::vector<detail::Command *> ToEnqueue;
@@ -279,11 +278,11 @@ TEST_F(SchedulerTest, HostTaskCleanup) {
279278
event Event = Queue.submit([&](sycl::handler &cgh) {
280279
cgh.host_task([&]() { std::unique_lock<std::mutex> Lock{Mutex}; });
281280
});
282-
detail::EventImplPtr EventImpl = detail::getSyclObjImpl(Event);
281+
detail::event_impl &EventImpl = *detail::getSyclObjImpl(Event);
283282

284283
// Unlike other commands, host task should be kept alive until its
285284
// completion.
286-
auto *Cmd = static_cast<detail::Command *>(EventImpl->getCommand());
285+
auto *Cmd = static_cast<detail::Command *>(EventImpl.getCommand());
287286
ASSERT_NE(Cmd, nullptr);
288287
EXPECT_TRUE(Cmd->isSuccessfullyEnqueued());
289288

@@ -293,7 +292,7 @@ TEST_F(SchedulerTest, HostTaskCleanup) {
293292
// submitted to the thread pool, shortly after the event is marked
294293
// as complete.
295294
detail::GlobalHandler::instance().drainThreadPool();
296-
ASSERT_EQ(EventImpl->getCommand(), nullptr);
295+
ASSERT_EQ(EventImpl.getCommand(), nullptr);
297296
}
298297

299298
struct AttachSchedulerWrapper {

0 commit comments

Comments
 (0)