@@ -100,17 +100,16 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
100
100
Source.pop ();
101
101
SortedNodes.push_back (Node);
102
102
103
- for (auto &SuccWP : Node->MSuccessors ) {
104
- auto Succ = SuccWP.lock ();
103
+ for (node_impl &Succ : Node->successors ()) {
105
104
106
- if (PartitionBounded && (Succ-> MPartitionNum != Node->MPartitionNum )) {
105
+ if (PartitionBounded && (Succ. MPartitionNum != Node->MPartitionNum )) {
107
106
continue ;
108
107
}
109
108
110
- auto &TotalVisitedEdges = Succ-> MTotalVisitedEdges ;
109
+ auto &TotalVisitedEdges = Succ. MTotalVisitedEdges ;
111
110
++TotalVisitedEdges;
112
- if (TotalVisitedEdges == Succ-> MPredecessors .size ()) {
113
- Source.push (Succ);
111
+ if (TotalVisitedEdges == Succ. MPredecessors .size ()) {
112
+ Source.push (Succ. weak_from_this () );
114
113
}
115
114
}
116
115
}
@@ -127,14 +126,14 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
127
126
// / a node with a smaller partition number.
128
127
// / @param Node Node to assign to the partition.
129
128
// / @param PartitionNum Number to propagate.
130
- void propagatePartitionUp (std::shared_ptr< node_impl> Node, int PartitionNum) {
131
- if (((Node-> MPartitionNum != -1 ) && (Node-> MPartitionNum <= PartitionNum)) ||
132
- (Node-> MCGType == sycl::detail::CGType::CodeplayHostTask)) {
129
+ void propagatePartitionUp (node_impl & Node, int PartitionNum) {
130
+ if (((Node. MPartitionNum != -1 ) && (Node. MPartitionNum <= PartitionNum)) ||
131
+ (Node. MCGType == sycl::detail::CGType::CodeplayHostTask)) {
133
132
return ;
134
133
}
135
- Node-> MPartitionNum = PartitionNum;
136
- for (auto &Predecessor : Node-> MPredecessors ) {
137
- propagatePartitionUp (Predecessor. lock () , PartitionNum);
134
+ Node. MPartitionNum = PartitionNum;
135
+ for (node_impl &Predecessor : Node. predecessors () ) {
136
+ propagatePartitionUp (Predecessor, PartitionNum);
138
137
}
139
138
}
140
139
@@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
146
145
// / @param HostTaskList List of host tasks that have already been processed and
147
146
// / are encountered as successors to the node Node.
148
147
void propagatePartitionDown (
149
- const std::shared_ptr< node_impl> &Node, int PartitionNum,
148
+ node_impl &Node, int PartitionNum,
150
149
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
151
- if (Node-> MCGType == sycl::detail::CGType::CodeplayHostTask) {
152
- if (Node-> MPartitionNum != -1 ) {
153
- HostTaskList.push_front (Node);
150
+ if (Node. MCGType == sycl::detail::CGType::CodeplayHostTask) {
151
+ if (Node. MPartitionNum != -1 ) {
152
+ HostTaskList.push_front (Node. shared_from_this () );
154
153
}
155
154
return ;
156
155
}
157
- Node-> MPartitionNum = PartitionNum;
158
- for (auto &Successor : Node-> MSuccessors ) {
159
- propagatePartitionDown (Successor. lock () , PartitionNum, HostTaskList);
156
+ Node. MPartitionNum = PartitionNum;
157
+ for (node_impl &Successor : Node. successors () ) {
158
+ propagatePartitionDown (Successor, PartitionNum, HostTaskList);
160
159
}
161
160
}
162
161
@@ -165,8 +164,8 @@ void propagatePartitionDown(
165
164
// / @param Node node to test
166
165
// / @return True if `Node` is a root of its partition
167
166
bool isPartitionRoot (std::shared_ptr<node_impl> Node) {
168
- for (auto &Predecessor : Node->MPredecessors ) {
169
- if (Predecessor.lock ()-> MPartitionNum == Node->MPartitionNum ) {
167
+ for (node_impl &Predecessor : Node->predecessors () ) {
168
+ if (Predecessor.MPartitionNum == Node->MPartitionNum ) {
170
169
return false ;
171
170
}
172
171
}
@@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() {
221
220
auto Node = HostTaskList.front ();
222
221
HostTaskList.pop_front ();
223
222
CurrentPartition++;
224
- for (auto &Predecessor : Node->MPredecessors ) {
225
- propagatePartitionUp (Predecessor. lock () , CurrentPartition);
223
+ for (node_impl &Predecessor : Node->predecessors () ) {
224
+ propagatePartitionUp (Predecessor, CurrentPartition);
226
225
}
227
226
CurrentPartition++;
228
227
Node->MPartitionNum = CurrentPartition;
229
228
CurrentPartition++;
230
229
auto TmpSize = HostTaskList.size ();
231
- for (auto &Successor : Node->MSuccessors ) {
232
- propagatePartitionDown (Successor. lock () , CurrentPartition, HostTaskList);
230
+ for (node_impl &Successor : Node->successors () ) {
231
+ propagatePartitionDown (Successor, CurrentPartition, HostTaskList);
233
232
}
234
233
if (HostTaskList.size () > TmpSize) {
235
234
// At least one HostTask has been re-numbered so group merge opportunities
@@ -256,7 +255,7 @@ void exec_graph_impl::makePartitions() {
256
255
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
257
256
for (auto &Node : MNodeStorage) {
258
257
if (Node->MPartitionNum == i) {
259
- MPartitionNodes[Node] = PartitionFinalNum;
258
+ MPartitionNodes[Node. get () ] = PartitionFinalNum;
260
259
if (isPartitionRoot (Node)) {
261
260
Partition->MRoots .insert (Node);
262
261
if (Node->MCGType == CGType::CodeplayHostTask) {
@@ -290,9 +289,8 @@ void exec_graph_impl::makePartitions() {
290
289
for (const auto &Partition : MPartitions) {
291
290
for (auto const &Root : Partition->MRoots ) {
292
291
auto RootNode = Root.lock ();
293
- for (const auto &Dep : RootNode->MPredecessors ) {
294
- auto NodeDep = Dep.lock ();
295
- auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]];
292
+ for (node_impl &NodeDep : RootNode->predecessors ()) {
293
+ auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
296
294
Partition->MPredecessors .push_back (Predecessor.get ());
297
295
Predecessor->MSuccessors .push_back (Partition.get ());
298
296
}
@@ -390,8 +388,8 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
390
388
bool ShouldAddDep = true ;
391
389
// If any of this node's successors have this requirement then we skip
392
390
// adding the current node as a dependency.
393
- for (auto &Succ : Node->MSuccessors ) {
394
- if (Succ.lock ()-> hasRequirementDependency (Req)) {
391
+ for (node_impl &Succ : Node->successors () ) {
392
+ if (Succ.hasRequirementDependency (Req)) {
395
393
ShouldAddDep = false ;
396
394
break ;
397
395
}
@@ -611,8 +609,7 @@ bool graph_impl::checkForCycles() {
611
609
return CycleFound;
612
610
}
613
611
614
- std::shared_ptr<node_impl>
615
- graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
612
+ node_impl *graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
616
613
if (!Queue) {
617
614
assert (0 ==
618
615
MInorderQueueMap.count (std::weak_ptr<sycl::detail::queue_impl>{}));
@@ -626,7 +623,7 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
626
623
627
624
void graph_impl::setLastInorderNode (sycl::detail::queue_impl &Queue,
628
625
std::shared_ptr<node_impl> Node) {
629
- MInorderQueueMap[Queue.weak_from_this ()] = std::move ( Node) ;
626
+ MInorderQueueMap[Queue.weak_from_this ()] = &* Node;
630
627
}
631
628
632
629
void graph_impl::makeEdge (std::shared_ptr<node_impl> Src,
@@ -721,17 +718,16 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
721
718
// predecessors until we find the real dependency.
722
719
void exec_graph_impl::findRealDeps (
723
720
std::vector<ur_exp_command_buffer_sync_point_t > &Deps,
724
- std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
725
- if (!CurrentNode->requiresEnqueue ()) {
726
- for (auto &N : CurrentNode->MPredecessors ) {
727
- auto NodeImpl = N.lock ();
721
+ node_impl &CurrentNode, int ReferencePartitionNum) {
722
+ if (!CurrentNode.requiresEnqueue ()) {
723
+ for (node_impl &NodeImpl : CurrentNode.predecessors ()) {
728
724
findRealDeps (Deps, NodeImpl, ReferencePartitionNum);
729
725
}
730
726
} else {
731
727
// Verify if CurrentNode belongs to the same partition
732
- if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) {
728
+ if (MPartitionNodes[& CurrentNode] == ReferencePartitionNum) {
733
729
// Verify that the sync point has actually been set for this node.
734
- auto SyncPoint = MSyncPoints.find (CurrentNode);
730
+ auto SyncPoint = MSyncPoints.find (& CurrentNode);
735
731
assert (SyncPoint != MSyncPoints.end () &&
736
732
" No sync point has been set for node dependency." );
737
733
// Check if the dependency has already been added.
@@ -749,8 +745,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749
745
ur_exp_command_buffer_handle_t CommandBuffer,
750
746
std::shared_ptr<node_impl> Node) {
751
747
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
752
- for (auto &N : Node->MPredecessors ) {
753
- findRealDeps (Deps, N. lock () , MPartitionNodes[Node]);
748
+ for (node_impl &N : Node->predecessors () ) {
749
+ findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
754
750
}
755
751
ur_exp_command_buffer_sync_point_t NewSyncPoint;
756
752
ur_exp_command_buffer_command_handle_t NewCommand = 0 ;
@@ -783,7 +779,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
783
779
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr , nullptr );
784
780
785
781
if (MIsUpdatable) {
786
- MCommandMap[Node] = NewCommand;
782
+ MCommandMap[Node. get () ] = NewCommand;
787
783
}
788
784
789
785
if (Res != UR_RESULT_SUCCESS) {
@@ -805,8 +801,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805
801
std::shared_ptr<node_impl> Node) {
806
802
807
803
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
808
- for (auto &N : Node->MPredecessors ) {
809
- findRealDeps (Deps, N. lock () , MPartitionNodes[Node]);
804
+ for (node_impl &N : Node->predecessors () ) {
805
+ findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
810
806
}
811
807
812
808
sycl::detail::EventImplPtr Event =
@@ -815,7 +811,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
815
811
/* EventNeeded=*/ true , CommandBuffer, Deps);
816
812
817
813
if (MIsUpdatable) {
818
- MCommandMap[Node] = Event->getCommandBufferCommand ();
814
+ MCommandMap[Node. get () ] = Event->getCommandBufferCommand ();
819
815
}
820
816
821
817
return Event->getSyncPoint ();
@@ -831,7 +827,8 @@ void exec_graph_impl::buildRequirements() {
831
827
Node->MCommandGroup ->getRequirements ().begin (),
832
828
Node->MCommandGroup ->getRequirements ().end ());
833
829
834
- std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
830
+ std::shared_ptr<partition> &Partition =
831
+ MPartitions[MPartitionNodes[Node.get ()]];
835
832
836
833
Partition->MRequirements .insert (
837
834
Partition->MRequirements .end (),
@@ -878,10 +875,10 @@ void exec_graph_impl::createCommandBuffers(
878
875
Node->MCommandGroup .get ())
879
876
->MStreams .size () ==
880
877
0 ) {
881
- MSyncPoints[Node] =
878
+ MSyncPoints[Node. get () ] =
882
879
enqueueNodeDirect (MContext, DeviceImpl, OutCommandBuffer, Node);
883
880
} else {
884
- MSyncPoints[Node] = enqueueNode (OutCommandBuffer, Node);
881
+ MSyncPoints[Node. get () ] = enqueueNode (OutCommandBuffer, Node);
885
882
}
886
883
}
887
884
@@ -1275,8 +1272,8 @@ void exec_graph_impl::duplicateNodes() {
1275
1272
auto NodeCopy = NewNodes[i];
1276
1273
// Look through all the original node successors, find their copies and
1277
1274
// register those as successors with the current copied node
1278
- for (auto &NextNode : OriginalNode->MSuccessors ) {
1279
- auto Successor = NodesMap.at (NextNode.lock ());
1275
+ for (node_impl &NextNode : OriginalNode->successors () ) {
1276
+ auto Successor = NodesMap.at (NextNode.shared_from_this ());
1280
1277
NodeCopy->registerSuccessor (Successor);
1281
1278
}
1282
1279
}
@@ -1317,8 +1314,8 @@ void exec_graph_impl::duplicateNodes() {
1317
1314
auto SubgraphNode = SubgraphNodes[i];
1318
1315
auto NodeCopy = NewSubgraphNodes[i];
1319
1316
1320
- for (auto &NextNode : SubgraphNode->MSuccessors ) {
1321
- auto Successor = SubgraphNodesMap.at (NextNode.lock ());
1317
+ for (node_impl &NextNode : SubgraphNode->successors () ) {
1318
+ auto Successor = SubgraphNodesMap.at (NextNode.shared_from_this ());
1322
1319
NodeCopy->registerSuccessor (Successor);
1323
1320
}
1324
1321
}
@@ -1339,9 +1336,8 @@ void exec_graph_impl::duplicateNodes() {
1339
1336
// original subgraph node
1340
1337
1341
1338
// Predecessors
1342
- for (auto &PredNodeWeak : NewNode->MPredecessors ) {
1343
- auto PredNode = PredNodeWeak.lock ();
1344
- auto &Successors = PredNode->MSuccessors ;
1339
+ for (node_impl &PredNode : NewNode->predecessors ()) {
1340
+ auto &Successors = PredNode.MSuccessors ;
1345
1341
1346
1342
// Remove the subgraph node from this node's successors
1347
1343
Successors.erase (std::remove_if (Successors.begin (), Successors.end (),
@@ -1353,14 +1349,13 @@ void exec_graph_impl::duplicateNodes() {
1353
1349
// Add all input nodes from the subgraph as successors for this node
1354
1350
// instead
1355
1351
for (auto &Input : Inputs) {
1356
- PredNode-> registerSuccessor (Input);
1352
+ PredNode. registerSuccessor (Input);
1357
1353
}
1358
1354
}
1359
1355
1360
1356
// Successors
1361
- for (auto &SuccNodeWeak : NewNode->MSuccessors ) {
1362
- auto SuccNode = SuccNodeWeak.lock ();
1363
- auto &Predecessors = SuccNode->MPredecessors ;
1357
+ for (node_impl &SuccNode : NewNode->successors ()) {
1358
+ auto &Predecessors = SuccNode.MPredecessors ;
1364
1359
1365
1360
// Remove the subgraph node from this node's predecessors
1366
1361
Predecessors.erase (std::remove_if (Predecessors.begin (),
@@ -1373,7 +1368,7 @@ void exec_graph_impl::duplicateNodes() {
1373
1368
// Add all Output nodes from the subgraph as predecessors for this node
1374
1369
// instead
1375
1370
for (auto &Output : Outputs) {
1376
- Output->registerSuccessor (SuccNode);
1371
+ Output->registerSuccessor (SuccNode. shared_from_this () );
1377
1372
}
1378
1373
}
1379
1374
@@ -1729,7 +1724,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
1729
1724
auto ExecNode = MIDCache.find (Node->MID );
1730
1725
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1731
1726
1732
- auto Command = MCommandMap.find (ExecNode->second );
1727
+ auto Command = MCommandMap.find (ExecNode->second . get () );
1733
1728
assert (Command != MCommandMap.end ());
1734
1729
UpdateDesc.hCommand = Command->second ;
1735
1730
@@ -1759,7 +1754,7 @@ exec_graph_impl::getURUpdatableNodes(
1759
1754
1760
1755
auto ExecNode = MIDCache.find (Node->MID );
1761
1756
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1762
- auto PartitionIndex = MPartitionNodes.find (ExecNode->second );
1757
+ auto PartitionIndex = MPartitionNodes.find (ExecNode->second . get () );
1763
1758
assert (PartitionIndex != MPartitionNodes.end ());
1764
1759
PartitionedNodes[PartitionIndex->second ].push_back (Node);
1765
1760
}
0 commit comments