From 99e9fa1be37358333a72029da9a62b352ce45547 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Mon, 7 Jul 2025 18:01:04 -0700 Subject: [PATCH] [NFC][SYCL][Graph] Use raw `node_impl *` in `MRoots`/`MSchedule` ... and update the code surrounding their uses in the same spirit. Continuation of https://github.com/intel/llvm/pull/19295 https://github.com/intel/llvm/pull/19332 https://github.com/intel/llvm/pull/19334 --- sycl/source/detail/graph/graph_impl.cpp | 99 +++++------ sycl/source/detail/graph/graph_impl.hpp | 59 +++---- sycl/source/detail/graph/node_impl.cpp | 7 +- sycl/source/detail/graph/node_impl.hpp | 10 +- .../Extensions/CommandGraph/Barrier.cpp | 163 ++++++++---------- .../Extensions/CommandGraph/CommandGraph.cpp | 2 +- .../Extensions/CommandGraph/Exceptions.cpp | 2 +- .../Extensions/CommandGraph/MultiThreaded.cpp | 14 +- 8 files changed, 160 insertions(+), 196 deletions(-) diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index d5555ef688767..0220b0a5e46c8 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -85,31 +85,29 @@ inline const char *nodeTypeToString(node_type NodeType) { /// @param[in] PartitionBounded If set to true, the topological sort is stopped /// at partition borders. Hence, nodes belonging to a partition different from /// the NodeImpl partition are not processed. -void sortTopological(std::set, - std::owner_less>> &Roots, - std::list> &SortedNodes, +void sortTopological(nodes_range Roots, std::list &SortedNodes, bool PartitionBounded) { - std::stack> Source; + std::stack Source; - for (auto &Node : Roots) { - Source.push(Node); + for (node_impl &Node : Roots) { + Source.push(&Node); } while (!Source.empty()) { - auto Node = Source.top().lock(); + node_impl &Node = *Source.top(); Source.pop(); - SortedNodes.push_back(Node); + SortedNodes.push_back(&Node); - for (node_impl &Succ : Node->successors()) { + for (node_impl &Succ : Node.successors()) { - if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) { + if (PartitionBounded && (Succ.MPartitionNum != Node.MPartitionNum)) { continue; } auto &TotalVisitedEdges = Succ.MTotalVisitedEdges; ++TotalVisitedEdges; if (TotalVisitedEdges == Succ.MPredecessors.size()) { - Source.push(Succ.weak_from_this()); + Source.push(&Succ); } } } @@ -163,9 +161,9 @@ void propagatePartitionDown( /// belong to the same partition) /// @param Node node to test /// @return True is `Node` is a root of its partition -bool isPartitionRoot(std::shared_ptr Node) { - for (node_impl &Predecessor : Node->predecessors()) { - if (Predecessor.MPartitionNum == Node->MPartitionNum) { +bool isPartitionRoot(node_impl &Node) { + for (node_impl &Predecessor : Node.predecessors()) { + if (Predecessor.MPartitionNum == Node.MPartitionNum) { return false; } } @@ -173,7 +171,7 @@ bool isPartitionRoot(std::shared_ptr Node) { } } // anonymous namespace -void partition::schedule() { +void partition::updateSchedule() { if (MSchedule.empty()) { // There is no need to reset MTotalVisitedEdges before calling // sortTopological because this function is only called once per partition. @@ -256,8 +254,8 @@ void exec_graph_impl::makePartitions() { for (auto &Node : MNodeStorage) { if (Node->MPartitionNum == i) { MPartitionNodes[Node.get()] = PartitionFinalNum; - if (isPartitionRoot(Node)) { - Partition->MRoots.insert(Node); + if (isPartitionRoot(*Node)) { + Partition->MRoots.insert(Node.get()); if (Node->MCGType == CGType::CodeplayHostTask) { Partition->MIsHostTask = true; } @@ -265,7 +263,7 @@ void exec_graph_impl::makePartitions() { } } if (Partition->MRoots.size() > 0) { - Partition->schedule(); + Partition->updateSchedule(); Partition->MIsInOrderGraph = Partition->checkIfGraphIsSinglePath(); MPartitions.push_back(Partition); MRootPartitions.push_back(Partition); @@ -287,9 +285,8 @@ void exec_graph_impl::makePartitions() { // Compute partition dependencies for (const auto &Partition : MPartitions) { - for (auto const &Root : Partition->MRoots) { - auto RootNode = Root.lock(); - for (node_impl &NodeDep : RootNode->predecessors()) { + for (node_impl &Root : Partition->roots()) { + for (node_impl &NodeDep : Root.predecessors()) { auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]]; Partition->MPredecessors.push_back(Predecessor.get()); Predecessor->MSuccessors.push_back(Partition.get()); @@ -340,13 +337,9 @@ graph_impl::~graph_impl() { } } -void graph_impl::addRoot(const std::shared_ptr &Root) { - MRoots.insert(Root); -} +void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); } -void graph_impl::removeRoot(const std::shared_ptr &Root) { - MRoots.erase(Root); -} +void graph_impl::removeRoot(node_impl &Root) { MRoots.erase(&Root); } std::set> graph_impl::getCGEdges( const std::shared_ptr &CommandGroup) const { @@ -593,7 +586,7 @@ bool graph_impl::clearQueues() { } bool graph_impl::checkForCycles() { - std::list> SortedNodes; + std::list SortedNodes; sortTopological(MRoots, SortedNodes, false); // If after a topological sort, not all the nodes in the graph are sorted, @@ -664,7 +657,7 @@ void graph_impl::makeEdge(std::shared_ptr Src, bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1; if (DestLostRootStatus) { // Dest is no longer a Root node, so we need to remove it from MRoots. - MRoots.erase(Dest); + MRoots.erase(Dest.get()); } // We can skip cycle checks if either Dest has no successors (cycle not @@ -679,14 +672,14 @@ void graph_impl::makeEdge(std::shared_ptr Src, Dest->MPredecessors.pop_back(); if (DestLostRootStatus) { // Add Dest back into MRoots. - MRoots.insert(Dest); + MRoots.insert(Dest.get()); } throw sycl::exception(make_error_code(sycl::errc::invalid), "Command graphs cannot contain cycles."); } } - removeRoot(Dest); // remove receiver from root node list + removeRoot(*Dest); // remove receiver from root node list } std::vector graph_impl::getExitNodesEvents( @@ -740,14 +733,12 @@ void exec_graph_impl::findRealDeps( } } -ur_exp_command_buffer_sync_point_t -exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, - sycl::detail::device_impl &DeviceImpl, - ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node) { +ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect( + const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl, + ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node) { std::vector Deps; - for (node_impl &N : Node->predecessors()) { - findRealDeps(Deps, N, MPartitionNodes[Node.get()]); + for (node_impl &N : Node.predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[&Node]); } ur_exp_command_buffer_sync_point_t NewSyncPoint; ur_exp_command_buffer_command_handle_t NewCommand = 0; @@ -760,7 +751,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, if (xptiEnabled) { StreamID = xptiRegisterStream(sycl::detail::SYCL_STREAM_NAME); sycl::detail::CGExecKernel *CGExec = - static_cast(Node->MCommandGroup.get()); + static_cast(Node.MCommandGroup.get()); sycl::detail::code_location CodeLoc(CGExec->MFileName.c_str(), CGExec->MFunctionName.c_str(), CGExec->MLine, CGExec->MColumn); @@ -776,11 +767,11 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel( Ctx, DeviceImpl, CommandBuffer, - *static_cast((Node->MCommandGroup.get())), + *static_cast((Node.MCommandGroup.get())), Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr); if (MIsUpdatable) { - MCommandMap[Node.get()] = NewCommand; + MCommandMap[&Node] = NewCommand; } if (Res != UR_RESULT_SUCCESS) { @@ -799,20 +790,20 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node) { + node_impl &Node) { std::vector Deps; - for (node_impl &N : Node->predecessors()) { - findRealDeps(Deps, N, MPartitionNodes[Node.get()]); + for (node_impl &N : Node.predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[&Node]); } sycl::detail::EventImplPtr Event = sycl::detail::Scheduler::getInstance().addCG( - Node->getCGCopy(), *MQueueImpl, + Node.getCGCopy(), *MQueueImpl, /*EventNeeded=*/true, CommandBuffer, Deps); if (MIsUpdatable) { - MCommandMap[Node.get()] = Event->getCommandBufferCommand(); + MCommandMap[&Node] = Event->getCommandBufferCommand(); } return Event->getSyncPoint(); @@ -861,25 +852,25 @@ void exec_graph_impl::createCommandBuffers( Partition->MCommandBuffers[Device] = OutCommandBuffer; - for (const auto &Node : Partition->MSchedule) { + for (node_impl &Node : Partition->schedule()) { // Some nodes are not scheduled like other nodes, and only their // dependencies are propagated in findRealDeps - if (!Node->requiresEnqueue()) + if (!Node.requiresEnqueue()) continue; - sycl::detail::CGType type = Node->MCGType; + sycl::detail::CGType type = Node.MCGType; // If the node is a kernel with no special requirements we can enqueue it // directly. if (type == sycl::detail::CGType::Kernel && - Node->MCommandGroup->getRequirements().size() + + Node.MCommandGroup->getRequirements().size() + static_cast( - Node->MCommandGroup.get()) + Node.MCommandGroup.get()) ->MStreams.size() == 0) { - MSyncPoints[Node.get()] = + MSyncPoints[&Node] = enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node); } else { - MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node); + MSyncPoints[&Node] = enqueueNode(OutCommandBuffer, Node); } } @@ -2007,7 +1998,7 @@ std::vector modifiable_command_graph::get_nodes() const { std::vector modifiable_command_graph::get_root_nodes() const { graph_impl::ReadLock Lock(impl->MMutex); auto &Roots = impl->MRoots; - std::vector> Impls{}; + std::vector Impls{}; std::copy(Roots.begin(), Roots.end(), std::back_inserter(Impls)); return createNodesFromImpls(Impls); diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 0e257a77e5ef1..08255c5639627 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -53,10 +53,9 @@ class partition { partition() : MSchedule(), MCommandBuffers() {} /// List of root nodes. - std::set, std::owner_less>> - MRoots; + std::set MRoots; /// Execution schedule of nodes in the graph. - std::list> MSchedule; + std::list MSchedule; /// Map of devices to command buffers. std::unordered_map MCommandBuffers; @@ -84,17 +83,20 @@ class partition { // replaced every time the partition is executed. EventImplPtr MEvent; + nodes_range roots() const { return MRoots; } + nodes_range schedule() const { return MSchedule; } + /// Checks if the graph is single path, i.e. each node has a single successor. /// @return True if the graph is a single path bool checkIfGraphIsSinglePath() { if (MRoots.size() > 1) { return false; } - for (const auto &Node : MSchedule) { + for (node_impl &Node : schedule()) { // In version 1.3.28454 of the L0 driver, 2D Copy ops cannot not // be enqueued in an in-order cmd-list (causing execution to stall). // The 2D Copy test should be removed from here when the bug is fixed. - if ((Node->MSuccessors.size() > 1) || (Node->isNDCopyNode())) { + if ((Node.MSuccessors.size() > 1) || (Node.isNDCopyNode())) { return false; } } @@ -103,7 +105,7 @@ class partition { } /// Add nodes to MSchedule. - void schedule(); + void updateSchedule(); }; /// Implementation details of command_graph. @@ -126,7 +128,7 @@ class graph_impl : public std::enable_shared_from_this { /// Remove node from list of root nodes. /// @param Root Node to remove from list of root nodes. - void removeRoot(const std::shared_ptr &Root); + void removeRoot(node_impl &Root); /// Verifies the CG is valid to add to the graph and returns set of /// dependent nodes if so. @@ -281,8 +283,7 @@ class graph_impl : public std::enable_shared_from_this { sycl::device getDevice() const { return MDevice; } /// List of root nodes. - std::set, std::owner_less>> - MRoots; + std::set MRoots; /// Storage for all nodes contained within a graph. Nodes are connected to /// each other via weak_ptrs and so do not extend each other's lifetimes. @@ -290,6 +291,8 @@ class graph_impl : public std::enable_shared_from_this { /// than needing an expensive depth first search. std::vector> MNodeStorage; + nodes_range roots() const { return MRoots; } + /// Find the last node added to this graph from an in-order queue. /// @param Queue In-order queue to find the last node added to the graph from. /// @return Last node in this graph added from \p Queue recording, or empty @@ -312,8 +315,8 @@ class graph_impl : public std::enable_shared_from_this { std::fstream Stream(FilePath, std::ios::out); Stream << "digraph dot {" << std::endl; - for (std::weak_ptr Node : MRoots) - Node.lock()->printDotRecursive(Stream, VisitedNodes, Verbose); + for (node_impl &Node : roots()) + Node.printDotRecursive(Stream, VisitedNodes, Verbose); Stream << "}" << std::endl; @@ -418,13 +421,10 @@ class graph_impl : public std::enable_shared_from_this { } size_t RootsFound = 0; - for (std::weak_ptr NodeA : MRoots) { - for (std::weak_ptr NodeB : Graph.MRoots) { - auto NodeALocked = NodeA.lock(); - auto NodeBLocked = NodeB.lock(); - - if (NodeALocked->isSimilar(*NodeBLocked)) { - if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) { + for (node_impl &NodeA : roots()) { + for (node_impl &NodeB : Graph.roots()) { + if (NodeA.isSimilar(NodeB)) { + if (checkNodeRecursive(NodeA, NodeB)) { RootsFound++; break; } @@ -518,7 +518,7 @@ class graph_impl : public std::enable_shared_from_this { /// Insert node into list of root nodes. /// @param Root Node to add to list of root nodes. - void addRoot(const std::shared_ptr &Root); + void addRoot(node_impl &Root); /// Adds dependencies for a new node, if it has no deps it will be /// added as a root node. @@ -527,10 +527,10 @@ class graph_impl : public std::enable_shared_from_this { void addDepsToNode(const std::shared_ptr &Node, nodes_range Deps) { for (node_impl &N : Deps) { N.registerSuccessor(Node); - this->removeRoot(Node); + this->removeRoot(*Node); } if (Node->MPredecessors.empty()) { - this->addRoot(Node); + this->addRoot(*Node); } } @@ -647,9 +647,7 @@ class exec_graph_impl { /// Query the scheduling of node execution. /// @return List of nodes in execution order. - const std::list> &getSchedule() const { - return MSchedule; - } + const std::list &getSchedule() const { return MSchedule; } /// Query the graph_impl. /// @return pointer to the graph_impl MGraphImpl @@ -730,8 +728,7 @@ class exec_graph_impl { /// @param Node The node being enqueued. /// @return UR sync point created for this node in the command-buffer. ur_exp_command_buffer_sync_point_t - enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node); + enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node); /// Enqueue a node directly to the command-buffer without going through the /// scheduler. @@ -740,11 +737,9 @@ class exec_graph_impl { /// @param CommandBuffer Command-buffer to add node to as a command. /// @param Node The node being enqueued. /// @return UR sync point created for this node in the command-buffer. - ur_exp_command_buffer_sync_point_t - enqueueNodeDirect(const sycl::context &Ctx, - sycl::detail::device_impl &DeviceImpl, - ur_exp_command_buffer_handle_t CommandBuffer, - std::shared_ptr Node); + ur_exp_command_buffer_sync_point_t enqueueNodeDirect( + const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl, + ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node); /// Enqueues a host-task partition (i.e. a partition that contains only a /// single node and that node is a host-task). @@ -873,7 +868,7 @@ class exec_graph_impl { ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const; /// Execution schedule of nodes in the graph. - std::list> MSchedule; + std::list MSchedule; /// Pointer to the modifiable graph impl associated with this executable /// graph. /// Thread-safe implementation note: in the current implementation diff --git a/sycl/source/detail/graph/node_impl.cpp b/sycl/source/detail/graph/node_impl.cpp index 12c76c5522a97..36d2f589f0391 100644 --- a/sycl/source/detail/graph/node_impl.cpp +++ b/sycl/source/detail/graph/node_impl.cpp @@ -31,14 +31,11 @@ std::vector createNodesFromImpls( return Nodes; } -/// Takes a vector of shared_ptrs to node_impls and returns a vector of node -/// objects created from those impls, in the same order. -std::vector createNodesFromImpls( - const std::vector> &Impls) { +std::vector createNodesFromImpls(nodes_range Impls) { std::vector Nodes{}; Nodes.reserve(Impls.size()); - for (std::shared_ptr Impl : Impls) { + for (detail::node_impl &Impl : Impls) { Nodes.push_back(sycl::detail::createSyclObjFromImpl(Impl)); } diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index 11166e1eba897..578b62e0ab614 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -39,10 +40,7 @@ class exec_graph_impl; std::vector createNodesFromImpls(const std::vector> &Impls); -/// Takes a vector of shared_ptrs to node_impls and returns a vector of node -/// objects created from those impls, in the same order. -std::vector -createNodesFromImpls(const std::vector> &Impls); +std::vector createNodesFromImpls(nodes_range Impls); inline node_type getNodeTypeFromCG(sycl::detail::CGType CGType) { using sycl::detail::CG; @@ -774,7 +772,9 @@ class nodes_range { // from `weak_ptr`s this alternative should be removed too. std::vector>, // - std::set>>; + std::set>, std::set, + // + std::list>; storage_iter Begin; storage_iter End; diff --git a/sycl/unittests/Extensions/CommandGraph/Barrier.cpp b/sycl/unittests/Extensions/CommandGraph/Barrier.cpp index 1793588cef111..648c28ab0af7b 100644 --- a/sycl/unittests/Extensions/CommandGraph/Barrier.cpp +++ b/sycl/unittests/Extensions/CommandGraph/Barrier.cpp @@ -39,10 +39,9 @@ TEST_F(CommandGraphTest, EnqueueBarrier) { // / \ // (4) (5) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto BarrierNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto BarrierNode = Root.MSuccessors.front().lock(); ASSERT_EQ(BarrierNode->MCGType, sycl::detail::CGType::Barrier); ASSERT_EQ(GraphImpl.getEventForNode(BarrierNode).get(), &*getSyclObjImpl(Barrier)); @@ -79,14 +78,12 @@ TEST_F(CommandGraphTest, EnqueueBarrierMultipleQueues) { // / \ // (4) (5) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - if (GraphImpl.getEventForNode(RootNode).get() == + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2Graph)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier)); @@ -110,7 +107,7 @@ TEST_F(CommandGraphTest, EnqueueBarrierMultipleQueues) { } } } else { - ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); + ASSERT_EQ(Root.MSuccessors.size(), 0lu); } } } @@ -147,10 +144,9 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitList) { // / \ / // (4) (5) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto SuccNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CGType::Barrier) { ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier)); @@ -204,10 +200,9 @@ TEST_F(CommandGraphTest, EnqueueBarrierWaitListMultipleQueues) { // \|/ // (B2) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto SuccNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CGType::Barrier) { ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier)); @@ -267,10 +262,9 @@ TEST_F(CommandGraphTest, EnqueueMultipleBarrier) { // / | \ // (6) (7) (8) ASSERT_EQ(GraphImpl.MRoots.size(), 3lu); - for (auto Root : GraphImpl.MRoots) { - auto Node = Root.lock(); - ASSERT_EQ(Node->MSuccessors.size(), 1lu); - auto SuccNode = Node->MSuccessors.front().lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); if (SuccNode->MCGType == sycl::detail::CGType::Barrier) { ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier1)); @@ -335,10 +329,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithPreviousCommand) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); - ASSERT_TRUE(RootNode->MCGType == sycl::detail::CGType::Barrier); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 0lu); + ASSERT_TRUE(Root.MCGType == sycl::detail::CGType::Barrier); } } @@ -370,20 +363,19 @@ TEST_F(CommandGraphTest, InOrderQueuesWithBarrier) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + &*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_TRUE(SuccNode->MCGType == sycl::detail::CGType::Barrier); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); - } else if (GraphImpl.getEventForNode(RootNode).get() == + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); + ASSERT_EQ(Root.MSuccessors.size(), 0lu); } else { ASSERT_TRUE(false && "Unexpected root node"); } @@ -417,12 +409,10 @@ TEST_F(CommandGraphTest, InOrderQueuesWithBarrierWaitList) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); @@ -461,24 +451,23 @@ TEST_F(CommandGraphTest, InOrderQueuesWithEmptyBarrierWaitList) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + &*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); - } else if (GraphImpl.getEventForNode(RootNode).get() == + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); + ASSERT_EQ(Root.MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node3)); @@ -525,19 +514,18 @@ TEST_F(CommandGraphTest, BarrierMixedQueueTypes) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - } else if (GraphImpl.getEventForNode(RootNode).get() == + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + &*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node2)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 2lu); + ASSERT_EQ(Root.MSuccessors.size(), 2lu); } else { ASSERT_TRUE(false && "Unexpected root node"); } - for (auto Succ : RootNode->MSuccessors) { + for (auto Succ : Root.MSuccessors) { auto SuccNode = Succ.lock(); if (GraphImpl.getEventForNode(SuccNode).get() == @@ -580,15 +568,14 @@ TEST_F(CommandGraphTest, BarrierBetweenExplicitNodes) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { - if (GraphImpl.getEventForNode(RootNode).get() == + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(BarrierNode)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 0lu); - } else if (RootNode.get() == &*getSyclObjImpl(Node1)) { - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - auto SuccNode = RootNode->MSuccessors.front().lock(); + ASSERT_EQ(Root.MSuccessors.size(), 0lu); + } else if (&Root == &*getSyclObjImpl(Node1)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(SuccNode.get(), &*getSyclObjImpl(Node2)); } else { ASSERT_TRUE(false); @@ -636,13 +623,12 @@ TEST_F(CommandGraphTest, BarrierMultipleOOOQueue) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 4u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if ((RootNodeEvent.get() == &*getSyclObjImpl(Node1)) || (RootNodeEvent.get() == &*getSyclObjImpl(Node2))) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); @@ -659,7 +645,7 @@ TEST_F(CommandGraphTest, BarrierMultipleOOOQueue) { &*getSyclObjImpl(Node6)); } else if ((RootNodeEvent.get() == &*getSyclObjImpl(Node3)) || (RootNodeEvent.get() == &*getSyclObjImpl(Node4))) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node5)); @@ -701,18 +687,17 @@ TEST_F(CommandGraphTest, BarrierMultipleInOrderQueue) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if (RootNodeEvent.get() == &*getSyclObjImpl(Node1)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); } else if (RootNodeEvent.get() == &*getSyclObjImpl(Node2)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node3)); @@ -752,18 +737,17 @@ TEST_F(CommandGraphTest, BarrierMultipleMixedOrderQueues) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if (RootNodeEvent.get() == &*getSyclObjImpl(Node1)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(BarrierNode)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); } else if (RootNodeEvent.get() == &*getSyclObjImpl(Node2)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Node3)); @@ -797,18 +781,17 @@ TEST_F(CommandGraphTest, BarrierMultipleQueuesMultipleBarriers) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2u); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); - auto RootNodeEvent = GraphImpl.getEventForNode(RootNode); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { + auto RootNodeEvent = GraphImpl.getEventForNode(Root.shared_from_this()); if (RootNodeEvent.get() == &*getSyclObjImpl(Barrier1)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier4)); ASSERT_EQ(SuccNode->MPredecessors.size(), 1lu); ASSERT_EQ(SuccNode->MSuccessors.size(), 0lu); } else if (RootNodeEvent.get() == &*getSyclObjImpl(Barrier2)) { - auto SuccNode = RootNode->MSuccessors.front().lock(); + auto SuccNode = Root.MSuccessors.front().lock(); ASSERT_EQ(GraphImpl.getEventForNode(SuccNode).get(), &*getSyclObjImpl(Barrier3)); @@ -875,21 +858,21 @@ TEST_F(CommandGraphTest, BarrierWithInOrderCommands) { experimental::detail::graph_impl &GraphImpl = *getSyclObjImpl(Graph); ASSERT_EQ(GraphImpl.MRoots.size(), 2lu); - for (auto Root : GraphImpl.MRoots) { - auto RootNode = Root.lock(); + for (experimental::detail::node_impl &Root : GraphImpl.roots()) { bool EvenPath; - ASSERT_EQ(RootNode->MSuccessors.size(), 1lu); - if (GraphImpl.getEventForNode(RootNode).get() == &*getSyclObjImpl(Node2)) { + ASSERT_EQ(Root.MSuccessors.size(), 1lu); + if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == + &*getSyclObjImpl(Node2)) { EvenPath = true; - } else if (GraphImpl.getEventForNode(RootNode).get() == + } else if (GraphImpl.getEventForNode(Root.shared_from_this()).get() == &*getSyclObjImpl(Node1)) { EvenPath = false; } else { ASSERT_TRUE(false); } - auto Succ1Node = RootNode->MSuccessors.front().lock(); + auto Succ1Node = Root.MSuccessors.front().lock(); ASSERT_EQ(Succ1Node->MSuccessors.size(), 1lu); if (EvenPath) { ASSERT_EQ(GraphImpl.getEventForNode(Succ1Node).get(), diff --git a/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp b/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp index f9e9915ae9119..9799847d88d9f 100644 --- a/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp +++ b/sycl/unittests/Extensions/CommandGraph/CommandGraph.cpp @@ -59,7 +59,7 @@ TEST_F(CommandGraphTest, AddNode) { [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); ASSERT_FALSE(getSyclObjImpl(Node1)->isEmpty()); ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); - ASSERT_EQ(GraphImpl.MRoots.begin()->lock().get(), &*getSyclObjImpl(Node1)); + ASSERT_EQ(*GraphImpl.MRoots.begin(), &*getSyclObjImpl(Node1)); ASSERT_TRUE(getSyclObjImpl(Node1)->MSuccessors.empty()); ASSERT_TRUE(getSyclObjImpl(Node1)->MPredecessors.empty()); diff --git a/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp b/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp index 589583e758fb9..ab651e140fe83 100644 --- a/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp +++ b/sycl/unittests/Extensions/CommandGraph/Exceptions.cpp @@ -507,7 +507,7 @@ TEST_F(CommandGraphTest, MakeEdgeErrors) { experimental::detail::node_impl &NodeBImpl = *getSyclObjImpl(NodeB); ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); - ASSERT_EQ(GraphImpl.MRoots.begin()->lock().get(), &NodeAImpl); + ASSERT_EQ(*GraphImpl.MRoots.begin(), &NodeAImpl); ASSERT_EQ(NodeAImpl.MSuccessors.size(), 1lu); ASSERT_EQ(NodeAImpl.MPredecessors.size(), 0lu); diff --git a/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp b/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp index f5242af68c535..fef1a947d0bc0 100644 --- a/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp +++ b/sycl/unittests/Extensions/CommandGraph/MultiThreaded.cpp @@ -84,12 +84,10 @@ bool checkExecGraphSchedule( if (ScheduleA.size() != ScheduleB.size()) return false; - std::vector< - std::shared_ptr> - VScheduleA{std::begin(ScheduleA), std::end(ScheduleA)}; - std::vector< - std::shared_ptr> - VScheduleB{std::begin(ScheduleB), std::end(ScheduleB)}; + std::vector VScheduleA{ + std::begin(ScheduleA), std::end(ScheduleA)}; + std::vector VScheduleB{ + std::begin(ScheduleB), std::end(ScheduleB)}; for (size_t i = 0; i < VScheduleA.size(); i++) { if (!VScheduleA[i]->isSimilar(*VScheduleB[i])) @@ -244,7 +242,7 @@ TEST_F(MultiThreadGraphTest, RecordAddNodesInOrderQueue) { ASSERT_EQ(GraphImpl.MRoots.size(), 1lu); // Check structure graph - auto CurrentNode = GraphImpl.MRoots.begin()->lock(); + experimental::detail::node_impl *CurrentNode = *GraphImpl.MRoots.begin(); for (size_t i = 1; i <= GraphImpl.getNumberOfNodes(); i++) { EXPECT_LE(CurrentNode->MSuccessors.size(), 1lu); @@ -254,7 +252,7 @@ TEST_F(MultiThreadGraphTest, RecordAddNodesInOrderQueue) { } else { // Check other nodes have 1 successor EXPECT_EQ(CurrentNode->MSuccessors.size(), 1lu); - CurrentNode = CurrentNode->MSuccessors[0].lock(); + CurrentNode = CurrentNode->MSuccessors[0].lock().get(); } } }