Skip to content

Commit 64069af

Browse files
[NFCI][SYCL][Graph] Refactor graph_impl::add
Part of the refactoring to eliminate `std::weak_ptr<node_impl>` and reduce usage of `std::shared_ptr<node_impl>` by preferring raw ptr/ref. Previous PRs in the series: #19295 #19332 #19334 #19350 * Accept `Deps` as `nodes_range` in `graph_impl::add` * Return `node_impl &` from `graph_impl::add` * Add `node` support in `nodes_range` and use that together with the modified `graph_impl::add` when creating new `node_impl`s based on `std::vector<node> Deps` to avoid creation of temporary `DepImpls` storage. * Also updated `registerSuccessor/registerPredecessor` and `addEventForNode/addDepsToNode` to accept raw `node_impl &` as the changes above resulted in having raw references at the call sites.
1 parent 0bba746 commit 64069af

File tree

9 files changed

+234
-281
lines changed

9 files changed

+234
-281
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 77 additions & 103 deletions
Large diffs are not rendered by default.

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,9 @@ class partition {
5353
partition() : MSchedule(), MCommandBuffers() {}
5454

5555
/// List of root nodes.
56-
std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
57-
MRoots;
56+
std::set<node_impl *> MRoots;
5857
/// Execution schedule of nodes in the graph.
59-
std::list<std::shared_ptr<node_impl>> MSchedule;
58+
std::list<node_impl *> MSchedule;
6059
/// Map of devices to command buffers.
6160
std::unordered_map<sycl::device, ur_exp_command_buffer_handle_t>
6261
MCommandBuffers;
@@ -84,17 +83,20 @@ class partition {
8483
// replaced every time the partition is executed.
8584
EventImplPtr MEvent;
8685

86+
nodes_range roots() const { return MRoots; }
87+
nodes_range schedule() const { return MSchedule; }
88+
8789
/// Checks if the graph is single path, i.e. each node has a single successor.
8890
/// @return True if the graph is a single path
8991
bool checkIfGraphIsSinglePath() {
9092
if (MRoots.size() > 1) {
9193
return false;
9294
}
93-
for (const auto &Node : MSchedule) {
95+
for (node_impl &Node : schedule()) {
9496
// In version 1.3.28454 of the L0 driver, 2D Copy ops cannot
9597
// be enqueued in an in-order cmd-list (causing execution to stall).
9698
// The 2D Copy test should be removed from here when the bug is fixed.
97-
if ((Node->MSuccessors.size() > 1) || (Node->isNDCopyNode())) {
99+
if ((Node.MSuccessors.size() > 1) || (Node.isNDCopyNode())) {
98100
return false;
99101
}
100102
}
@@ -103,7 +105,7 @@ class partition {
103105
}
104106

105107
/// Add nodes to MSchedule.
106-
void schedule();
108+
void updateSchedule();
107109
};
108110

109111
/// Implementation details of command_graph<modifiable>.
@@ -126,7 +128,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
126128

127129
/// Remove node from list of root nodes.
128130
/// @param Root Node to remove from list of root nodes.
129-
void removeRoot(const std::shared_ptr<node_impl> &Root);
131+
void removeRoot(node_impl &Root);
130132

131133
/// Verifies the CG is valid to add to the graph and returns set of
132134
/// dependent nodes if so.
@@ -145,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
145147
/// @param CommandGroup The CG which stores all information for this node.
146148
/// @param Deps Dependencies of the created node.
147149
/// @return Created node in the graph.
148-
std::shared_ptr<node_impl> add(node_type NodeType,
149-
std::shared_ptr<sycl::detail::CG> CommandGroup,
150-
nodes_range Deps);
150+
node_impl &add(node_type NodeType,
151+
std::shared_ptr<sycl::detail::CG> CommandGroup,
152+
nodes_range Deps);
151153

152154
/// Create a CGF node in the graph.
153155
/// @param CGF Command-group function to create node with.
154156
/// @param Args Node arguments.
155157
/// @param Deps Dependencies of the created node.
156158
/// @return Created node in the graph.
157-
std::shared_ptr<node_impl> add(std::function<void(handler &)> CGF,
158-
const std::vector<sycl::detail::ArgDesc> &Args,
159-
std::vector<std::shared_ptr<node_impl>> &Deps);
159+
node_impl &add(std::function<void(handler &)> CGF,
160+
const std::vector<sycl::detail::ArgDesc> &Args,
161+
nodes_range Deps);
160162

161163
/// Create an empty node in the graph.
162164
/// @param Deps List of predecessor nodes.
163165
/// @return Created node in the graph.
164-
std::shared_ptr<node_impl> add(nodes_range Deps);
166+
node_impl &add(nodes_range Deps);
165167

166168
/// Create a dynamic command-group node in the graph.
167169
/// @param DynCGImpl Dynamic command-group used to create node.
168170
/// @param Deps List of predecessor nodes.
169171
/// @return Created node in the graph.
170-
std::shared_ptr<node_impl>
171-
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
172+
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
173+
nodes_range Deps);
172174

173175
/// Add a queue to the set of queues which are currently recording to this
174176
/// graph.
@@ -190,10 +192,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
190192
/// @param EventImpl Event to associate with a node in map.
191193
/// @param NodeImpl Node to associate with event in map.
192194
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
193-
const std::shared_ptr<node_impl> &NodeImpl) {
195+
node_impl &NodeImpl) {
194196
if (!(EventImpl->hasCommandGraph()))
195197
EventImpl->setCommandGraph(shared_from_this());
196-
MEventsMap[EventImpl] = NodeImpl;
198+
MEventsMap[EventImpl] = NodeImpl.shared_from_this();
197199
}
198200

199201
/// Find the sycl event associated with a node.
@@ -281,15 +283,16 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
281283
sycl::device getDevice() const { return MDevice; }
282284

283285
/// List of root nodes.
284-
std::set<std::weak_ptr<node_impl>, std::owner_less<std::weak_ptr<node_impl>>>
285-
MRoots;
286+
std::set<node_impl *> MRoots;
286287

287288
/// Storage for all nodes contained within a graph. Nodes are connected to
288289
/// each other via weak_ptrs and so do not extend each other's lifetimes.
289290
/// This storage allows easy iteration over all nodes in the graph, rather
290291
/// than needing an expensive depth first search.
291292
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
292293

294+
nodes_range roots() const { return MRoots; }
295+
293296
/// Find the last node added to this graph from an in-order queue.
294297
/// @param Queue In-order queue to find the last node added to the graph from.
295298
/// @return Last node in this graph added from \p Queue recording, or empty
@@ -312,8 +315,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
312315
std::fstream Stream(FilePath, std::ios::out);
313316
Stream << "digraph dot {" << std::endl;
314317

315-
for (std::weak_ptr<node_impl> Node : MRoots)
316-
Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose);
318+
for (node_impl &Node : roots())
319+
Node.printDotRecursive(Stream, VisitedNodes, Verbose);
317320

318321
Stream << "}" << std::endl;
319322

@@ -418,13 +421,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
418421
}
419422

420423
size_t RootsFound = 0;
421-
for (std::weak_ptr<node_impl> NodeA : MRoots) {
422-
for (std::weak_ptr<node_impl> NodeB : Graph.MRoots) {
423-
auto NodeALocked = NodeA.lock();
424-
auto NodeBLocked = NodeB.lock();
425-
426-
if (NodeALocked->isSimilar(*NodeBLocked)) {
427-
if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) {
424+
for (node_impl &NodeA : roots()) {
425+
for (node_impl &NodeB : Graph.roots()) {
426+
if (NodeA.isSimilar(NodeB)) {
427+
if (checkNodeRecursive(NodeA, NodeB)) {
428428
RootsFound++;
429429
break;
430430
}
@@ -510,6 +510,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
510510
}
511511

512512
private:
513+
template <typename... Ts> node_impl &createNode(Ts &&...Args) {
514+
MNodeStorage.push_back(
515+
std::make_shared<node_impl>(std::forward<Ts>(Args)...));
516+
return *MNodeStorage.back();
517+
}
518+
513519
/// Check the graph for cycles by performing a depth-first search of the
514520
/// graph. If a node is visited more than once in a given path through the
515521
/// graph, a cycle is present and the search ends immediately.
@@ -518,18 +524,18 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
518524

519525
/// Insert node into list of root nodes.
520526
/// @param Root Node to add to list of root nodes.
521-
void addRoot(const std::shared_ptr<node_impl> &Root);
527+
void addRoot(node_impl &Root);
522528

523529
/// Adds dependencies for a new node, if it has no deps it will be
524530
/// added as a root node.
525531
/// @param Node The node to add deps for
526532
/// @param Deps List of dependent nodes
527-
void addDepsToNode(const std::shared_ptr<node_impl> &Node, nodes_range Deps) {
533+
void addDepsToNode(node_impl &Node, nodes_range Deps) {
528534
for (node_impl &N : Deps) {
529535
N.registerSuccessor(Node);
530536
this->removeRoot(Node);
531537
}
532-
if (Node->MPredecessors.empty()) {
538+
if (Node.MPredecessors.empty()) {
533539
this->addRoot(Node);
534540
}
535541
}
@@ -647,9 +653,7 @@ class exec_graph_impl {
647653

648654
/// Query the scheduling of node execution.
649655
/// @return List of nodes in execution order.
650-
const std::list<std::shared_ptr<node_impl>> &getSchedule() const {
651-
return MSchedule;
652-
}
656+
const std::list<node_impl *> &getSchedule() const { return MSchedule; }
653657

654658
/// Query the graph_impl.
655659
/// @return pointer to the graph_impl MGraphImpl
@@ -730,8 +734,7 @@ class exec_graph_impl {
730734
/// @param Node The node being enqueued.
731735
/// @return UR sync point created for this node in the command-buffer.
732736
ur_exp_command_buffer_sync_point_t
733-
enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
734-
std::shared_ptr<node_impl> Node);
737+
enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node);
735738

736739
/// Enqueue a node directly to the command-buffer without going through the
737740
/// scheduler.
@@ -740,11 +743,9 @@ class exec_graph_impl {
740743
/// @param CommandBuffer Command-buffer to add node to as a command.
741744
/// @param Node The node being enqueued.
742745
/// @return UR sync point created for this node in the command-buffer.
743-
ur_exp_command_buffer_sync_point_t
744-
enqueueNodeDirect(const sycl::context &Ctx,
745-
sycl::detail::device_impl &DeviceImpl,
746-
ur_exp_command_buffer_handle_t CommandBuffer,
747-
std::shared_ptr<node_impl> Node);
746+
ur_exp_command_buffer_sync_point_t enqueueNodeDirect(
747+
const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
748+
ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node);
748749

749750
/// Enqueues a host-task partition (i.e. a partition that contains only a
750751
/// single node and that node is a host-task).
@@ -873,7 +874,7 @@ class exec_graph_impl {
873874
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const;
874875

875876
/// Execution schedule of nodes in the graph.
876-
std::list<std::shared_ptr<node_impl>> MSchedule;
877+
std::list<node_impl *> MSchedule;
877878
/// Pointer to the modifiable graph impl associated with this executable
878879
/// graph.
879880
/// Thread-safe implementation note: in the current implementation

sycl/source/detail/graph/node_impl.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,11 @@ std::vector<node> createNodesFromImpls(
3131
return Nodes;
3232
}
3333

34-
/// Takes a vector of shared_ptrs to node_impls and returns a vector of node
35-
/// objects created from those impls, in the same order.
36-
std::vector<node> createNodesFromImpls(
37-
const std::vector<std::shared_ptr<detail::node_impl>> &Impls) {
34+
std::vector<node> createNodesFromImpls(nodes_range Impls) {
3835
std::vector<node> Nodes{};
3936
Nodes.reserve(Impls.size());
4037

41-
for (std::shared_ptr<detail::node_impl> Impl : Impls) {
38+
for (detail::node_impl &Impl : Impls) {
4239
Nodes.push_back(sycl::detail::createSyclObjFromImpl<node>(Impl));
4340
}
4441

sycl/source/detail/graph/node_impl.hpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
#include <sycl/detail/cg_types.hpp> // for CGType
1515
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t
1616

17+
#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node
18+
1719
#include <cstring>
1820
#include <fstream>
1921
#include <iomanip>
22+
#include <list>
2023
#include <set>
2124
#include <vector>
2225

@@ -25,8 +28,6 @@ inline namespace _V1 {
2528
namespace ext {
2629
namespace oneapi {
2730
namespace experimental {
28-
// Forward declarations
29-
class node;
3031

3132
namespace detail {
3233
// Forward declarations
@@ -39,10 +40,7 @@ class exec_graph_impl;
3940
std::vector<node>
4041
createNodesFromImpls(const std::vector<std::weak_ptr<node_impl>> &Impls);
4142

42-
/// Takes a vector of shared_ptrs to node_impls and returns a vector of node
43-
/// objects created from those impls, in the same order.
44-
std::vector<node>
45-
createNodesFromImpls(const std::vector<std::shared_ptr<node_impl>> &Impls);
43+
std::vector<node> createNodesFromImpls(nodes_range Impls);
4644

4745
inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) {
4846
using sycl::detail::CG;
@@ -123,27 +121,27 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
123121

124122
/// Add successor to the node.
125123
/// @param Node Node to add as a successor.
126-
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
124+
void registerSuccessor(node_impl &Node) {
127125
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
128-
[Node](const std::weak_ptr<node_impl> &Ptr) {
129-
return Ptr.lock() == Node;
126+
[&Node](const std::weak_ptr<node_impl> &Ptr) {
127+
return Ptr.lock().get() == &Node;
130128
}) != MSuccessors.end()) {
131129
return;
132130
}
133-
MSuccessors.push_back(Node);
134-
Node->registerPredecessor(shared_from_this());
131+
MSuccessors.push_back(Node.weak_from_this());
132+
Node.registerPredecessor(*this);
135133
}
136134

137135
/// Add predecessor to the node.
138136
/// @param Node Node to add as a predecessor.
139-
void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
137+
void registerPredecessor(node_impl &Node) {
140138
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
141139
[&Node](const std::weak_ptr<node_impl> &Ptr) {
142-
return Ptr.lock() == Node;
140+
return Ptr.lock().get() == &Node;
143141
}) != MPredecessors.end()) {
144142
return;
145143
}
146-
MPredecessors.push_back(Node);
144+
MPredecessors.push_back(Node.weak_from_this());
147145
}
148146

149147
/// Construct an empty node.
@@ -774,7 +772,9 @@ class nodes_range {
774772
// from `weak_ptr`s this alternative should be removed too.
775773
std::vector<std::weak_ptr<node_impl>>,
776774
//
777-
std::set<std::shared_ptr<node_impl>>>;
775+
std::set<std::shared_ptr<node_impl>>, std::set<node_impl *>,
776+
//
777+
std::list<node_impl *>, std::vector<node>>;
778778

779779
storage_iter Begin;
780780
storage_iter End;
@@ -783,10 +783,8 @@ class nodes_range {
783783
public:
784784
nodes_range(const nodes_range &Other) = default;
785785

786-
template <
787-
typename ContainerTy,
788-
typename = std::enable_if_t<!std::is_same_v<nodes_range, ContainerTy>>>
789-
nodes_range(ContainerTy &Container)
786+
template <typename ContainerTy>
787+
nodes_range(const ContainerTy &Container)
790788
: Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} {
791789
}
792790

@@ -812,12 +810,14 @@ class nodes_range {
812810
return std::visit(
813811
[](auto &&It) -> node_impl & {
814812
auto &Elem = *It;
815-
if constexpr (std::is_same_v<std::decay_t<decltype(Elem)>,
816-
std::weak_ptr<node_impl>>) {
813+
using Ty = std::decay_t<decltype(Elem)>;
814+
if constexpr (std::is_same_v<Ty, std::weak_ptr<node_impl>>) {
817815
// This assumes that weak_ptr doesn't actually manage lifetime and
818816
// the object is guaranteed to be alive (which seems to be the
819817
// assumption across all graph code).
820818
return *Elem.lock();
819+
} else if constexpr (std::is_same_v<Ty, node>) {
820+
return *getSyclObjImpl(Elem);
821821
} else {
822822
return *Elem;
823823
}

0 commit comments

Comments
 (0)