From 99e9fa1be37358333a72029da9a62b352ce45547 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Mon, 7 Jul 2025 18:01:04 -0700 Subject: [PATCH 1/2] [NFC][SYCL][Graph] Use raw `node_impl *` in `MRoots`/`MSchedule` ... and update the code surrounding their uses in the same spirit. Continuation of https://github.com/intel/llvm/pull/19295 https://github.com/intel/llvm/pull/19332 https://github.com/intel/llvm/pull/19334 --- sycl/source/detail/graph/graph_impl.cpp | 99 +++++------ sycl/source/detail/graph/graph_impl.hpp | 59 +++---- sycl/source/detail/graph/node_impl.cpp | 7 +- sycl/source/detail/graph/node_impl.hpp | 10 +- .../Extensions/CommandGraph/Barrier.cpp | 163 ++++++++---------- .../Extensions/CommandGraph/CommandGraph.cpp | 2 +- .../Extensions/CommandGraph/Exceptions.cpp | 2 +- .../Extensions/CommandGraph/MultiThreaded.cpp | 14 +- 8 files changed, 160 insertions(+), 196 deletions(-) diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index d5555ef688767..0220b0a5e46c8 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -85,31 +85,29 @@ inline const char *nodeTypeToString(node_type NodeType) { /// @param[in] PartitionBounded If set to true, the topological sort is stopped /// at partition borders. Hence, nodes belonging to a partition different from /// the NodeImpl partition are not processed. 
-void sortTopological(std::set, - std::owner_less>> &Roots, - std::list> &SortedNodes, +void sortTopological(nodes_range Roots, std::list &SortedNodes, bool PartitionBounded) { - std::stack> Source; + std::stack Source; - for (auto &Node : Roots) { - Source.push(Node); + for (node_impl &Node : Roots) { + Source.push(&Node); } while (!Source.empty()) { - auto Node = Source.top().lock(); + node_impl &Node = *Source.top(); Source.pop(); - SortedNodes.push_back(Node); + SortedNodes.push_back(&Node); - for (node_impl &Succ : Node->successors()) { + for (node_impl &Succ : Node.successors()) { - if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) { + if (PartitionBounded && (Succ.MPartitionNum != Node.MPartitionNum)) { continue; } auto &TotalVisitedEdges = Succ.MTotalVisitedEdges; ++TotalVisitedEdges; if (TotalVisitedEdges == Succ.MPredecessors.size()) { - Source.push(Succ.weak_from_this()); + Source.push(&Succ); } } } @@ -163,9 +161,9 @@ void propagatePartitionDown( /// belong to the same partition) /// @param Node node to test /// @return True is `Node` is a root of its partition -bool isPartitionRoot(std::shared_ptr Node) { - for (node_impl &Predecessor : Node->predecessors()) { - if (Predecessor.MPartitionNum == Node->MPartitionNum) { +bool isPartitionRoot(node_impl &Node) { + for (node_impl &Predecessor : Node.predecessors()) { + if (Predecessor.MPartitionNum == Node.MPartitionNum) { return false; } } @@ -173,7 +171,7 @@ bool isPartitionRoot(std::shared_ptr Node) { } } // anonymous namespace -void partition::schedule() { +void partition::updateSchedule() { if (MSchedule.empty()) { // There is no need to reset MTotalVisitedEdges before calling // sortTopological because this function is only called once per partition. 
@@ -256,8 +254,8 @@ void exec_graph_impl::makePartitions() { for (auto &Node : MNodeStorage) { if (Node->MPartitionNum == i) { MPartitionNodes[Node.get()] = PartitionFinalNum; - if (isPartitionRoot(Node)) { - Partition->MRoots.insert(Node); + if (isPartitionRoot(*Node)) { + Partition->MRoots.insert(Node.get()); if (Node->MCGType == CGType::CodeplayHostTask) { Partition->MIsHostTask = true; } @@ -265,7 +263,7 @@ void exec_graph_impl::makePartitions() { } } if (Partition->MRoots.size() > 0) { - Partition->schedule(); + Partition->updateSchedule(); Partition->MIsInOrderGraph = Partition->checkIfGraphIsSinglePath(); MPartitions.push_back(Partition); MRootPartitions.push_back(Partition); @@ -287,9 +285,8 @@ void exec_graph_impl::makePartitions() { // Compute partition dependencies for (const auto &Partition : MPartitions) { - for (auto const &Root : Partition->MRoots) { - auto RootNode = Root.lock(); - for (node_impl &NodeDep : RootNode->predecessors()) { + for (node_impl &Root : Partition->roots()) { + for (node_impl &NodeDep : Root.predecessors()) { auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]]; Partition->MPredecessors.push_back(Predecessor.get()); Predecessor->MSuccessors.push_back(Partition.get()); @@ -340,13 +337,9 @@ graph_impl::~graph_impl() { } } -void graph_impl::addRoot(const std::shared_ptr &Root) { - MRoots.insert(Root); -} +void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); } -void graph_impl::removeRoot(const std::shared_ptr &Root) { - MRoots.erase(Root); -} +void graph_impl::removeRoot(node_impl &Root) { MRoots.erase(&Root); } std::set> graph_impl::getCGEdges( const std::shared_ptr &CommandGroup) const { @@ -593,7 +586,7 @@ bool graph_impl::clearQueues() { } bool graph_impl::checkForCycles() { - std::list> SortedNodes; + std::list SortedNodes; sortTopological(MRoots, SortedNodes, false); // If after a topological sort, not all the nodes in the graph are sorted, @@ -664,7 +657,7 @@ void graph_impl::makeEdge(std::shared_ptr Src, 
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1; if (DestLostRootStatus) { // Dest is no longer a Root node, so we need to remove it from MRoots. - MRoots.erase(Dest); + MRoots.erase(Dest.get()); } // We can skip cycle checks if either Dest has no successors (cycle not @@ -679,14 +672,14 @@ void graph_impl::makeEdge(std::shared_ptr Src, Dest->MPredecessors.pop_back(); if (DestLostRootStatus) { // Add Dest back into MRoots. - MRoots.insert(Dest); + MRoots.insert(Dest.get()); } throw sycl::exception(make_error_code(sycl::errc::invalid), "Command graphs cannot contain cycles."); } } - removeRoot(Dest); // remove receiver from root node list + removeRoot(*Dest); // remove receiver from root node list } std::vector graph_impl::getExitNodesEvents( @@ -740,14 +733,12 @@ void exec_graph_impl::findRealDeps( } } -ur_exp_command_buffer_sync_point_t -exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, - sycl::detail::device_impl &DeviceImpl, - ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node) { +ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect( + const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl, + ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node) { std::vector Deps; - for (node_impl &N : Node->predecessors()) { - findRealDeps(Deps, N, MPartitionNodes[Node.get()]); + for (node_impl &N : Node.predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[&Node]); } ur_exp_command_buffer_sync_point_t NewSyncPoint; ur_exp_command_buffer_command_handle_t NewCommand = 0; @@ -760,7 +751,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, if (xptiEnabled) { StreamID = xptiRegisterStream(sycl::detail::SYCL_STREAM_NAME); sycl::detail::CGExecKernel *CGExec = - static_cast(Node->MCommandGroup.get()); + static_cast(Node.MCommandGroup.get()); sycl::detail::code_location CodeLoc(CGExec->MFileName.c_str(), CGExec->MFunctionName.c_str(), CGExec->MLine, CGExec->MColumn); @@ 
-776,11 +767,11 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel( Ctx, DeviceImpl, CommandBuffer, - *static_cast((Node->MCommandGroup.get())), + *static_cast((Node.MCommandGroup.get())), Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr); if (MIsUpdatable) { - MCommandMap[Node.get()] = NewCommand; + MCommandMap[&Node] = NewCommand; } if (Res != UR_RESULT_SUCCESS) { @@ -799,20 +790,20 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node) { + node_impl &Node) { std::vector Deps; - for (node_impl &N : Node->predecessors()) { - findRealDeps(Deps, N, MPartitionNodes[Node.get()]); + for (node_impl &N : Node.predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[&Node]); } sycl::detail::EventImplPtr Event = sycl::detail::Scheduler::getInstance().addCG( - Node->getCGCopy(), *MQueueImpl, + Node.getCGCopy(), *MQueueImpl, /*EventNeeded=*/true, CommandBuffer, Deps); if (MIsUpdatable) { - MCommandMap[Node.get()] = Event->getCommandBufferCommand(); + MCommandMap[&Node] = Event->getCommandBufferCommand(); } return Event->getSyncPoint(); @@ -861,25 +852,25 @@ void exec_graph_impl::createCommandBuffers( Partition->MCommandBuffers[Device] = OutCommandBuffer; - for (const auto &Node : Partition->MSchedule) { + for (node_impl &Node : Partition->schedule()) { // Some nodes are not scheduled like other nodes, and only their // dependencies are propagated in findRealDeps - if (!Node->requiresEnqueue()) + if (!Node.requiresEnqueue()) continue; - sycl::detail::CGType type = Node->MCGType; + sycl::detail::CGType type = Node.MCGType; // If the node is a kernel with no special requirements we can enqueue it // directly. 
if (type == sycl::detail::CGType::Kernel && - Node->MCommandGroup->getRequirements().size() + + Node.MCommandGroup->getRequirements().size() + static_cast( - Node->MCommandGroup.get()) + Node.MCommandGroup.get()) ->MStreams.size() == 0) { - MSyncPoints[Node.get()] = + MSyncPoints[&Node] = enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node); } else { - MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node); + MSyncPoints[&Node] = enqueueNode(OutCommandBuffer, Node); } } @@ -2007,7 +1998,7 @@ std::vector modifiable_command_graph::get_nodes() const { std::vector modifiable_command_graph::get_root_nodes() const { graph_impl::ReadLock Lock(impl->MMutex); auto &Roots = impl->MRoots; - std::vector> Impls{}; + std::vector Impls{}; std::copy(Roots.begin(), Roots.end(), std::back_inserter(Impls)); return createNodesFromImpls(Impls); diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 0e257a77e5ef1..08255c5639627 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -53,10 +53,9 @@ class partition { partition() : MSchedule(), MCommandBuffers() {} /// List of root nodes. - std::set, std::owner_less>> - MRoots; + std::set MRoots; /// Execution schedule of nodes in the graph. - std::list> MSchedule; + std::list MSchedule; /// Map of devices to command buffers. std::unordered_map MCommandBuffers; @@ -84,17 +83,20 @@ class partition { // replaced every time the partition is executed. EventImplPtr MEvent; + nodes_range roots() const { return MRoots; } + nodes_range schedule() const { return MSchedule; } + /// Checks if the graph is single path, i.e. each node has a single successor. 
/// @return True if the graph is a single path bool checkIfGraphIsSinglePath() { if (MRoots.size() > 1) { return false; } - for (const auto &Node : MSchedule) { + for (node_impl &Node : schedule()) { // In version 1.3.28454 of the L0 driver, 2D Copy ops cannot not // be enqueued in an in-order cmd-list (causing execution to stall). // The 2D Copy test should be removed from here when the bug is fixed. - if ((Node->MSuccessors.size() > 1) || (Node->isNDCopyNode())) { + if ((Node.MSuccessors.size() > 1) || (Node.isNDCopyNode())) { return false; } } @@ -103,7 +105,7 @@ class partition { } /// Add nodes to MSchedule. - void schedule(); + void updateSchedule(); }; /// Implementation details of command_graph. @@ -126,7 +128,7 @@ class graph_impl : public std::enable_shared_from_this { /// Remove node from list of root nodes. /// @param Root Node to remove from list of root nodes. - void removeRoot(const std::shared_ptr &Root); + void removeRoot(node_impl &Root); /// Verifies the CG is valid to add to the graph and returns set of /// dependent nodes if so. @@ -281,8 +283,7 @@ class graph_impl : public std::enable_shared_from_this { sycl::device getDevice() const { return MDevice; } /// List of root nodes. - std::set, std::owner_less>> - MRoots; + std::set MRoots; /// Storage for all nodes contained within a graph. Nodes are connected to /// each other via weak_ptrs and so do not extend each other's lifetimes. @@ -290,6 +291,8 @@ class graph_impl : public std::enable_shared_from_this { /// than needing an expensive depth first search. std::vector> MNodeStorage; + nodes_range roots() const { return MRoots; } + /// Find the last node added to this graph from an in-order queue. /// @param Queue In-order queue to find the last node added to the graph from. 
/// @return Last node in this graph added from \p Queue recording, or empty @@ -312,8 +315,8 @@ class graph_impl : public std::enable_shared_from_this { std::fstream Stream(FilePath, std::ios::out); Stream << "digraph dot {" << std::endl; - for (std::weak_ptr Node : MRoots) - Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose); + for (node_impl &Node : roots()) + Node.printDotRecursive(Stream, VisitedNodes, Verbose); Stream << "}" << std::endl; @@ -418,13 +421,10 @@ class graph_impl : public std::enable_shared_from_this { } size_t RootsFound = 0; - for (std::weak_ptr NodeA : MRoots) { - for (std::weak_ptr NodeB : Graph.MRoots) { - auto NodeALocked = NodeA.lock(); - auto NodeBLocked = NodeB.lock(); - - if (NodeALocked->isSimilar(*NodeBLocked)) { - if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) { + for (node_impl &NodeA : roots()) { + for (node_impl &NodeB : Graph.roots()) { + if (NodeA.isSimilar(NodeB)) { + if (checkNodeRecursive(NodeA, NodeB)) { RootsFound++; break; } @@ -518,7 +518,7 @@ class graph_impl : public std::enable_shared_from_this { /// Insert node into list of root nodes. /// @param Root Node to add to list of root nodes. - void addRoot(const std::shared_ptr &Root); + void addRoot(node_impl &Root); /// Adds dependencies for a new node, if it has no deps it will be /// added as a root node. @@ -527,10 +527,10 @@ class graph_impl : public std::enable_shared_from_this { void addDepsToNode(const std::shared_ptr &Node, nodes_range Deps) { for (node_impl &N : Deps) { N.registerSuccessor(Node); - this->removeRoot(Node); + this->removeRoot(*Node); } if (Node->MPredecessors.empty()) { - this->addRoot(Node); + this->addRoot(*Node); } } @@ -647,9 +647,7 @@ class exec_graph_impl { /// Query the scheduling of node execution. /// @return List of nodes in execution order. - const std::list> &getSchedule() const { - return MSchedule; - } + const std::list &getSchedule() const { return MSchedule; } /// Query the graph_impl. 
/// @return pointer to the graph_impl MGraphImpl @@ -730,8 +728,7 @@ class exec_graph_impl { /// @param Node The node being enqueued. /// @return UR sync point created for this node in the command-buffer. ur_exp_command_buffer_sync_point_t - enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node); + enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node); /// Enqueue a node directly to the command-buffer without going through the /// scheduler. @@ -740,11 +737,9 @@ class exec_graph_impl { /// @param CommandBuffer Command-buffer to add node to as a command. /// @param Node The node being enqueued. /// @return UR sync point created for this node in the command-buffer. - ur_exp_command_buffer_sync_point_t - enqueueNodeDirect(const sycl::context &Ctx, - sycl::detail::device_impl &DeviceImpl, - ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node); + ur_exp_command_buffer_sync_point_t enqueueNodeDirect( + const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl, + ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node); /// Enqueues a host-task partition (i.e. a partition that contains only a /// single node and that node is a host-task). @@ -873,7 +868,7 @@ class exec_graph_impl { ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const; /// Execution schedule of nodes in the graph. - std::list> MSchedule; + std::list MSchedule; /// Pointer to the modifiable graph impl associated with this executable /// graph. 
/// Thread-safe implementation note: in the current implementation diff --git a/sycl/source/detail/graph/node_impl.cpp b/sycl/source/detail/graph/node_impl.cpp index 12c76c5522a97..36d2f589f0391 100644 --- a/sycl/source/detail/graph/node_impl.cpp +++ b/sycl/source/detail/graph/node_impl.cpp @@ -31,14 +31,11 @@ std::vector createNodesFromImpls( return Nodes; } -/// Takes a vector of shared_ptrs to node_impls and returns a vector of node -/// objects created from those impls, in the same order. -std::vector createNodesFromImpls( - const std::vector> &Impls) { +std::vector createNodesFromImpls(nodes_range Impls) { std::vector Nodes{}; Nodes.reserve(Impls.size()); - for (std::shared_ptr Impl : Impls) { + for (detail::node_impl &Impl : Impls) { Nodes.push_back(sycl::detail::createSyclObjFromImpl(Impl)); } diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index 11166e1eba897..578b62e0ab614 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -39,10 +40,7 @@ class exec_graph_impl; std::vector createNodesFromImpls(const std::vector> &Impls); -/// Takes a vector of shared_ptrs to node_impls and returns a vector of node -/// objects created from those impls, in the same order. -std::vector -createNodesFromImpls(const std::vector> &Impls); +std::vector createNodesFromImpls(nodes_range Impls); inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) { using sycl::detail::CG; @@ -774,7 +772,9 @@ class nodes_range { // from `weak_ptr`s this alternative should be removed too. 
std::vector>, // - std::set>>; + std::set>, std::set, + // + std::list>; storage_iter Begin; storage_iter End; diff --git a/sycl/unittests/Extensions/CommandGraph/Barrier.cpp b/sycl/unittests/Extensions/CommandGraph/Barrier.cpp index 1793588cef111..648c28ab0af7b 100644 --- a/sycl/unittests/Extensions/CommandGraph/Barrier.cpp +++ b/sycl/unittests/Extensions/CommandGraph/Barrier.cpp @@ -39,10 +39,9 @@ TEST_F(CommandGraphTest, EnqueueBarrier) { // / \ // (4) (5) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto BarrierNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto BarrierNode = Root.MSuccessors.front().lock(); ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CGType::Barrier); ASSERT_EQ(GraphImpl.getEventForNode(BarrierNode).get(), &*getSyclObjImpl(Barrier)); @@ -79,14 +78,12 @@ TEST_F(CommandGraphTest, EnqueueBarrierMultipleQueues) { // / \ // (4) (5) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - if (GraphImpl.getEventForNode(RootNode).get() == + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2Graph)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier)); @@ -110,7 +107,7 @@ TEST_F(CommandGraphTest, EnqueueBarrierMultipleQueues) { } } } else { - ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); + ASSERT_EQ(Root.MSuccessors.size(), 0lu); } } } @@ -147,10 +144,9 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitList) { // / \ / // (4) (5) 
ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto SuccNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CGType::Barrier) { ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier)); @@ -204,10 +200,9 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitListMultipleQueues) { // \|/ // (B2) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto SuccNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CGType::Barrier) { ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier)); @@ -267,10 +262,9 @@ TEST_F(CommandGraphTest, EnqueueMultipleBarrier) { // / | \ // (6) (7) (8) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto SuccNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CGType::Barrier) { ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier1)); @@ -335,10 +329,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithPreviousCommand) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - 
ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); - ASSERT_TRUE(RootNode->MCGType == sycl::detail::CGType::Barrier); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 0lu); + ASSERT_TRUE(Root.MCGType == sycl::detail::CGType::Barrier); } } @@ -370,20 +363,19 @@ TEST_F(CommandGraphTest, InOrderQueuesWithBarrier) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + &*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_TRUE(SuccNode->MCGType == sycl::detail::CGType::Barrier); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); - } else if (GraphImpl.getEventForNode(RootNode).get() == + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); + ASSERT_EQ(Root.MSuccessors.size(), 0lu); } else { ASSERT_TRUE(false && "Unexpected root node"); } @@ -417,12 +409,10 @@ TEST_F(CommandGraphTest, InOrderQueuesWithBarrierWaitList) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); 
ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); @@ -461,24 +451,23 @@ TEST_F(CommandGraphTest, InOrderQueuesWithEmptyBarrierWaitList) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + &*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); - } else if (GraphImpl.getEventForNode(RootNode).get() == + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node3)); @@ -525,19 +514,18 @@ TEST_F(CommandGraphTest, BarrierMixedQueueTypes) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - } else if (GraphImpl.getEventForNode(RootNode).get() == + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + 
&*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 2lu); + ASSERT_EQ(Root.MSuccessors.size(), 2lu); } else { ASSERT_TRUE(false && "Unexpected root node"); } - for (auto Succ : RootNode->MSuccessors) { + for (auto Succ : Root.MSuccessors) { auto SuccNode = Succ.lock(); if (GraphImpl.getEventForNode(SuccNode).get() == @@ -580,15 +568,14 @@ TEST_F(CommandGraphTest, BarrierBetweenExplicitNodes) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { - if (GraphImpl.getEventForNode(RootNode).get() == + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(BarrierNode)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); - } else if (RootNode.get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + ASSERT_EQ(Root.MSuccessors.size(), 0lu); + } else if (&Root == &*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(SuccNode.get(), &*getSyclObjImpl(Node2)); } else { ASSERT_TRUE(false); @@ -636,13 +623,12 @@ TEST_F(CommandGraphTest, BarrierMultipleOOOQueue) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 4u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if ((RootNodeEvent.get() == &*getSyclObjImpl(Node1)) || (RootNodeEvent.get() == &*getSyclObjImpl(Node2))) { - 
auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); @@ -659,7 +645,7 @@ TEST_F(CommandGraphTest, BarrierMultipleOOOQueue) { &*getSyclObjImpl(Node6)); } else if ((RootNodeEvent.get() == &*getSyclObjImpl(Node3)) || (RootNodeEvent.get() == &*getSyclObjImpl(Node4))) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node5)); @@ -701,18 +687,17 @@ TEST_F(CommandGraphTest, BarrierMultipleInOrderQueue) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if (RootNodeEvent.get() == &*getSyclObjImpl(Node1)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); } else if (RootNodeEvent.get() == &*getSyclObjImpl(Node2)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node3)); @@ -752,18 +737,17 @@ TEST_F(CommandGraphTest, BarrierMultipleMixedOrderQueues) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl 
&Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if (RootNodeEvent.get() == &*getSyclObjImpl(Node1)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); } else if (RootNodeEvent.get() == &*getSyclObjImpl(Node2)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node3)); @@ -797,18 +781,17 @@ TEST_F(CommandGraphTest, BarrierMultipleQueuesMultipleBarriers) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if (RootNodeEvent.get() == &*getSyclObjImpl(Barrier1)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier4)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); } else if (RootNodeEvent.get() == &*getSyclObjImpl(Barrier2)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier3)); @@ -875,21 +858,21 @@ TEST_F(CommandGraphTest, BarrierWithInOrderCommands) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = 
Root.lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { bool EvenPath; - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node2)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + &*getSyclObjImpl(Node2)) { EvenPath = true; - } else if (GraphImpl.getEventForNode(RootNode).get() == + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node1)) { EvenPath = false; } else { ASSERT_TRUE(false); } - auto Succ1Node = RootNode->MSuccessors.front().lock(); + auto Succ1Node = Root.MSuccessors.front().lock(); ASSERT_EQ(Succ1Node->MSuccessors.size(), 1lu); if (EvenPath) { ASSERT_EQ(GraphImpl.getEventForNode(Succ1Node).get(), diff --git a/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp b/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp index f9e9915ae9119..9799847d88d9f 100644 --- a/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp +++ b/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp @@ -59,7 +59,7 @@ TEST_F(CommandGraphTest, AddNode) { [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); ASSERT_FALSE(getSyclObjImpl(Node1)->isEmpty()); ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); - ASSERT_EQ(GraphImpl.MRoots.begin()->lock().get(), &*getSyclObjImpl(Node1)); + ASSERT_EQ(*GraphImpl.MRoots.begin(), &*getSyclObjImpl(Node1)); ASSERT_TRUE(getSyclObjImpl(Node1)->MSuccessors.empty()); ASSERT_TRUE(getSyclObjImpl(Node1)->MPredecessors.empty()); diff --git a/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp b/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp index 589583e758fb9..ab651e140fe83 100644 --- a/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp +++ b/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp @@ -507,7 +507,7 @@ TEST_F(CommandGraphTest, MakeEdgeErrors) { experimental::detail::node_impl &NodeBImpl = *getSyclObjImpl(NodeB); 
ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); - ASSERT_EQ(GraphImpl.MRoots.begin()->lock().get(), &NodeAImpl); + ASSERT_EQ(*GraphImpl.MRoots.begin(), &NodeAImpl); ASSERT_EQ(NodeAImpl.MSuccessors.size(), 1lu); ASSERT_EQ(NodeAImpl.MPredecessors.size(), 0lu); diff --git a/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp b/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp index f5242af68c535..fef1a947d0bc0 100644 --- a/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp +++ b/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp @@ -84,12 +84,10 @@ bool checkExecGraphSchedule( if (ScheduleA.size() != ScheduleB.size()) return false; - std::vector< - std::shared_ptr> - VScheduleA{std::begin(ScheduleA), std::end(ScheduleA)}; - std::vector< - std::shared_ptr> - VScheduleB{std::begin(ScheduleB), std::end(ScheduleB)}; + std::vector VScheduleA{ + std::begin(ScheduleA), std::end(ScheduleA)}; + std::vector VScheduleB{ + std::begin(ScheduleB), std::end(ScheduleB)}; for (size_t i = 0; i < VScheduleA.size(); i++) { if (!VScheduleA[i]->isSimilar(*VScheduleB[i])) @@ -244,7 +242,7 @@ TEST_F(MultiThreadGraphTest, RecordAddNodesInOrderQueue) { ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); // Check structure graph - auto CurrentNode = GraphImpl.MRoots.begin()->lock(); + experimental::detail::node_impl *CurrentNode = *GraphImpl.MRoots.begin(); for (size_t i = 1; i <= GraphImpl.getNumberOfNodes(); i++) { EXPECT_LE(CurrentNode->MSuccessors.size(), 1lu); @@ -254,7 +252,7 @@ TEST_F(MultiThreadGraphTest, RecordAddNodesInOrderQueue) { } else { // Check other nodes have 1 successor EXPECT_EQ(CurrentNode->MSuccessors.size(), 1lu); - CurrentNode = CurrentNode->MSuccessors[0].lock(); + CurrentNode = CurrentNode->MSuccessors[0].lock().get(); } } } From 395f4763dad7b12abdcdfb8148486e6d3e63c71c Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Tue, 8 Jul 2025 08:25:35 -0700 Subject: [PATCH 2/2] [NFCI][SYCL][Graph] Refactor `graph_impl::add` Part of the refactoring to eliminate 
`std::weak_ptr` and reduce usage of `std::shared_ptr` by preferring raw ptr/ref. Previous PRs in the series: https://github.com/intel/llvm/pull/19295 https://github.com/intel/llvm/pull/19332 https://github.com/intel/llvm/pull/19334 https://github.com/intel/llvm/pull/19350 * Accept `Deps` as `nodes_range` in `graph_impl::add` * Return `node_impl &` from `graph_impl::add` * Add `node` support in `nodes_range` and use that together with the modified `graph_impl::add` when creating new `node_impl`s based on `std::vector<node> Deps` to avoid creation of temporary `DepImpls` storage. * Also updated `registerSuccessor/registerPredecessor` and `addEventForNode/addDepsToNode` to accept raw `node_impl &` as the changes above resulted in having raw references at the call sites. --- sycl/source/detail/graph/graph_impl.cpp | 81 ++++++++++--------------- sycl/source/detail/graph/graph_impl.hpp | 36 ++++++----- sycl/source/detail/graph/node_impl.hpp | 34 +++++------ sycl/source/handler.cpp | 18 +++--- 4 files changed, 79 insertions(+), 90 deletions(-) diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 0220b0a5e46c8..53730570a76c8 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -409,10 +409,8 @@ void graph_impl::markCGMemObjs( } } -std::shared_ptr graph_impl::add(nodes_range Deps) { - const std::shared_ptr &NodeImpl = std::make_shared(); - - MNodeStorage.push_back(NodeImpl); +node_impl &graph_impl::add(nodes_range Deps) { + node_impl &NodeImpl = createNode(); addDepsToNode(NodeImpl, Deps); // Add an event associated with this explicit node for mixed usage @@ -421,10 +419,9 @@ std::shared_ptr graph_impl::add(nodes_range Deps) { return NodeImpl; } -std::shared_ptr -graph_impl::add(std::function CGF, - const std::vector &Args, - std::vector> &Deps) { +node_impl &graph_impl::add(std::function CGF, + const std::vector &Args, + nodes_range Deps) { (void)Args; #ifdef 
__INTEL_PREVIEW_BREAKING_CHANGES detail::handler_impl HandlerImpl{*this}; @@ -435,7 +432,9 @@ graph_impl::add(std::function CGF, // Pass the node deps to the handler so they are available when processing the // CGF, need for async_malloc nodes. - Handler.impl->MNodeDeps = Deps; + Handler.impl->MNodeDeps.clear(); + for (node_impl &N : Deps) + Handler.impl->MNodeDeps.push_back(N.shared_from_this()); #if XPTI_ENABLE_INSTRUMENTATION // Save code location if one was set in TLS. @@ -471,7 +470,7 @@ graph_impl::add(std::function CGF, : ext::oneapi::experimental::detail::getNodeTypeFromCG( Handler.getType()); - auto NodeImpl = + node_impl &NodeImpl = this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Deps); // Add an event associated with this explicit node for mixed usage @@ -489,16 +488,15 @@ graph_impl::add(std::function CGF, } for (auto &[DynamicParam, ArgIndex] : DynamicParams) { - DynamicParam->registerNode(NodeImpl, ArgIndex); + DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex); } return NodeImpl; } -std::shared_ptr -graph_impl::add(node_type NodeType, - std::shared_ptr CommandGroup, - nodes_range Deps) { +node_impl &graph_impl::add(node_type NodeType, + std::shared_ptr CommandGroup, + nodes_range Deps) { // A unique set of dependencies obtained by checking requirements and events std::set> UniqueDeps = getCGEdges(CommandGroup); @@ -506,9 +504,7 @@ graph_impl::add(node_type NodeType, // Track and mark the memory objects being used by the graph. 
markCGMemObjs(CommandGroup); - const std::shared_ptr &NodeImpl = - std::make_shared(NodeType, std::move(CommandGroup)); - MNodeStorage.push_back(NodeImpl); + node_impl &NodeImpl = createNode(NodeType, std::move(CommandGroup)); // Add any deps determined from requirements and events into the dependency // list @@ -516,17 +512,17 @@ graph_impl::add(node_type NodeType, addDepsToNode(NodeImpl, UniqueDeps); if (NodeType == node_type::async_free) { - auto AsyncFreeCG = - static_cast(NodeImpl->MCommandGroup.get()); + auto AsyncFreeCG = static_cast(NodeImpl.MCommandGroup.get()); // If this is an async free node mark that it is now available for reuse, // and pass the async free node for tracking. - MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl); + MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), + NodeImpl.shared_from_this()); } return NodeImpl; } -std::shared_ptr +node_impl & graph_impl::add(std::shared_ptr &DynCGImpl, nodes_range Deps) { // Set of Dependent nodes based on CG event and accessor dependencies. 
@@ -551,15 +547,14 @@ graph_impl::add(std::shared_ptr &DynCGImpl, const auto &ActiveKernel = DynCGImpl->getActiveCG(); node_type NodeType = ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType); - std::shared_ptr NodeImpl = - add(NodeType, ActiveKernel, Deps); + detail::node_impl &NodeImpl = add(NodeType, ActiveKernel, Deps); // Add an event associated with this explicit node for mixed usage addEventForNode(sycl::detail::event_impl::create_completed_host_event(), NodeImpl); // Track the dynamic command-group used inside the node object - DynCGImpl->MNodes.push_back(NodeImpl); + DynCGImpl->MNodes.push_back(NodeImpl.shared_from_this()); return NodeImpl; } @@ -652,7 +647,7 @@ void graph_impl::makeEdge(std::shared_ptr Src, bool DestWasGraphRoot = Dest->MPredecessors.size() == 0; // We need to add the edges first before checking for cycles - Src->registerSuccessor(Dest); + Src->registerSuccessor(*Dest); bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1; if (DestLostRootStatus) { @@ -1265,7 +1260,7 @@ void exec_graph_impl::duplicateNodes() { // Look through all the original node successors, find their copies and // register those as successors with the current copied node for (node_impl &NextNode : OriginalNode->successors()) { - auto Successor = NodesMap.at(NextNode.shared_from_this()); + node_impl &Successor = *NodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1307,7 +1302,8 @@ void exec_graph_impl::duplicateNodes() { auto NodeCopy = NewSubgraphNodes[i]; for (node_impl &NextNode : SubgraphNode->successors()) { - auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this()); + node_impl &Successor = + *SubgraphNodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1341,7 +1337,7 @@ void exec_graph_impl::duplicateNodes() { // Add all input nodes from the subgraph as successors for this node // instead for (auto &Input : Inputs) { - 
PredNode.registerSuccessor(Input); + PredNode.registerSuccessor(*Input); } } @@ -1360,7 +1356,7 @@ void exec_graph_impl::duplicateNodes() { // Add all Output nodes from the subgraph as predecessors for this node // instead for (auto &Output : Outputs) { - Output->registerSuccessor(SuccNode.shared_from_this()); + Output->registerSuccessor(SuccNode); } } @@ -1843,38 +1839,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF, "dynamic command-group."); } - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } - graph_impl::WriteLock Lock(impl->MMutex); - std::shared_ptr NodeImpl = impl->add(DynCGFImpl, DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(DynCGFImpl, Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } node modifiable_command_graph::addImpl(const std::vector &Deps) { impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function"); - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } graph_impl::WriteLock Lock(impl->MMutex); - std::shared_ptr NodeImpl = impl->add(DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } node modifiable_command_graph::addImpl(std::function CGF, const std::vector &Deps) { impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function"); - std::vector> DepImpls; - for (auto &D : Deps) { - DepImpls.push_back(sycl::detail::getSyclObjImpl(D)); - } - std::shared_ptr NodeImpl = impl->add(CGF, {}, DepImpls); - return sycl::detail::createSyclObjFromImpl(std::move(NodeImpl)); + detail::node_impl &NodeImpl = impl->add(CGF, {}, Deps); + return sycl::detail::createSyclObjFromImpl(NodeImpl); } void modifiable_command_graph::addGraphLeafDependencies(node Node) { diff --git 
a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 08255c5639627..d90e1f5132b89 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -147,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this { /// @param CommandGroup The CG which stores all information for this node. /// @param Deps Dependencies of the created node. /// @return Created node in the graph. - std::shared_ptr add(node_type NodeType, - std::shared_ptr CommandGroup, - nodes_range Deps); + node_impl &add(node_type NodeType, + std::shared_ptr CommandGroup, + nodes_range Deps); /// Create a CGF node in the graph. /// @param CGF Command-group function to create node with. /// @param Args Node arguments. /// @param Deps Dependencies of the created node. /// @return Created node in the graph. - std::shared_ptr add(std::function CGF, - const std::vector &Args, - std::vector> &Deps); + node_impl &add(std::function CGF, + const std::vector &Args, + nodes_range Deps); /// Create an empty node in the graph. /// @param Deps List of predecessor nodes. /// @return Created node in the graph. - std::shared_ptr add(nodes_range Deps); + node_impl &add(nodes_range Deps); /// Create a dynamic command-group node in the graph. /// @param DynCGImpl Dynamic command-group used to create node. /// @param Deps List of predecessor nodes. /// @return Created node in the graph. - std::shared_ptr - add(std::shared_ptr &DynCGImpl, nodes_range Deps); + node_impl &add(std::shared_ptr &DynCGImpl, + nodes_range Deps); /// Add a queue to the set of queues which are currently recording to this /// graph. @@ -192,10 +192,10 @@ class graph_impl : public std::enable_shared_from_this { /// @param EventImpl Event to associate with a node in map. /// @param NodeImpl Node to associate with event in map. 
void addEventForNode(std::shared_ptr EventImpl, - const std::shared_ptr &NodeImpl) { + node_impl &NodeImpl) { if (!(EventImpl->hasCommandGraph())) EventImpl->setCommandGraph(shared_from_this()); - MEventsMap[EventImpl] = NodeImpl; + MEventsMap[EventImpl] = NodeImpl.shared_from_this(); } /// Find the sycl event associated with a node. @@ -510,6 +510,12 @@ class graph_impl : public std::enable_shared_from_this { } private: + template node_impl &createNode(Ts &&...Args) { + MNodeStorage.push_back( + std::make_shared(std::forward(Args)...)); + return *MNodeStorage.back(); + } + /// Check the graph for cycles by performing a depth-first search of the /// graph. If a node is visited more than once in a given path through the /// graph, a cycle is present and the search ends immediately. @@ -524,13 +530,13 @@ class graph_impl : public std::enable_shared_from_this { /// added as a root node. /// @param Node The node to add deps for /// @param Deps List of dependent nodes - void addDepsToNode(const std::shared_ptr &Node, nodes_range Deps) { + void addDepsToNode(node_impl &Node, nodes_range Deps) { for (node_impl &N : Deps) { N.registerSuccessor(Node); - this->removeRoot(*Node); + this->removeRoot(Node); } - if (Node->MPredecessors.empty()) { - this->addRoot(*Node); + if (Node.MPredecessors.empty()) { + this->addRoot(Node); } } diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index 578b62e0ab614..3268b1ce1adf8 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -14,6 +14,8 @@ #include // for CGType #include // for kernel_param_kind_t +#include // for node + #include #include #include @@ -26,8 +28,6 @@ inline namespace _V1 { namespace ext { namespace oneapi { namespace experimental { -// Forward declarations -class node; namespace detail { // Forward declarations @@ -121,27 +121,27 @@ class node_impl : public std::enable_shared_from_this { /// Add successor to the node. 
/// @param Node Node to add as a successor. - void registerSuccessor(const std::shared_ptr &Node) { + void registerSuccessor(node_impl &Node) { if (std::find_if(MSuccessors.begin(), MSuccessors.end(), - [Node](const std::weak_ptr &Ptr) { - return Ptr.lock() == Node; + [&Node](const std::weak_ptr &Ptr) { + return Ptr.lock().get() == &Node; }) != MSuccessors.end()) { return; } - MSuccessors.push_back(Node); - Node->registerPredecessor(shared_from_this()); + MSuccessors.push_back(Node.weak_from_this()); + Node.registerPredecessor(*this); } /// Add predecessor to the node. /// @param Node Node to add as a predecessor. - void registerPredecessor(const std::shared_ptr &Node) { + void registerPredecessor(node_impl &Node) { if (std::find_if(MPredecessors.begin(), MPredecessors.end(), [&Node](const std::weak_ptr &Ptr) { - return Ptr.lock() == Node; + return Ptr.lock().get() == &Node; }) != MPredecessors.end()) { return; } - MPredecessors.push_back(Node); + MPredecessors.push_back(Node.weak_from_this()); } /// Construct an empty node. @@ -774,7 +774,7 @@ class nodes_range { // std::set>, std::set, // - std::list>; + std::list, std::vector>; storage_iter Begin; storage_iter End; @@ -783,10 +783,8 @@ class nodes_range { public: nodes_range(const nodes_range &Other) = default; - template < - typename ContainerTy, - typename = std::enable_if_t>> - nodes_range(ContainerTy &Container) + template + nodes_range(const ContainerTy &Container) : Begin{Container.begin()}, End{Container.end()}, Size{Container.size()} { } @@ -812,12 +810,14 @@ class nodes_range { return std::visit( [](auto &&It) -> node_impl & { auto &Elem = *It; - if constexpr (std::is_same_v, - std::weak_ptr>) { + using Ty = std::decay_t; + if constexpr (std::is_same_v>) { // This assumes that weak_ptr doesn't actually manage lifetime and // the object is guaranteed to be alive (which seems to be the // assumption across all graph code). 
return *Elem.lock(); + } else if constexpr (std::is_same_v) { + return *getSyclObjImpl(Elem); } else { return *Elem; } diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index f575885b6a24d..04026c1532290 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -886,13 +886,13 @@ event handler::finalize() { // In-order queues create implicit linear dependencies between nodes. // Find the last node added to the graph from this queue, so our new // node can set it as a predecessor. - std::vector> - Deps; + std::vector Deps; if (ext::oneapi::experimental::detail::node_impl *DependentNode = GraphImpl->getLastInorderNode(Queue)) { - Deps.push_back(DependentNode->shared_from_this()); + Deps.push_back(DependentNode); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); + NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) + .shared_from_this(); // If we are recording an in-order queue remember the new node, so it // can be used as a dependency for any more nodes recorded from this @@ -902,13 +902,13 @@ event handler::finalize() { ext::oneapi::experimental::detail::node_impl *LastBarrierRecordedFromQueue = GraphImpl->getBarrierDep(Queue->weak_from_this()); - std::vector> - Deps; + std::vector Deps; if (LastBarrierRecordedFromQueue) { - Deps.push_back(LastBarrierRecordedFromQueue->shared_from_this()); + Deps.push_back(LastBarrierRecordedFromQueue); } - NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); + NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps) + .shared_from_this(); if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) { GraphImpl->setBarrierDep(Queue->weak_from_this(), *NodeImpl); @@ -916,7 +916,7 @@ event handler::finalize() { } // Associate an event with this new node and return the event. - GraphImpl->addEventForNode(EventImpl, std::move(NodeImpl)); + GraphImpl->addEventForNode(EventImpl, *NodeImpl); #ifdef __INTEL_PREVIEW_BREAKING_CHANGES return EventImpl;