Skip to content

[NFC][SYCL][Graph] Switch more sets/maps to raw node_impl * #19371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 21 additions & 24 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); }

/// Erases the address of \p Root from the MRoots container.
/// @param Root Node whose address is removed from MRoots.
void graph_impl::removeRoot(node_impl &Root) {
  node_impl *const RootPtr = &Root;
  MRoots.erase(RootPtr);
}

std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
std::set<node_impl *> graph_impl::getCGEdges(
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
const auto &Requirements = CommandGroup->getRequirements();
if (!MAllowBuffers && Requirements.size()) {
Expand All @@ -362,14 +362,14 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
}

// Add any nodes specified by event dependencies into the dependency list
std::set<std::shared_ptr<node_impl>> UniqueDeps;
std::set<node_impl *> UniqueDeps;
for (auto &Dep : CommandGroup->getEvents()) {
if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl == MEventsMap.end()) {
throw sycl::exception(sycl::make_error_code(errc::invalid),
"Event dependency from handler::depends_on does "
"not correspond to a node within the graph");
} else {
UniqueDeps.insert(NodeImpl->second);
UniqueDeps.insert(NodeImpl->second.get());
}
}

Expand All @@ -388,7 +388,7 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
}
}
if (ShouldAddDep) {
UniqueDeps.insert(Node);
UniqueDeps.insert(Node.get());
}
}
}
Expand Down Expand Up @@ -501,7 +501,7 @@ graph_impl::add(node_type NodeType,
nodes_range Deps) {

// A unique set of dependencies obtained by checking requirements and events
std::set<std::shared_ptr<node_impl>> UniqueDeps = getCGEdges(CommandGroup);
std::set<node_impl *> UniqueDeps = getCGEdges(CommandGroup);

// Track and mark the memory objects being used by the graph.
markCGMemObjs(CommandGroup);
Expand Down Expand Up @@ -530,8 +530,7 @@ std::shared_ptr<node_impl>
graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
nodes_range Deps) {
// Set of Dependent nodes based on CG event and accessor dependencies.
std::set<std::shared_ptr<node_impl>> DynCGDeps =
getCGEdges(DynCGImpl->MCommandGroups[0]);
std::set<node_impl *> DynCGDeps = getCGEdges(DynCGImpl->MCommandGroups[0]);
for (unsigned i = 1; i < DynCGImpl->getNumCGs(); i++) {
auto &CG = DynCGImpl->MCommandGroups[i];
auto CGEdges = getCGEdges(CG);
Expand Down Expand Up @@ -1559,7 +1558,7 @@ bool exec_graph_impl::needsScheduledUpdate(
}

void exec_graph_impl::populateURKernelUpdateStructs(
const std::shared_ptr<node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
node_impl &Node, FastKernelCacheValPtr &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
Expand All @@ -1574,7 +1573,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(

// Gather arg information from Node
auto &ExecCG =
*(static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get()));
*(static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get()));
// Copy args because we may modify them
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments();
// Copy NDR desc since we need to modify it
Expand Down Expand Up @@ -1713,7 +1712,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
// TODO: Handle subgraphs or any other cases where multiple nodes may be
// associated with a single key, once those node types are supported for
// update.
auto ExecNode = MIDCache.find(Node->MID);
auto ExecNode = MIDCache.find(Node.MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

auto Command = MCommandMap.find(ExecNode->second.get());
Expand All @@ -1725,30 +1724,29 @@ void exec_graph_impl::populateURKernelUpdateStructs(
ExecNode->second->updateFromOtherNode(Node);
}

std::map<int, std::vector<std::shared_ptr<node_impl>>>
exec_graph_impl::getURUpdatableNodes(
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
std::map<int, std::vector<node_impl *>>
exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const {
// Iterate over the list of nodes, and for every node that can
// be updated through UR, add it to the list of updatable nodes
// for its UR command-buffer partition.
std::map<int, std::vector<std::shared_ptr<node_impl>>> PartitionedNodes;
std::map<int, std::vector<node_impl *>> PartitionedNodes;

// Initialize vector for each partition
for (size_t i = 0; i < MPartitions.size(); i++) {
PartitionedNodes[i] = {};
}

for (auto &Node : Nodes) {
for (node_impl &Node : Nodes) {
// Kernel node update is the only command type supported in UR for update.
if (Node->MCGType != sycl::detail::CGType::Kernel) {
if (Node.MCGType != sycl::detail::CGType::Kernel) {
continue;
}

auto ExecNode = MIDCache.find(Node->MID);
auto ExecNode = MIDCache.find(Node.MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get());
assert(PartitionIndex != MPartitionNodes.end());
PartitionedNodes[PartitionIndex->second].push_back(Node);
PartitionedNodes[PartitionIndex->second].push_back(&Node);
}

return PartitionedNodes;
Expand All @@ -1765,13 +1763,12 @@ void exec_graph_impl::updateHostTasksImpl(
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

ExecNode->second->updateFromOtherNode(Node);
ExecNode->second->updateFromOtherNode(*Node);
}
}

void exec_graph_impl::updateURImpl(
ur_exp_command_buffer_handle_t CommandBuffer,
const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
void exec_graph_impl::updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer,
nodes_range Nodes) const {
const size_t NumUpdatableNodes = Nodes.size();
if (NumUpdatableNodes == 0) {
return;
Expand All @@ -1797,10 +1794,10 @@ void exec_graph_impl::updateURImpl(
std::vector<FastKernelCacheValPtr> KernelBundleObjList(NumUpdatableNodes);

size_t StructListIndex = 0;
for (auto &Node : Nodes) {
for (node_impl &Node : Nodes) {
// This should be the case when getURUpdatableNodes() is used to
// create the list of nodes.
assert(Node->MCGType == sycl::detail::CGType::Kernel);
assert(Node.MCGType == sycl::detail::CGType::Kernel);

auto &MemobjDescs = MemobjDescsList[StructListIndex];
auto &MemobjProps = MemobjPropsList[StructListIndex];
Expand Down
10 changes: 5 additions & 5 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// dependent nodes if so.
/// @param CommandGroup The command group to verify and retrieve edges for.
/// @return Set of dependent nodes in the graph.
std::set<std::shared_ptr<node_impl>>
std::set<node_impl *>
getCGEdges(const std::shared_ptr<sycl::detail::CG> &CommandGroup) const;

/// Identifies the sycl buffers used in the command-group and marks them
Expand Down Expand Up @@ -692,7 +692,7 @@ class exec_graph_impl {
/// through UR should be included in this list, currently this is only
/// nodes of kernel type.
void updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer,
const std::vector<std::shared_ptr<node_impl>> &Nodes) const;
nodes_range Nodes) const;

/// Update host-task nodes
/// @param Nodes List of nodes to update, any node that is not a host-task
Expand All @@ -708,8 +708,8 @@ class exec_graph_impl {
///
/// @param Nodes List of nodes to split
/// @return Map of partition indexes to nodes
std::map<int, std::vector<std::shared_ptr<node_impl>>> getURUpdatableNodes(
const std::vector<std::shared_ptr<node_impl>> &Nodes) const;
std::map<int, std::vector<node_impl *>>
getURUpdatableNodes(nodes_range Nodes) const;

unsigned long long getID() const { return MID; }

Expand Down Expand Up @@ -859,7 +859,7 @@ class exec_graph_impl {
/// @param[out] NDRDesc ND-Range to update.
/// @param[out] UpdateDesc Base struct in the pointer chain.
void populateURKernelUpdateStructs(
const std::shared_ptr<node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
node_impl &Node, FastKernelCacheValPtr &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/graph/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,9 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
}
/// Update this node with the command-group from another node.
/// @param Other The node whose command-group is copied into this node; must
/// be of the same node type.
void updateFromOtherNode(const std::shared_ptr<node_impl> &Other) {
assert(MNodeType == Other->MNodeType);
MCommandGroup = Other->getCGCopy();
void updateFromOtherNode(node_impl &Other) {
assert(MNodeType == Other.MNodeType);
MCommandGroup = Other.getCGCopy();
}

id_type getID() const { return MID; }
Expand Down
Loading