Skip to content

Commit 6bae737

Browse files
authored
[SYCL] optimize wait() for in-order queue (#18656)
This shortens the critical section for in-order and out-of-order cases and avoids taking the lock entirely if MNoLastEventMode is set.
1 parent aaab7e9 commit 6bae737

File tree

1 file changed

+46
-37
lines changed

1 file changed

+46
-37
lines changed

sycl/source/detail/queue_impl.cpp

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ queue_impl::submit_impl(const detail::type_erased_cgfo_ty &CGF,
394394
CGF, SecondaryQueue, /*CallerNeedsEvent*/ true, Loc, IsTopCodeLoc, {});
395395
if (EventImpl)
396396
EventImpl->attachEventToCompleteWeak(FlushEvent);
397-
registerStreamServiceEvent(FlushEvent);
397+
if (!isInOrder()) {
398+
// For in-order queue, the dependencies will be tracked by LastEvent
399+
registerStreamServiceEvent(FlushEvent);
400+
}
398401
}
399402

400403
return EventImpl;
@@ -615,51 +618,57 @@ void queue_impl::wait(const detail::code_location &CodeLoc) {
615618
}
616619
}
617620

618-
std::vector<std::weak_ptr<event_impl>> WeakEvents;
619-
EventImplPtr LastEvent;
620-
{
621-
std::lock_guard<std::mutex> Lock(MMutex);
622-
WeakEvents.swap(MEventsWeak);
623-
LastEvent = MDefaultGraphDeps.LastEventPtr;
621+
if (isInOrder() && !MNoLastEventMode.load(std::memory_order_relaxed)) {
622+
// if MLastEvent is not null, we need to wait for it
623+
EventImplPtr LastEvent;
624+
{
625+
std::lock_guard<std::mutex> Lock(MMutex);
626+
LastEvent = MDefaultGraphDeps.LastEventPtr;
627+
}
628+
if (LastEvent) {
629+
LastEvent->wait(LastEvent);
630+
}
631+
} else if (!isInOrder()) {
632+
std::vector<std::weak_ptr<event_impl>> WeakEvents;
633+
{
634+
std::lock_guard<std::mutex> Lock(MMutex);
635+
WeakEvents.swap(MEventsWeak);
636+
MMissedCleanupRequests.unset(
637+
[&](MissedCleanupRequestsType &MissedCleanupRequests) {
638+
for (auto &UpdatedGraph : MissedCleanupRequests)
639+
doUnenqueuedCommandCleanup(UpdatedGraph);
640+
MissedCleanupRequests.clear();
641+
});
642+
}
624643

625-
MMissedCleanupRequests.unset(
626-
[&](MissedCleanupRequestsType &MissedCleanupRequests) {
627-
for (auto &UpdatedGraph : MissedCleanupRequests)
628-
doUnenqueuedCommandCleanup(UpdatedGraph);
629-
MissedCleanupRequests.clear();
630-
});
631-
}
632-
// If the queue is either a host one or does not support OOO (and we use
633-
// multiple in-order queues as a result of that), wait for each event
634-
// directly. Otherwise, only wait for unenqueued or host task events, starting
635-
// from the latest submitted task in order to minimize total amount of calls,
636-
// then handle the rest with urQueueFinish.
637-
for (auto EventImplWeakPtrIt = WeakEvents.rbegin();
638-
EventImplWeakPtrIt != WeakEvents.rend(); ++EventImplWeakPtrIt) {
639-
if (std::shared_ptr<event_impl> EventImplSharedPtr =
640-
EventImplWeakPtrIt->lock()) {
641-
// A nullptr UR event indicates that urQueueFinish will not cover it,
642-
// either because it's a host task event or an unenqueued one.
643-
if (nullptr == EventImplSharedPtr->getHandle()) {
644-
EventImplSharedPtr->wait(EventImplSharedPtr);
644+
// Wait for unenqueued or host task events, starting
645+
// from the latest submitted task in order to minimize total amount of
646+
// calls, then handle the rest with urQueueFinish.
647+
for (auto EventImplWeakPtrIt = WeakEvents.rbegin();
648+
EventImplWeakPtrIt != WeakEvents.rend(); ++EventImplWeakPtrIt) {
649+
if (std::shared_ptr<event_impl> EventImplSharedPtr =
650+
EventImplWeakPtrIt->lock()) {
651+
// A nullptr UR event indicates that urQueueFinish will not cover it,
652+
// either because it's a host task event or an unenqueued one.
653+
if (nullptr == EventImplSharedPtr->getHandle()) {
654+
EventImplSharedPtr->wait(EventImplSharedPtr);
655+
}
645656
}
646657
}
647658
}
648659

649-
if (LastEvent) {
650-
LastEvent->wait(LastEvent);
651-
}
652-
653660
const AdapterPtr &Adapter = getAdapter();
654661
Adapter->call<UrApiKind::urQueueFinish>(getHandleRef());
655662

656-
std::vector<EventImplPtr> StreamsServiceEvents;
657-
{
658-
std::lock_guard<std::mutex> Lock(MStreamsServiceEventsMutex);
659-
StreamsServiceEvents.swap(MStreamsServiceEvents);
663+
if (!isInOrder()) {
664+
std::vector<EventImplPtr> StreamsServiceEvents;
665+
{
666+
std::lock_guard<std::mutex> Lock(MStreamsServiceEventsMutex);
667+
StreamsServiceEvents.swap(MStreamsServiceEvents);
668+
}
669+
for (const EventImplPtr &Event : StreamsServiceEvents)
670+
Event->wait(Event);
660671
}
661-
for (const EventImplPtr &Event : StreamsServiceEvents)
662-
Event->wait(Event);
663672

664673
#ifdef XPTI_ENABLE_INSTRUMENTATION
665674
if (xptiEnabled) {

0 commit comments

Comments
 (0)