From 865e68e7983e1ee7dc2bf2728c31ff5a574d0d64 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Mon, 7 Jul 2025 12:05:24 -0700 Subject: [PATCH] [NFC][SYCL][Graph] Add `successors`/`predecessors` views + cleanup Part of the refactoring to get rid of most (all?) `std::weak_ptr` and some of `std::shared_ptr` started in https://github.com/intel/llvm/pull/19295. Use `nodes_range` from that PR to implement `successors`/`predecessors` views and update read-only accesses to the successors/predecessors to go through them. I'm not changing the data members `MSuccessors`/`MPredecessors` yet because it would affect unit tests. I'd prefer to refactor most of the code in future PRs before making that change and updating unit tests in one go. I'm updating some APIs to accept `node_impl &` instead of `std::shared_ptr` where the change is mostly localized to the callers iterating over successors/predecessors and doesn't spill into other code too much. For those that weren't updated here we (temporarily) use `shared_from_this()` but I expect to eliminate those unnecessary copies when those interfaces are updated in subsequent PRs. 
--- sycl/source/detail/graph/graph_impl.cpp | 101 +++++++++++------------ sycl/source/detail/graph/graph_impl.hpp | 17 ++-- sycl/source/detail/graph/memory_pool.cpp | 29 +++---- sycl/source/detail/graph/memory_pool.hpp | 2 +- sycl/source/detail/graph/node_impl.hpp | 9 ++ 5 files changed, 78 insertions(+), 80 deletions(-) diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 1f2d60c0dc9d..a6fb80afcd87 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -100,17 +100,16 @@ void sortTopological(std::set, Source.pop(); SortedNodes.push_back(Node); - for (auto &SuccWP : Node->MSuccessors) { - auto Succ = SuccWP.lock(); + for (node_impl &Succ : Node->successors()) { - if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) { + if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) { continue; } - auto &TotalVisitedEdges = Succ->MTotalVisitedEdges; + auto &TotalVisitedEdges = Succ.MTotalVisitedEdges; ++TotalVisitedEdges; - if (TotalVisitedEdges == Succ->MPredecessors.size()) { - Source.push(Succ); + if (TotalVisitedEdges == Succ.MPredecessors.size()) { + Source.push(Succ.weak_from_this()); } } } @@ -127,14 +126,14 @@ void sortTopological(std::set, /// a node with a smaller partition number. /// @param Node Node to assign to the partition. /// @param PartitionNum Number to propagate. 
-void propagatePartitionUp(std::shared_ptr Node, int PartitionNum) { - if (((Node->MPartitionNum != -1) && (Node->MPartitionNum <= PartitionNum)) || - (Node->MCGType == sycl::detail::CGType::CodeplayHostTask)) { +void propagatePartitionUp(node_impl &Node, int PartitionNum) { + if (((Node.MPartitionNum != -1) && (Node.MPartitionNum <= PartitionNum)) || + (Node.MCGType == sycl::detail::CGType::CodeplayHostTask)) { return; } - Node->MPartitionNum = PartitionNum; - for (auto &Predecessor : Node->MPredecessors) { - propagatePartitionUp(Predecessor.lock(), PartitionNum); + Node.MPartitionNum = PartitionNum; + for (node_impl &Predecessor : Node.predecessors()) { + propagatePartitionUp(Predecessor, PartitionNum); } } @@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr Node, int PartitionNum) { /// @param HostTaskList List of host tasks that have already been processed and /// are encountered as successors to the node Node. void propagatePartitionDown( - const std::shared_ptr &Node, int PartitionNum, + node_impl &Node, int PartitionNum, std::list> &HostTaskList) { - if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) { - if (Node->MPartitionNum != -1) { - HostTaskList.push_front(Node); + if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) { + if (Node.MPartitionNum != -1) { + HostTaskList.push_front(Node.shared_from_this()); } return; } - Node->MPartitionNum = PartitionNum; - for (auto &Successor : Node->MSuccessors) { - propagatePartitionDown(Successor.lock(), PartitionNum, HostTaskList); + Node.MPartitionNum = PartitionNum; + for (node_impl &Successor : Node.successors()) { + propagatePartitionDown(Successor, PartitionNum, HostTaskList); } } @@ -165,8 +164,8 @@ void propagatePartitionDown( /// @param Node node to test /// @return True is `Node` is a root of its partition bool isPartitionRoot(std::shared_ptr Node) { - for (auto &Predecessor : Node->MPredecessors) { - if (Predecessor.lock()->MPartitionNum == Node->MPartitionNum) { + for 
(node_impl &Predecessor : Node->predecessors()) { + if (Predecessor.MPartitionNum == Node->MPartitionNum) { return false; } } @@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() { auto Node = HostTaskList.front(); HostTaskList.pop_front(); CurrentPartition++; - for (auto &Predecessor : Node->MPredecessors) { - propagatePartitionUp(Predecessor.lock(), CurrentPartition); + for (node_impl &Predecessor : Node->predecessors()) { + propagatePartitionUp(Predecessor, CurrentPartition); } CurrentPartition++; Node->MPartitionNum = CurrentPartition; CurrentPartition++; auto TmpSize = HostTaskList.size(); - for (auto &Successor : Node->MSuccessors) { - propagatePartitionDown(Successor.lock(), CurrentPartition, HostTaskList); + for (node_impl &Successor : Node->successors()) { + propagatePartitionDown(Successor, CurrentPartition, HostTaskList); } if (HostTaskList.size() > TmpSize) { // At least one HostTask has been re-numbered so group merge opportunities @@ -290,9 +289,9 @@ void exec_graph_impl::makePartitions() { for (const auto &Partition : MPartitions) { for (auto const &Root : Partition->MRoots) { auto RootNode = Root.lock(); - for (const auto &Dep : RootNode->MPredecessors) { - auto NodeDep = Dep.lock(); - auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]]; + for (node_impl &NodeDep : RootNode->predecessors()) { + auto &Predecessor = + MPartitions[MPartitionNodes[NodeDep.shared_from_this()]]; Partition->MPredecessors.push_back(Predecessor.get()); Predecessor->MSuccessors.push_back(Partition.get()); } @@ -424,8 +423,8 @@ std::set> graph_impl::getCGEdges( bool ShouldAddDep = true; // If any of this node's successors have this requirement then we skip // adding the current node as a dependency. 
- for (auto &Succ : Node->MSuccessors) { - if (Succ.lock()->hasRequirementDependency(Req)) { + for (node_impl &Succ : Node->successors()) { + if (Succ.hasRequirementDependency(Req)) { ShouldAddDep = false; break; } @@ -774,17 +773,17 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) { // predecessors until we find the real dependency. void exec_graph_impl::findRealDeps( std::vector &Deps, - std::shared_ptr CurrentNode, int ReferencePartitionNum) { - if (!CurrentNode->requiresEnqueue()) { - for (auto &N : CurrentNode->MPredecessors) { - auto NodeImpl = N.lock(); + node_impl &CurrentNode, int ReferencePartitionNum) { + if (!CurrentNode.requiresEnqueue()) { + for (node_impl &NodeImpl : CurrentNode.predecessors()) { findRealDeps(Deps, NodeImpl, ReferencePartitionNum); } } else { + auto CurrentNodePtr = CurrentNode.shared_from_this(); // Verify if CurrentNode belong the the same partition - if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) { + if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) { // Verify that the sync point has actually been set for this node. - auto SyncPoint = MSyncPoints.find(CurrentNode); + auto SyncPoint = MSyncPoints.find(CurrentNodePtr); assert(SyncPoint != MSyncPoints.end() && "No sync point has been set for node dependency."); // Check if the dependency has already been added. 
@@ -802,8 +801,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, ur_exp_command_buffer_handle_t CommandBuffer, std::shared_ptr Node) { std::vector Deps; - for (auto &N : Node->MPredecessors) { - findRealDeps(Deps, N.lock(), MPartitionNodes[Node]); + for (node_impl &N : Node->predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[Node]); } ur_exp_command_buffer_sync_point_t NewSyncPoint; ur_exp_command_buffer_command_handle_t NewCommand = 0; @@ -858,8 +857,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, std::shared_ptr Node) { std::vector Deps; - for (auto &N : Node->MPredecessors) { - findRealDeps(Deps, N.lock(), MPartitionNodes[Node]); + for (node_impl &N : Node->predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[Node]); } sycl::detail::EventImplPtr Event = @@ -1328,8 +1327,8 @@ void exec_graph_impl::duplicateNodes() { auto NodeCopy = NewNodes[i]; // Look through all the original node successors, find their copies and // register those as successors with the current copied node - for (auto &NextNode : OriginalNode->MSuccessors) { - auto Successor = NodesMap.at(NextNode.lock()); + for (node_impl &NextNode : OriginalNode->successors()) { + auto Successor = NodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1370,8 +1369,8 @@ void exec_graph_impl::duplicateNodes() { auto SubgraphNode = SubgraphNodes[i]; auto NodeCopy = NewSubgraphNodes[i]; - for (auto &NextNode : SubgraphNode->MSuccessors) { - auto Successor = SubgraphNodesMap.at(NextNode.lock()); + for (node_impl &NextNode : SubgraphNode->successors()) { + auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1392,9 +1391,8 @@ void exec_graph_impl::duplicateNodes() { // original subgraph node // Predecessors - for (auto &PredNodeWeak : NewNode->MPredecessors) { - auto PredNode = PredNodeWeak.lock(); - auto &Successors = PredNode->MSuccessors; + for (node_impl 
&PredNode : NewNode->predecessors()) { + auto &Successors = PredNode.MSuccessors; // Remove the subgraph node from this nodes successors Successors.erase(std::remove_if(Successors.begin(), Successors.end(), @@ -1406,14 +1404,13 @@ void exec_graph_impl::duplicateNodes() { // Add all input nodes from the subgraph as successors for this node // instead for (auto &Input : Inputs) { - PredNode->registerSuccessor(Input); + PredNode.registerSuccessor(Input); } } // Successors - for (auto &SuccNodeWeak : NewNode->MSuccessors) { - auto SuccNode = SuccNodeWeak.lock(); - auto &Predecessors = SuccNode->MPredecessors; + for (node_impl &SuccNode : NewNode->successors()) { + auto &Predecessors = SuccNode.MPredecessors; // Remove the subgraph node from this nodes successors Predecessors.erase(std::remove_if(Predecessors.begin(), @@ -1426,7 +1423,7 @@ void exec_graph_impl::duplicateNodes() { // Add all Output nodes from the subgraph as predecessors for this node // instead for (auto &Output : Outputs) { - Output->registerSuccessor(SuccNode); + Output->registerSuccessor(SuccNode.shared_from_this()); } } diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index 8c22c6fbb754..8a82c258d604 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -352,19 +352,17 @@ class graph_impl : public std::enable_shared_from_this { /// @param NodeA pointer to the first node for comparison /// @param NodeB pointer to the second node for comparison /// @return true is same structure found, false otherwise - static bool checkNodeRecursive(const std::shared_ptr &NodeA, - const std::shared_ptr &NodeB) { + static bool checkNodeRecursive(node_impl &NodeA, node_impl &NodeB) { size_t FoundCnt = 0; - for (std::weak_ptr &SuccA : NodeA->MSuccessors) { - for (std::weak_ptr &SuccB : NodeB->MSuccessors) { - if (NodeA->isSimilar(*NodeB) && - checkNodeRecursive(SuccA.lock(), SuccB.lock())) { + for (node_impl &SuccA : 
NodeA.successors()) { + for (node_impl &SuccB : NodeB.successors()) { + if (NodeA.isSimilar(NodeB) && checkNodeRecursive(SuccA, SuccB)) { FoundCnt++; break; } } } - if (FoundCnt != NodeA->MSuccessors.size()) { + if (FoundCnt != NodeA.MSuccessors.size()) { return false; } @@ -434,7 +432,7 @@ class graph_impl : public std::enable_shared_from_this { auto NodeBLocked = NodeB.lock(); if (NodeALocked->isSimilar(*NodeBLocked)) { - if (checkNodeRecursive(NodeALocked, NodeBLocked)) { + if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) { RootsFound++; break; } @@ -829,8 +827,7 @@ class exec_graph_impl { /// SyncPoint for CurrentNode, otherwise we need to /// synchronize on the host with the completion of previous partitions. void findRealDeps(std::vector &Deps, - std::shared_ptr CurrentNode, - int ReferencePartitionNum); + node_impl &CurrentNode, int ReferencePartitionNum); /// Duplicate nodes from the modifiable graph associated with this executable /// graph and store them locally. Any subgraph nodes in the modifiable graph diff --git a/sycl/source/detail/graph/memory_pool.cpp b/sycl/source/detail/graph/memory_pool.cpp index fdbf90df56be..c36ee1f21669 100644 --- a/sycl/source/detail/graph/memory_pool.cpp +++ b/sycl/source/detail/graph/memory_pool.cpp @@ -116,25 +116,20 @@ graph_mem_pool::tryReuseExistingAllocation( // free nodes. We do this in a breadth-first approach because we want to find // the shortest path to a reusable allocation. - std::queue> NodesToCheck; + std::queue NodesToCheck; // Add all the dependent nodes to the queue, they will be popped first for (auto &Dep : DepNodes) { - NodesToCheck.push(Dep); + NodesToCheck.push(&*Dep); } // Called when traversing over nodes to check if the current node is a free // node for one of the available allocations. If it is we populate AllocInfo // with the allocation to be reused. 
auto CheckNodeEqual = - [&CompatibleAllocs](const std::shared_ptr &CurrentNode) - -> std::optional { + [&CompatibleAllocs](node_impl &CurrentNode) -> std::optional { for (auto &Alloc : CompatibleAllocs) { - const auto &AllocFreeNode = Alloc.LastFreeNode; - // Compare control blocks without having to lock AllocFreeNode to check - // for node equality - if (!CurrentNode.owner_before(AllocFreeNode) && - !AllocFreeNode.owner_before(CurrentNode)) { + if (&CurrentNode == Alloc.LastFreeNode) { return Alloc; } } @@ -142,9 +137,9 @@ graph_mem_pool::tryReuseExistingAllocation( }; while (!NodesToCheck.empty()) { - auto CurrentNode = NodesToCheck.front().lock(); + node_impl &CurrentNode = *NodesToCheck.front(); - if (CurrentNode->MTotalVisitedEdges > 0) { + if (CurrentNode.MTotalVisitedEdges > 0) { continue; } @@ -152,13 +147,13 @@ graph_mem_pool::tryReuseExistingAllocation( // for any of the allocations which are free for reuse. We should not bother // checking nodes that are not free nodes, so we continue and check their // predecessors. 
- if (CurrentNode->MNodeType == node_type::async_free) { + if (CurrentNode.MNodeType == node_type::async_free) { std::optional AllocFound = CheckNodeEqual(CurrentNode); if (AllocFound) { // Reset visited nodes tracking MGraph.resetNodeVisitedEdges(); // Reset last free node for allocation - MAllocations.at(AllocFound.value().Ptr).LastFreeNode.reset(); + MAllocations.at(AllocFound.value().Ptr).LastFreeNode = nullptr; // Remove found allocation from the free list MFreeAllocations.erase(std::find(MFreeAllocations.begin(), MFreeAllocations.end(), @@ -168,12 +163,12 @@ graph_mem_pool::tryReuseExistingAllocation( } // Add CurrentNode predecessors to queue - for (auto &Pred : CurrentNode->MPredecessors) { - NodesToCheck.push(Pred); + for (node_impl &Pred : CurrentNode.predecessors()) { + NodesToCheck.push(&Pred); } // Mark node as visited - CurrentNode->MTotalVisitedEdges = 1; + CurrentNode.MTotalVisitedEdges = 1; NodesToCheck.pop(); } @@ -183,7 +178,7 @@ graph_mem_pool::tryReuseExistingAllocation( void graph_mem_pool::markAllocationAsAvailable( void *Ptr, const std::shared_ptr &FreeNode) { MFreeAllocations.push_back(Ptr); - MAllocations.at(Ptr).LastFreeNode = FreeNode; + MAllocations.at(Ptr).LastFreeNode = FreeNode.get(); } } // namespace detail diff --git a/sycl/source/detail/graph/memory_pool.hpp b/sycl/source/detail/graph/memory_pool.hpp index aa4a2d1fb011..a6e023a6c4db 100644 --- a/sycl/source/detail/graph/memory_pool.hpp +++ b/sycl/source/detail/graph/memory_pool.hpp @@ -44,7 +44,7 @@ class graph_mem_pool { // Should the allocation be zero initialized during initial allocation bool ZeroInit = false; // Last free node for this allocation in the graph - std::weak_ptr LastFreeNode = {}; + node_impl *LastFreeNode = nullptr; }; public: diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index 455d3b921fd2..11166e1eba89 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -31,6 +31,7 
@@ class node; namespace detail { // Forward declarations class node_impl; +class nodes_range; class exec_graph_impl; /// Takes a vector of weak_ptrs to node_impls and returns a vector of node @@ -116,6 +117,10 @@ class node_impl : public std::enable_shared_from_this { /// cannot be used to find out the partion of a node outside of this process. int MPartitionNum = -1; + // Out-of-class as need "complete" `nodes_range`: + inline nodes_range successors() const; + inline nodes_range predecessors() const; + /// Add successor to the node. /// @param Node Node to add as a successor. void registerSuccessor(const std::shared_ptr &Node) { @@ -830,6 +835,10 @@ class nodes_range { size_t size() const { return Size; } bool empty() const { return Size == 0; } }; + +inline nodes_range node_impl::successors() const { return MSuccessors; } +inline nodes_range node_impl::predecessors() const { return MPredecessors; } + } // namespace detail } // namespace experimental } // namespace oneapi