Skip to content

Commit 028a3e0

Browse files
[NFC][SYCL][Graph] Use raw node_impl ptr/ref in node<->event mapping
Continuation of #19295 #19332 #19334 #19350 #19352
1 parent 19d83d5 commit 028a3e0

File tree

5 files changed

+67
-72
lines changed

5 files changed

+67
-72
lines changed

sycl/source/detail/async_alloc.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ getUrEvents(const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
3333
return RetUrEvents;
3434
}
3535

36-
std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
36+
std::vector<detail::node_impl *> getDepGraphNodes(
3737
sycl::handler &Handler, detail::queue_impl *Queue,
3838
const std::shared_ptr<detail::graph_impl> &Graph,
3939
const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
@@ -42,14 +42,14 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
4242
auto DepNodes = Graph->getNodesForEvents(DepEvents);
4343
// If this node was added explicitly we may have node deps in the handler as
4444
// well, so add them to the list
45-
DepNodes.insert(DepNodes.end(), HandlerImpl.MNodeDeps.begin(),
46-
HandlerImpl.MNodeDeps.end());
45+
for (auto &N : HandlerImpl.MNodeDeps)
46+
DepNodes.push_back(N.get());
4747
// If this is being recorded from an in-order queue we need to get the last
4848
// in-order node if any, since this will later become a dependency of the
4949
// node being processed here.
5050
if (detail::node_impl *LastInOrderNode = Graph->getLastInorderNode(Queue);
5151
LastInOrderNode) {
52-
DepNodes.push_back(LastInOrderNode->shared_from_this());
52+
DepNodes.push_back(LastInOrderNode);
5353
}
5454
return DepNodes;
5555
}

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
369369
"Event dependency from handler::depends_on does "
370370
"not correspond to a node within the graph");
371371
} else {
372-
UniqueDeps.insert(NodeImpl->second);
372+
UniqueDeps.insert(NodeImpl->second->shared_from_this());
373373
}
374374
}
375375

@@ -417,7 +417,7 @@ std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
417417
addDepsToNode(NodeImpl, Deps);
418418
// Add an event associated with this explicit node for mixed usage
419419
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
420-
NodeImpl);
420+
*NodeImpl);
421421
return NodeImpl;
422422
}
423423

@@ -476,7 +476,7 @@ graph_impl::add(std::function<void(handler &)> CGF,
476476

477477
// Add an event associated with this explicit node for mixed usage
478478
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
479-
NodeImpl);
479+
*NodeImpl);
480480

481481
// Retrieve any dynamic parameters which have been registered in the CGF and
482482
// register the actual nodes with them.
@@ -556,7 +556,7 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
556556

557557
// Add an event associated with this explicit node for mixed usage
558558
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
559-
NodeImpl);
559+
*NodeImpl);
560560

561561
// Track the dynamic command-group used inside the node object
562562
DynCGImpl->MNodes.push_back(NodeImpl);
@@ -689,9 +689,9 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
689689
auto RecordedQueueSP = RecordedQueue.lock();
690690
for (auto &Node : MNodeStorage) {
691691
if (Node->MSuccessors.empty()) {
692-
auto EventForNode = getEventForNode(Node);
692+
auto EventForNode = getEventForNode(*Node);
693693
if (EventForNode->getSubmittedQueue() == RecordedQueueSP) {
694-
Events.push_back(getEventForNode(Node));
694+
Events.push_back(getEventForNode(*Node));
695695
}
696696
}
697697
}

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -192,21 +192,21 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
192192
/// @param EventImpl Event to associate with a node in map.
193193
/// @param NodeImpl Node to associate with event in map.
194194
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
195-
const std::shared_ptr<node_impl> &NodeImpl) {
195+
node_impl &NodeImpl) {
196196
if (!(EventImpl->hasCommandGraph()))
197197
EventImpl->setCommandGraph(shared_from_this());
198-
MEventsMap[EventImpl] = NodeImpl;
198+
MEventsMap[EventImpl] = &NodeImpl;
199199
}
200200

201201
/// Find the sycl event associated with a node.
202202
/// @param NodeImpl Node to find event for.
203203
/// @return Event associated with node.
204204
std::shared_ptr<sycl::detail::event_impl>
205-
getEventForNode(std::shared_ptr<node_impl> NodeImpl) const {
205+
getEventForNode(node_impl &NodeImpl) const {
206206
ReadLock Lock(MMutex);
207207
if (auto EventImpl = std::find_if(
208208
MEventsMap.begin(), MEventsMap.end(),
209-
[NodeImpl](auto &it) { return it.second == NodeImpl; });
209+
[&NodeImpl](auto &it) { return it.second == &NodeImpl; });
210210
EventImpl != MEventsMap.end()) {
211211
return EventImpl->first;
212212
}
@@ -220,13 +220,14 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
220220
/// the given event.
221221
/// @param EventImpl Event to find the node for.
222222
/// @return Node associated with the event.
223-
std::shared_ptr<node_impl>
223+
node_impl &
224224
getNodeForEvent(std::shared_ptr<sycl::detail::event_impl> EventImpl) {
225225
ReadLock Lock(MMutex);
226226

227227
if (auto NodeFound = MEventsMap.find(EventImpl);
228228
NodeFound != std::end(MEventsMap)) {
229-
return NodeFound->second;
229+
// TODO: Is it guaranteed to be non-null?
230+
return *NodeFound->second;
230231
}
231232

232233
throw sycl::exception(
@@ -238,9 +239,9 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
238239
/// found for a given event.
239240
/// @param Events Events to find nodes for.
240241
/// @return A list of node counterparts for each event, in the same order.
241-
std::vector<std::shared_ptr<node_impl>> getNodesForEvents(
242+
std::vector<node_impl *> getNodesForEvents(
242243
const std::vector<std::shared_ptr<sycl::detail::event_impl>> &Events) {
243-
std::vector<std::shared_ptr<node_impl>> NodeList{};
244+
std::vector<node_impl *> NodeList{};
244245
NodeList.reserve(Events.size());
245246

246247
ReadLock Lock(MMutex);
@@ -544,8 +545,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
544545
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
545546
MRecordingQueues;
546547
/// Map of events to their associated recorded nodes.
547-
std::unordered_map<std::shared_ptr<sycl::detail::event_impl>,
548-
std::shared_ptr<node_impl>>
548+
std::unordered_map<std::shared_ptr<sycl::detail::event_impl>, node_impl *>
549549
MEventsMap;
550550
/// Map for every in-order queue thats recorded a node to the graph, what
551551
/// the last node added was. We can use this to create new edges on the last

sycl/source/handler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ event handler::finalize() {
916916
}
917917

918918
// Associate an event with this new node and return the event.
919-
GraphImpl->addEventForNode(EventImpl, std::move(NodeImpl));
919+
GraphImpl->addEventForNode(EventImpl, *NodeImpl);
920920

921921
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
922922
return EventImpl;

0 commit comments

Comments
 (0)