Skip to content

Commit 45c2e52

Browse files
[NFC][SYCL][Graph] Update some maps to use raw node_impl *
Continuation of the refactoring in #19295 #19332
1 parent d2ff9fe commit 45c2e52

File tree

5 files changed

+70
-75
lines changed

5 files changed

+70
-75
lines changed

sycl/source/detail/async_alloc.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
4747
// If this is being recorded from an in-order queue we need to get the last
4848
// in-order node if any, since this will later become a dependency of the
4949
// node being processed here.
50-
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
50+
if (detail::node_impl *LastInOrderNode = Graph->getLastInorderNode(Queue);
5151
LastInOrderNode) {
52-
DepNodes.push_back(LastInOrderNode);
52+
DepNodes.push_back(LastInOrderNode->shared_from_this());
5353
}
5454
return DepNodes;
5555
}

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ void exec_graph_impl::makePartitions() {
255255
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
256256
for (auto &Node : MNodeStorage) {
257257
if (Node->MPartitionNum == i) {
258-
MPartitionNodes[Node] = PartitionFinalNum;
258+
MPartitionNodes[Node.get()] = PartitionFinalNum;
259259
if (isPartitionRoot(Node)) {
260260
Partition->MRoots.insert(Node);
261261
if (Node->MCGType == CGType::CodeplayHostTask) {
@@ -290,8 +290,7 @@ void exec_graph_impl::makePartitions() {
290290
for (auto const &Root : Partition->MRoots) {
291291
auto RootNode = Root.lock();
292292
for (node_impl &NodeDep : RootNode->predecessors()) {
293-
auto &Predecessor =
294-
MPartitions[MPartitionNodes[NodeDep.shared_from_this()]];
293+
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
295294
Partition->MPredecessors.push_back(Predecessor.get());
296295
Predecessor->MSuccessors.push_back(Partition.get());
297296
}
@@ -610,8 +609,7 @@ bool graph_impl::checkForCycles() {
610609
return CycleFound;
611610
}
612611

613-
std::shared_ptr<node_impl>
614-
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
612+
node_impl *graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
615613
if (!Queue) {
616614
assert(0 ==
617615
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
@@ -624,8 +622,8 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
624622
}
625623

626624
void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
627-
std::shared_ptr<node_impl> Node) {
628-
MInorderQueueMap[Queue.weak_from_this()] = std::move(Node);
625+
node_impl &Node) {
626+
MInorderQueueMap[Queue.weak_from_this()] = &Node;
629627
}
630628

631629
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
@@ -726,11 +724,10 @@ void exec_graph_impl::findRealDeps(
726724
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
727725
}
728726
} else {
729-
auto CurrentNodePtr = CurrentNode.shared_from_this();
730727
// Verify if CurrentNode belongs to the same partition
731-
if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) {
728+
if (MPartitionNodes[&CurrentNode] == ReferencePartitionNum) {
732729
// Verify that the sync point has actually been set for this node.
733-
auto SyncPoint = MSyncPoints.find(CurrentNodePtr);
730+
auto SyncPoint = MSyncPoints.find(&CurrentNode);
734731
assert(SyncPoint != MSyncPoints.end() &&
735732
"No sync point has been set for node dependency.");
736733
// Check if the dependency has already been added.
@@ -749,7 +746,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749746
std::shared_ptr<node_impl> Node) {
750747
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
751748
for (node_impl &N : Node->predecessors()) {
752-
findRealDeps(Deps, N, MPartitionNodes[Node]);
749+
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
753750
}
754751
ur_exp_command_buffer_sync_point_t NewSyncPoint;
755752
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -782,7 +779,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
782779
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);
783780

784781
if (MIsUpdatable) {
785-
MCommandMap[Node] = NewCommand;
782+
MCommandMap[Node.get()] = NewCommand;
786783
}
787784

788785
if (Res != UR_RESULT_SUCCESS) {
@@ -805,7 +802,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805802

806803
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
807804
for (node_impl &N : Node->predecessors()) {
808-
findRealDeps(Deps, N, MPartitionNodes[Node]);
805+
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
809806
}
810807

811808
sycl::detail::EventImplPtr Event =
@@ -814,7 +811,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
814811
/*EventNeeded=*/true, CommandBuffer, Deps);
815812

816813
if (MIsUpdatable) {
817-
MCommandMap[Node] = Event->getCommandBufferCommand();
814+
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
818815
}
819816

820817
return Event->getSyncPoint();
@@ -830,7 +827,8 @@ void exec_graph_impl::buildRequirements() {
830827
Node->MCommandGroup->getRequirements().begin(),
831828
Node->MCommandGroup->getRequirements().end());
832829

833-
std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
830+
std::shared_ptr<partition> &Partition =
831+
MPartitions[MPartitionNodes[Node.get()]];
834832

835833
Partition->MRequirements.insert(
836834
Partition->MRequirements.end(),
@@ -877,10 +875,10 @@ void exec_graph_impl::createCommandBuffers(
877875
Node->MCommandGroup.get())
878876
->MStreams.size() ==
879877
0) {
880-
MSyncPoints[Node] =
878+
MSyncPoints[Node.get()] =
881879
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
882880
} else {
883-
MSyncPoints[Node] = enqueueNode(OutCommandBuffer, Node);
881+
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
884882
}
885883
}
886884

@@ -1726,7 +1724,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17261724
auto ExecNode = MIDCache.find(Node->MID);
17271725
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17281726

1729-
auto Command = MCommandMap.find(ExecNode->second);
1727+
auto Command = MCommandMap.find(ExecNode->second.get());
17301728
assert(Command != MCommandMap.end());
17311729
UpdateDesc.hCommand = Command->second;
17321730

@@ -1756,7 +1754,7 @@ exec_graph_impl::getURUpdatableNodes(
17561754

17571755
auto ExecNode = MIDCache.find(Node->MID);
17581756
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1759-
auto PartitionIndex = MPartitionNodes.find(ExecNode->second);
1757+
auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get());
17601758
assert(PartitionIndex != MPartitionNodes.end());
17611759
PartitionedNodes[PartitionIndex->second].push_back(Node);
17621760
}

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -294,14 +294,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
294294
/// @param Queue In-order queue to find the last node added to the graph from.
295295
/// @return Last node in this graph added from \p Queue recording, or nullptr
296296
/// if none.
297-
std::shared_ptr<node_impl>
298-
getLastInorderNode(sycl::detail::queue_impl *Queue);
297+
node_impl *getLastInorderNode(sycl::detail::queue_impl *Queue);
299298

300299
/// Track the last node added to this graph from an in-order queue.
301300
/// @param Queue In-order queue to register \p Node for.
302301
/// @param Node Last node that was added to this graph from \p Queue.
303-
void setLastInorderNode(sycl::detail::queue_impl &Queue,
304-
std::shared_ptr<node_impl> Node);
302+
void setLastInorderNode(sycl::detail::queue_impl &Queue, node_impl &Node);
305303

306304
/// Prints the contents of the graph to a text file in DOT format.
307305
/// @param FilePath Path to the output file.
@@ -465,15 +463,14 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
465463
/// @param[in] Queue The queue the barrier was recorded from.
466464
/// @param[in] BarrierNodeImpl The created barrier node.
467465
void setBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue,
468-
std::shared_ptr<node_impl> BarrierNodeImpl) {
469-
MBarrierDependencyMap[Queue] = BarrierNodeImpl;
466+
node_impl &BarrierNodeImpl) {
467+
MBarrierDependencyMap[Queue] = &BarrierNodeImpl;
470468
}
471469

472470
/// Get the last barrier node that was submitted to the queue.
473471
/// @param[in] Queue The queue to find the last barrier node of. A null
474472
/// pointer is returned if no barrier node has been recorded to the queue.
475-
std::shared_ptr<node_impl>
476-
getBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue) {
473+
node_impl *getBarrierDep(std::weak_ptr<sycl::detail::queue_impl> Queue) {
477474
return MBarrierDependencyMap[Queue];
478475
}
479476

@@ -553,7 +550,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
553550
/// Map for every in-order queue that has recorded a node to the graph, what
554551
/// the last node added was. We can use this to create new edges on the last
555552
/// node if any more nodes are added to the graph from the queue.
556-
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
553+
std::map<std::weak_ptr<sycl::detail::queue_impl>, node_impl *,
557554
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
558555
MInorderQueueMap;
559556
/// Controls whether we skip the cycle checks in makeEdge, set by the presence
@@ -568,7 +565,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
568565

569566
/// Mapping from queues to barrier nodes. For each queue the last barrier
570567
/// node recorded to the graph from the queue is stored.
571-
std::map<std::weak_ptr<sycl::detail::queue_impl>, std::shared_ptr<node_impl>,
568+
std::map<std::weak_ptr<sycl::detail::queue_impl>, node_impl *,
572569
std::owner_less<std::weak_ptr<sycl::detail::queue_impl>>>
573570
MBarrierDependencyMap;
574571
/// Graph memory pool for handling graph-owned memory allocations for this
@@ -886,14 +883,13 @@ class exec_graph_impl {
886883
std::shared_ptr<graph_impl> MGraphImpl;
887884
/// Map of nodes in the exec graph to the sync point representing their
888885
/// execution in the command graph.
889-
std::unordered_map<std::shared_ptr<node_impl>,
890-
ur_exp_command_buffer_sync_point_t>
886+
std::unordered_map<node_impl *, ur_exp_command_buffer_sync_point_t>
891887
MSyncPoints;
892888
/// Sycl queue impl ptr associated with this graph.
893889
std::shared_ptr<sycl::detail::queue_impl> MQueueImpl;
894890
/// Map of nodes in the exec graph to the partition number to which they
895891
/// belong.
896-
std::unordered_map<std::shared_ptr<node_impl>, int> MPartitionNodes;
892+
std::unordered_map<node_impl *, int> MPartitionNodes;
897893
/// Device associated with this executable graph.
898894
sycl::device MDevice;
899895
/// Context associated with this executable graph.
@@ -909,8 +905,7 @@ class exec_graph_impl {
909905
/// Storage for copies of nodes from the original modifiable graph.
910906
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
911907
/// Map of nodes to their associated UR command handles.
912-
std::unordered_map<std::shared_ptr<node_impl>,
913-
ur_exp_command_buffer_command_handle_t>
908+
std::unordered_map<node_impl *, ur_exp_command_buffer_command_handle_t>
914909
MCommandMap;
915910
/// List of partition without any predecessors in this exec graph.
916911
std::vector<std::weak_ptr<partition>> MRootPartitions;

sycl/source/handler.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -888,28 +888,30 @@ event handler::finalize() {
888888
// node can set it as a predecessor.
889889
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
890890
Deps;
891-
if (auto DependentNode = GraphImpl->getLastInorderNode(Queue)) {
892-
Deps.push_back(std::move(DependentNode));
891+
if (ext::oneapi::experimental::detail::node_impl *DependentNode =
892+
GraphImpl->getLastInorderNode(Queue)) {
893+
Deps.push_back(DependentNode->shared_from_this());
893894
}
894895
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);
895896

896897
// If we are recording an in-order queue remember the new node, so it
897898
// can be used as a dependency for any more nodes recorded from this
898899
// queue.
899-
GraphImpl->setLastInorderNode(*Queue, NodeImpl);
900+
GraphImpl->setLastInorderNode(*Queue, *NodeImpl);
900901
} else {
901-
auto LastBarrierRecordedFromQueue =
902-
GraphImpl->getBarrierDep(Queue->weak_from_this());
902+
ext::oneapi::experimental::detail::node_impl
903+
*LastBarrierRecordedFromQueue =
904+
GraphImpl->getBarrierDep(Queue->weak_from_this());
903905
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
904906
Deps;
905907

906908
if (LastBarrierRecordedFromQueue) {
907-
Deps.push_back(LastBarrierRecordedFromQueue);
909+
Deps.push_back(LastBarrierRecordedFromQueue->shared_from_this());
908910
}
909911
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);
910912

911913
if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) {
912-
GraphImpl->setBarrierDep(Queue->weak_from_this(), NodeImpl);
914+
GraphImpl->setBarrierDep(Queue->weak_from_this(), *NodeImpl);
913915
}
914916
}
915917

0 commit comments

Comments
 (0)