Skip to content

Commit 25f9b07

Browse files
[NFCI][SYCL][Graph] Refactor graph_impl::add
Part of the refactoring to eliminate `std::weak_ptr<node_impl>` and reduce usage of `std::shared_ptr<node_impl>` by preferring raw ptr/ref. Previous PRs in the series: intel#19295 intel#19332 intel#19334 intel#19350 * Accept `Deps` as `nodes_range` in `graph_impl::add` * Return `node_impl &` from `graph_impl::add` * Add `node` support in `nodes_range` and use that together with modified `graph_impl::add` when created new `node_impl`s based on `std::vector<node> Deps` to avoid creation of temporary `DepImpls` storage. * Also updated `registerSuccessor/registerPredecessor` and `addEventForNode/addDepsToNode` to accept raw `node_impl &` as the changes above resulted in having raw reference at the call sites.
1 parent 99e9fa1 commit 25f9b07

File tree

4 files changed

+78
-89
lines changed

4 files changed

+78
-89
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 35 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -409,10 +409,9 @@ void graph_impl::markCGMemObjs(
409409
}
410410
}
411411

412-
std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
413-
const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>();
414-
415-
MNodeStorage.push_back(NodeImpl);
412+
node_impl &graph_impl::add(nodes_range Deps) {
413+
MNodeStorage.push_back(std::make_shared<node_impl>());
414+
node_impl &NodeImpl = *MNodeStorage.back();
416415

417416
addDepsToNode(NodeImpl, Deps);
418417
// Add an event associated with this explicit node for mixed usage
@@ -421,10 +420,9 @@ std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
421420
return NodeImpl;
422421
}
423422

424-
std::shared_ptr<node_impl>
425-
graph_impl::add(std::function<void(handler &)> CGF,
426-
const std::vector<sycl::detail::ArgDesc> &Args,
427-
std::vector<std::shared_ptr<node_impl>> &Deps) {
423+
node_impl &graph_impl::add(std::function<void(handler &)> CGF,
424+
const std::vector<sycl::detail::ArgDesc> &Args,
425+
nodes_range Deps) {
428426
(void)Args;
429427
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
430428
detail::handler_impl HandlerImpl{*this};
@@ -435,7 +433,9 @@ graph_impl::add(std::function<void(handler &)> CGF,
435433

436434
// Pass the node deps to the handler so they are available when processing the
437435
// CGF, need for async_malloc nodes.
438-
Handler.impl->MNodeDeps = Deps;
436+
Handler.impl->MNodeDeps.clear(); // TODO: Is that right?
437+
for (node_impl &N : Deps)
438+
Handler.impl->MNodeDeps.push_back(N.shared_from_this());
439439

440440
#if XPTI_ENABLE_INSTRUMENTATION
441441
// Save code location if one was set in TLS.
@@ -471,7 +471,7 @@ graph_impl::add(std::function<void(handler &)> CGF,
471471
: ext::oneapi::experimental::detail::getNodeTypeFromCG(
472472
Handler.getType());
473473

474-
auto NodeImpl =
474+
node_impl &NodeImpl =
475475
this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Deps);
476476

477477
// Add an event associated with this explicit node for mixed usage
@@ -489,26 +489,25 @@ graph_impl::add(std::function<void(handler &)> CGF,
489489
}
490490

491491
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
492-
DynamicParam->registerNode(NodeImpl, ArgIndex);
492+
DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex);
493493
}
494494

495495
return NodeImpl;
496496
}
497497

498-
std::shared_ptr<node_impl>
499-
graph_impl::add(node_type NodeType,
500-
std::shared_ptr<sycl::detail::CG> CommandGroup,
501-
nodes_range Deps) {
498+
node_impl &graph_impl::add(node_type NodeType,
499+
std::shared_ptr<sycl::detail::CG> CommandGroup,
500+
nodes_range Deps) {
502501

503502
// A unique set of dependencies obtained by checking requirements and events
504503
std::set<std::shared_ptr<node_impl>> UniqueDeps = getCGEdges(CommandGroup);
505504

506505
// Track and mark the memory objects being used by the graph.
507506
markCGMemObjs(CommandGroup);
508507

509-
const std::shared_ptr<node_impl> &NodeImpl =
510-
std::make_shared<node_impl>(NodeType, std::move(CommandGroup));
511-
MNodeStorage.push_back(NodeImpl);
508+
MNodeStorage.push_back(
509+
std::make_shared<node_impl>(NodeType, std::move(CommandGroup)));
510+
node_impl &NodeImpl = *MNodeStorage.back();
512511

513512
// Add any deps determined from requirements and events into the dependency
514513
// list
@@ -517,16 +516,18 @@ graph_impl::add(node_type NodeType,
517516

518517
if (NodeType == node_type::async_free) {
519518
auto AsyncFreeCG =
520-
static_cast<CGAsyncFree *>(NodeImpl->MCommandGroup.get());
519+
static_cast<CGAsyncFree *>(NodeImpl.MCommandGroup.get());
521520
// If this is an async free node mark that it is now available for reuse,
522521
// and pass the async free node for tracking.
523-
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl);
522+
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(),
523+
// TODO: use raw
524+
NodeImpl.shared_from_this());
524525
}
525526

526527
return NodeImpl;
527528
}
528529

529-
std::shared_ptr<node_impl>
530+
node_impl&
530531
graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
531532
nodes_range Deps) {
532533
// Set of Dependent nodes based on CG event and accessor dependencies.
@@ -551,15 +552,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
551552
const auto &ActiveKernel = DynCGImpl->getActiveCG();
552553
node_type NodeType =
553554
ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType);
554-
std::shared_ptr<detail::node_impl> NodeImpl =
555-
add(NodeType, ActiveKernel, Deps);
555+
detail::node_impl &NodeImpl = add(NodeType, ActiveKernel, Deps);
556556

557557
// Add an event associated with this explicit node for mixed usage
558558
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
559559
NodeImpl);
560560

561561
// Track the dynamic command-group used inside the node object
562-
DynCGImpl->MNodes.push_back(NodeImpl);
562+
DynCGImpl->MNodes.push_back(NodeImpl.shared_from_this());
563563

564564
return NodeImpl;
565565
}
@@ -652,7 +652,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
652652
bool DestWasGraphRoot = Dest->MPredecessors.size() == 0;
653653

654654
// We need to add the edges first before checking for cycles
655-
Src->registerSuccessor(Dest);
655+
Src->registerSuccessor(*Dest);
656656

657657
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
658658
if (DestLostRootStatus) {
@@ -1265,7 +1265,7 @@ void exec_graph_impl::duplicateNodes() {
12651265
// Look through all the original node successors, find their copies and
12661266
// register those as successors with the current copied node
12671267
for (node_impl &NextNode : OriginalNode->successors()) {
1268-
auto Successor = NodesMap.at(NextNode.shared_from_this());
1268+
node_impl &Successor = *NodesMap.at(NextNode.shared_from_this());
12691269
NodeCopy->registerSuccessor(Successor);
12701270
}
12711271
}
@@ -1307,7 +1307,7 @@ void exec_graph_impl::duplicateNodes() {
13071307
auto NodeCopy = NewSubgraphNodes[i];
13081308

13091309
for (node_impl &NextNode : SubgraphNode->successors()) {
1310-
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
1310+
node_impl &Successor = *SubgraphNodesMap.at(NextNode.shared_from_this());
13111311
NodeCopy->registerSuccessor(Successor);
13121312
}
13131313
}
@@ -1341,7 +1341,7 @@ void exec_graph_impl::duplicateNodes() {
13411341
// Add all input nodes from the subgraph as successors for this node
13421342
// instead
13431343
for (auto &Input : Inputs) {
1344-
PredNode.registerSuccessor(Input);
1344+
PredNode.registerSuccessor(*Input);
13451345
}
13461346
}
13471347

@@ -1360,7 +1360,7 @@ void exec_graph_impl::duplicateNodes() {
13601360
// Add all Output nodes from the subgraph as predecessors for this node
13611361
// instead
13621362
for (auto &Output : Outputs) {
1363-
Output->registerSuccessor(SuccNode.shared_from_this());
1363+
Output->registerSuccessor(SuccNode);
13641364
}
13651365
}
13661366

@@ -1843,38 +1843,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF,
18431843
"dynamic command-group.");
18441844
}
18451845

1846-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1847-
for (auto &D : Deps) {
1848-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1849-
}
1850-
18511846
graph_impl::WriteLock Lock(impl->MMutex);
1852-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DynCGFImpl, DepImpls);
1853-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1847+
detail::node_impl &NodeImpl = impl->add(DynCGFImpl, Deps);
1848+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18541849
}
18551850

18561851
node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
18571852
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
1858-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1859-
for (auto &D : Deps) {
1860-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1861-
}
18621853

18631854
graph_impl::WriteLock Lock(impl->MMutex);
1864-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
1865-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1855+
detail::node_impl &NodeImpl = impl->add(Deps);
1856+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18661857
}
18671858

18681859
node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
18691860
const std::vector<node> &Deps) {
18701861
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
1871-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1872-
for (auto &D : Deps) {
1873-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1874-
}
18751862

1876-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
1877-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1863+
detail::node_impl &NodeImpl = impl->add(CGF, {}, Deps);
1864+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18781865
}
18791866

18801867
void modifiable_command_graph::addGraphLeafDependencies(node Node) {

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
147147
/// @param CommandGroup The CG which stores all information for this node.
148148
/// @param Deps Dependencies of the created node.
149149
/// @return Created node in the graph.
150-
std::shared_ptr<node_impl> add(node_type NodeType,
151-
std::shared_ptr<sycl::detail::CG> CommandGroup,
152-
nodes_range Deps);
150+
node_impl &add(node_type NodeType,
151+
std::shared_ptr<sycl::detail::CG> CommandGroup,
152+
nodes_range Deps);
153153

154154
/// Create a CGF node in the graph.
155155
/// @param CGF Command-group function to create node with.
156156
/// @param Args Node arguments.
157157
/// @param Deps Dependencies of the created node.
158158
/// @return Created node in the graph.
159-
std::shared_ptr<node_impl> add(std::function<void(handler &)> CGF,
160-
const std::vector<sycl::detail::ArgDesc> &Args,
161-
std::vector<std::shared_ptr<node_impl>> &Deps);
159+
node_impl &add(std::function<void(handler &)> CGF,
160+
const std::vector<sycl::detail::ArgDesc> &Args,
161+
nodes_range Deps);
162162

163163
/// Create an empty node in the graph.
164164
/// @param Deps List of predecessor nodes.
165165
/// @return Created node in the graph.
166-
std::shared_ptr<node_impl> add(nodes_range Deps);
166+
node_impl &add(nodes_range Deps);
167167

168168
/// Create a dynamic command-group node in the graph.
169169
/// @param DynCGImpl Dynamic command-group used to create node.
170170
/// @param Deps List of predecessor nodes.
171171
/// @return Created node in the graph.
172-
std::shared_ptr<node_impl>
173-
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
172+
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
173+
nodes_range Deps);
174174

175175
/// Add a queue to the set of queues which are currently recording to this
176176
/// graph.
@@ -192,10 +192,10 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
192192
/// @param EventImpl Event to associate with a node in map.
193193
/// @param NodeImpl Node to associate with event in map.
194194
void addEventForNode(std::shared_ptr<sycl::detail::event_impl> EventImpl,
195-
const std::shared_ptr<node_impl> &NodeImpl) {
195+
node_impl &NodeImpl) {
196196
if (!(EventImpl->hasCommandGraph()))
197197
EventImpl->setCommandGraph(shared_from_this());
198-
MEventsMap[EventImpl] = NodeImpl;
198+
MEventsMap[EventImpl] = NodeImpl.shared_from_this();
199199
}
200200

201201
/// Find the sycl event associated with a node.
@@ -289,6 +289,8 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
289289
/// each other via weak_ptrs and so do not extend each other's lifetimes.
290290
/// This storage allows easy iteration over all nodes in the graph, rather
291291
/// than needing an expensive depth first search.
292+
//
293+
// FIXME: consider `create` method that would push_back/back atomically.
292294
std::vector<std::shared_ptr<node_impl>> MNodeStorage;
293295

294296
nodes_range roots() const { return MRoots; }
@@ -524,13 +526,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
524526
/// added as a root node.
525527
/// @param Node The node to add deps for
526528
/// @param Deps List of dependent nodes
527-
void addDepsToNode(const std::shared_ptr<node_impl> &Node, nodes_range Deps) {
529+
void addDepsToNode(node_impl &Node, nodes_range Deps) {
528530
for (node_impl &N : Deps) {
529531
N.registerSuccessor(Node);
530-
this->removeRoot(*Node);
532+
this->removeRoot(Node);
531533
}
532-
if (Node->MPredecessors.empty()) {
533-
this->addRoot(*Node);
534+
if (Node.MPredecessors.empty()) {
535+
this->addRoot(Node);
534536
}
535537
}
536538

sycl/source/detail/graph/node_impl.hpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <sycl/detail/cg_types.hpp> // for CGType
1515
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t
1616

17+
#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node
18+
1719
#include <cstring>
1820
#include <fstream>
1921
#include <iomanip>
@@ -26,8 +28,6 @@ inline namespace _V1 {
2628
namespace ext {
2729
namespace oneapi {
2830
namespace experimental {
29-
// Forward declarations
30-
class node;
3131

3232
namespace detail {
3333
// Forward declarations
@@ -121,27 +121,27 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
121121

122122
/// Add successor to the node.
123123
/// @param Node Node to add as a successor.
124-
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
124+
void registerSuccessor(node_impl &Node) {
125125
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
126-
[Node](const std::weak_ptr<node_impl> &Ptr) {
127-
return Ptr.lock() == Node;
126+
[&Node](const std::weak_ptr<node_impl> &Ptr) {
127+
return Ptr.lock().get() == &Node;
128128
}) != MSuccessors.end()) {
129129
return;
130130
}
131-
MSuccessors.push_back(Node);
132-
Node->registerPredecessor(shared_from_this());
131+
MSuccessors.push_back(Node.weak_from_this());
132+
Node.registerPredecessor(*this);
133133
}
134134

135135
/// Add predecessor to the node.
136136
/// @param Node Node to add as a predecessor.
137-
void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
137+
void registerPredecessor(node_impl &Node) {
138138
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
139139
[&Node](const std::weak_ptr<node_impl> &Ptr) {
140-
return Ptr.lock() == Node;
140+
return Ptr.lock().get() == &Node;
141141
}) != MPredecessors.end()) {
142142
return;
143143
}
144-
MPredecessors.push_back(Node);
144+
MPredecessors.push_back(Node.weak_from_this());
145145
}
146146

147147
/// Construct an empty node.
@@ -774,7 +774,7 @@ class nodes_range {
774774
//
775775
std::set<std::shared_ptr<node_impl>>, std::set<node_impl *>,
776776
//
777-
std::list<node_impl *>>;
777+
std::list<node_impl *>, std::vector<node>>;
778778

779779
storage_iter Begin;
780780
storage_iter End;
@@ -783,10 +783,8 @@ class nodes_range {
783783
public:
784784
nodes_range(const nodes_range &Other) = default;
785785

786-
template <
787-
typename ContainerTy,
788-
typename = std::enable_if_t<!std::is_same_v<nodes_range, ContainerTy>>>
789-
nodes_range(ContainerTy &Container)
786+
template <typename ContainerTy>
787+
nodes_range(const ContainerTy &Container)
790788
: Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} {
791789
}
792790

@@ -812,12 +810,14 @@ class nodes_range {
812810
return std::visit(
813811
[](auto &&It) -> node_impl & {
814812
auto &Elem = *It;
815-
if constexpr (std::is_same_v<std::decay_t<decltype(Elem)>,
816-
std::weak_ptr<node_impl>>) {
813+
using Ty = std::decay_t<decltype(Elem)>;
814+
if constexpr (std::is_same_v<Ty, std::weak_ptr<node_impl>>) {
817815
// This assumes that weak_ptr doesn't actually manage lifetime and
818816
// the object is guaranteed to be alive (which seems to be the
819817
// assumption across all graph code).
820818
return *Elem.lock();
819+
} else if constexpr (std::is_same_v<Ty, node>) {
820+
return *getSyclObjImpl(Elem);
821821
} else {
822822
return *Elem;
823823
}

0 commit comments

Comments
 (0)