Skip to content

Commit 395f476

Browse files
[NFCI][SYCL][Graph] Refactor graph_impl::add
Part of the refactoring to eliminate `std::weak_ptr<node_impl>` and reduce usage of `std::shared_ptr<node_impl>` by preferring raw pointers/references. Previous PRs in the series: intel#19295, intel#19332, intel#19334, intel#19350.

* Accept `Deps` as `nodes_range` in `graph_impl::add`.
* Return `node_impl &` from `graph_impl::add`.
* Add `node` support to `nodes_range` and use it together with the modified `graph_impl::add` when creating new `node_impl`s based on `std::vector<node> Deps`, to avoid creating temporary `DepImpls` storage.
* Also update `registerSuccessor/registerPredecessor` and `addEventForNode/addDepsToNode` to accept a raw `node_impl &`, since the changes above leave raw references at the call sites.
1 parent 99e9fa1 commit 395f476

File tree

4 files changed

+79
-90
lines changed

4 files changed

+79
-90
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 32 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -409,10 +409,8 @@ void graph_impl::markCGMemObjs(
409409
}
410410
}
411411

412-
std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
413-
const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>();
414-
415-
MNodeStorage.push_back(NodeImpl);
412+
node_impl &graph_impl::add(nodes_range Deps) {
413+
node_impl &NodeImpl = createNode();
416414

417415
addDepsToNode(NodeImpl, Deps);
418416
// Add an event associated with this explicit node for mixed usage
@@ -421,10 +419,9 @@ std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
421419
return NodeImpl;
422420
}
423421

424-
std::shared_ptr<node_impl>
425-
graph_impl::add(std::function<void(handler &)> CGF,
426-
const std::vector<sycl::detail::ArgDesc> &Args,
427-
std::vector<std::shared_ptr<node_impl>> &Deps) {
422+
node_impl &graph_impl::add(std::function<void(handler &)> CGF,
423+
const std::vector<sycl::detail::ArgDesc> &Args,
424+
nodes_range Deps) {
428425
(void)Args;
429426
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
430427
detail::handler_impl HandlerImpl{*this};
@@ -435,7 +432,9 @@ graph_impl::add(std::function<void(handler &)> CGF,
435432

436433
// Pass the node deps to the handler so they are available when processing the
437434
// CGF, need for async_malloc nodes.
438-
Handler.impl->MNodeDeps = Deps;
435+
Handler.impl->MNodeDeps.clear();
436+
for (node_impl &N : Deps)
437+
Handler.impl->MNodeDeps.push_back(N.shared_from_this());
439438

440439
#if XPTI_ENABLE_INSTRUMENTATION
441440
// Save code location if one was set in TLS.
@@ -471,7 +470,7 @@ graph_impl::add(std::function<void(handler &)> CGF,
471470
: ext::oneapi::experimental::detail::getNodeTypeFromCG(
472471
Handler.getType());
473472

474-
auto NodeImpl =
473+
node_impl &NodeImpl =
475474
this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Deps);
476475

477476
// Add an event associated with this explicit node for mixed usage
@@ -489,44 +488,41 @@ graph_impl::add(std::function<void(handler &)> CGF,
489488
}
490489

491490
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
492-
DynamicParam->registerNode(NodeImpl, ArgIndex);
491+
DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex);
493492
}
494493

495494
return NodeImpl;
496495
}
497496

498-
std::shared_ptr<node_impl>
499-
graph_impl::add(node_type NodeType,
500-
std::shared_ptr<sycl::detail::CG> CommandGroup,
501-
nodes_range Deps) {
497+
node_impl &graph_impl::add(node_type NodeType,
498+
std::shared_ptr<sycl::detail::CG> CommandGroup,
499+
nodes_range Deps) {
502500

503501
// A unique set of dependencies obtained by checking requirements and events
504502
std::set<std::shared_ptr<node_impl>> UniqueDeps = getCGEdges(CommandGroup);
505503

506504
// Track and mark the memory objects being used by the graph.
507505
markCGMemObjs(CommandGroup);
508506

509-
const std::shared_ptr<node_impl> &NodeImpl =
510-
std::make_shared<node_impl>(NodeType, std::move(CommandGroup));
511-
MNodeStorage.push_back(NodeImpl);
507+
node_impl &NodeImpl = createNode(NodeType, std::move(CommandGroup));
512508

513509
// Add any deps determined from requirements and events into the dependency
514510
// list
515511
addDepsToNode(NodeImpl, Deps);
516512
addDepsToNode(NodeImpl, UniqueDeps);
517513

518514
if (NodeType == node_type::async_free) {
519-
auto AsyncFreeCG =
520-
static_cast<CGAsyncFree *>(NodeImpl->MCommandGroup.get());
515+
auto AsyncFreeCG = static_cast<CGAsyncFree *>(NodeImpl.MCommandGroup.get());
521516
// If this is an async free node mark that it is now available for reuse,
522517
// and pass the async free node for tracking.
523-
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl);
518+
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(),
519+
NodeImpl.shared_from_this());
524520
}
525521

526522
return NodeImpl;
527523
}
528524

529-
std::shared_ptr<node_impl>
525+
node_impl &
530526
graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
531527
nodes_range Deps) {
532528
// Set of Dependent nodes based on CG event and accessor dependencies.
@@ -551,15 +547,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
551547
const auto &ActiveKernel = DynCGImpl->getActiveCG();
552548
node_type NodeType =
553549
ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType);
554-
std::shared_ptr<detail::node_impl> NodeImpl =
555-
add(NodeType, ActiveKernel, Deps);
550+
detail::node_impl &NodeImpl = add(NodeType, ActiveKernel, Deps);
556551

557552
// Add an event associated with this explicit node for mixed usage
558553
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
559554
NodeImpl);
560555

561556
// Track the dynamic command-group used inside the node object
562-
DynCGImpl->MNodes.push_back(NodeImpl);
557+
DynCGImpl->MNodes.push_back(NodeImpl.shared_from_this());
563558

564559
return NodeImpl;
565560
}
@@ -652,7 +647,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
652647
bool DestWasGraphRoot = Dest->MPredecessors.size() == 0;
653648

654649
// We need to add the edges first before checking for cycles
655-
Src->registerSuccessor(Dest);
650+
Src->registerSuccessor(*Dest);
656651

657652
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
658653
if (DestLostRootStatus) {
@@ -1265,7 +1260,7 @@ void exec_graph_impl::duplicateNodes() {
12651260
// Look through all the original node successors, find their copies and
12661261
// register those as successors with the current copied node
12671262
for (node_impl &NextNode : OriginalNode->successors()) {
1268-
auto Successor = NodesMap.at(NextNode.shared_from_this());
1263+
node_impl &Successor = *NodesMap.at(NextNode.shared_from_this());
12691264
NodeCopy->registerSuccessor(Successor);
12701265
}
12711266
}
@@ -1307,7 +1302,8 @@ void exec_graph_impl::duplicateNodes() {
13071302
auto NodeCopy = NewSubgraphNodes[i];
13081303

13091304
for (node_impl &NextNode : SubgraphNode->successors()) {
1310-
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
1305+
node_impl &Successor =
1306+
*SubgraphNodesMap.at(NextNode.shared_from_this());
13111307
NodeCopy->registerSuccessor(Successor);
13121308
}
13131309
}
@@ -1341,7 +1337,7 @@ void exec_graph_impl::duplicateNodes() {
13411337
// Add all input nodes from the subgraph as successors for this node
13421338
// instead
13431339
for (auto &Input : Inputs) {
1344-
PredNode.registerSuccessor(Input);
1340+
PredNode.registerSuccessor(*Input);
13451341
}
13461342
}
13471343

@@ -1360,7 +1356,7 @@ void exec_graph_impl::duplicateNodes() {
13601356
// Add all Output nodes from the subgraph as predecessors for this node
13611357
// instead
13621358
for (auto &Output : Outputs) {
1363-
Output->registerSuccessor(SuccNode.shared_from_this());
1359+
Output->registerSuccessor(SuccNode);
13641360
}
13651361
}
13661362

@@ -1843,38 +1839,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF,
18431839
"dynamic command-group.");
18441840
}
18451841

1846-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1847-
for (auto &D : Deps) {
1848-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1849-
}
1850-
18511842
graph_impl::WriteLock Lock(impl->MMutex);
1852-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DynCGFImpl, DepImpls);
1853-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1843+
detail::node_impl &NodeImpl = impl->add(DynCGFImpl, Deps);
1844+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18541845
}
18551846

18561847
node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
18571848
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
1858-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1859-
for (auto &D : Deps) {
1860-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1861-
}
18621849

18631850
graph_impl::WriteLock Lock(impl->MMutex);
1864-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
1865-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1851+
detail::node_impl &NodeImpl = impl->add(Deps);
1852+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18661853
}
18671854

18681855
node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
18691856
const std::vector<node> &Deps) {
18701857
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
1871-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1872-
for (auto &D : Deps) {
1873-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1874-
}
18751858

1876-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
1877-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1859+
detail::node_impl &NodeImpl = impl->add(CGF, {}, Deps);
1860+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18781861
}
18791862

18801863
void modifiable_command_graph::addGraphLeafDependencies(node Node) {

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
147147
/// @param CommandGroup The CG which stores all information for this node.
148148
/// @param Deps Dependencies of the created node.
149149
/// @return Created node in the graph.
150-
std::shared_ptr<node_impl> add(node_type NodeType,
151-
std::shared_ptr<sycl::detail::CG> CommandGroup,
152-
nodes_range Deps);
150+
node_impl &add(node_type NodeType,
151+
std::shared_ptr<sycl::detail::CG> CommandGroup,
152+
nodes_range Deps);
153153

154154
/// Create a CGF node in the graph.
155155
/// @param CGF Command-group function to create node with.
156156
/// @param Args Node arguments.
157157
/// @param Deps Dependencies of the created node.
158158
/// @return Created node in the graph.
159-
std::shared_ptr<node_impl> add(std::function<void(handler &)> CGF,
160-
const std::vector<sycl::detail::ArgDesc> &Args,
161-
std::vector<std::shared_ptr<node_impl>> &Deps);
159+
node_impl &add(std::function<void(handler &)> CGF,
160+
const std::vector<sycl::detail::ArgDesc> &Args,
161+
nodes_range Deps);
162162

163163
/// Create an empty node in the graph.
164164
/// @param Deps List of predecessor nodes.
165165
/// @return Created node in the graph.
166-
std::shared_ptr<node_impl> add(nodes_range Deps);
166+
node_impl &add(nodes_range Deps);
167167

168168
/// Create a dynamic command-group node in the graph.
169169
/// @param DynCGImpl Dynamic command-group used to create node.
170170
/// @param Deps List of predecessor nodes.
171171
/// @return Created node in the graph.
172-
std::shared_ptr<node_impl>
173-
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
172+
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
173+
nodes_range Deps);
174174

175175
/// Add a queue to the set of queues which are currently recording to this
176176
/// graph.
@@ -192,10 +192,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
192192
/// @param EventImpl Event to associate with a node in map.
193193
/// @param NodeImpl Node to associate with event in map.
194194
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
195-
const std::shared_ptr<node_impl> &NodeImpl) {
195+
node_impl &NodeImpl) {
196196
if (!(EventImpl->hasCommandGraph()))
197197
EventImpl->setCommandGraph(shared_from_this());
198-
MEventsMap[EventImpl] = NodeImpl;
198+
MEventsMap[EventImpl] = NodeImpl.shared_from_this();
199199
}
200200

201201
/// Find the sycl event associated with a node.
@@ -510,6 +510,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
510510
}
511511

512512
private:
513+
template <typename... Ts> node_impl &createNode(Ts &&...Args) {
514+
MNodeStorage.push_back(
515+
std::make_shared<node_impl>(std::forward<Ts>(Args)...));
516+
return *MNodeStorage.back();
517+
}
518+
513519
/// Check the graph for cycles by performing a depth-first search of the
514520
/// graph. If a node is visited more than once in a given path through the
515521
/// graph, a cycle is present and the search ends immediately.
@@ -524,13 +530,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
524530
/// added as a root node.
525531
/// @param Node The node to add deps for
526532
/// @param Deps List of dependent nodes
527-
void addDepsToNode(const std::shared_ptr<node_impl> &Node, nodes_range Deps) {
533+
void addDepsToNode(node_impl &Node, nodes_range Deps) {
528534
for (node_impl &N : Deps) {
529535
N.registerSuccessor(Node);
530-
this->removeRoot(*Node);
536+
this->removeRoot(Node);
531537
}
532-
if (Node->MPredecessors.empty()) {
533-
this->addRoot(*Node);
538+
if (Node.MPredecessors.empty()) {
539+
this->addRoot(Node);
534540
}
535541
}
536542

sycl/source/detail/graph/node_impl.hpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <sycl/detail/cg_types.hpp> // for CGType
1515
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t
1616

17+
#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node
18+
1719
#include <cstring>
1820
#include <fstream>
1921
#include <iomanip>
@@ -26,8 +28,6 @@ inline namespace _V1 {
2628
namespace ext {
2729
namespace oneapi {
2830
namespace experimental {
29-
// Forward declarations
30-
class node;
3131

3232
namespace detail {
3333
// Forward declarations
@@ -121,27 +121,27 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
121121

122122
/// Add successor to the node.
123123
/// @param Node Node to add as a successor.
124-
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
124+
void registerSuccessor(node_impl &Node) {
125125
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
126-
[Node](const std::weak_ptr<node_impl> &Ptr) {
127-
return Ptr.lock() == Node;
126+
[&Node](const std::weak_ptr<node_impl> &Ptr) {
127+
return Ptr.lock().get() == &Node;
128128
}) != MSuccessors.end()) {
129129
return;
130130
}
131-
MSuccessors.push_back(Node);
132-
Node->registerPredecessor(shared_from_this());
131+
MSuccessors.push_back(Node.weak_from_this());
132+
Node.registerPredecessor(*this);
133133
}
134134

135135
/// Add predecessor to the node.
136136
/// @param Node Node to add as a predecessor.
137-
void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
137+
void registerPredecessor(node_impl &Node) {
138138
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
139139
[&Node](const std::weak_ptr<node_impl> &Ptr) {
140-
return Ptr.lock() == Node;
140+
return Ptr.lock().get() == &Node;
141141
}) != MPredecessors.end()) {
142142
return;
143143
}
144-
MPredecessors.push_back(Node);
144+
MPredecessors.push_back(Node.weak_from_this());
145145
}
146146

147147
/// Construct an empty node.
@@ -774,7 +774,7 @@ class nodes_range {
774774
//
775775
std::set<std::shared_ptr<node_impl>>, std::set<node_impl *>,
776776
//
777-
std::list<node_impl *>>;
777+
std::list<node_impl *>, std::vector<node>>;
778778

779779
storage_iter Begin;
780780
storage_iter End;
@@ -783,10 +783,8 @@ class nodes_range {
783783
public:
784784
nodes_range(const nodes_range &Other) = default;
785785

786-
template <
787-
typename ContainerTy,
788-
typename = std::enable_if_t<!std::is_same_v<nodes_range, ContainerTy>>>
789-
nodes_range(ContainerTy &Container)
786+
template <typename ContainerTy>
787+
nodes_range(const ContainerTy &Container)
790788
: Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} {
791789
}
792790

@@ -812,12 +810,14 @@ class nodes_range {
812810
return std::visit(
813811
[](auto &&It) -> node_impl & {
814812
auto &Elem = *It;
815-
if constexpr (std::is_same_v<std::decay_t<decltype(Elem)>,
816-
std::weak_ptr<node_impl>>) {
813+
using Ty = std::decay_t<decltype(Elem)>;
814+
if constexpr (std::is_same_v<Ty, std::weak_ptr<node_impl>>) {
817815
// This assumes that weak_ptr doesn't actually manage lifetime and
818816
// the object is guaranteed to be alive (which seems to be the
819817
// assumption across all graph code).
820818
return *Elem.lock();
819+
} else if constexpr (std::is_same_v<Ty, node>) {
820+
return *getSyclObjImpl(Elem);
821821
} else {
822822
return *Elem;
823823
}

0 commit comments

Comments
 (0)