@@ -53,10 +53,9 @@ class partition {
53
53
/// Creates a partition whose schedule and device-to-command-buffer map are
/// both value-initialized.
partition() : MSchedule{}, MCommandBuffers{} {}
54
54
55
55
// / List of root nodes.
56
- std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
57
- MRoots;
56
+ std::set<node_impl *> MRoots;
58
57
// / Execution schedule of nodes in the graph.
59
- std::list<std::shared_ptr< node_impl> > MSchedule;
58
+ std::list<node_impl * > MSchedule;
60
59
// / Map of devices to command buffers.
61
60
std::unordered_map<sycl::device, ur_exp_command_buffer_handle_t >
62
61
MCommandBuffers;
@@ -84,17 +83,20 @@ class partition {
84
83
// replaced every time the partition is executed.
85
84
EventImplPtr MEvent;
86
85
86
+ nodes_range roots () const { return MRoots; }
87
+ nodes_range schedule () const { return MSchedule; }
88
+
87
89
// / Checks if the graph is single path, i.e. each node has a single successor.
88
90
// / @return True if the graph is a single path
89
91
bool checkIfGraphIsSinglePath () {
90
92
if (MRoots.size () > 1 ) {
91
93
return false ;
92
94
}
93
- for (const auto &Node : MSchedule ) {
95
+ for (node_impl &Node : schedule () ) {
94
96
// In version 1.3.28454 of the L0 driver, 2D Copy ops cannot
95
97
// be enqueued in an in-order cmd-list (causing execution to stall).
96
98
// The 2D Copy test should be removed from here when the bug is fixed.
97
- if ((Node-> MSuccessors .size () > 1 ) || (Node-> isNDCopyNode ())) {
99
+ if ((Node. MSuccessors .size () > 1 ) || (Node. isNDCopyNode ())) {
98
100
return false ;
99
101
}
100
102
}
@@ -103,7 +105,7 @@ class partition {
103
105
}
104
106
105
107
// / Add nodes to MSchedule.
106
- void schedule ();
108
+ void updateSchedule ();
107
109
};
108
110
109
111
// / Implementation details of command_graph<modifiable>.
@@ -126,7 +128,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
126
128
127
129
// / Remove node from list of root nodes.
128
130
// / @param Root Node to remove from list of root nodes.
129
- void removeRoot (const std::shared_ptr< node_impl> &Root);
131
+ void removeRoot (node_impl &Root);
130
132
131
133
// / Verifies the CG is valid to add to the graph and returns set of
132
134
// / dependent nodes if so.
@@ -145,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
145
147
// / @param CommandGroup The CG which stores all information for this node.
146
148
// / @param Deps Dependencies of the created node.
147
149
// / @return Created node in the graph.
148
- std::shared_ptr< node_impl> add (node_type NodeType,
149
- std::shared_ptr<sycl::detail::CG> CommandGroup,
150
- nodes_range Deps);
150
+ node_impl & add (node_type NodeType,
151
+ std::shared_ptr<sycl::detail::CG> CommandGroup,
152
+ nodes_range Deps);
151
153
152
154
// / Create a CGF node in the graph.
153
155
// / @param CGF Command-group function to create node with.
154
156
// / @param Args Node arguments.
155
157
// / @param Deps Dependencies of the created node.
156
158
// / @return Created node in the graph.
157
- std::shared_ptr< node_impl> add (std::function<void (handler &)> CGF,
158
- const std::vector<sycl::detail::ArgDesc> &Args,
159
- std::vector<std::shared_ptr<node_impl>> & Deps);
159
+ node_impl & add (std::function<void (handler &)> CGF,
160
+ const std::vector<sycl::detail::ArgDesc> &Args,
161
+ nodes_range Deps);
160
162
161
163
// / Create an empty node in the graph.
162
164
// / @param Deps List of predecessor nodes.
163
165
// / @return Created node in the graph.
164
- std::shared_ptr< node_impl> add (nodes_range Deps);
166
+ node_impl & add (nodes_range Deps);
165
167
166
168
// / Create a dynamic command-group node in the graph.
167
169
// / @param DynCGImpl Dynamic command-group used to create node.
168
170
// / @param Deps List of predecessor nodes.
169
171
// / @return Created node in the graph.
170
- std::shared_ptr<node_impl>
171
- add (std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
172
+ node_impl & add ( std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
173
+ nodes_range Deps);
172
174
173
175
// / Add a queue to the set of queues which are currently recording to this
174
176
// / graph.
@@ -190,10 +192,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
190
192
// / @param EventImpl Event to associate with a node in map.
191
193
// / @param NodeImpl Node to associate with event in map.
192
194
void addEventForNode (std::shared_ptr<sycl::detail::event_impl> EventImpl,
193
- const std::shared_ptr< node_impl> &NodeImpl) {
195
+ node_impl &NodeImpl) {
194
196
if (!(EventImpl->hasCommandGraph ()))
195
197
EventImpl->setCommandGraph (shared_from_this ());
196
- MEventsMap[EventImpl] = NodeImpl;
198
+ MEventsMap[EventImpl] = NodeImpl. shared_from_this () ;
197
199
}
198
200
199
201
// / Find the sycl event associated with a node.
@@ -281,15 +283,16 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
281
283
sycl::device getDevice () const { return MDevice; }
282
284
283
285
// / List of root nodes.
284
- std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
285
- MRoots;
286
+ std::set<node_impl *> MRoots;
286
287
287
288
// / Storage for all nodes contained within a graph. Nodes are connected to
288
289
// / each other via weak_ptrs and so do not extend each other's lifetimes.
289
290
// / This storage allows easy iteration over all nodes in the graph, rather
290
291
// / than needing an expensive depth first search.
291
292
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
292
293
294
+ nodes_range roots () const { return MRoots; }
295
+
293
296
// / Find the last node added to this graph from an in-order queue.
294
297
// / @param Queue In-order queue to find the last node added to the graph from.
295
298
// / @return Last node in this graph added from \p Queue recording, or empty
@@ -312,8 +315,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
312
315
std::fstream Stream (FilePath, std::ios::out);
313
316
Stream << " digraph dot {" << std::endl;
314
317
315
- for (std::weak_ptr< node_impl> Node : MRoots )
316
- Node.lock ()-> printDotRecursive (Stream, VisitedNodes, Verbose);
318
+ for (node_impl & Node : roots () )
319
+ Node.printDotRecursive (Stream, VisitedNodes, Verbose);
317
320
318
321
Stream << " }" << std::endl;
319
322
@@ -418,13 +421,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
418
421
}
419
422
420
423
size_t RootsFound = 0 ;
421
- for (std::weak_ptr<node_impl> NodeA : MRoots) {
422
- for (std::weak_ptr<node_impl> NodeB : Graph.MRoots ) {
423
- auto NodeALocked = NodeA.lock ();
424
- auto NodeBLocked = NodeB.lock ();
425
-
426
- if (NodeALocked->isSimilar (*NodeBLocked)) {
427
- if (checkNodeRecursive (*NodeALocked, *NodeBLocked)) {
424
+ for (node_impl &NodeA : roots ()) {
425
+ for (node_impl &NodeB : Graph.roots ()) {
426
+ if (NodeA.isSimilar (NodeB)) {
427
+ if (checkNodeRecursive (NodeA, NodeB)) {
428
428
RootsFound++;
429
429
break ;
430
430
}
@@ -510,6 +510,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
510
510
}
511
511
512
512
private:
513
+ template <typename ... Ts> node_impl &createNode (Ts &&...Args) {
514
+ MNodeStorage.push_back (
515
+ std::make_shared<node_impl>(std::forward<Ts>(Args)...));
516
+ return *MNodeStorage.back ();
517
+ }
518
+
513
519
// / Check the graph for cycles by performing a depth-first search of the
514
520
// / graph. If a node is visited more than once in a given path through the
515
521
// / graph, a cycle is present and the search ends immediately.
@@ -518,18 +524,18 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
518
524
519
525
// / Insert node into list of root nodes.
520
526
// / @param Root Node to add to list of root nodes.
521
- void addRoot (const std::shared_ptr< node_impl> &Root);
527
+ void addRoot (node_impl &Root);
522
528
523
529
// / Adds dependencies for a new node, if it has no deps it will be
524
530
// / added as a root node.
525
531
// / @param Node The node to add deps for
526
532
// / @param Deps List of dependent nodes
527
- void addDepsToNode (const std::shared_ptr< node_impl> &Node, nodes_range Deps) {
533
+ void addDepsToNode (node_impl &Node, nodes_range Deps) {
528
534
for (node_impl &N : Deps) {
529
535
N.registerSuccessor (Node);
530
536
this ->removeRoot (Node);
531
537
}
532
- if (Node-> MPredecessors .empty ()) {
538
+ if (Node. MPredecessors .empty ()) {
533
539
this ->addRoot (Node);
534
540
}
535
541
}
@@ -647,9 +653,7 @@ class exec_graph_impl {
647
653
648
654
// / Query the scheduling of node execution.
649
655
// / @return List of nodes in execution order.
650
- const std::list<std::shared_ptr<node_impl>> &getSchedule () const {
651
- return MSchedule;
652
- }
656
+ const std::list<node_impl *> &getSchedule () const { return MSchedule; }
653
657
654
658
// / Query the graph_impl.
655
659
// / @return pointer to the graph_impl MGraphImpl
@@ -730,8 +734,7 @@ class exec_graph_impl {
730
734
// / @param Node The node being enqueued.
731
735
// / @return UR sync point created for this node in the command-buffer.
732
736
ur_exp_command_buffer_sync_point_t
733
- enqueueNode (ur_exp_command_buffer_handle_t CommandBuffer,
734
- std::shared_ptr<node_impl> Node);
737
+ enqueueNode (ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node);
735
738
736
739
// / Enqueue a node directly to the command-buffer without going through the
737
740
// / scheduler.
@@ -740,11 +743,9 @@ class exec_graph_impl {
740
743
// / @param CommandBuffer Command-buffer to add node to as a command.
741
744
// / @param Node The node being enqueued.
742
745
// / @return UR sync point created for this node in the command-buffer.
743
- ur_exp_command_buffer_sync_point_t
744
- enqueueNodeDirect (const sycl::context &Ctx,
745
- sycl::detail::device_impl &DeviceImpl,
746
- ur_exp_command_buffer_handle_t CommandBuffer,
747
- std::shared_ptr<node_impl> Node);
746
+ ur_exp_command_buffer_sync_point_t enqueueNodeDirect (
747
+ const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
748
+ ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node);
748
749
749
750
// / Enqueues a host-task partition (i.e. a partition that contains only a
750
751
// / single node and that node is a host-task).
@@ -873,7 +874,7 @@ class exec_graph_impl {
873
874
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const ;
874
875
875
876
// / Execution schedule of nodes in the graph.
876
- std::list<std::shared_ptr< node_impl> > MSchedule;
877
+ std::list<node_impl * > MSchedule;
877
878
// / Pointer to the modifiable graph impl associated with this executable
878
879
// / graph.
879
880
// / Thread-safe implementation note: in the current implementation
0 commit comments