Skip to content

[NFC][SYCL][Graph] Add successors/predecessors views + cleanup #19332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 49 additions & 52 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,16 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
Source.pop();
SortedNodes.push_back(Node);

for (auto &SuccWP : Node->MSuccessors) {
auto Succ = SuccWP.lock();
for (node_impl &Succ : Node->successors()) {

if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) {
if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
continue;
}

auto &TotalVisitedEdges = Succ->MTotalVisitedEdges;
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
++TotalVisitedEdges;
if (TotalVisitedEdges == Succ->MPredecessors.size()) {
Source.push(Succ);
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
Source.push(Succ.weak_from_this());
}
}
}
Expand All @@ -127,14 +126,14 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
/// a node with a smaller partition number.
/// @param Node Node to assign to the partition.
/// @param PartitionNum Number to propagate.
void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
if (((Node->MPartitionNum != -1) && (Node->MPartitionNum <= PartitionNum)) ||
(Node->MCGType == sycl::detail::CGType::CodeplayHostTask)) {
void propagatePartitionUp(node_impl &Node, int PartitionNum) {
if (((Node.MPartitionNum != -1) && (Node.MPartitionNum <= PartitionNum)) ||
(Node.MCGType == sycl::detail::CGType::CodeplayHostTask)) {
return;
}
Node->MPartitionNum = PartitionNum;
for (auto &Predecessor : Node->MPredecessors) {
propagatePartitionUp(Predecessor.lock(), PartitionNum);
Node.MPartitionNum = PartitionNum;
for (node_impl &Predecessor : Node.predecessors()) {
propagatePartitionUp(Predecessor, PartitionNum);
}
}

Expand All @@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
/// @param HostTaskList List of host tasks that have already been processed and
/// are encountered as successors to the node Node.
void propagatePartitionDown(
const std::shared_ptr<node_impl> &Node, int PartitionNum,
node_impl &Node, int PartitionNum,
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) {
if (Node->MPartitionNum != -1) {
HostTaskList.push_front(Node);
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
if (Node.MPartitionNum != -1) {
HostTaskList.push_front(Node.shared_from_this());
}
return;
}
Node->MPartitionNum = PartitionNum;
for (auto &Successor : Node->MSuccessors) {
propagatePartitionDown(Successor.lock(), PartitionNum, HostTaskList);
Node.MPartitionNum = PartitionNum;
for (node_impl &Successor : Node.successors()) {
propagatePartitionDown(Successor, PartitionNum, HostTaskList);
}
}

Expand All @@ -165,8 +164,8 @@ void propagatePartitionDown(
/// @param Node node to test
/// @return True if `Node` is a root of its partition
bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
for (auto &Predecessor : Node->MPredecessors) {
if (Predecessor.lock()->MPartitionNum == Node->MPartitionNum) {
for (node_impl &Predecessor : Node->predecessors()) {
if (Predecessor.MPartitionNum == Node->MPartitionNum) {
return false;
}
}
Expand Down Expand Up @@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() {
auto Node = HostTaskList.front();
HostTaskList.pop_front();
CurrentPartition++;
for (auto &Predecessor : Node->MPredecessors) {
propagatePartitionUp(Predecessor.lock(), CurrentPartition);
for (node_impl &Predecessor : Node->predecessors()) {
propagatePartitionUp(Predecessor, CurrentPartition);
}
CurrentPartition++;
Node->MPartitionNum = CurrentPartition;
CurrentPartition++;
auto TmpSize = HostTaskList.size();
for (auto &Successor : Node->MSuccessors) {
propagatePartitionDown(Successor.lock(), CurrentPartition, HostTaskList);
for (node_impl &Successor : Node->successors()) {
propagatePartitionDown(Successor, CurrentPartition, HostTaskList);
}
if (HostTaskList.size() > TmpSize) {
// At least one HostTask has been re-numbered so group merge opportunities
Expand Down Expand Up @@ -290,9 +289,9 @@ void exec_graph_impl::makePartitions() {
for (const auto &Partition : MPartitions) {
for (auto const &Root : Partition->MRoots) {
auto RootNode = Root.lock();
for (const auto &Dep : RootNode->MPredecessors) {
auto NodeDep = Dep.lock();
auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]];
for (node_impl &NodeDep : RootNode->predecessors()) {
auto &Predecessor =
MPartitions[MPartitionNodes[NodeDep.shared_from_this()]];
Partition->MPredecessors.push_back(Predecessor.get());
Predecessor->MSuccessors.push_back(Partition.get());
}
Expand Down Expand Up @@ -424,8 +423,8 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
bool ShouldAddDep = true;
// If any of this node's successors have this requirement then we skip
// adding the current node as a dependency.
for (auto &Succ : Node->MSuccessors) {
if (Succ.lock()->hasRequirementDependency(Req)) {
for (node_impl &Succ : Node->successors()) {
if (Succ.hasRequirementDependency(Req)) {
ShouldAddDep = false;
break;
}
Expand Down Expand Up @@ -774,17 +773,17 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
// predecessors until we find the real dependency.
void exec_graph_impl::findRealDeps(
std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
if (!CurrentNode->requiresEnqueue()) {
for (auto &N : CurrentNode->MPredecessors) {
auto NodeImpl = N.lock();
node_impl &CurrentNode, int ReferencePartitionNum) {
if (!CurrentNode.requiresEnqueue()) {
for (node_impl &NodeImpl : CurrentNode.predecessors()) {
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
}
} else {
auto CurrentNodePtr = CurrentNode.shared_from_this();
// Verify that CurrentNode belongs to the same partition
if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) {
if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) {
// Verify that the sync point has actually been set for this node.
auto SyncPoint = MSyncPoints.find(CurrentNode);
auto SyncPoint = MSyncPoints.find(CurrentNodePtr);
assert(SyncPoint != MSyncPoints.end() &&
"No sync point has been set for node dependency.");
// Check if the dependency has already been added.
Expand All @@ -802,8 +801,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node) {
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (auto &N : Node->MPredecessors) {
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node]);
}
ur_exp_command_buffer_sync_point_t NewSyncPoint;
ur_exp_command_buffer_command_handle_t NewCommand = 0;
Expand Down Expand Up @@ -858,8 +857,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node) {

std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (auto &N : Node->MPredecessors) {
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node]);
}

sycl::detail::EventImplPtr Event =
Expand Down Expand Up @@ -1328,8 +1327,8 @@ void exec_graph_impl::duplicateNodes() {
auto NodeCopy = NewNodes[i];
// Look through all the original node successors, find their copies and
// register those as successors with the current copied node
for (auto &NextNode : OriginalNode->MSuccessors) {
auto Successor = NodesMap.at(NextNode.lock());
for (node_impl &NextNode : OriginalNode->successors()) {
auto Successor = NodesMap.at(NextNode.shared_from_this());
NodeCopy->registerSuccessor(Successor);
}
}
Expand Down Expand Up @@ -1370,8 +1369,8 @@ void exec_graph_impl::duplicateNodes() {
auto SubgraphNode = SubgraphNodes[i];
auto NodeCopy = NewSubgraphNodes[i];

for (auto &NextNode : SubgraphNode->MSuccessors) {
auto Successor = SubgraphNodesMap.at(NextNode.lock());
for (node_impl &NextNode : SubgraphNode->successors()) {
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
NodeCopy->registerSuccessor(Successor);
}
}
Expand All @@ -1392,9 +1391,8 @@ void exec_graph_impl::duplicateNodes() {
// original subgraph node

// Predecessors
for (auto &PredNodeWeak : NewNode->MPredecessors) {
auto PredNode = PredNodeWeak.lock();
auto &Successors = PredNode->MSuccessors;
for (node_impl &PredNode : NewNode->predecessors()) {
auto &Successors = PredNode.MSuccessors;

// Remove the subgraph node from this node's successors
Successors.erase(std::remove_if(Successors.begin(), Successors.end(),
Expand All @@ -1406,14 +1404,13 @@ void exec_graph_impl::duplicateNodes() {
// Add all input nodes from the subgraph as successors for this node
// instead
for (auto &Input : Inputs) {
PredNode->registerSuccessor(Input);
PredNode.registerSuccessor(Input);
}
}

// Successors
for (auto &SuccNodeWeak : NewNode->MSuccessors) {
auto SuccNode = SuccNodeWeak.lock();
auto &Predecessors = SuccNode->MPredecessors;
for (node_impl &SuccNode : NewNode->successors()) {
auto &Predecessors = SuccNode.MPredecessors;

// Remove the subgraph node from this node's predecessors
Predecessors.erase(std::remove_if(Predecessors.begin(),
Expand All @@ -1426,7 +1423,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all Output nodes from the subgraph as predecessors for this node
// instead
for (auto &Output : Outputs) {
Output->registerSuccessor(SuccNode);
Output->registerSuccessor(SuccNode.shared_from_this());
}
}

Expand Down
17 changes: 7 additions & 10 deletions sycl/source/detail/graph/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,19 +352,17 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @param NodeA pointer to the first node for comparison
/// @param NodeB pointer to the second node for comparison
/// @return true if the same structure is found, false otherwise
static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
const std::shared_ptr<node_impl> &NodeB) {
static bool checkNodeRecursive(node_impl &NodeA, node_impl &NodeB) {
size_t FoundCnt = 0;
for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
if (NodeA->isSimilar(*NodeB) &&
checkNodeRecursive(SuccA.lock(), SuccB.lock())) {
for (node_impl &SuccA : NodeA.successors()) {
for (node_impl &SuccB : NodeB.successors()) {
if (NodeA.isSimilar(NodeB) && checkNodeRecursive(SuccA, SuccB)) {
FoundCnt++;
break;
}
}
}
if (FoundCnt != NodeA->MSuccessors.size()) {
if (FoundCnt != NodeA.MSuccessors.size()) {
return false;
}

Expand Down Expand Up @@ -434,7 +432,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
auto NodeBLocked = NodeB.lock();

if (NodeALocked->isSimilar(*NodeBLocked)) {
if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) {
RootsFound++;
break;
}
Expand Down Expand Up @@ -829,8 +827,7 @@ class exec_graph_impl {
/// SyncPoint for CurrentNode, otherwise we need to
/// synchronize on the host with the completion of previous partitions.
void findRealDeps(std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
std::shared_ptr<node_impl> CurrentNode,
int ReferencePartitionNum);
node_impl &CurrentNode, int ReferencePartitionNum);

/// Duplicate nodes from the modifiable graph associated with this executable
/// graph and store them locally. Any subgraph nodes in the modifiable graph
Expand Down
29 changes: 12 additions & 17 deletions sycl/source/detail/graph/memory_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,49 +116,44 @@ graph_mem_pool::tryReuseExistingAllocation(
// free nodes. We do this in a breadth-first approach because we want to find
// the shortest path to a reusable allocation.

std::queue<std::weak_ptr<node_impl>> NodesToCheck;
std::queue<node_impl *> NodesToCheck;

// Add all the dependent nodes to the queue, they will be popped first
for (auto &Dep : DepNodes) {
NodesToCheck.push(Dep);
NodesToCheck.push(&*Dep);
}

// Called when traversing over nodes to check if the current node is a free
// node for one of the available allocations. If it is we populate AllocInfo
// with the allocation to be reused.
auto CheckNodeEqual =
[&CompatibleAllocs](const std::shared_ptr<node_impl> &CurrentNode)
-> std::optional<alloc_info> {
[&CompatibleAllocs](node_impl &CurrentNode) -> std::optional<alloc_info> {
for (auto &Alloc : CompatibleAllocs) {
const auto &AllocFreeNode = Alloc.LastFreeNode;
// Compare control blocks without having to lock AllocFreeNode to check
// for node equality
if (!CurrentNode.owner_before(AllocFreeNode) &&
!AllocFreeNode.owner_before(CurrentNode)) {
if (&CurrentNode == Alloc.LastFreeNode) {
return Alloc;
}
}
return std::nullopt;
};

while (!NodesToCheck.empty()) {
auto CurrentNode = NodesToCheck.front().lock();
node_impl &CurrentNode = *NodesToCheck.front();

if (CurrentNode->MTotalVisitedEdges > 0) {
if (CurrentNode.MTotalVisitedEdges > 0) {
continue;
}

// Check if the node is a free node and, if so, check if it is a free node
// for any of the allocations which are free for reuse. We should not bother
// checking nodes that are not free nodes, so we continue and check their
// predecessors.
if (CurrentNode->MNodeType == node_type::async_free) {
if (CurrentNode.MNodeType == node_type::async_free) {
std::optional<alloc_info> AllocFound = CheckNodeEqual(CurrentNode);
if (AllocFound) {
// Reset visited nodes tracking
MGraph.resetNodeVisitedEdges();
// Reset last free node for allocation
MAllocations.at(AllocFound.value().Ptr).LastFreeNode.reset();
MAllocations.at(AllocFound.value().Ptr).LastFreeNode = nullptr;
// Remove found allocation from the free list
MFreeAllocations.erase(std::find(MFreeAllocations.begin(),
MFreeAllocations.end(),
Expand All @@ -168,12 +163,12 @@ graph_mem_pool::tryReuseExistingAllocation(
}

// Add CurrentNode predecessors to queue
for (auto &Pred : CurrentNode->MPredecessors) {
NodesToCheck.push(Pred);
for (node_impl &Pred : CurrentNode.predecessors()) {
NodesToCheck.push(&Pred);
}

// Mark node as visited
CurrentNode->MTotalVisitedEdges = 1;
CurrentNode.MTotalVisitedEdges = 1;
NodesToCheck.pop();
}

Expand All @@ -183,7 +178,7 @@ graph_mem_pool::tryReuseExistingAllocation(
void graph_mem_pool::markAllocationAsAvailable(
void *Ptr, const std::shared_ptr<node_impl> &FreeNode) {
MFreeAllocations.push_back(Ptr);
MAllocations.at(Ptr).LastFreeNode = FreeNode;
MAllocations.at(Ptr).LastFreeNode = FreeNode.get();
}

} // namespace detail
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/graph/memory_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class graph_mem_pool {
// Should the allocation be zero initialized during initial allocation
bool ZeroInit = false;
// Last free node for this allocation in the graph
std::weak_ptr<node_impl> LastFreeNode = {};
node_impl *LastFreeNode = nullptr;
};

public:
Expand Down
9 changes: 9 additions & 0 deletions sycl/source/detail/graph/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class node;
namespace detail {
// Forward declarations
class node_impl;
class nodes_range;
class exec_graph_impl;

/// Takes a vector of weak_ptrs to node_impls and returns a vector of node
Expand Down Expand Up @@ -116,6 +117,10 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
/// cannot be used to find out the partition of a node outside of this process.
int MPartitionNum = -1;

// Defined out-of-class because they need the complete `nodes_range` type:
inline nodes_range successors() const;
inline nodes_range predecessors() const;

/// Add successor to the node.
/// @param Node Node to add as a successor.
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
Expand Down Expand Up @@ -830,6 +835,10 @@ class nodes_range {
size_t size() const { return Size; }
bool empty() const { return Size == 0; }
};

inline nodes_range node_impl::successors() const { return MSuccessors; }
inline nodes_range node_impl::predecessors() const { return MPredecessors; }

} // namespace detail
} // namespace experimental
} // namespace oneapi
Expand Down
Loading