@@ -85,31 +85,29 @@ inline const char *nodeTypeToString(node_type NodeType) {
85
85
// / @param[in] PartitionBounded If set to true, the topological sort is stopped
86
86
// / at partition borders. Hence, nodes belonging to a partition different from
87
87
// / the NodeImpl partition are not processed.
88
- void sortTopological (std::set<std::weak_ptr<node_impl>,
89
- std::owner_less<std::weak_ptr<node_impl>>> &Roots,
90
- std::list<std::shared_ptr<node_impl>> &SortedNodes,
88
+ void sortTopological (nodes_range Roots, std::list<node_impl *> &SortedNodes,
91
89
bool PartitionBounded) {
92
- std::stack<std::weak_ptr< node_impl> > Source;
90
+ std::stack<node_impl * > Source;
93
91
94
- for (auto &Node : Roots) {
95
- Source.push (Node);
92
+ for (node_impl &Node : Roots) {
93
+ Source.push (& Node);
96
94
}
97
95
98
96
while (!Source.empty ()) {
99
- auto Node = Source.top (). lock ();
97
+ node_impl & Node = * Source.top ();
100
98
Source.pop ();
101
- SortedNodes.push_back (Node);
99
+ SortedNodes.push_back (& Node);
102
100
103
- for (node_impl &Succ : Node-> successors ()) {
101
+ for (node_impl &Succ : Node. successors ()) {
104
102
105
- if (PartitionBounded && (Succ.MPartitionNum != Node-> MPartitionNum )) {
103
+ if (PartitionBounded && (Succ.MPartitionNum != Node. MPartitionNum )) {
106
104
continue ;
107
105
}
108
106
109
107
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges ;
110
108
++TotalVisitedEdges;
111
109
if (TotalVisitedEdges == Succ.MPredecessors .size ()) {
112
- Source.push (Succ. weak_from_this () );
110
+ Source.push (& Succ);
113
111
}
114
112
}
115
113
}
@@ -173,7 +171,7 @@ bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
173
171
}
174
172
} // anonymous namespace
175
173
176
- void partition::schedule () {
174
+ void partition::updateSchedule () {
177
175
if (MSchedule.empty ()) {
178
176
// There is no need to reset MTotalVisitedEdges before calling
179
177
// sortTopological because this function is only called once per partition.
@@ -257,15 +255,15 @@ void exec_graph_impl::makePartitions() {
257
255
if (Node->MPartitionNum == i) {
258
256
MPartitionNodes[Node.get ()] = PartitionFinalNum;
259
257
if (isPartitionRoot (Node)) {
260
- Partition->MRoots .insert (Node);
258
+ Partition->MRoots .insert (Node. get () );
261
259
if (Node->MCGType == CGType::CodeplayHostTask) {
262
260
Partition->MIsHostTask = true ;
263
261
}
264
262
}
265
263
}
266
264
}
267
265
if (Partition->MRoots .size () > 0 ) {
268
- Partition->schedule ();
266
+ Partition->updateSchedule ();
269
267
Partition->MIsInOrderGraph = Partition->checkIfGraphIsSinglePath ();
270
268
MPartitions.push_back (Partition);
271
269
MRootPartitions.push_back (Partition);
@@ -287,9 +285,8 @@ void exec_graph_impl::makePartitions() {
287
285
288
286
// Compute partition dependencies
289
287
for (const auto &Partition : MPartitions) {
290
- for (auto const &Root : Partition->MRoots ) {
291
- auto RootNode = Root.lock ();
292
- for (node_impl &NodeDep : RootNode->predecessors ()) {
288
+ for (node_impl &Root : Partition->roots ()) {
289
+ for (node_impl &NodeDep : Root.predecessors ()) {
293
290
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
294
291
Partition->MPredecessors .push_back (Predecessor.get ());
295
292
Predecessor->MSuccessors .push_back (Partition.get ());
@@ -340,13 +337,9 @@ graph_impl::~graph_impl() {
340
337
}
341
338
}
342
339
343
- void graph_impl::addRoot (const std::shared_ptr<node_impl> &Root) {
344
- MRoots.insert (Root);
345
- }
340
+ void graph_impl::addRoot (node_impl &Root) { MRoots.insert (&Root); }
346
341
347
- void graph_impl::removeRoot (const std::shared_ptr<node_impl> &Root) {
348
- MRoots.erase (Root);
349
- }
342
+ void graph_impl::removeRoot (node_impl &Root) { MRoots.erase (&Root); }
350
343
351
344
std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges (
352
345
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
@@ -593,7 +586,7 @@ bool graph_impl::clearQueues() {
593
586
}
594
587
595
588
bool graph_impl::checkForCycles () {
596
- std::list<std::shared_ptr< node_impl> > SortedNodes;
589
+ std::list<node_impl * > SortedNodes;
597
590
sortTopological (MRoots, SortedNodes, false );
598
591
599
592
// If after a topological sort, not all the nodes in the graph are sorted,
@@ -664,7 +657,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
664
657
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors .size () == 1 ;
665
658
if (DestLostRootStatus) {
666
659
// Dest is no longer a Root node, so we need to remove it from MRoots.
667
- MRoots.erase (Dest);
660
+ MRoots.erase (Dest. get () );
668
661
}
669
662
670
663
// We can skip cycle checks if either Dest has no successors (cycle not
@@ -679,14 +672,14 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
679
672
Dest->MPredecessors .pop_back ();
680
673
if (DestLostRootStatus) {
681
674
// Add Dest back into MRoots.
682
- MRoots.insert (Dest);
675
+ MRoots.insert (Dest. get () );
683
676
}
684
677
685
678
throw sycl::exception (make_error_code (sycl::errc::invalid),
686
679
" Command graphs cannot contain cycles." );
687
680
}
688
681
}
689
- removeRoot (Dest); // remove receiver from root node list
682
+ removeRoot (* Dest); // remove receiver from root node list
690
683
}
691
684
692
685
std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents (
@@ -739,14 +732,12 @@ void exec_graph_impl::findRealDeps(
739
732
}
740
733
}
741
734
742
- ur_exp_command_buffer_sync_point_t
743
- exec_graph_impl::enqueueNodeDirect (const sycl::context &Ctx,
744
- sycl::detail::device_impl &DeviceImpl,
745
- ur_exp_command_buffer_handle_t CommandBuffer,
746
- std::shared_ptr<node_impl> Node) {
735
+ ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect (
736
+ const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
737
+ ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node) {
747
738
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
748
- for (node_impl &N : Node-> predecessors ()) {
749
- findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
739
+ for (node_impl &N : Node. predecessors ()) {
740
+ findRealDeps (Deps, N, MPartitionNodes[& Node]);
750
741
}
751
742
ur_exp_command_buffer_sync_point_t NewSyncPoint;
752
743
ur_exp_command_buffer_command_handle_t NewCommand = 0 ;
@@ -759,7 +750,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
759
750
if (xptiEnabled) {
760
751
StreamID = xptiRegisterStream (sycl::detail::SYCL_STREAM_NAME);
761
752
sycl::detail::CGExecKernel *CGExec =
762
- static_cast <sycl::detail::CGExecKernel *>(Node-> MCommandGroup .get ());
753
+ static_cast <sycl::detail::CGExecKernel *>(Node. MCommandGroup .get ());
763
754
sycl::detail::code_location CodeLoc (CGExec->MFileName .c_str (),
764
755
CGExec->MFunctionName .c_str (),
765
756
CGExec->MLine , CGExec->MColumn );
@@ -775,11 +766,11 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
775
766
776
767
ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel (
777
768
Ctx, DeviceImpl, CommandBuffer,
778
- *static_cast <sycl::detail::CGExecKernel *>((Node-> MCommandGroup .get ())),
769
+ *static_cast <sycl::detail::CGExecKernel *>((Node. MCommandGroup .get ())),
779
770
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr , nullptr );
780
771
781
772
if (MIsUpdatable) {
782
- MCommandMap[Node. get () ] = NewCommand;
773
+ MCommandMap[& Node] = NewCommand;
783
774
}
784
775
785
776
if (Res != UR_RESULT_SUCCESS) {
@@ -798,20 +789,20 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
798
789
799
790
ur_exp_command_buffer_sync_point_t
800
791
exec_graph_impl::enqueueNode (ur_exp_command_buffer_handle_t CommandBuffer,
801
- std::shared_ptr< node_impl> Node) {
792
+ node_impl & Node) {
802
793
803
794
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
804
- for (node_impl &N : Node-> predecessors ()) {
805
- findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
795
+ for (node_impl &N : Node. predecessors ()) {
796
+ findRealDeps (Deps, N, MPartitionNodes[& Node]);
806
797
}
807
798
808
799
sycl::detail::EventImplPtr Event =
809
800
sycl::detail::Scheduler::getInstance ().addCG (
810
- Node-> getCGCopy (), *MQueueImpl,
801
+ Node. getCGCopy (), *MQueueImpl,
811
802
/* EventNeeded=*/ true , CommandBuffer, Deps);
812
803
813
804
if (MIsUpdatable) {
814
- MCommandMap[Node. get () ] = Event->getCommandBufferCommand ();
805
+ MCommandMap[& Node] = Event->getCommandBufferCommand ();
815
806
}
816
807
817
808
return Event->getSyncPoint ();
@@ -860,25 +851,25 @@ void exec_graph_impl::createCommandBuffers(
860
851
861
852
Partition->MCommandBuffers [Device] = OutCommandBuffer;
862
853
863
- for (const auto &Node : Partition->MSchedule ) {
854
+ for (node_impl &Node : Partition->schedule () ) {
864
855
// Some nodes are not scheduled like other nodes, and only their
865
856
// dependencies are propagated in findRealDeps
866
- if (!Node-> requiresEnqueue ())
857
+ if (!Node. requiresEnqueue ())
867
858
continue ;
868
859
869
- sycl::detail::CGType type = Node-> MCGType ;
860
+ sycl::detail::CGType type = Node. MCGType ;
870
861
// If the node is a kernel with no special requirements we can enqueue it
871
862
// directly.
872
863
if (type == sycl::detail::CGType::Kernel &&
873
- Node-> MCommandGroup ->getRequirements ().size () +
864
+ Node. MCommandGroup ->getRequirements ().size () +
874
865
static_cast <sycl::detail::CGExecKernel *>(
875
- Node-> MCommandGroup .get ())
866
+ Node. MCommandGroup .get ())
876
867
->MStreams .size () ==
877
868
0 ) {
878
- MSyncPoints[Node. get () ] =
869
+ MSyncPoints[& Node] =
879
870
enqueueNodeDirect (MContext, DeviceImpl, OutCommandBuffer, Node);
880
871
} else {
881
- MSyncPoints[Node. get () ] = enqueueNode (OutCommandBuffer, Node);
872
+ MSyncPoints[& Node] = enqueueNode (OutCommandBuffer, Node);
882
873
}
883
874
}
884
875
@@ -2006,7 +1997,7 @@ std::vector<node> modifiable_command_graph::get_nodes() const {
2006
1997
std::vector<node> modifiable_command_graph::get_root_nodes () const {
2007
1998
graph_impl::ReadLock Lock (impl->MMutex );
2008
1999
auto &Roots = impl->MRoots ;
2009
- std::vector<std::weak_ptr< node_impl> > Impls{};
2000
+ std::vector<node_impl * > Impls{};
2010
2001
2011
2002
std::copy (Roots.begin (), Roots.end (), std::back_inserter (Impls));
2012
2003
return createNodesFromImpls (Impls);
0 commit comments