Skip to content

Commit 2cad128

Browse files
[NFC][SYCL][Graph] Update some maps to use raw node_impl *
Continuation of the refactoring in intel#19295 intel#19332
1 parent f409eb7 commit 2cad128

File tree

8 files changed

+136
-142
lines changed

8 files changed

+136
-142
lines changed

sycl/source/detail/async_alloc.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
4747
// If this is being recorded from an in-order queue we need to get the last
4848
// in-order node if any, since this will later become a dependency of the
4949
// node being processed here.
50-
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
50+
if (detail::node_impl *LastInOrderNode = Graph->getLastInorderNode(Queue);
5151
LastInOrderNode) {
52-
DepNodes.push_back(LastInOrderNode);
52+
DepNodes.push_back(LastInOrderNode->shared_from_this());
5353
}
5454
return DepNodes;
5555
}

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 58 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,16 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
100100
Source.pop();
101101
SortedNodes.push_back(Node);
102102

103-
for (auto &SuccWP : Node->MSuccessors) {
104-
auto Succ = SuccWP.lock();
103+
for (node_impl &Succ : Node->successors()) {
105104

106-
if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) {
105+
if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
107106
continue;
108107
}
109108

110-
auto &TotalVisitedEdges = Succ->MTotalVisitedEdges;
109+
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
111110
++TotalVisitedEdges;
112-
if (TotalVisitedEdges == Succ->MPredecessors.size()) {
113-
Source.push(Succ);
111+
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
112+
Source.push(Succ.weak_from_this());
114113
}
115114
}
116115
}
@@ -127,14 +126,14 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
127126
/// a node with a smaller partition number.
128127
/// @param Node Node to assign to the partition.
129128
/// @param PartitionNum Number to propagate.
130-
void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
131-
if (((Node->MPartitionNum != -1) && (Node->MPartitionNum <= PartitionNum)) ||
132-
(Node->MCGType == sycl::detail::CGType::CodeplayHostTask)) {
129+
void propagatePartitionUp(node_impl &Node, int PartitionNum) {
130+
if (((Node.MPartitionNum != -1) && (Node.MPartitionNum <= PartitionNum)) ||
131+
(Node.MCGType == sycl::detail::CGType::CodeplayHostTask)) {
133132
return;
134133
}
135-
Node->MPartitionNum = PartitionNum;
136-
for (auto &Predecessor : Node->MPredecessors) {
137-
propagatePartitionUp(Predecessor.lock(), PartitionNum);
134+
Node.MPartitionNum = PartitionNum;
135+
for (node_impl &Predecessor : Node.predecessors()) {
136+
propagatePartitionUp(Predecessor, PartitionNum);
138137
}
139138
}
140139

@@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
146145
/// @param HostTaskList List of host tasks that have already been processed and
147146
/// are encountered as successors to the node Node.
148147
void propagatePartitionDown(
149-
const std::shared_ptr<node_impl> &Node, int PartitionNum,
148+
node_impl &Node, int PartitionNum,
150149
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
151-
if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) {
152-
if (Node->MPartitionNum != -1) {
153-
HostTaskList.push_front(Node);
150+
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
151+
if (Node.MPartitionNum != -1) {
152+
HostTaskList.push_front(Node.shared_from_this());
154153
}
155154
return;
156155
}
157-
Node->MPartitionNum = PartitionNum;
158-
for (auto &Successor : Node->MSuccessors) {
159-
propagatePartitionDown(Successor.lock(), PartitionNum, HostTaskList);
156+
Node.MPartitionNum = PartitionNum;
157+
for (node_impl &Successor : Node.successors()) {
158+
propagatePartitionDown(Successor, PartitionNum, HostTaskList);
160159
}
161160
}
162161

@@ -165,8 +164,8 @@ void propagatePartitionDown(
165164
/// @param Node node to test
166165
/// @return True is `Node` is a root of its partition
167166
bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
168-
for (auto &Predecessor : Node->MPredecessors) {
169-
if (Predecessor.lock()->MPartitionNum == Node->MPartitionNum) {
167+
for (node_impl &Predecessor : Node->predecessors()) {
168+
if (Predecessor.MPartitionNum == Node->MPartitionNum) {
170169
return false;
171170
}
172171
}
@@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() {
221220
auto Node = HostTaskList.front();
222221
HostTaskList.pop_front();
223222
CurrentPartition++;
224-
for (auto &Predecessor : Node->MPredecessors) {
225-
propagatePartitionUp(Predecessor.lock(), CurrentPartition);
223+
for (node_impl &Predecessor : Node->predecessors()) {
224+
propagatePartitionUp(Predecessor, CurrentPartition);
226225
}
227226
CurrentPartition++;
228227
Node->MPartitionNum = CurrentPartition;
229228
CurrentPartition++;
230229
auto TmpSize = HostTaskList.size();
231-
for (auto &Successor : Node->MSuccessors) {
232-
propagatePartitionDown(Successor.lock(), CurrentPartition, HostTaskList);
230+
for (node_impl &Successor : Node->successors()) {
231+
propagatePartitionDown(Successor, CurrentPartition, HostTaskList);
233232
}
234233
if (HostTaskList.size() > TmpSize) {
235234
// At least one HostTask has been re-numbered so group merge opportunities
@@ -256,7 +255,7 @@ void exec_graph_impl::makePartitions() {
256255
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
257256
for (auto &Node : MNodeStorage) {
258257
if (Node->MPartitionNum == i) {
259-
MPartitionNodes[Node] = PartitionFinalNum;
258+
MPartitionNodes[Node.get()] = PartitionFinalNum;
260259
if (isPartitionRoot(Node)) {
261260
Partition->MRoots.insert(Node);
262261
if (Node->MCGType == CGType::CodeplayHostTask) {
@@ -290,9 +289,8 @@ void exec_graph_impl::makePartitions() {
290289
for (const auto &Partition : MPartitions) {
291290
for (auto const &Root : Partition->MRoots) {
292291
auto RootNode = Root.lock();
293-
for (const auto &Dep : RootNode->MPredecessors) {
294-
auto NodeDep = Dep.lock();
295-
auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]];
292+
for (node_impl &NodeDep : RootNode->predecessors()) {
293+
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
296294
Partition->MPredecessors.push_back(Predecessor.get());
297295
Predecessor->MSuccessors.push_back(Partition.get());
298296
}
@@ -390,8 +388,8 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
390388
bool ShouldAddDep = true;
391389
// If any of this node's successors have this requirement then we skip
392390
// adding the current node as a dependency.
393-
for (auto &Succ : Node->MSuccessors) {
394-
if (Succ.lock()->hasRequirementDependency(Req)) {
391+
for (node_impl &Succ : Node->successors()) {
392+
if (Succ.hasRequirementDependency(Req)) {
395393
ShouldAddDep = false;
396394
break;
397395
}
@@ -611,8 +609,7 @@ bool graph_impl::checkForCycles() {
611609
return CycleFound;
612610
}
613611

614-
std::shared_ptr<node_impl>
615-
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
612+
node_impl *graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
616613
if (!Queue) {
617614
assert(0 ==
618615
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
@@ -626,7 +623,7 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
626623

627624
void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
628625
std::shared_ptr<node_impl> Node) {
629-
MInorderQueueMap[Queue.weak_from_this()] = std::move(Node);
626+
MInorderQueueMap[Queue.weak_from_this()] = &*Node;
630627
}
631628

632629
void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
@@ -721,17 +718,16 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
721718
// predecessors until we find the real dependency.
722719
void exec_graph_impl::findRealDeps(
723720
std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
724-
std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
725-
if (!CurrentNode->requiresEnqueue()) {
726-
for (auto &N : CurrentNode->MPredecessors) {
727-
auto NodeImpl = N.lock();
721+
node_impl &CurrentNode, int ReferencePartitionNum) {
722+
if (!CurrentNode.requiresEnqueue()) {
723+
for (node_impl &NodeImpl : CurrentNode.predecessors()) {
728724
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
729725
}
730726
} else {
731727
// Verify if CurrentNode belong the the same partition
732-
if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) {
728+
if (MPartitionNodes[&CurrentNode] == ReferencePartitionNum) {
733729
// Verify that the sync point has actually been set for this node.
734-
auto SyncPoint = MSyncPoints.find(CurrentNode);
730+
auto SyncPoint = MSyncPoints.find(&CurrentNode);
735731
assert(SyncPoint != MSyncPoints.end() &&
736732
"No sync point has been set for node dependency.");
737733
// Check if the dependency has already been added.
@@ -749,8 +745,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749745
ur_exp_command_buffer_handle_t CommandBuffer,
750746
std::shared_ptr<node_impl> Node) {
751747
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
752-
for (auto &N : Node->MPredecessors) {
753-
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
748+
for (node_impl &N : Node->predecessors()) {
749+
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
754750
}
755751
ur_exp_command_buffer_sync_point_t NewSyncPoint;
756752
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -783,7 +779,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
783779
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);
784780

785781
if (MIsUpdatable) {
786-
MCommandMap[Node] = NewCommand;
782+
MCommandMap[Node.get()] = NewCommand;
787783
}
788784

789785
if (Res != UR_RESULT_SUCCESS) {
@@ -805,8 +801,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805801
std::shared_ptr<node_impl> Node) {
806802

807803
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
808-
for (auto &N : Node->MPredecessors) {
809-
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
804+
for (node_impl &N : Node->predecessors()) {
805+
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
810806
}
811807

812808
sycl::detail::EventImplPtr Event =
@@ -815,7 +811,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
815811
/*EventNeeded=*/true, CommandBuffer, Deps);
816812

817813
if (MIsUpdatable) {
818-
MCommandMap[Node] = Event->getCommandBufferCommand();
814+
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
819815
}
820816

821817
return Event->getSyncPoint();
@@ -831,7 +827,8 @@ void exec_graph_impl::buildRequirements() {
831827
Node->MCommandGroup->getRequirements().begin(),
832828
Node->MCommandGroup->getRequirements().end());
833829

834-
std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
830+
std::shared_ptr<partition> &Partition =
831+
MPartitions[MPartitionNodes[Node.get()]];
835832

836833
Partition->MRequirements.insert(
837834
Partition->MRequirements.end(),
@@ -878,10 +875,10 @@ void exec_graph_impl::createCommandBuffers(
878875
Node->MCommandGroup.get())
879876
->MStreams.size() ==
880877
0) {
881-
MSyncPoints[Node] =
878+
MSyncPoints[Node.get()] =
882879
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
883880
} else {
884-
MSyncPoints[Node] = enqueueNode(OutCommandBuffer, Node);
881+
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
885882
}
886883
}
887884

@@ -1275,8 +1272,8 @@ void exec_graph_impl::duplicateNodes() {
12751272
auto NodeCopy = NewNodes[i];
12761273
// Look through all the original node successors, find their copies and
12771274
// register those as successors with the current copied node
1278-
for (auto &NextNode : OriginalNode->MSuccessors) {
1279-
auto Successor = NodesMap.at(NextNode.lock());
1275+
for (node_impl &NextNode : OriginalNode->successors()) {
1276+
auto Successor = NodesMap.at(NextNode.shared_from_this());
12801277
NodeCopy->registerSuccessor(Successor);
12811278
}
12821279
}
@@ -1317,8 +1314,8 @@ void exec_graph_impl::duplicateNodes() {
13171314
auto SubgraphNode = SubgraphNodes[i];
13181315
auto NodeCopy = NewSubgraphNodes[i];
13191316

1320-
for (auto &NextNode : SubgraphNode->MSuccessors) {
1321-
auto Successor = SubgraphNodesMap.at(NextNode.lock());
1317+
for (node_impl &NextNode : SubgraphNode->successors()) {
1318+
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
13221319
NodeCopy->registerSuccessor(Successor);
13231320
}
13241321
}
@@ -1339,9 +1336,8 @@ void exec_graph_impl::duplicateNodes() {
13391336
// original subgraph node
13401337

13411338
// Predecessors
1342-
for (auto &PredNodeWeak : NewNode->MPredecessors) {
1343-
auto PredNode = PredNodeWeak.lock();
1344-
auto &Successors = PredNode->MSuccessors;
1339+
for (node_impl &PredNode : NewNode->predecessors()) {
1340+
auto &Successors = PredNode.MSuccessors;
13451341

13461342
// Remove the subgraph node from this nodes successors
13471343
Successors.erase(std::remove_if(Successors.begin(), Successors.end(),
@@ -1353,14 +1349,13 @@ void exec_graph_impl::duplicateNodes() {
13531349
// Add all input nodes from the subgraph as successors for this node
13541350
// instead
13551351
for (auto &Input : Inputs) {
1356-
PredNode->registerSuccessor(Input);
1352+
PredNode.registerSuccessor(Input);
13571353
}
13581354
}
13591355

13601356
// Successors
1361-
for (auto &SuccNodeWeak : NewNode->MSuccessors) {
1362-
auto SuccNode = SuccNodeWeak.lock();
1363-
auto &Predecessors = SuccNode->MPredecessors;
1357+
for (node_impl &SuccNode : NewNode->successors()) {
1358+
auto &Predecessors = SuccNode.MPredecessors;
13641359

13651360
// Remove the subgraph node from this nodes successors
13661361
Predecessors.erase(std::remove_if(Predecessors.begin(),
@@ -1373,7 +1368,7 @@ void exec_graph_impl::duplicateNodes() {
13731368
// Add all Output nodes from the subgraph as predecessors for this node
13741369
// instead
13751370
for (auto &Output : Outputs) {
1376-
Output->registerSuccessor(SuccNode);
1371+
Output->registerSuccessor(SuccNode.shared_from_this());
13771372
}
13781373
}
13791374

@@ -1729,7 +1724,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17291724
auto ExecNode = MIDCache.find(Node->MID);
17301725
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17311726

1732-
auto Command = MCommandMap.find(ExecNode->second);
1727+
auto Command = MCommandMap.find(ExecNode->second.get());
17331728
assert(Command != MCommandMap.end());
17341729
UpdateDesc.hCommand = Command->second;
17351730

@@ -1759,7 +1754,7 @@ exec_graph_impl::getURUpdatableNodes(
17591754

17601755
auto ExecNode = MIDCache.find(Node->MID);
17611756
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
1762-
auto PartitionIndex = MPartitionNodes.find(ExecNode->second);
1757+
auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get());
17631758
assert(PartitionIndex != MPartitionNodes.end());
17641759
PartitionedNodes[PartitionIndex->second].push_back(Node);
17651760
}

0 commit comments

Comments
 (0)