Skip to content

Commit 75bbf97

Browse files
[NFC][SYCL][Graph] Switch more sets/maps to raw node_impl *
Continuation of intel#19295 intel#19332 intel#19334 intel#19350 intel#19352
1 parent 19d83d5 commit 75bbf97

File tree

3 files changed

+29
-32
lines changed

3 files changed

+29
-32
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); }
341341

342342
void graph_impl::removeRoot(node_impl &Root) { MRoots.erase(&Root); }
343343

344-
std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
344+
std::set<node_impl *> graph_impl::getCGEdges(
345345
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
346346
const auto &Requirements = CommandGroup->getRequirements();
347347
if (!MAllowBuffers && Requirements.size()) {
@@ -362,14 +362,14 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
362362
}
363363

364364
// Add any nodes specified by event dependencies into the dependency list
365-
std::set<std::shared_ptr<node_impl>> UniqueDeps;
365+
std::set<node_impl *> UniqueDeps;
366366
for (auto &Dep : CommandGroup->getEvents()) {
367367
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl == MEventsMap.end()) {
368368
throw sycl::exception(sycl::make_error_code(errc::invalid),
369369
"Event dependency from handler::depends_on does "
370370
"not correspond to a node within the graph");
371371
} else {
372-
UniqueDeps.insert(NodeImpl->second);
372+
UniqueDeps.insert(NodeImpl->second.get());
373373
}
374374
}
375375

@@ -388,7 +388,7 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
388388
}
389389
}
390390
if (ShouldAddDep) {
391-
UniqueDeps.insert(Node);
391+
UniqueDeps.insert(Node.get());
392392
}
393393
}
394394
}
@@ -501,7 +501,7 @@ graph_impl::add(node_type NodeType,
501501
nodes_range Deps) {
502502

503503
// A unique set of dependencies obtained by checking requirements and events
504-
std::set<std::shared_ptr<node_impl>> UniqueDeps = getCGEdges(CommandGroup);
504+
std::set<node_impl *> UniqueDeps = getCGEdges(CommandGroup);
505505

506506
// Track and mark the memory objects being used by the graph.
507507
markCGMemObjs(CommandGroup);
@@ -530,8 +530,7 @@ std::shared_ptr<node_impl>
530530
graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
531531
nodes_range Deps) {
532532
// Set of Dependent nodes based on CG event and accessor dependencies.
533-
std::set<std::shared_ptr<node_impl>> DynCGDeps =
534-
getCGEdges(DynCGImpl->MCommandGroups[0]);
533+
std::set<node_impl *> DynCGDeps = getCGEdges(DynCGImpl->MCommandGroups[0]);
535534
for (unsigned i = 1; i < DynCGImpl->getNumCGs(); i++) {
536535
auto &CG = DynCGImpl->MCommandGroups[i];
537536
auto CGEdges = getCGEdges(CG);
@@ -1559,7 +1558,7 @@ bool exec_graph_impl::needsScheduledUpdate(
15591558
}
15601559

15611560
void exec_graph_impl::populateURKernelUpdateStructs(
1562-
const std::shared_ptr<node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
1561+
node_impl &Node, FastKernelCacheValPtr &BundleObjs,
15631562
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
15641563
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
15651564
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
@@ -1574,7 +1573,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
15741573

15751574
// Gather arg information from Node
15761575
auto &ExecCG =
1577-
*(static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get()));
1576+
*(static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get()));
15781577
// Copy args because we may modify them
15791578
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments();
15801579
// Copy NDR desc since we need to modify it
@@ -1713,7 +1712,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17131712
// TODO: Handle subgraphs or any other cases where multiple nodes may be
17141713
// associated with a single key, once those node types are supported for
17151714
// update.
1716-
auto ExecNode = MIDCache.find(Node->MID);
1715+
auto ExecNode = MIDCache.find(Node.MID);
17171716
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17181717

17191718
auto Command = MCommandMap.find(ExecNode->second.get());
@@ -1725,30 +1724,29 @@ void exec_graph_impl::populateURKernelUpdateStructs(
17251724
ExecNode->second->updateFromOtherNode(Node);
17261725
}
17271726

1728-
std::map<int, std::vector<std::shared_ptr<node_impl>>>
1729-
exec_graph_impl::getURUpdatableNodes(
1730-
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1727+
std::map<int, std::vector<node_impl *>>
1728+
exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const {
17311729
// Iterate over the list of nodes, and for every node that can
17321730
// be updated through UR, add it to the list of nodes for
17331731
// that can be updated for the UR command-buffer partition.
1734-
std::map<int, std::vector<std::shared_ptr<node_impl>>> PartitionedNodes;
1732+
std::map<int, std::vector<node_impl *>> PartitionedNodes;
17351733

17361734
// Initialize vector for each partition
17371735
for (size_t i = 0; i < MPartitions.size(); i++) {
17381736
PartitionedNodes[i] = {};
17391737
}
17401738

1741-
for (auto &Node : Nodes) {
1739+
for (node_impl &Node : Nodes) {
17421740
// Kernel node update is the only command type supported in UR for update.
1743-
if (Node->MCGType != sycl::detail::CGType::Kernel) {
1741+
if (Node.MCGType != sycl::detail::CGType::Kernel) {
17441742
continue;
17451743
}
17461744

1747-
auto ExecNode = MIDCache.find(Node->MID);
1745+
auto ExecNode = MIDCache.find(Node.MID);
17481746
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17491747
auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get());
17501748
assert(PartitionIndex != MPartitionNodes.end());
1751-
PartitionedNodes[PartitionIndex->second].push_back(Node);
1749+
PartitionedNodes[PartitionIndex->second].push_back(&Node);
17521750
}
17531751

17541752
return PartitionedNodes;
@@ -1765,13 +1763,12 @@ void exec_graph_impl::updateHostTasksImpl(
17651763
auto ExecNode = MIDCache.find(Node->MID);
17661764
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
17671765

1768-
ExecNode->second->updateFromOtherNode(Node);
1766+
ExecNode->second->updateFromOtherNode(*Node);
17691767
}
17701768
}
17711769

1772-
void exec_graph_impl::updateURImpl(
1773-
ur_exp_command_buffer_handle_t CommandBuffer,
1774-
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1770+
void exec_graph_impl::updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer,
1771+
nodes_range Nodes) const {
17751772
const size_t NumUpdatableNodes = Nodes.size();
17761773
if (NumUpdatableNodes == 0) {
17771774
return;
@@ -1797,10 +1794,10 @@ void exec_graph_impl::updateURImpl(
17971794
std::vector<FastKernelCacheValPtr> KernelBundleObjList(NumUpdatableNodes);
17981795

17991796
size_t StructListIndex = 0;
1800-
for (auto &Node : Nodes) {
1797+
for (node_impl &Node : Nodes) {
18011798
// This should be the case when getURUpdatableNodes() is used to
18021799
// create the list of nodes.
1803-
assert(Node->MCGType == sycl::detail::CGType::Kernel);
1800+
assert(Node.MCGType == sycl::detail::CGType::Kernel);
18041801

18051802
auto &MemobjDescs = MemobjDescsList[StructListIndex];
18061803
auto &MemobjProps = MemobjPropsList[StructListIndex];

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
134134
/// dependent nodes if so.
135135
/// @param CommandGroup The command group to verify and retrieve edges for.
136136
/// @return Set of dependent nodes in the graph.
137-
std::set<std::shared_ptr<node_impl>>
137+
std::set<node_impl *>
138138
getCGEdges(const std::shared_ptr<sycl::detail::CG> &CommandGroup) const;
139139

140140
/// Identifies the sycl buffers used in the command-group and marks them
@@ -692,7 +692,7 @@ class exec_graph_impl {
692692
/// through UR should be included in this list, currently this is only
693693
/// nodes of kernel type.
694694
void updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer,
695-
const std::vector<std::shared_ptr<node_impl>> &Nodes) const;
695+
nodes_range Nodes) const;
696696

697697
/// Update host-task nodes
698698
/// @param Nodes List of nodes to update, any node that is not a host-task
@@ -708,8 +708,8 @@ class exec_graph_impl {
708708
///
709709
/// @param Nodes List of nodes to split
710710
/// @return Map of partition indexes to nodes
711-
std::map<int, std::vector<std::shared_ptr<node_impl>>> getURUpdatableNodes(
712-
const std::vector<std::shared_ptr<node_impl>> &Nodes) const;
711+
std::map<int, std::vector<node_impl *>>
712+
getURUpdatableNodes(nodes_range Nodes) const;
713713

714714
unsigned long long getID() const { return MID; }
715715

@@ -859,7 +859,7 @@ class exec_graph_impl {
859859
/// @param[out] NDRDesc ND-Range to update.
860860
/// @param[out] UpdateDesc Base struct in the pointer chain.
861861
void populateURKernelUpdateStructs(
862-
const std::shared_ptr<node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
862+
node_impl &Node, FastKernelCacheValPtr &BundleObjs,
863863
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
864864
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
865865
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,

sycl/source/detail/graph/node_impl.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,9 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
460460
}
461461
/// Update this node with the command-group from another node.
462462
/// @param Other The other node to update, must be of the same node type.
463-
void updateFromOtherNode(const std::shared_ptr<node_impl> &Other) {
464-
assert(MNodeType == Other->MNodeType);
465-
MCommandGroup = Other->getCGCopy();
463+
void updateFromOtherNode(node_impl &Other) {
464+
assert(MNodeType == Other.MNodeType);
465+
MCommandGroup = Other.getCGCopy();
466466
}
467467

468468
id_type getID() const { return MID; }

0 commit comments

Comments
 (0)