diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 3acafa17d20f..62983da8a4d1 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -409,10 +409,8 @@ void graph_impl::markCGMemObjs( } } -std::shared_ptr graph_impl::add(nodes_range Deps) { - const std::shared_ptr &NodeImpl = std::make_shared(); - - MNodeStorage.push_back(NodeImpl); +node_impl &graph_impl::add(nodes_range Deps) { + node_impl &NodeImpl = createNode(); addDepsToNode(NodeImpl, Deps); // Add an event associated with this explicit node for mixed usage @@ -421,10 +419,9 @@ std::shared_ptr graph_impl::add(nodes_range Deps) { return NodeImpl; } -std::shared_ptr -graph_impl::add(std::function CGF, - const std::vector &Args, - std::vector> &Deps) { +node_impl &graph_impl::add(std::function CGF, + const std::vector &Args, + nodes_range Deps) { (void)Args; #ifdef __INTEL_PREVIEW_BREAKING_CHANGES detail::handler_impl HandlerImpl{*this}; @@ -435,7 +432,9 @@ graph_impl::add(std::function CGF, // Pass the node deps to the handler so they are available when processing the // CGF, need for async_malloc nodes. - Handler.impl->MNodeDeps = Deps; + Handler.impl->MNodeDeps.clear(); + for (node_impl &N : Deps) + Handler.impl->MNodeDeps.push_back(N.shared_from_this()); #if XPTI_ENABLE_INSTRUMENTATION // Save code location if one was set in TLS. @@ -471,7 +470,7 @@ graph_impl::add(std::function CGF, : ext::oneapi::experimental::detail::getNodeTypeFromCG( Handler.getType()); - auto NodeImpl = + node_impl &NodeImpl = this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Deps); // Add an event associated with this explicit node for mixed usage @@ -489,16 +488,15 @@ graph_impl::add(std::function CGF, } for (auto &[DynamicParam, ArgIndex] : DynamicParams) { - DynamicParam->registerNode(NodeImpl, ArgIndex); + DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex); } return NodeImpl; } -std::shared_ptr -graph_impl::add(node_type NodeType, - std::shared_ptr CommandGroup, - nodes_range Deps) { +node_impl &graph_impl::add(node_type NodeType, + std::shared_ptr CommandGroup, + nodes_range Deps) { // A unique set of dependencies obtained by checking requirements and events std::set> UniqueDeps = getCGEdges(CommandGroup); @@ -506,9 +504,7 @@ graph_impl::add(node_type NodeType, // Track and mark the memory objects being used by the graph. markCGMemObjs(CommandGroup); - const std::shared_ptr &NodeImpl = - std::make_shared(NodeType, std::move(CommandGroup)); - MNodeStorage.push_back(NodeImpl); + node_impl &NodeImpl = createNode(NodeType, std::move(CommandGroup)); // Add any deps determined from requirements and events into the dependency // list @@ -516,17 +512,16 @@ graph_impl::add(node_type NodeType, addDepsToNode(NodeImpl, UniqueDeps); if (NodeType == node_type::async_free) { - auto AsyncFreeCG = - static_cast(NodeImpl->MCommandGroup.get()); + auto AsyncFreeCG = static_cast(NodeImpl.MCommandGroup.get()); // If this is an async free node mark that it is now available for reuse, // and pass the async free node for tracking. - MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), *NodeImpl); + MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl); } return NodeImpl; } -std::shared_ptr +node_impl & graph_impl::add(std::shared_ptr &DynCGImpl, nodes_range Deps) { // Set of Dependent nodes based on CG event and accessor dependencies. @@ -551,15 +546,14 @@ graph_impl::add(std::shared_ptr &DynCGImpl, const auto &ActiveKernel = DynCGImpl->getActiveCG(); node_type NodeType = ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType); - std::shared_ptr NodeImpl = - add(NodeType, ActiveKernel, Deps); + detail::node_impl &NodeImpl = add(NodeType, ActiveKernel, Deps); // Add an event associated with this explicit node for mixed usage addEventForNode(sycl::detail::event_impl::create_completed_host_event(), NodeImpl); // Track the dynamic command-group used inside the node object - DynCGImpl->MNodes.push_back(NodeImpl); + DynCGImpl->MNodes.push_back(NodeImpl.shared_from_this()); return NodeImpl; } @@ -652,7 +646,7 @@ void graph_impl::makeEdge(std::shared_ptr Src, bool DestWasGraphRoot = Dest->MPredecessors.size() == 0; // We need to add the edges first before checking for cycles - Src->registerSuccessor(Dest); + Src->registerSuccessor(*Dest); bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1; if (DestLostRootStatus) { @@ -1265,7 +1259,7 @@ void exec_graph_impl::duplicateNodes() { // Look through all the original node successors, find their copies and // register those as successors with the current copied node for (node_impl &NextNode : OriginalNode->successors()) { - auto Successor = NodesMap.at(NextNode.shared_from_this()); + node_impl &Successor = *NodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1307,7 +1301,8 @@ void exec_graph_impl::duplicateNodes() { auto NodeCopy = NewSubgraphNodes[i]; for (node_impl &NextNode : SubgraphNode->successors()) { - auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this()); + node_impl &Successor = + *SubgraphNodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1341,7 +1336,7 @@ void exec_graph_impl::duplicateNodes() { // Add all input nodes from the subgraph as successors for this node // instead for (auto &Input : Inputs) { - PredNode.registerSuccessor(Input); + PredNode.registerSuccessor(*Input); } } @@ -1360,7 +1355,7 @@ void exec_graph_impl::duplicateNodes() { // Add all Output nodes from the subgraph as predecessors for this node // instead for (auto &Output : Outputs) { - Output->registerSuccessor(SuccNode.shared_from_this()); + Output->registerSuccessor(SuccNode); } } @@ -1843,38 +1838,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF, "dynamic command-group."); } - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } - graph_impl::WriteLock Lock(impl->MMutex); - std::shared_ptr NodeImpl = impl->add(DynCGFImpl, DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(DynCGFImpl, Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } node modifiable_command_graph::addImpl(const std::vector &Deps) { impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function"); - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } graph_impl::WriteLock Lock(impl->MMutex); - std::shared_ptr NodeImpl = impl->add(DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } node modifiable_command_graph::addImpl(std::function CGF, const std::vector &Deps) { impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function"); - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } - std::shared_ptr NodeImpl = impl->add(CGF, {}, DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(CGF, {}, Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } void modifiable_command_graph::addGraphLeafDependencies(node Node) { diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 08255c563962..d90e1f5132b8 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -147,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this { /// @param CommandGroup The CG which stores all information for this node. /// @param Deps Dependencies of the created node. /// @return Created node in the graph. - std::shared_ptr add(node_type NodeType, - std::shared_ptr CommandGroup, - nodes_range Deps); + node_impl &add(node_type NodeType, + std::shared_ptr CommandGroup, + nodes_range Deps); /// Create a CGF node in the graph. /// @param CGF Command-group function to create node with. /// @param Args Node arguments. /// @param Deps Dependencies of the created node. /// @return Created node in the graph. - std::shared_ptr add(std::function CGF, - const std::vector &Args, - std::vector> &Deps); + node_impl &add(std::function CGF, + const std::vector &Args, + nodes_range Deps); /// Create an empty node in the graph. /// @param Deps List of predecessor nodes. /// @return Created node in the graph. - std::shared_ptr add(nodes_range Deps); + node_impl &add(nodes_range Deps); /// Create a dynamic command-group node in the graph. /// @param DynCGImpl Dynamic command-group used to create node. /// @param Deps List of predecessor nodes. /// @return Created node in the graph. - std::shared_ptr - add(std::shared_ptr &DynCGImpl, nodes_range Deps); + node_impl &add(std::shared_ptr &DynCGImpl, + nodes_range Deps); /// Add a queue to the set of queues which are currently recording to this /// graph. @@ -192,10 +192,10 @@ class graph_impl : public std::enable_shared_from_this { /// @param EventImpl Event to associate with a node in map. /// @param NodeImpl Node to associate with event in map. void addEventForNode(std::shared_ptr EventImpl, - const std::shared_ptr &NodeImpl) { + node_impl &NodeImpl) { if (!(EventImpl->hasCommandGraph())) EventImpl->setCommandGraph(shared_from_this()); - MEventsMap[EventImpl] = NodeImpl; + MEventsMap[EventImpl] = NodeImpl.shared_from_this(); } /// Find the sycl event associated with a node. @@ -510,6 +510,12 @@ class graph_impl : public std::enable_shared_from_this { } private: + template node_impl &createNode(Ts &&...Args) { + MNodeStorage.push_back( + std::make_shared(std::forward(Args)...)); + return *MNodeStorage.back(); + } + /// Check the graph for cycles by performing a depth-first search of the /// graph. If a node is visited more than once in a given path through the /// graph, a cycle is present and the search ends immediately. @@ -524,13 +530,13 @@ class graph_impl : public std::enable_shared_from_this { /// added as a root node. /// @param Node The node to add deps for /// @param Deps List of dependent nodes - void addDepsToNode(const std::shared_ptr &Node, nodes_range Deps) { + void addDepsToNode(node_impl &Node, nodes_range Deps) { for (node_impl &N : Deps) { N.registerSuccessor(Node); - this->removeRoot(*Node); + this->removeRoot(Node); } - if (Node->MPredecessors.empty()) { - this->addRoot(*Node); + if (Node.MPredecessors.empty()) { + this->addRoot(Node); } } diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index 578b62e0ab61..3268b1ce1adf 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -14,6 +14,8 @@ #include // for CGType #include // for kernel_param_kind_t +#include // for node + #include #include #include @@ -26,8 +28,6 @@ inline namespace _V1 { namespace ext { namespace oneapi { namespace experimental { -// Forward declarations -class node; namespace detail { // Forward declarations @@ -121,27 +121,27 @@ class node_impl : public std::enable_shared_from_this { /// Add successor to the node. /// @param Node Node to add as a successor. - void registerSuccessor(const std::shared_ptr &Node) { + void registerSuccessor(node_impl &Node) { if (std::find_if(MSuccessors.begin(), MSuccessors.end(), - [Node](const std::weak_ptr &Ptr) { - return Ptr.lock() == Node; + [&Node](const std::weak_ptr &Ptr) { + return Ptr.lock().get() == &Node; }) != MSuccessors.end()) { return; } - MSuccessors.push_back(Node); - Node->registerPredecessor(shared_from_this()); + MSuccessors.push_back(Node.weak_from_this()); + Node.registerPredecessor(*this); } /// Add predecessor to the node. /// @param Node Node to add as a predecessor. - void registerPredecessor(const std::shared_ptr &Node) { + void registerPredecessor(node_impl &Node) { if (std::find_if(MPredecessors.begin(), MPredecessors.end(), [&Node](const std::weak_ptr &Ptr) { - return Ptr.lock() == Node; + return Ptr.lock().get() == &Node; }) != MPredecessors.end()) { return; } - MPredecessors.push_back(Node); + MPredecessors.push_back(Node.weak_from_this()); } /// Construct an empty node. @@ -774,7 +774,7 @@ class nodes_range { // std::set>, std::set, // - std::list>; + std::list, std::vector>; storage_iter Begin; storage_iter End; @@ -783,10 +783,8 @@ class nodes_range { public: nodes_range(const nodes_range &Other) = default; - template < - typename ContainerTy, - typename = std::enable_if_t>> - nodes_range(ContainerTy &Container) + template + nodes_range(const ContainerTy &Container) : Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} { } @@ -812,12 +810,14 @@ class nodes_range { return std::visit( [](auto &&It) -> node_impl & { auto &Elem = *It; - if constexpr (std::is_same_v, - std::weak_ptr>) { + using Ty = std::decay_t; + if constexpr (std::is_same_v>) { // This assumes that weak_ptr doesn't actually manage lifetime and // the object is guaranteed to be alive (which seems to be the // assumption across all graph code). return *Elem.lock(); + } else if constexpr (std::is_same_v) { + return *getSyclObjImpl(Elem); } else { return *Elem; } diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index abcdcbfd5577..af7390c6ca7e 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -886,13 +886,13 @@ event handler::finalize() { // In-order queues create implicit linear dependencies between nodes. // Find the last node added to the graph from this queue, so our new // node can set it as a predecessor. - std::vector> - Deps; + std::vector Deps; if (ext::oneapi::experimental::detail::node_impl *DependentNode = GraphImpl->getLastInorderNode(Queue)) { - Deps.push_back(DependentNode->shared_from_this()); + Deps.push_back(DependentNode); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); + NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) + .shared_from_this(); // If we are recording an in-order queue remember the new node, so it // can be used as a dependency for any more nodes recorded from this @@ -902,13 +902,13 @@ event handler::finalize() { ext::oneapi::experimental::detail::node_impl *LastBarrierRecordedFromQueue = GraphImpl->getBarrierDep(Queue->weak_from_this()); - std::vector> - Deps; + std::vector Deps; if (LastBarrierRecordedFromQueue) { - Deps.push_back(LastBarrierRecordedFromQueue->shared_from_this()); + Deps.push_back(LastBarrierRecordedFromQueue); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); + NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) + .shared_from_this(); if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) { GraphImpl->setBarrierDep(Queue->weak_from_this(), *NodeImpl); @@ -916,7 +916,7 @@ event handler::finalize() { } // Associate an event with this new node and return the event. - GraphImpl->addEventForNode(EventImpl, std::move(NodeImpl)); + GraphImpl->addEventForNode(EventImpl, *NodeImpl); #ifdef __INTEL_PREVIEW_BREAKING_CHANGES return EventImpl;