diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 3acafa17d20f..75537d9c2f7f 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -341,7 +341,7 @@ void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); } void graph_impl::removeRoot(node_impl &Root) { MRoots.erase(&Root); } -std::set> graph_impl::getCGEdges( +std::set graph_impl::getCGEdges( const std::shared_ptr &CommandGroup) const { const auto &Requirements = CommandGroup->getRequirements(); if (!MAllowBuffers && Requirements.size()) { @@ -362,14 +362,14 @@ std::set> graph_impl::getCGEdges( } // Add any nodes specified by event dependencies into the dependency list - std::set> UniqueDeps; + std::set UniqueDeps; for (auto &Dep : CommandGroup->getEvents()) { if (auto NodeImpl = MEventsMap.find(Dep); NodeImpl == MEventsMap.end()) { throw sycl::exception(sycl::make_error_code(errc::invalid), "Event dependency from handler::depends_on does " "not correspond to a node within the graph"); } else { - UniqueDeps.insert(NodeImpl->second); + UniqueDeps.insert(NodeImpl->second.get()); } } @@ -388,7 +388,7 @@ std::set> graph_impl::getCGEdges( } } if (ShouldAddDep) { - UniqueDeps.insert(Node); + UniqueDeps.insert(Node.get()); } } } @@ -501,7 +501,7 @@ graph_impl::add(node_type NodeType, nodes_range Deps) { // A unique set of dependencies obtained by checking requirements and events - std::set> UniqueDeps = getCGEdges(CommandGroup); + std::set UniqueDeps = getCGEdges(CommandGroup); // Track and mark the memory objects being used by the graph. markCGMemObjs(CommandGroup); @@ -530,8 +530,7 @@ std::shared_ptr graph_impl::add(std::shared_ptr &DynCGImpl, nodes_range Deps) { // Set of Dependent nodes based on CG event and accessor dependencies. - std::set> DynCGDeps = - getCGEdges(DynCGImpl->MCommandGroups[0]); + std::set DynCGDeps = getCGEdges(DynCGImpl->MCommandGroups[0]); for (unsigned i = 1; i < DynCGImpl->getNumCGs(); i++) { auto &CG = DynCGImpl->MCommandGroups[i]; auto CGEdges = getCGEdges(CG); @@ -1559,7 +1558,7 @@ bool exec_graph_impl::needsScheduledUpdate( } void exec_graph_impl::populateURKernelUpdateStructs( - const std::shared_ptr &Node, FastKernelCacheValPtr &BundleObjs, + node_impl &Node, FastKernelCacheValPtr &BundleObjs, std::vector &MemobjDescs, std::vector &MemobjProps, std::vector &PtrDescs, @@ -1574,7 +1573,7 @@ void exec_graph_impl::populateURKernelUpdateStructs( // Gather arg information from Node auto &ExecCG = - *(static_cast(Node->MCommandGroup.get())); + *(static_cast(Node.MCommandGroup.get())); // Copy args because we may modify them std::vector NodeArgs = ExecCG.getArguments(); // Copy NDR desc since we need to modify it @@ -1713,7 +1712,7 @@ void exec_graph_impl::populateURKernelUpdateStructs( // TODO: Handle subgraphs or any other cases where multiple nodes may be // associated with a single key, once those node types are supported for // update. - auto ExecNode = MIDCache.find(Node->MID); + auto ExecNode = MIDCache.find(Node.MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); auto Command = MCommandMap.find(ExecNode->second.get()); @@ -1725,30 +1724,29 @@ void exec_graph_impl::populateURKernelUpdateStructs( ExecNode->second->updateFromOtherNode(Node); } -std::map>> -exec_graph_impl::getURUpdatableNodes( - const std::vector> &Nodes) const { +std::map> +exec_graph_impl::getURUpdatableNodes(nodes_range Nodes) const { // Iterate over the list of nodes, and for every node that can // be updated through UR, add it to the list of nodes for // that can be updated for the UR command-buffer partition. - std::map>> PartitionedNodes; + std::map> PartitionedNodes; // Initialize vector for each partition for (size_t i = 0; i < MPartitions.size(); i++) { PartitionedNodes[i] = {}; } - for (auto &Node : Nodes) { + for (node_impl &Node : Nodes) { // Kernel node update is the only command type supported in UR for update. - if (Node->MCGType != sycl::detail::CGType::Kernel) { + if (Node.MCGType != sycl::detail::CGType::Kernel) { continue; } - auto ExecNode = MIDCache.find(Node->MID); + auto ExecNode = MIDCache.find(Node.MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get()); assert(PartitionIndex != MPartitionNodes.end()); - PartitionedNodes[PartitionIndex->second].push_back(Node); + PartitionedNodes[PartitionIndex->second].push_back(&Node); } return PartitionedNodes; @@ -1765,13 +1763,12 @@ void exec_graph_impl::updateHostTasksImpl( auto ExecNode = MIDCache.find(Node->MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - ExecNode->second->updateFromOtherNode(Node); + ExecNode->second->updateFromOtherNode(*Node); } } -void exec_graph_impl::updateURImpl( - ur_exp_command_buffer_handle_t CommandBuffer, - const std::vector> &Nodes) const { +void exec_graph_impl::updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer, + nodes_range Nodes) const { const size_t NumUpdatableNodes = Nodes.size(); if (NumUpdatableNodes == 0) { return; @@ -1797,10 +1794,10 @@ void exec_graph_impl::updateURImpl( std::vector KernelBundleObjList(NumUpdatableNodes); size_t StructListIndex = 0; - for (auto &Node : Nodes) { + for (node_impl &Node : Nodes) { // This should be the case when getURUpdatableNodes() is used to // create the list of nodes. - assert(Node->MCGType == sycl::detail::CGType::Kernel); + assert(Node.MCGType == sycl::detail::CGType::Kernel); auto &MemobjDescs = MemobjDescsList[StructListIndex]; auto &MemobjProps = MemobjPropsList[StructListIndex]; diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 08255c563962..fd7498500d84 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -134,7 +134,7 @@ class graph_impl : public std::enable_shared_from_this { /// dependent nodes if so. /// @param CommandGroup The command group to verify and retrieve edges for. /// @return Set of dependent nodes in the graph. - std::set> + std::set getCGEdges(const std::shared_ptr &CommandGroup) const; /// Identifies the sycl buffers used in the command-group and marks them @@ -692,7 +692,7 @@ class exec_graph_impl { /// through UR should be included in this list, currently this is only /// nodes of kernel type. void updateURImpl(ur_exp_command_buffer_handle_t CommandBuffer, - const std::vector> &Nodes) const; + nodes_range Nodes) const; /// Update host-task nodes /// @param Nodes List of nodes to update, any node that is not a host-task @@ -708,8 +708,8 @@ class exec_graph_impl { /// /// @param Nodes List of nodes to split /// @return Map of partition indexes to nodes - std::map>> getURUpdatableNodes( - const std::vector> &Nodes) const; + std::map> + getURUpdatableNodes(nodes_range Nodes) const; unsigned long long getID() const { return MID; } @@ -859,7 +859,7 @@ class exec_graph_impl { /// @param[out] NDRDesc ND-Range to update. /// @param[out] UpdateDesc Base struct in the pointer chain. void populateURKernelUpdateStructs( - const std::shared_ptr &Node, FastKernelCacheValPtr &BundleObjs, + node_impl &Node, FastKernelCacheValPtr &BundleObjs, std::vector &MemobjDescs, std::vector &MemobjProps, std::vector &PtrDescs, diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index 578b62e0ab61..52acd16ab699 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -460,9 +460,9 @@ class node_impl : public std::enable_shared_from_this { } /// Update this node with the command-group from another node. /// @param Other The other node to update, must be of the same node type. - void updateFromOtherNode(const std::shared_ptr &Other) { - assert(MNodeType == Other->MNodeType); - MCommandGroup = Other->getCGCopy(); + void updateFromOtherNode(node_impl &Other) { + assert(MNodeType == Other.MNodeType); + MCommandGroup = Other.getCGCopy(); } id_type getID() const { return MID; }