@@ -85,31 +85,29 @@ inline const char *nodeTypeToString(node_type NodeType) {
85
85
// / @param[in] PartitionBounded If set to true, the topological sort is stopped
86
86
// / at partition borders. Hence, nodes belonging to a partition different from
87
87
// / the NodeImpl partition are not processed.
88
- void sortTopological (std::set<std::weak_ptr<node_impl>,
89
- std::owner_less<std::weak_ptr<node_impl>>> &Roots,
90
- std::list<std::shared_ptr<node_impl>> &SortedNodes,
88
+ void sortTopological (nodes_range Roots, std::list<node_impl *> &SortedNodes,
91
89
bool PartitionBounded) {
92
- std::stack<std::weak_ptr< node_impl> > Source;
90
+ std::stack<node_impl * > Source;
93
91
94
- for (auto &Node : Roots) {
95
- Source.push (Node);
92
+ for (node_impl &Node : Roots) {
93
+ Source.push (& Node);
96
94
}
97
95
98
96
while (!Source.empty ()) {
99
- auto Node = Source.top (). lock ();
97
+ node_impl & Node = * Source.top ();
100
98
Source.pop ();
101
- SortedNodes.push_back (Node);
99
+ SortedNodes.push_back (& Node);
102
100
103
- for (node_impl &Succ : Node-> successors ()) {
101
+ for (node_impl &Succ : Node. successors ()) {
104
102
105
- if (PartitionBounded && (Succ.MPartitionNum != Node-> MPartitionNum )) {
103
+ if (PartitionBounded && (Succ.MPartitionNum != Node. MPartitionNum )) {
106
104
continue ;
107
105
}
108
106
109
107
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges ;
110
108
++TotalVisitedEdges;
111
109
if (TotalVisitedEdges == Succ.MPredecessors .size ()) {
112
- Source.push (Succ. weak_from_this () );
110
+ Source.push (& Succ);
113
111
}
114
112
}
115
113
}
@@ -163,17 +161,17 @@ void propagatePartitionDown(
163
161
// / belong to the same partition)
164
162
// / @param Node node to test
165
163
// / @return True is `Node` is a root of its partition
166
- bool isPartitionRoot (std::shared_ptr< node_impl> Node) {
167
- for (node_impl &Predecessor : Node-> predecessors ()) {
168
- if (Predecessor.MPartitionNum == Node-> MPartitionNum ) {
164
+ bool isPartitionRoot (node_impl & Node) {
165
+ for (node_impl &Predecessor : Node. predecessors ()) {
166
+ if (Predecessor.MPartitionNum == Node. MPartitionNum ) {
169
167
return false ;
170
168
}
171
169
}
172
170
return true ;
173
171
}
174
172
} // anonymous namespace
175
173
176
- void partition::schedule () {
174
+ void partition::updateSchedule () {
177
175
if (MSchedule.empty ()) {
178
176
// There is no need to reset MTotalVisitedEdges before calling
179
177
// sortTopological because this function is only called once per partition.
@@ -256,16 +254,16 @@ void exec_graph_impl::makePartitions() {
256
254
for (auto &Node : MNodeStorage) {
257
255
if (Node->MPartitionNum == i) {
258
256
MPartitionNodes[Node.get ()] = PartitionFinalNum;
259
- if (isPartitionRoot (Node)) {
260
- Partition->MRoots .insert (Node);
257
+ if (isPartitionRoot (* Node)) {
258
+ Partition->MRoots .insert (Node. get () );
261
259
if (Node->MCGType == CGType::CodeplayHostTask) {
262
260
Partition->MIsHostTask = true ;
263
261
}
264
262
}
265
263
}
266
264
}
267
265
if (Partition->MRoots .size () > 0 ) {
268
- Partition->schedule ();
266
+ Partition->updateSchedule ();
269
267
Partition->MIsInOrderGraph = Partition->checkIfGraphIsSinglePath ();
270
268
MPartitions.push_back (Partition);
271
269
MRootPartitions.push_back (Partition);
@@ -287,9 +285,8 @@ void exec_graph_impl::makePartitions() {
287
285
288
286
// Compute partition dependencies
289
287
for (const auto &Partition : MPartitions) {
290
- for (auto const &Root : Partition->MRoots ) {
291
- auto RootNode = Root.lock ();
292
- for (node_impl &NodeDep : RootNode->predecessors ()) {
288
+ for (node_impl &Root : Partition->roots ()) {
289
+ for (node_impl &NodeDep : Root.predecessors ()) {
293
290
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
294
291
Partition->MPredecessors .push_back (Predecessor.get ());
295
292
Predecessor->MSuccessors .push_back (Partition.get ());
@@ -340,13 +337,9 @@ graph_impl::~graph_impl() {
340
337
}
341
338
}
342
339
343
- void graph_impl::addRoot (const std::shared_ptr<node_impl> &Root) {
344
- MRoots.insert (Root);
345
- }
340
+ void graph_impl::addRoot (node_impl &Root) { MRoots.insert (&Root); }
346
341
347
- void graph_impl::removeRoot (const std::shared_ptr<node_impl> &Root) {
348
- MRoots.erase (Root);
349
- }
342
+ void graph_impl::removeRoot (node_impl &Root) { MRoots.erase (&Root); }
350
343
351
344
std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges (
352
345
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
@@ -593,7 +586,7 @@ bool graph_impl::clearQueues() {
593
586
}
594
587
595
588
bool graph_impl::checkForCycles () {
596
- std::list<std::shared_ptr< node_impl> > SortedNodes;
589
+ std::list<node_impl * > SortedNodes;
597
590
sortTopological (MRoots, SortedNodes, false );
598
591
599
592
// If after a topological sort, not all the nodes in the graph are sorted,
@@ -664,7 +657,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
664
657
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors .size () == 1 ;
665
658
if (DestLostRootStatus) {
666
659
// Dest is no longer a Root node, so we need to remove it from MRoots.
667
- MRoots.erase (Dest);
660
+ MRoots.erase (Dest. get () );
668
661
}
669
662
670
663
// We can skip cycle checks if either Dest has no successors (cycle not
@@ -679,14 +672,14 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
679
672
Dest->MPredecessors .pop_back ();
680
673
if (DestLostRootStatus) {
681
674
// Add Dest back into MRoots.
682
- MRoots.insert (Dest);
675
+ MRoots.insert (Dest. get () );
683
676
}
684
677
685
678
throw sycl::exception (make_error_code (sycl::errc::invalid),
686
679
" Command graphs cannot contain cycles." );
687
680
}
688
681
}
689
- removeRoot (Dest); // remove receiver from root node list
682
+ removeRoot (* Dest); // remove receiver from root node list
690
683
}
691
684
692
685
std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents (
@@ -740,14 +733,12 @@ void exec_graph_impl::findRealDeps(
740
733
}
741
734
}
742
735
743
- ur_exp_command_buffer_sync_point_t
744
- exec_graph_impl::enqueueNodeDirect (const sycl::context &Ctx,
745
- sycl::detail::device_impl &DeviceImpl,
746
- ur_exp_command_buffer_handle_t CommandBuffer,
747
- std::shared_ptr<node_impl> Node) {
736
+ ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect (
737
+ const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
738
+ ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node) {
748
739
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
749
- for (node_impl &N : Node-> predecessors ()) {
750
- findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
740
+ for (node_impl &N : Node. predecessors ()) {
741
+ findRealDeps (Deps, N, MPartitionNodes[& Node]);
751
742
}
752
743
ur_exp_command_buffer_sync_point_t NewSyncPoint;
753
744
ur_exp_command_buffer_command_handle_t NewCommand = 0 ;
@@ -760,7 +751,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
760
751
if (xptiEnabled) {
761
752
StreamID = xptiRegisterStream (sycl::detail::SYCL_STREAM_NAME);
762
753
sycl::detail::CGExecKernel *CGExec =
763
- static_cast <sycl::detail::CGExecKernel *>(Node-> MCommandGroup .get ());
754
+ static_cast <sycl::detail::CGExecKernel *>(Node. MCommandGroup .get ());
764
755
sycl::detail::code_location CodeLoc (CGExec->MFileName .c_str (),
765
756
CGExec->MFunctionName .c_str (),
766
757
CGExec->MLine , CGExec->MColumn );
@@ -776,11 +767,11 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
776
767
777
768
ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel (
778
769
Ctx, DeviceImpl, CommandBuffer,
779
- *static_cast <sycl::detail::CGExecKernel *>((Node-> MCommandGroup .get ())),
770
+ *static_cast <sycl::detail::CGExecKernel *>((Node. MCommandGroup .get ())),
780
771
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr , nullptr );
781
772
782
773
if (MIsUpdatable) {
783
- MCommandMap[Node. get () ] = NewCommand;
774
+ MCommandMap[& Node] = NewCommand;
784
775
}
785
776
786
777
if (Res != UR_RESULT_SUCCESS) {
@@ -799,20 +790,20 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
799
790
800
791
ur_exp_command_buffer_sync_point_t
801
792
exec_graph_impl::enqueueNode (ur_exp_command_buffer_handle_t CommandBuffer,
802
- std::shared_ptr< node_impl> Node) {
793
+ node_impl & Node) {
803
794
804
795
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
805
- for (node_impl &N : Node-> predecessors ()) {
806
- findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
796
+ for (node_impl &N : Node. predecessors ()) {
797
+ findRealDeps (Deps, N, MPartitionNodes[& Node]);
807
798
}
808
799
809
800
sycl::detail::EventImplPtr Event =
810
801
sycl::detail::Scheduler::getInstance ().addCG (
811
- Node-> getCGCopy (), *MQueueImpl,
802
+ Node. getCGCopy (), *MQueueImpl,
812
803
/* EventNeeded=*/ true , CommandBuffer, Deps);
813
804
814
805
if (MIsUpdatable) {
815
- MCommandMap[Node. get () ] = Event->getCommandBufferCommand ();
806
+ MCommandMap[& Node] = Event->getCommandBufferCommand ();
816
807
}
817
808
818
809
return Event->getSyncPoint ();
@@ -861,25 +852,25 @@ void exec_graph_impl::createCommandBuffers(
861
852
862
853
Partition->MCommandBuffers [Device] = OutCommandBuffer;
863
854
864
- for (const auto &Node : Partition->MSchedule ) {
855
+ for (node_impl &Node : Partition->schedule () ) {
865
856
// Some nodes are not scheduled like other nodes, and only their
866
857
// dependencies are propagated in findRealDeps
867
- if (!Node-> requiresEnqueue ())
858
+ if (!Node. requiresEnqueue ())
868
859
continue ;
869
860
870
- sycl::detail::CGType type = Node-> MCGType ;
861
+ sycl::detail::CGType type = Node. MCGType ;
871
862
// If the node is a kernel with no special requirements we can enqueue it
872
863
// directly.
873
864
if (type == sycl::detail::CGType::Kernel &&
874
- Node-> MCommandGroup ->getRequirements ().size () +
865
+ Node. MCommandGroup ->getRequirements ().size () +
875
866
static_cast <sycl::detail::CGExecKernel *>(
876
- Node-> MCommandGroup .get ())
867
+ Node. MCommandGroup .get ())
877
868
->MStreams .size () ==
878
869
0 ) {
879
- MSyncPoints[Node. get () ] =
870
+ MSyncPoints[& Node] =
880
871
enqueueNodeDirect (MContext, DeviceImpl, OutCommandBuffer, Node);
881
872
} else {
882
- MSyncPoints[Node. get () ] = enqueueNode (OutCommandBuffer, Node);
873
+ MSyncPoints[& Node] = enqueueNode (OutCommandBuffer, Node);
883
874
}
884
875
}
885
876
@@ -2007,7 +1998,7 @@ std::vector<node> modifiable_command_graph::get_nodes() const {
2007
1998
std::vector<node> modifiable_command_graph::get_root_nodes () const {
2008
1999
graph_impl::ReadLock Lock (impl->MMutex );
2009
2000
auto &Roots = impl->MRoots ;
2010
- std::vector<std::weak_ptr< node_impl> > Impls{};
2001
+ std::vector<node_impl * > Impls{};
2011
2002
2012
2003
std::copy (Roots.begin (), Roots.end (), std::back_inserter (Impls));
2013
2004
return createNodesFromImpls (Impls);
0 commit comments