From d2ff9fe4175681e0ae20365c357dd902f83660ad Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Mon, 7 Jul 2025 12:05:24 -0700 Subject: [PATCH 1/2] [NFC][SYCL][Graph] Add `successors`/`predecessors` views + cleanup Part of refactoring to get rid of most (all?) `std::weak_ptr` and some of `std::shared_ptr` started in https://github.com/intel/llvm/pull/19295. Use `nodes_range` from that PR to implement `successors`/`predecessors` views and update read-only accesses to the successors/predecessors to go through them. I'm not changing the data members `MSuccessors`/`MPredecessors` yet because it would affect unit tests. I'd prefer to refactor most of the code in future PRs before making that change and updating unit tests in one go. I'm updating some APIs to accept `node_impl &` instead of `std::shared_ptr` where the change is mostly localized to the callers iterating over successors/predecessors and doesn't spill into other code too much. For those that weren't updated here we (temporarily) use `shared_from_this()`, but I expect to eliminate those unnecessary copies once those interfaces are updated in subsequent PRs. 
--- sycl/source/detail/graph/graph_impl.cpp | 101 +++++++++++------------ sycl/source/detail/graph/graph_impl.hpp | 17 ++-- sycl/source/detail/graph/memory_pool.cpp | 29 +++---- sycl/source/detail/graph/memory_pool.hpp | 2 +- sycl/source/detail/graph/node_impl.hpp | 9 ++ 5 files changed, 78 insertions(+), 80 deletions(-) diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index ee54916b1201..24d208a833ab 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -100,17 +100,16 @@ void sortTopological(std::set, Source.pop(); SortedNodes.push_back(Node); - for (auto &SuccWP : Node->MSuccessors) { - auto Succ = SuccWP.lock(); + for (node_impl &Succ : Node->successors()) { - if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) { + if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) { continue; } - auto &TotalVisitedEdges = Succ->MTotalVisitedEdges; + auto &TotalVisitedEdges = Succ.MTotalVisitedEdges; ++TotalVisitedEdges; - if (TotalVisitedEdges == Succ->MPredecessors.size()) { - Source.push(Succ); + if (TotalVisitedEdges == Succ.MPredecessors.size()) { + Source.push(Succ.weak_from_this()); } } } @@ -127,14 +126,14 @@ void sortTopological(std::set, /// a node with a smaller partition number. /// @param Node Node to assign to the partition. /// @param PartitionNum Number to propagate. 
-void propagatePartitionUp(std::shared_ptr Node, int PartitionNum) { - if (((Node->MPartitionNum != -1) && (Node->MPartitionNum <= PartitionNum)) || - (Node->MCGType == sycl::detail::CGType::CodeplayHostTask)) { +void propagatePartitionUp(node_impl &Node, int PartitionNum) { + if (((Node.MPartitionNum != -1) && (Node.MPartitionNum <= PartitionNum)) || + (Node.MCGType == sycl::detail::CGType::CodeplayHostTask)) { return; } - Node->MPartitionNum = PartitionNum; - for (auto &Predecessor : Node->MPredecessors) { - propagatePartitionUp(Predecessor.lock(), PartitionNum); + Node.MPartitionNum = PartitionNum; + for (node_impl &Predecessor : Node.predecessors()) { + propagatePartitionUp(Predecessor, PartitionNum); } } @@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr Node, int PartitionNum) { /// @param HostTaskList List of host tasks that have already been processed and /// are encountered as successors to the node Node. void propagatePartitionDown( - const std::shared_ptr &Node, int PartitionNum, + node_impl &Node, int PartitionNum, std::list> &HostTaskList) { - if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) { - if (Node->MPartitionNum != -1) { - HostTaskList.push_front(Node); + if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) { + if (Node.MPartitionNum != -1) { + HostTaskList.push_front(Node.shared_from_this()); } return; } - Node->MPartitionNum = PartitionNum; - for (auto &Successor : Node->MSuccessors) { - propagatePartitionDown(Successor.lock(), PartitionNum, HostTaskList); + Node.MPartitionNum = PartitionNum; + for (node_impl &Successor : Node.successors()) { + propagatePartitionDown(Successor, PartitionNum, HostTaskList); } } @@ -165,8 +164,8 @@ void propagatePartitionDown( /// @param Node node to test /// @return True is `Node` is a root of its partition bool isPartitionRoot(std::shared_ptr Node) { - for (auto &Predecessor : Node->MPredecessors) { - if (Predecessor.lock()->MPartitionNum == Node->MPartitionNum) { + for 
(node_impl &Predecessor : Node->predecessors()) { + if (Predecessor.MPartitionNum == Node->MPartitionNum) { return false; } } @@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() { auto Node = HostTaskList.front(); HostTaskList.pop_front(); CurrentPartition++; - for (auto &Predecessor : Node->MPredecessors) { - propagatePartitionUp(Predecessor.lock(), CurrentPartition); + for (node_impl &Predecessor : Node->predecessors()) { + propagatePartitionUp(Predecessor, CurrentPartition); } CurrentPartition++; Node->MPartitionNum = CurrentPartition; CurrentPartition++; auto TmpSize = HostTaskList.size(); - for (auto &Successor : Node->MSuccessors) { - propagatePartitionDown(Successor.lock(), CurrentPartition, HostTaskList); + for (node_impl &Successor : Node->successors()) { + propagatePartitionDown(Successor, CurrentPartition, HostTaskList); } if (HostTaskList.size() > TmpSize) { // At least one HostTask has been re-numbered so group merge opportunities @@ -290,9 +289,9 @@ void exec_graph_impl::makePartitions() { for (const auto &Partition : MPartitions) { for (auto const &Root : Partition->MRoots) { auto RootNode = Root.lock(); - for (const auto &Dep : RootNode->MPredecessors) { - auto NodeDep = Dep.lock(); - auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]]; + for (node_impl &NodeDep : RootNode->predecessors()) { + auto &Predecessor = + MPartitions[MPartitionNodes[NodeDep.shared_from_this()]]; Partition->MPredecessors.push_back(Predecessor.get()); Predecessor->MSuccessors.push_back(Partition.get()); } @@ -390,8 +389,8 @@ std::set> graph_impl::getCGEdges( bool ShouldAddDep = true; // If any of this node's successors have this requirement then we skip // adding the current node as a dependency. 
- for (auto &Succ : Node->MSuccessors) { - if (Succ.lock()->hasRequirementDependency(Req)) { + for (node_impl &Succ : Node->successors()) { + if (Succ.hasRequirementDependency(Req)) { ShouldAddDep = false; break; } @@ -721,17 +720,17 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) { // predecessors until we find the real dependency. void exec_graph_impl::findRealDeps( std::vector &Deps, - std::shared_ptr CurrentNode, int ReferencePartitionNum) { - if (!CurrentNode->requiresEnqueue()) { - for (auto &N : CurrentNode->MPredecessors) { - auto NodeImpl = N.lock(); + node_impl &CurrentNode, int ReferencePartitionNum) { + if (!CurrentNode.requiresEnqueue()) { + for (node_impl &NodeImpl : CurrentNode.predecessors()) { findRealDeps(Deps, NodeImpl, ReferencePartitionNum); } } else { + auto CurrentNodePtr = CurrentNode.shared_from_this(); // Verify if CurrentNode belong the the same partition - if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) { + if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) { // Verify that the sync point has actually been set for this node. - auto SyncPoint = MSyncPoints.find(CurrentNode); + auto SyncPoint = MSyncPoints.find(CurrentNodePtr); assert(SyncPoint != MSyncPoints.end() && "No sync point has been set for node dependency."); // Check if the dependency has already been added. 
@@ -749,8 +748,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, ur_exp_command_buffer_handle_t CommandBuffer, std::shared_ptr Node) { std::vector Deps; - for (auto &N : Node->MPredecessors) { - findRealDeps(Deps, N.lock(), MPartitionNodes[Node]); + for (node_impl &N : Node->predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[Node]); } ur_exp_command_buffer_sync_point_t NewSyncPoint; ur_exp_command_buffer_command_handle_t NewCommand = 0; @@ -805,8 +804,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, std::shared_ptr Node) { std::vector Deps; - for (auto &N : Node->MPredecessors) { - findRealDeps(Deps, N.lock(), MPartitionNodes[Node]); + for (node_impl &N : Node->predecessors()) { + findRealDeps(Deps, N, MPartitionNodes[Node]); } sycl::detail::EventImplPtr Event = @@ -1275,8 +1274,8 @@ void exec_graph_impl::duplicateNodes() { auto NodeCopy = NewNodes[i]; // Look through all the original node successors, find their copies and // register those as successors with the current copied node - for (auto &NextNode : OriginalNode->MSuccessors) { - auto Successor = NodesMap.at(NextNode.lock()); + for (node_impl &NextNode : OriginalNode->successors()) { + auto Successor = NodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1317,8 +1316,8 @@ void exec_graph_impl::duplicateNodes() { auto SubgraphNode = SubgraphNodes[i]; auto NodeCopy = NewSubgraphNodes[i]; - for (auto &NextNode : SubgraphNode->MSuccessors) { - auto Successor = SubgraphNodesMap.at(NextNode.lock()); + for (node_impl &NextNode : SubgraphNode->successors()) { + auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this()); NodeCopy->registerSuccessor(Successor); } } @@ -1339,9 +1338,8 @@ void exec_graph_impl::duplicateNodes() { // original subgraph node // Predecessors - for (auto &PredNodeWeak : NewNode->MPredecessors) { - auto PredNode = PredNodeWeak.lock(); - auto &Successors = PredNode->MSuccessors; + for (node_impl 
&PredNode : NewNode->predecessors()) { + auto &Successors = PredNode.MSuccessors; // Remove the subgraph node from this nodes successors Successors.erase(std::remove_if(Successors.begin(), Successors.end(), @@ -1353,14 +1351,13 @@ void exec_graph_impl::duplicateNodes() { // Add all input nodes from the subgraph as successors for this node // instead for (auto &Input : Inputs) { - PredNode->registerSuccessor(Input); + PredNode.registerSuccessor(Input); } } // Successors - for (auto &SuccNodeWeak : NewNode->MSuccessors) { - auto SuccNode = SuccNodeWeak.lock(); - auto &Predecessors = SuccNode->MPredecessors; + for (node_impl &SuccNode : NewNode->successors()) { + auto &Predecessors = SuccNode.MPredecessors; // Remove the subgraph node from this nodes successors Predecessors.erase(std::remove_if(Predecessors.begin(), @@ -1373,7 +1370,7 @@ void exec_graph_impl::duplicateNodes() { // Add all Output nodes from the subgraph as predecessors for this node // instead for (auto &Output : Outputs) { - Output->registerSuccessor(SuccNode); + Output->registerSuccessor(SuccNode.shared_from_this()); } } diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index d0944af0eb5c..d181156ce097 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -346,19 +346,17 @@ class graph_impl : public std::enable_shared_from_this { /// @param NodeA pointer to the first node for comparison /// @param NodeB pointer to the second node for comparison /// @return true is same structure found, false otherwise - static bool checkNodeRecursive(const std::shared_ptr &NodeA, - const std::shared_ptr &NodeB) { + static bool checkNodeRecursive(node_impl &NodeA, node_impl &NodeB) { size_t FoundCnt = 0; - for (std::weak_ptr &SuccA : NodeA->MSuccessors) { - for (std::weak_ptr &SuccB : NodeB->MSuccessors) { - if (NodeA->isSimilar(*NodeB) && - checkNodeRecursive(SuccA.lock(), SuccB.lock())) { + for (node_impl &SuccA : 
NodeA.successors()) { + for (node_impl &SuccB : NodeB.successors()) { + if (NodeA.isSimilar(NodeB) && checkNodeRecursive(SuccA, SuccB)) { FoundCnt++; break; } } } - if (FoundCnt != NodeA->MSuccessors.size()) { + if (FoundCnt != NodeA.MSuccessors.size()) { return false; } @@ -428,7 +426,7 @@ class graph_impl : public std::enable_shared_from_this { auto NodeBLocked = NodeB.lock(); if (NodeALocked->isSimilar(*NodeBLocked)) { - if (checkNodeRecursive(NodeALocked, NodeBLocked)) { + if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) { RootsFound++; break; } @@ -817,8 +815,7 @@ class exec_graph_impl { /// SyncPoint for CurrentNode, otherwise we need to /// synchronize on the host with the completion of previous partitions. void findRealDeps(std::vector &Deps, - std::shared_ptr CurrentNode, - int ReferencePartitionNum); + node_impl &CurrentNode, int ReferencePartitionNum); /// Duplicate nodes from the modifiable graph associated with this executable /// graph and store them locally. Any subgraph nodes in the modifiable graph diff --git a/sycl/source/detail/graph/memory_pool.cpp b/sycl/source/detail/graph/memory_pool.cpp index fdbf90df56be..c36ee1f21669 100644 --- a/sycl/source/detail/graph/memory_pool.cpp +++ b/sycl/source/detail/graph/memory_pool.cpp @@ -116,25 +116,20 @@ graph_mem_pool::tryReuseExistingAllocation( // free nodes. We do this in a breadth-first approach because we want to find // the shortest path to a reusable allocation. - std::queue> NodesToCheck; + std::queue NodesToCheck; // Add all the dependent nodes to the queue, they will be popped first for (auto &Dep : DepNodes) { - NodesToCheck.push(Dep); + NodesToCheck.push(&*Dep); } // Called when traversing over nodes to check if the current node is a free // node for one of the available allocations. If it is we populate AllocInfo // with the allocation to be reused. 
auto CheckNodeEqual = - [&CompatibleAllocs](const std::shared_ptr &CurrentNode) - -> std::optional { + [&CompatibleAllocs](node_impl &CurrentNode) -> std::optional { for (auto &Alloc : CompatibleAllocs) { - const auto &AllocFreeNode = Alloc.LastFreeNode; - // Compare control blocks without having to lock AllocFreeNode to check - // for node equality - if (!CurrentNode.owner_before(AllocFreeNode) && - !AllocFreeNode.owner_before(CurrentNode)) { + if (&CurrentNode == Alloc.LastFreeNode) { return Alloc; } } @@ -142,9 +137,9 @@ graph_mem_pool::tryReuseExistingAllocation( }; while (!NodesToCheck.empty()) { - auto CurrentNode = NodesToCheck.front().lock(); + node_impl &CurrentNode = *NodesToCheck.front(); - if (CurrentNode->MTotalVisitedEdges > 0) { + if (CurrentNode.MTotalVisitedEdges > 0) { continue; } @@ -152,13 +147,13 @@ graph_mem_pool::tryReuseExistingAllocation( // for any of the allocations which are free for reuse. We should not bother // checking nodes that are not free nodes, so we continue and check their // predecessors. 
- if (CurrentNode->MNodeType == node_type::async_free) { + if (CurrentNode.MNodeType == node_type::async_free) { std::optional AllocFound = CheckNodeEqual(CurrentNode); if (AllocFound) { // Reset visited nodes tracking MGraph.resetNodeVisitedEdges(); // Reset last free node for allocation - MAllocations.at(AllocFound.value().Ptr).LastFreeNode.reset(); + MAllocations.at(AllocFound.value().Ptr).LastFreeNode = nullptr; // Remove found allocation from the free list MFreeAllocations.erase(std::find(MFreeAllocations.begin(), MFreeAllocations.end(), @@ -168,12 +163,12 @@ graph_mem_pool::tryReuseExistingAllocation( } // Add CurrentNode predecessors to queue - for (auto &Pred : CurrentNode->MPredecessors) { - NodesToCheck.push(Pred); + for (node_impl &Pred : CurrentNode.predecessors()) { + NodesToCheck.push(&Pred); } // Mark node as visited - CurrentNode->MTotalVisitedEdges = 1; + CurrentNode.MTotalVisitedEdges = 1; NodesToCheck.pop(); } @@ -183,7 +178,7 @@ graph_mem_pool::tryReuseExistingAllocation( void graph_mem_pool::markAllocationAsAvailable( void *Ptr, const std::shared_ptr &FreeNode) { MFreeAllocations.push_back(Ptr); - MAllocations.at(Ptr).LastFreeNode = FreeNode; + MAllocations.at(Ptr).LastFreeNode = FreeNode.get(); } } // namespace detail diff --git a/sycl/source/detail/graph/memory_pool.hpp b/sycl/source/detail/graph/memory_pool.hpp index aa4a2d1fb011..a6e023a6c4db 100644 --- a/sycl/source/detail/graph/memory_pool.hpp +++ b/sycl/source/detail/graph/memory_pool.hpp @@ -44,7 +44,7 @@ class graph_mem_pool { // Should the allocation be zero initialized during initial allocation bool ZeroInit = false; // Last free node for this allocation in the graph - std::weak_ptr LastFreeNode = {}; + node_impl *LastFreeNode = nullptr; }; public: diff --git a/sycl/source/detail/graph/node_impl.hpp b/sycl/source/detail/graph/node_impl.hpp index 455d3b921fd2..11166e1eba89 100644 --- a/sycl/source/detail/graph/node_impl.hpp +++ b/sycl/source/detail/graph/node_impl.hpp @@ -31,6 +31,7 
@@ class node; namespace detail { // Forward declarations class node_impl; +class nodes_range; class exec_graph_impl; /// Takes a vector of weak_ptrs to node_impls and returns a vector of node @@ -116,6 +117,10 @@ class node_impl : public std::enable_shared_from_this { /// cannot be used to find out the partion of a node outside of this process. int MPartitionNum = -1; + // Out-of-class as need "complete" `nodes_range`: + inline nodes_range successors() const; + inline nodes_range predecessors() const; + /// Add successor to the node. /// @param Node Node to add as a successor. void registerSuccessor(const std::shared_ptr &Node) { @@ -830,6 +835,10 @@ class nodes_range { size_t size() const { return Size; } bool empty() const { return Size == 0; } }; + +inline nodes_range node_impl::successors() const { return MSuccessors; } +inline nodes_range node_impl::predecessors() const { return MPredecessors; } + } // namespace detail } // namespace experimental } // namespace oneapi From 45c2e5269b289c178a9378a962ce6548ebe0cbe8 Mon Sep 17 00:00:00 2001 From: Andrei Elovikov Date: Mon, 7 Jul 2025 14:30:50 -0700 Subject: [PATCH 2/2] [NFC][SYCL][Graph] Update some maps to use raw `node_impl *` Continuation of the refactoring in https://github.com/intel/llvm/pull/19295 https://github.com/intel/llvm/pull/19332 --- sycl/source/detail/async_alloc.cpp | 4 +- sycl/source/detail/graph/graph_impl.cpp | 36 +++++------ sycl/source/detail/graph/graph_impl.hpp | 25 +++----- sycl/source/handler.cpp | 16 +++-- .../Extensions/CommandGraph/InOrderQueue.cpp | 64 +++++++++---------- 5 files changed, 70 insertions(+), 75 deletions(-) diff --git a/sycl/source/detail/async_alloc.cpp b/sycl/source/detail/async_alloc.cpp index 96861fa8a587..8142109fbd18 100644 --- a/sycl/source/detail/async_alloc.cpp +++ b/sycl/source/detail/async_alloc.cpp @@ -47,9 +47,9 @@ std::vector> getDepGraphNodes( // If this is being recorded from an in-order queue we need to get the last // in-order node if any, since this 
will later become a dependency of the // node being processed here. - if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue); + if (detail::node_impl *LastInOrderNode = Graph->getLastInorderNode(Queue); LastInOrderNode) { - DepNodes.push_back(LastInOrderNode); + DepNodes.push_back(LastInOrderNode->shared_from_this()); } return DepNodes; } diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index 24d208a833ab..bd71eb7a68ca 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -255,7 +255,7 @@ void exec_graph_impl::makePartitions() { const std::shared_ptr &Partition = std::make_shared(); for (auto &Node : MNodeStorage) { if (Node->MPartitionNum == i) { - MPartitionNodes[Node] = PartitionFinalNum; + MPartitionNodes[Node.get()] = PartitionFinalNum; if (isPartitionRoot(Node)) { Partition->MRoots.insert(Node); if (Node->MCGType == CGType::CodeplayHostTask) { @@ -290,8 +290,7 @@ void exec_graph_impl::makePartitions() { for (auto const &Root : Partition->MRoots) { auto RootNode = Root.lock(); for (node_impl &NodeDep : RootNode->predecessors()) { - auto &Predecessor = - MPartitions[MPartitionNodes[NodeDep.shared_from_this()]]; + auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]]; Partition->MPredecessors.push_back(Predecessor.get()); Predecessor->MSuccessors.push_back(Partition.get()); } @@ -610,8 +609,7 @@ bool graph_impl::checkForCycles() { return CycleFound; } -std::shared_ptr -graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) { +node_impl *graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) { if (!Queue) { assert(0 == MInorderQueueMap.count(std::weak_ptr{})); @@ -624,8 +622,8 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) { } void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue, - std::shared_ptr Node) { - MInorderQueueMap[Queue.weak_from_this()] = std::move(Node); + node_impl &Node) { + 
MInorderQueueMap[Queue.weak_from_this()] = &Node; } void graph_impl::makeEdge(std::shared_ptr Src, @@ -726,11 +724,10 @@ void exec_graph_impl::findRealDeps( findRealDeps(Deps, NodeImpl, ReferencePartitionNum); } } else { - auto CurrentNodePtr = CurrentNode.shared_from_this(); // Verify if CurrentNode belong the the same partition - if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) { + if (MPartitionNodes[&CurrentNode] == ReferencePartitionNum) { // Verify that the sync point has actually been set for this node. - auto SyncPoint = MSyncPoints.find(CurrentNodePtr); + auto SyncPoint = MSyncPoints.find(&CurrentNode); assert(SyncPoint != MSyncPoints.end() && "No sync point has been set for node dependency."); // Check if the dependency has already been added. @@ -749,7 +746,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, std::shared_ptr Node) { std::vector Deps; for (node_impl &N : Node->predecessors()) { - findRealDeps(Deps, N, MPartitionNodes[Node]); + findRealDeps(Deps, N, MPartitionNodes[Node.get()]); } ur_exp_command_buffer_sync_point_t NewSyncPoint; ur_exp_command_buffer_command_handle_t NewCommand = 0; @@ -782,7 +779,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx, Deps, &NewSyncPoint, MIsUpdatable ? 
&NewCommand : nullptr, nullptr); if (MIsUpdatable) { - MCommandMap[Node] = NewCommand; + MCommandMap[Node.get()] = NewCommand; } if (Res != UR_RESULT_SUCCESS) { @@ -805,7 +802,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, std::vector Deps; for (node_impl &N : Node->predecessors()) { - findRealDeps(Deps, N, MPartitionNodes[Node]); + findRealDeps(Deps, N, MPartitionNodes[Node.get()]); } sycl::detail::EventImplPtr Event = @@ -814,7 +811,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer, /*EventNeeded=*/true, CommandBuffer, Deps); if (MIsUpdatable) { - MCommandMap[Node] = Event->getCommandBufferCommand(); + MCommandMap[Node.get()] = Event->getCommandBufferCommand(); } return Event->getSyncPoint(); @@ -830,7 +827,8 @@ void exec_graph_impl::buildRequirements() { Node->MCommandGroup->getRequirements().begin(), Node->MCommandGroup->getRequirements().end()); - std::shared_ptr &Partition = MPartitions[MPartitionNodes[Node]]; + std::shared_ptr &Partition = + MPartitions[MPartitionNodes[Node.get()]]; Partition->MRequirements.insert( Partition->MRequirements.end(), @@ -877,10 +875,10 @@ void exec_graph_impl::createCommandBuffers( Node->MCommandGroup.get()) ->MStreams.size() == 0) { - MSyncPoints[Node] = + MSyncPoints[Node.get()] = enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node); } else { - MSyncPoints[Node] = enqueueNode(OutCommandBuffer, Node); + MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node); } } @@ -1726,7 +1724,7 @@ void exec_graph_impl::populateURKernelUpdateStructs( auto ExecNode = MIDCache.find(Node->MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - auto Command = MCommandMap.find(ExecNode->second); + auto Command = MCommandMap.find(ExecNode->second.get()); assert(Command != MCommandMap.end()); UpdateDesc.hCommand = Command->second; @@ -1756,7 +1754,7 @@ exec_graph_impl::getURUpdatableNodes( auto ExecNode = MIDCache.find(Node->MID); 
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); - auto PartitionIndex = MPartitionNodes.find(ExecNode->second); + auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get()); assert(PartitionIndex != MPartitionNodes.end()); PartitionedNodes[PartitionIndex->second].push_back(Node); } diff --git a/sycl/source/detail/graph/graph_impl.hpp b/sycl/source/detail/graph/graph_impl.hpp index d181156ce097..0e257a77e5ef 100644 --- a/sycl/source/detail/graph/graph_impl.hpp +++ b/sycl/source/detail/graph/graph_impl.hpp @@ -294,14 +294,12 @@ class graph_impl : public std::enable_shared_from_this { /// @param Queue In-order queue to find the last node added to the graph from. /// @return Last node in this graph added from \p Queue recording, or empty /// shared pointer if none. - std::shared_ptr - getLastInorderNode(sycl::detail::queue_impl *Queue); + node_impl *getLastInorderNode(sycl::detail::queue_impl *Queue); /// Track the last node added to this graph from an in-order queue. /// @param Queue In-order queue to register \p Node for. /// @param Node Last node that was added to this graph from \p Queue. - void setLastInorderNode(sycl::detail::queue_impl &Queue, - std::shared_ptr Node); + void setLastInorderNode(sycl::detail::queue_impl &Queue, node_impl &Node); /// Prints the contents of the graph to a text file in DOT format. /// @param FilePath Path to the output file. @@ -465,15 +463,14 @@ class graph_impl : public std::enable_shared_from_this { /// @param[in] Queue The queue the barrier was recorded from. /// @param[in] BarrierNodeImpl The created barrier node. void setBarrierDep(std::weak_ptr Queue, - std::shared_ptr BarrierNodeImpl) { - MBarrierDependencyMap[Queue] = BarrierNodeImpl; + node_impl &BarrierNodeImpl) { + MBarrierDependencyMap[Queue] = &BarrierNodeImpl; } /// Get the last barrier node that was submitted to the queue. /// @param[in] Queue The queue to find the last barrier node of. 
An empty /// shared_ptr is returned if no barrier node has been recorded to the queue. - std::shared_ptr - getBarrierDep(std::weak_ptr Queue) { + node_impl *getBarrierDep(std::weak_ptr Queue) { return MBarrierDependencyMap[Queue]; } @@ -553,7 +550,7 @@ class graph_impl : public std::enable_shared_from_this { /// Map for every in-order queue thats recorded a node to the graph, what /// the last node added was. We can use this to create new edges on the last /// node if any more nodes are added to the graph from the queue. - std::map, std::shared_ptr, + std::map, node_impl *, std::owner_less>> MInorderQueueMap; /// Controls whether we skip the cycle checks in makeEdge, set by the presence @@ -568,7 +565,7 @@ class graph_impl : public std::enable_shared_from_this { /// Mapping from queues to barrier nodes. For each queue the last barrier /// node recorded to the graph from the queue is stored. - std::map, std::shared_ptr, + std::map, node_impl *, std::owner_less>> MBarrierDependencyMap; /// Graph memory pool for handling graph-owned memory allocations for this @@ -886,14 +883,13 @@ class exec_graph_impl { std::shared_ptr MGraphImpl; /// Map of nodes in the exec graph to the sync point representing their /// execution in the command graph. - std::unordered_map, - ur_exp_command_buffer_sync_point_t> + std::unordered_map MSyncPoints; /// Sycl queue impl ptr associated with this graph. std::shared_ptr MQueueImpl; /// Map of nodes in the exec graph to the partition number to which they /// belong. - std::unordered_map, int> MPartitionNodes; + std::unordered_map MPartitionNodes; /// Device associated with this executable graph. sycl::device MDevice; /// Context associated with this executable graph. @@ -909,8 +905,7 @@ class exec_graph_impl { /// Storage for copies of nodes from the original modifiable graph. std::vector> MNodeStorage; /// Map of nodes to their associated UR command handles. 
- std::unordered_map, - ur_exp_command_buffer_command_handle_t> + std::unordered_map MCommandMap; /// List of partition without any predecessors in this exec graph. std::vector> MRootPartitions; diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp index 9c1f9068096b..f575885b6a24 100644 --- a/sycl/source/handler.cpp +++ b/sycl/source/handler.cpp @@ -888,28 +888,30 @@ event handler::finalize() { // node can set it as a predecessor. std::vector> Deps; - if (auto DependentNode = GraphImpl->getLastInorderNode(Queue)) { - Deps.push_back(std::move(DependentNode)); + if (ext::oneapi::experimental::detail::node_impl *DependentNode = + GraphImpl->getLastInorderNode(Queue)) { + Deps.push_back(DependentNode->shared_from_this()); } NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); // If we are recording an in-order queue remember the new node, so it // can be used as a dependency for any more nodes recorded from this // queue. - GraphImpl->setLastInorderNode(*Queue, NodeImpl); + GraphImpl->setLastInorderNode(*Queue, *NodeImpl); } else { - auto LastBarrierRecordedFromQueue = - GraphImpl->getBarrierDep(Queue->weak_from_this()); + ext::oneapi::experimental::detail::node_impl + *LastBarrierRecordedFromQueue = + GraphImpl->getBarrierDep(Queue->weak_from_this()); std::vector> Deps; if (LastBarrierRecordedFromQueue) { - Deps.push_back(LastBarrierRecordedFromQueue); + Deps.push_back(LastBarrierRecordedFromQueue->shared_from_this()); } NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps); if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) { - GraphImpl->setBarrierDep(Queue->weak_from_this(), NodeImpl); + GraphImpl->setBarrierDep(Queue->weak_from_this(), *NodeImpl); } } diff --git a/sycl/unittests/Extensions/CommandGraph/InOrderQueue.cpp b/sycl/unittests/Extensions/CommandGraph/InOrderQueue.cpp index a2c2e3b8df2e..17703be2de59 100644 --- a/sycl/unittests/Extensions/CommandGraph/InOrderQueue.cpp +++ 
b/sycl/unittests/Extensions/CommandGraph/InOrderQueue.cpp @@ -35,9 +35,9 @@ TEST_F(CommandGraphTest, InOrderQueue) { ASSERT_NE(PtrNode2, nullptr); ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit( [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); @@ -47,9 +47,9 @@ TEST_F(CommandGraphTest, InOrderQueue) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue); @@ -92,9 +92,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithEmpty) { ASSERT_NE(PtrNode2, nullptr); ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit( [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); @@ -104,9 +104,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithEmpty) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + 
ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue); @@ -150,9 +150,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithEmptyFirst) { ASSERT_NE(PtrNode2, nullptr); ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit( [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); @@ -162,9 +162,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithEmptyFirst) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue); @@ -209,9 +209,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithEmptyLast) { ASSERT_NE(PtrNode2, nullptr); ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit([&](sycl::handler &cgh) {}); @@ -220,9 +220,9 @@ 
TEST_F(CommandGraphTest, InOrderQueueWithEmptyLast) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue); @@ -279,9 +279,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithPreviousHostTask) { ASSERT_NE(PtrNode2, nullptr); ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit( [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); @@ -291,9 +291,9 @@ TEST_F(CommandGraphTest, InOrderQueueWithPreviousHostTask) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue); @@ -346,9 +346,9 @@ TEST_F(CommandGraphTest, InOrderQueueHostTaskAndGraph) { ASSERT_NE(PtrNode2, nullptr); ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); 
- ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit( [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); @@ -358,9 +358,9 @@ TEST_F(CommandGraphTest, InOrderQueueHostTaskAndGraph) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue); @@ -423,9 +423,9 @@ TEST_F(CommandGraphTest, InOrderQueueMemsetAndGraph) { ASSERT_NE(PtrNode2, nullptr); ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit( [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); @@ -435,9 +435,9 @@ TEST_F(CommandGraphTest, InOrderQueueMemsetAndGraph) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue); @@ -483,9 +483,9 @@ TEST_F(CommandGraphTest, InOrderQueueMemcpyAndGraph) { ASSERT_NE(PtrNode2, nullptr); 
ASSERT_NE(PtrNode2, PtrNode1); ASSERT_EQ(PtrNode1->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode1->MSuccessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode1->MSuccessors.front().lock().get(), PtrNode2); ASSERT_EQ(PtrNode2->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MPredecessors.front().lock(), PtrNode1); + ASSERT_EQ(PtrNode2->MPredecessors.front().lock().get(), PtrNode1); auto Node3Graph = InOrderQueue.submit( [&](sycl::handler &cgh) { cgh.single_task([]() {}); }); @@ -495,9 +495,9 @@ TEST_F(CommandGraphTest, InOrderQueueMemcpyAndGraph) { ASSERT_NE(PtrNode3, nullptr); ASSERT_NE(PtrNode3, PtrNode2); ASSERT_EQ(PtrNode2->MSuccessors.size(), 1lu); - ASSERT_EQ(PtrNode2->MSuccessors.front().lock(), PtrNode3); + ASSERT_EQ(PtrNode2->MSuccessors.front().lock().get(), PtrNode3); ASSERT_EQ(PtrNode3->MPredecessors.size(), 1lu); - ASSERT_EQ(PtrNode3->MPredecessors.front().lock(), PtrNode2); + ASSERT_EQ(PtrNode3->MPredecessors.front().lock().get(), PtrNode2); InOrderGraph.end_recording(InOrderQueue);