Skip to content

Commit 4b7e279

Browse files
[NFC][SYCL][Graph] Update some maps to use raw node_impl *
Continuation of the refactoring in #19295 #19332
1 parent d2ff9fe commit 4b7e279

File tree

5 files changed

+65
-69
lines changed

5 files changed

+65
-69
lines changed

sycl/source/detail/async_alloc.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
4747
// If this is being recorded from an in-order queue we need to get the last
4848
// in-order node if any, since this will later become a dependency of the
4949
// node being processed here.
50-
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
50+
if (detail::node_impl *LastInOrderNode = Graph->getLastInorderNode(Queue);
5151
LastInOrderNode) {
52-
DepNodes.push_back(LastInOrderNode);
52+
DepNodes.push_back(LastInOrderNode->shared_from_this());
5353
}
5454
return DepNodes;
5555
}

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ void exec_graph_impl::makePartitions() {
255255
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
256256
for (auto &Node : MNodeStorage) {
257257
if (Node->MPartitionNum == i) {
258-
MPartitionNodes[Node] = PartitionFinalNum;
258+
MPartitionNodes[Node.get()] = PartitionFinalNum;
259259
if (isPartitionRoot(Node)) {
260260
Partition->MRoots.insert(Node);
261261
if (Node->MCGType == CGType::CodeplayHostTask) {
@@ -290,8 +290,7 @@ void exec_graph_impl::makePartitions() {
290290
for (auto const &Root : Partition->MRoots) {
291291
auto RootNode = Root.lock();
292292
for (node_impl &NodeDep : RootNode->predecessors()) {
293-
auto &Predecessor =
294-
MPartitions[MPartitionNodes[NodeDep.shared_from_this()]];
293+
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
295294
Partition->MPredecessors.push_back(Predecessor.get());
296295
Predecessor->MSuccessors.push_back(Partition.get());
297296
}
@@ -610,8 +609,7 @@ bool graph_impl::checkForCycles() {
610609
return CycleFound;
611610
}
612611

613-
std::shared_ptr<node_impl>
614-
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
612+
node_impl *graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
615613
if (!Queue) {
616614
assert(0 ==
617615
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
@@ -625,7 +623,7 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
625623

626624
void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
627625
std::shared_ptr<node_impl> Node) {
628-
MInorderQueueMap[Queue.weak_from_this()] = std::move(Node);
626+
MInorderQueueMap[Queue.weak_from_this()] = &*Node;
629627
}
630628

631629
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
@@ -726,11 +724,10 @@ void exec_graph_impl::findRealDeps(
726724
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
727725
}
728726
} else {
729-
auto CurrentNodePtr = CurrentNode.shared_from_this();
730727
// Verify if CurrentNode belong the the same partition
731-
if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) {
728+
if (MPartitionNodes[&CurrentNode] == ReferencePartitionNum) {
732729
// Verify that the sync point has actually been set for this node.
733-
auto SyncPoint = MSyncPoints.find(CurrentNodePtr);
730+
auto SyncPoint = MSyncPoints.find(&CurrentNode);
734731
assert(SyncPoint != MSyncPoints.end() &&
735732
"No sync point has been set for node dependency.");
736733
// Check if the dependency has already been added.
@@ -749,7 +746,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749746
std::shared_ptr<node_impl> Node) {
750747
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
751748
for (node_impl &N : Node->predecessors()) {
752-
findRealDeps(Deps, N, MPartitionNodes[Node]);
749+
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
753750
}
754751
ur_exp_command_buffer_sync_point_t NewSyncPoint;
755752
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -782,7 +779,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
782779
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);
783780

784781
if (MIsUpdatable) {
785-
MCommandMap[Node] = NewCommand;
782+
MCommandMap[Node.get()] = NewCommand;
786783
}
787784

788785
if (Res != UR_RESULT_SUCCESS) {
@@ -805,7 +802,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805802

806803
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
807804
for (node_impl &N : Node->predecessors()) {
808-
findRealDeps(Deps, N, MPartitionNodes[Node]);
805+
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
809806
}
810807

811808
sycl::detail::EventImplPtr Event =
@@ -814,7 +811,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
814811
/*EventNeeded=*/true, CommandBuffer, Deps);
815812

816813
if (MIsUpdatable) {
817-
MCommandMap[Node] = Event->getCommandBufferCommand();
814+
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
818815
}
819816

820817
return Event->getSyncPoint();
@@ -830,7 +827,8 @@ void exec_graph_impl::buildRequirements() {
830827
Node->MCommandGroup->getRequirements().begin(),
831828
Node->MCommandGroup->getRequirements().end());
832829

833-
std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
830+
std::shared_ptr<partition> &Partition =
831+
MPartitions[MPartitionNodes[Node.get()]];
834832

835833
Partition->MRequirements.insert(
836834
Partition->MRequirements.end(),
@@ -877,10 +875,10 @@ void exec_graph_impl::createCommandBuffers(
877875
Node->MCommandGroup.get())
878876
->MStreams.size() ==
879877
0) {
880-
MSyncPoints[Node] =
878+
MSyncPoints[Node.get()] =
881879
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
882880
} else {
883-
MSyncPoints[Node] = enqueueNode(OutCommandBuffer, Node);
881+
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
884882
}
885883
}
886884

@@ -1726,7 +1724,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17261724
auto ExecNode = MIDCache.find(Node->MID);
17271725
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17281726

1729-
auto Command = MCommandMap.find(ExecNode->second);
1727+
auto Command = MCommandMap.find(ExecNode->second.get());
17301728
assert(Command != MCommandMap.end());
17311729
UpdateDesc.hCommand = Command->second;
17321730

@@ -1756,7 +1754,7 @@ exec_graph_impl::getURUpdatableNodes(
17561754

17571755
auto ExecNode = MIDCache.find(Node->MID);
17581756
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1759-
auto PartitionIndex = MPartitionNodes.find(ExecNode->second);
1757+
auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get());
17601758
assert(PartitionIndex != MPartitionNodes.end());
17611759
PartitionedNodes[PartitionIndex->second].push_back(Node);
17621760
}

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
294294
/// @param Queue In-order queue to find the last node added to the graph from.
295295
/// @return Last node in this graph added from \p Queue recording, or empty
296296
/// shared pointer if none.
297-
std::shared_ptr<node_impl>
298-
getLastInorderNode(sycl::detail::queue_impl *Queue);
297+
node_impl *getLastInorderNode(sycl::detail::queue_impl *Queue);
299298

300299
/// Track the last node added to this graph from an in-order queue.
301300
/// @param Queue In-order queue to register \p Node for.
@@ -466,14 +465,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
466465
/// @param[in] BarrierNodeImpl The created barrier node.
467466
void setBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue,
468467
std::shared_ptr<node_impl> BarrierNodeImpl) {
469-
MBarrierDependencyMap[Queue] = BarrierNodeImpl;
468+
MBarrierDependencyMap[Queue] = &*BarrierNodeImpl;
470469
}
471470

472471
/// Get the last barrier node that was submitted to the queue.
473472
/// @param[in] Queue The queue to find the last barrier node of. An empty
474473
/// shared_ptr is returned if no barrier node has been recorded to the queue.
475-
std::shared_ptr<node_impl>
476-
getBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue) {
474+
node_impl *getBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue) {
477475
return MBarrierDependencyMap[Queue];
478476
}
479477

@@ -553,7 +551,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
553551
/// Map for every in-order queue thats recorded a node to the graph, what
554552
/// the last node added was. We can use this to create new edges on the last
555553
/// node if any more nodes are added to the graph from the queue.
556-
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
554+
std::map<std::weak_ptr<sycl::detail::queue_impl>, node_impl *,
557555
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
558556
MInorderQueueMap;
559557
/// Controls whether we skip the cycle checks in makeEdge, set by the presence
@@ -568,7 +566,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
568566

569567
/// Mapping from queues to barrier nodes. For each queue the last barrier
570568
/// node recorded to the graph from the queue is stored.
571-
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
569+
std::map<std::weak_ptr<sycl::detail::queue_impl>, node_impl *,
572570
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
573571
MBarrierDependencyMap;
574572
/// Graph memory pool for handling graph-owned memory allocations for this
@@ -886,14 +884,13 @@ class exec_graph_impl {
886884
std::shared_ptr<graph_impl> MGraphImpl;
887885
/// Map of nodes in the exec graph to the sync point representing their
888886
/// execution in the command graph.
889-
std::unordered_map<std::shared_ptr<node_impl>,
890-
ur_exp_command_buffer_sync_point_t>
887+
std::unordered_map<node_impl *, ur_exp_command_buffer_sync_point_t>
891888
MSyncPoints;
892889
/// Sycl queue impl ptr associated with this graph.
893890
std::shared_ptr<sycl::detail::queue_impl> MQueueImpl;
894891
/// Map of nodes in the exec graph to the partition number to which they
895892
/// belong.
896-
std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
893+
std::unordered_map<node_impl *, int> MPartitionNodes;
897894
/// Device associated with this executable graph.
898895
sycl::device MDevice;
899896
/// Context associated with this executable graph.
@@ -909,8 +906,7 @@ class exec_graph_impl {
909906
/// Storage for copies of nodes from the original modifiable graph.
910907
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
911908
/// Map of nodes to their associated UR command handles.
912-
std::unordered_map<std::shared_ptr<node_impl>,
913-
ur_exp_command_buffer_command_handle_t>
909+
std::unordered_map<node_impl *, ur_exp_command_buffer_command_handle_t>
914910
MCommandMap;
915911
/// List of partition without any predecessors in this exec graph.
916912
std::vector<std::weak_ptr<partition>> MRootPartitions;

sycl/source/handler.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -888,8 +888,9 @@ event handler::finalize() {
888888
// node can set it as a predecessor.
889889
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
890890
Deps;
891-
if (auto DependentNode = GraphImpl->getLastInorderNode(Queue)) {
892-
Deps.push_back(std::move(DependentNode));
891+
if (ext::oneapi::experimental::detail::node_impl *DependentNode =
892+
GraphImpl->getLastInorderNode(Queue)) {
893+
Deps.push_back(DependentNode->shared_from_this());
893894
}
894895
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);
895896

@@ -898,13 +899,14 @@ event handler::finalize() {
898899
// queue.
899900
GraphImpl->setLastInorderNode(*Queue, NodeImpl);
900901
} else {
901-
auto LastBarrierRecordedFromQueue =
902-
GraphImpl->getBarrierDep(Queue->weak_from_this());
902+
ext::oneapi::experimental::detail::node_impl
903+
*LastBarrierRecordedFromQueue =
904+
GraphImpl->getBarrierDep(Queue->weak_from_this());
903905
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
904906
Deps;
905907

906908
if (LastBarrierRecordedFromQueue) {
907-
Deps.push_back(LastBarrierRecordedFromQueue);
909+
Deps.push_back(LastBarrierRecordedFromQueue->shared_from_this());
908910
}
909911
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);
910912

0 commit comments

Comments
 (0)