[NFC][SYCL][Graph] Update some maps to use raw node_impl * #19334

Open · wants to merge 2 commits into base: sycl
4 changes: 2 additions & 2 deletions sycl/source/detail/async_alloc.cpp
@@ -47,9 +47,9 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
// If this is being recorded from an in-order queue we need to get the last
// in-order node if any, since this will later become a dependency of the
// node being processed here.
if (const auto &LastInOrderNode = Graph->getLastInorderNode(Queue);
if (detail::node_impl *LastInOrderNode = Graph->getLastInorderNode(Queue);
LastInOrderNode) {
DepNodes.push_back(LastInOrderNode);
DepNodes.push_back(LastInOrderNode->shared_from_this());
}
return DepNodes;
}
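
For illustration, a minimal, self-contained sketch of the pattern this hunk adopts, with a hypothetical Node type standing in for detail::node_impl: the last in-order node is observed through a raw pointer, and a shared_ptr is materialized with shared_from_this() only where the dependency list needs shared ownership. This relies on the node type deriving from std::enable_shared_from_this, as the shared_from_this() call above implies, and on the graph's node storage keeping the node alive.

#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-ins for detail::node_impl and the graph bookkeeping.
struct Node : std::enable_shared_from_this<Node> {
  int Id = 0;
};

struct Graph {
  std::vector<std::shared_ptr<Node>> NodeStorage; // owns every node
  Node *LastInOrderNode = nullptr;                // non-owning observer

  Node *addNode(int Id) {
    NodeStorage.push_back(std::make_shared<Node>());
    NodeStorage.back()->Id = Id;
    LastInOrderNode = NodeStorage.back().get();
    return LastInOrderNode;
  }
};

// Mirrors the shape of getDepGraphNodes(): collect dependencies as
// shared_ptrs, but query the last in-order node through a raw pointer and
// convert only when the container requires shared ownership.
std::vector<std::shared_ptr<Node>> getDeps(Graph &G) {
  std::vector<std::shared_ptr<Node>> DepNodes;
  if (Node *LastInOrderNode = G.LastInOrderNode) {
    // Safe: NodeStorage keeps the node alive, so shared_from_this() is valid.
    DepNodes.push_back(LastInOrderNode->shared_from_this());
  }
  return DepNodes;
}

int main() {
  Graph G;
  G.addNode(1);
  auto Deps = getDeps(G);
  assert(Deps.size() == 1 && Deps[0]->Id == 1);
  return 0;
}
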
123 changes: 59 additions & 64 deletions sycl/source/detail/graph/graph_impl.cpp
@@ -100,17 +100,16 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
Source.pop();
SortedNodes.push_back(Node);

for (auto &SuccWP : Node->MSuccessors) {
auto Succ = SuccWP.lock();
for (node_impl &Succ : Node->successors()) {

if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) {
if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
continue;
}

auto &TotalVisitedEdges = Succ->MTotalVisitedEdges;
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
++TotalVisitedEdges;
if (TotalVisitedEdges == Succ->MPredecessors.size()) {
Source.push(Succ);
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
Source.push(Succ.weak_from_this());
}
}
}
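
For orientation, the function above is a Kahn-style topological sort: each successor counts how many of its incoming edges have been visited and is enqueued once that count equals its predecessor count. Below is a stripped-down sketch of that scheme with a hypothetical Node type, using plain raw pointers instead of the weak_ptr work queue and partition bound used above.

#include <cstddef>
#include <queue>
#include <vector>

struct Node {
  std::vector<Node *> Successors;
  std::vector<Node *> Predecessors;
  std::size_t TotalVisitedEdges = 0; // mirrors MTotalVisitedEdges
};

// Kahn-style topological sort: start from nodes with no predecessors and
// release a successor once all of its incoming edges have been visited.
std::vector<Node *> sortTopological(std::vector<Node *> const &Roots) {
  std::vector<Node *> Sorted;
  std::queue<Node *> Source;
  for (Node *Root : Roots)
    Source.push(Root);

  while (!Source.empty()) {
    Node *N = Source.front();
    Source.pop();
    Sorted.push_back(N);

    for (Node *Succ : N->Successors) {
      if (++Succ->TotalVisitedEdges == Succ->Predecessors.size())
        Source.push(Succ);
    }
  }
  return Sorted;
}

int main() {
  Node A, B, C; // A -> B -> C
  A.Successors = {&B};
  B.Predecessors = {&A};
  B.Successors = {&C};
  C.Predecessors = {&B};
  auto Order = sortTopological({&A}); // yields A, B, C
  (void)Order;
  return 0;
}
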
@@ -127,14 +126,14 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
/// a node with a smaller partition number.
/// @param Node Node to assign to the partition.
/// @param PartitionNum Number to propagate.
void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
if (((Node->MPartitionNum != -1) && (Node->MPartitionNum <= PartitionNum)) ||
(Node->MCGType == sycl::detail::CGType::CodeplayHostTask)) {
void propagatePartitionUp(node_impl &Node, int PartitionNum) {
if (((Node.MPartitionNum != -1) && (Node.MPartitionNum <= PartitionNum)) ||
(Node.MCGType == sycl::detail::CGType::CodeplayHostTask)) {
return;
}
Node->MPartitionNum = PartitionNum;
for (auto &Predecessor : Node->MPredecessors) {
propagatePartitionUp(Predecessor.lock(), PartitionNum);
Node.MPartitionNum = PartitionNum;
for (node_impl &Predecessor : Node.predecessors()) {
propagatePartitionUp(Predecessor, PartitionNum);
}
}

@@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
/// @param HostTaskList List of host tasks that have already been processed and
/// are encountered as successors to the node Node.
void propagatePartitionDown(
const std::shared_ptr<node_impl> &Node, int PartitionNum,
node_impl &Node, int PartitionNum,
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) {
if (Node->MPartitionNum != -1) {
HostTaskList.push_front(Node);
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
if (Node.MPartitionNum != -1) {
HostTaskList.push_front(Node.shared_from_this());
}
return;
}
Node->MPartitionNum = PartitionNum;
for (auto &Successor : Node->MSuccessors) {
propagatePartitionDown(Successor.lock(), PartitionNum, HostTaskList);
Node.MPartitionNum = PartitionNum;
for (node_impl &Successor : Node.successors()) {
propagatePartitionDown(Successor, PartitionNum, HostTaskList);
}
}
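
For illustration, the two helpers above now take node_impl& and recurse over reference ranges, with host-task nodes acting as partition boundaries. A self-contained sketch of the same pass-by-reference recursion with a hypothetical Node type, reduced to the partition-number and host-task checks (the HostTaskList bookkeeping is omitted):

#include <vector>

struct Node {
  int PartitionNum = -1;
  bool IsHostTask = false;
  std::vector<Node *> Predecessors;
  std::vector<Node *> Successors;
};

// Walk predecessors, pulling them into (at least) the given partition.
void propagatePartitionUp(Node &N, int PartitionNum) {
  if ((N.PartitionNum != -1 && N.PartitionNum <= PartitionNum) || N.IsHostTask)
    return;
  N.PartitionNum = PartitionNum;
  for (Node *Pred : N.Predecessors)
    propagatePartitionUp(*Pred, PartitionNum);
}

// Walk successors, pushing them into the given partition; host tasks act as
// partition boundaries and are left for a later pass.
void propagatePartitionDown(Node &N, int PartitionNum) {
  if (N.IsHostTask)
    return;
  N.PartitionNum = PartitionNum;
  for (Node *Succ : N.Successors)
    propagatePartitionDown(*Succ, PartitionNum);
}

int main() {
  Node A, B; // A -> B
  A.Successors = {&B};
  B.Predecessors = {&A};
  propagatePartitionDown(A, /*PartitionNum=*/0); // A and B land in partition 0
  return 0;
}
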

@@ -165,8 +164,8 @@ void propagatePartitionDown(
/// @param Node node to test
/// @return True if `Node` is a root of its partition
bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
for (auto &Predecessor : Node->MPredecessors) {
if (Predecessor.lock()->MPartitionNum == Node->MPartitionNum) {
for (node_impl &Predecessor : Node->predecessors()) {
if (Predecessor.MPartitionNum == Node->MPartitionNum) {
return false;
}
}
@@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() {
auto Node = HostTaskList.front();
HostTaskList.pop_front();
CurrentPartition++;
for (auto &Predecessor : Node->MPredecessors) {
propagatePartitionUp(Predecessor.lock(), CurrentPartition);
for (node_impl &Predecessor : Node->predecessors()) {
propagatePartitionUp(Predecessor, CurrentPartition);
}
CurrentPartition++;
Node->MPartitionNum = CurrentPartition;
CurrentPartition++;
auto TmpSize = HostTaskList.size();
for (auto &Successor : Node->MSuccessors) {
propagatePartitionDown(Successor.lock(), CurrentPartition, HostTaskList);
for (node_impl &Successor : Node->successors()) {
propagatePartitionDown(Successor, CurrentPartition, HostTaskList);
}
if (HostTaskList.size() > TmpSize) {
// At least one HostTask has been re-numbered so group merge opportunities
@@ -256,7 +255,7 @@ void exec_graph_impl::makePartitions() {
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
for (auto &Node : MNodeStorage) {
if (Node->MPartitionNum == i) {
MPartitionNodes[Node] = PartitionFinalNum;
MPartitionNodes[Node.get()] = PartitionFinalNum;
if (isPartitionRoot(Node)) {
Partition->MRoots.insert(Node);
if (Node->MCGType == CGType::CodeplayHostTask) {
@@ -290,9 +289,8 @@ void exec_graph_impl::makePartitions() {
for (const auto &Partition : MPartitions) {
for (auto const &Root : Partition->MRoots) {
auto RootNode = Root.lock();
for (const auto &Dep : RootNode->MPredecessors) {
auto NodeDep = Dep.lock();
auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]];
for (node_impl &NodeDep : RootNode->predecessors()) {
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
Partition->MPredecessors.push_back(Predecessor.get());
Predecessor->MSuccessors.push_back(Partition.get());
}
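
The map changes here, and in the enqueue and update paths below, follow one idea: MPartitionNodes, MSyncPoints and MCommandMap are keyed by raw node_impl* rather than shared_ptr<node_impl>, which avoids refcount traffic on every lookup and stays safe because the graph's node storage owns the nodes for the lifetime of the executable graph. A minimal sketch with hypothetical names (the real key and container types live in the graph headers and may differ):

#include <memory>
#include <unordered_map>
#include <vector>

struct Node { int PartitionNum = -1; };

struct Graph {
  // Owns the nodes; nothing below outlives this storage.
  std::vector<std::shared_ptr<Node>> NodeStorage;
  // Non-owning lookup table: raw-pointer keys are cheap to hash and compare
  // and never dangle while NodeStorage holds the nodes.
  std::unordered_map<Node *, int> PartitionNodes;

  Node *addNode() {
    NodeStorage.push_back(std::make_shared<Node>());
    return NodeStorage.back().get();
  }
};

int main() {
  Graph G;
  Node *N = G.addNode();
  G.PartitionNodes[N] = 0;             // key with the raw pointer...
  int Partition = G.PartitionNodes[N]; // ...and look up the same way
  (void)Partition;
  return 0;
}
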
@@ -390,8 +388,8 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
bool ShouldAddDep = true;
// If any of this node's successors have this requirement then we skip
// adding the current node as a dependency.
for (auto &Succ : Node->MSuccessors) {
if (Succ.lock()->hasRequirementDependency(Req)) {
for (node_impl &Succ : Node->successors()) {
if (Succ.hasRequirementDependency(Req)) {
ShouldAddDep = false;
break;
}
@@ -611,8 +609,7 @@ bool graph_impl::checkForCycles() {
return CycleFound;
}

std::shared_ptr<node_impl>
graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
node_impl *graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
if (!Queue) {
assert(0 ==
MInorderQueueMap.count(std::weak_ptr<sycl::detail::queue_impl>{}));
@@ -625,8 +622,8 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
}

void graph_impl::setLastInorderNode(sycl::detail::queue_impl &Queue,
std::shared_ptr<node_impl> Node) {
MInorderQueueMap[Queue.weak_from_this()] = std::move(Node);
node_impl &Node) {
MInorderQueueMap[Queue.weak_from_this()] = &Node;
}
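
An illustrative sketch of the bookkeeping these two functions manage, with hypothetical names and an assumed std::owner_less comparator for the weak_ptr key: the map from a recording queue to its last in-order node now stores a non-owning node pointer as the mapped value.

#include <map>
#include <memory>

struct Queue : std::enable_shared_from_this<Queue> {};
struct Node {};

struct Graph {
  // Keyed by a weak handle to the queue; owner_less gives weak_ptrs a strict
  // ordering so they can serve as map keys. Values are non-owning.
  std::map<std::weak_ptr<Queue>, Node *, std::owner_less<std::weak_ptr<Queue>>>
      InorderQueueMap;

  void setLastInorderNode(Queue &Q, Node &N) {
    InorderQueueMap[Q.weak_from_this()] = &N;
  }

  Node *getLastInorderNode(Queue &Q) {
    auto It = InorderQueueMap.find(Q.weak_from_this());
    return It == InorderQueueMap.end() ? nullptr : It->second;
  }
};

int main() {
  auto Q = std::make_shared<Queue>();
  Node N;
  Graph G;
  G.setLastInorderNode(*Q, N);
  Node *Last = G.getLastInorderNode(*Q); // == &N
  (void)Last;
  return 0;
}
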

void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
@@ -721,17 +718,16 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
// predecessors until we find the real dependency.
void exec_graph_impl::findRealDeps(
std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
if (!CurrentNode->requiresEnqueue()) {
for (auto &N : CurrentNode->MPredecessors) {
auto NodeImpl = N.lock();
node_impl &CurrentNode, int ReferencePartitionNum) {
if (!CurrentNode.requiresEnqueue()) {
for (node_impl &NodeImpl : CurrentNode.predecessors()) {
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
}
} else {
// Verify that CurrentNode belongs to the same partition
if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) {
if (MPartitionNodes[&CurrentNode] == ReferencePartitionNum) {
// Verify that the sync point has actually been set for this node.
auto SyncPoint = MSyncPoints.find(CurrentNode);
auto SyncPoint = MSyncPoints.find(&CurrentNode);
assert(SyncPoint != MSyncPoints.end() &&
"No sync point has been set for node dependency.");
// Check if the dependency has already been added.
@@ -749,8 +745,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node) {
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (auto &N : Node->MPredecessors) {
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
}
ur_exp_command_buffer_sync_point_t NewSyncPoint;
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -783,7 +779,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);

if (MIsUpdatable) {
MCommandMap[Node] = NewCommand;
MCommandMap[Node.get()] = NewCommand;
}

if (Res != UR_RESULT_SUCCESS) {
Expand All @@ -805,8 +801,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node) {

std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (auto &N : Node->MPredecessors) {
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
}

sycl::detail::EventImplPtr Event =
@@ -815,7 +811,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
/*EventNeeded=*/true, CommandBuffer, Deps);

if (MIsUpdatable) {
MCommandMap[Node] = Event->getCommandBufferCommand();
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
}

return Event->getSyncPoint();
@@ -831,7 +827,8 @@ void exec_graph_impl::buildRequirements() {
Node->MCommandGroup->getRequirements().begin(),
Node->MCommandGroup->getRequirements().end());

std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
std::shared_ptr<partition> &Partition =
MPartitions[MPartitionNodes[Node.get()]];

Partition->MRequirements.insert(
Partition->MRequirements.end(),
@@ -878,10 +875,10 @@ void exec_graph_impl::createCommandBuffers(
Node->MCommandGroup.get())
->MStreams.size() ==
0) {
MSyncPoints[Node] =
MSyncPoints[Node.get()] =
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
} else {
MSyncPoints[Node] = enqueueNode(OutCommandBuffer, Node);
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
}
}

@@ -1275,8 +1272,8 @@ void exec_graph_impl::duplicateNodes() {
auto NodeCopy = NewNodes[i];
// Look through all the original node successors, find their copies and
// register those as successors with the current copied node
for (auto &NextNode : OriginalNode->MSuccessors) {
auto Successor = NodesMap.at(NextNode.lock());
for (node_impl &NextNode : OriginalNode->successors()) {
auto Successor = NodesMap.at(NextNode.shared_from_this());
NodeCopy->registerSuccessor(Successor);
}
}
@@ -1317,8 +1314,8 @@ void exec_graph_impl::duplicateNodes() {
auto SubgraphNode = SubgraphNodes[i];
auto NodeCopy = NewSubgraphNodes[i];

for (auto &NextNode : SubgraphNode->MSuccessors) {
auto Successor = SubgraphNodesMap.at(NextNode.lock());
for (node_impl &NextNode : SubgraphNode->successors()) {
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
NodeCopy->registerSuccessor(Successor);
}
}
@@ -1339,9 +1336,8 @@ void exec_graph_impl::duplicateNodes() {
// original subgraph node

// Predecessors
for (auto &PredNodeWeak : NewNode->MPredecessors) {
auto PredNode = PredNodeWeak.lock();
auto &Successors = PredNode->MSuccessors;
for (node_impl &PredNode : NewNode->predecessors()) {
auto &Successors = PredNode.MSuccessors;

// Remove the subgraph node from this node's successors
Successors.erase(std::remove_if(Successors.begin(), Successors.end(),
@@ -1353,14 +1349,13 @@ void exec_graph_impl::duplicateNodes() {
// Add all input nodes from the subgraph as successors for this node
// instead
for (auto &Input : Inputs) {
PredNode->registerSuccessor(Input);
PredNode.registerSuccessor(Input);
}
}

// Successors
for (auto &SuccNodeWeak : NewNode->MSuccessors) {
auto SuccNode = SuccNodeWeak.lock();
auto &Predecessors = SuccNode->MPredecessors;
for (node_impl &SuccNode : NewNode->successors()) {
auto &Predecessors = SuccNode.MPredecessors;

// Remove the subgraph node from this node's predecessors
Predecessors.erase(std::remove_if(Predecessors.begin(),
@@ -1373,7 +1368,7 @@ void exec_graph_impl::duplicateNodes() {
// Add all Output nodes from the subgraph as predecessors for this node
// instead
for (auto &Output : Outputs) {
Output->registerSuccessor(SuccNode);
Output->registerSuccessor(SuccNode.shared_from_this());
}
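
A simplified illustration of the remapping pattern used throughout duplicateNodes(), with a hypothetical Node type: edges on the copies are re-created by looking each original successor up in a map from original nodes to their clones, and since the loops now yield node references, a shared_ptr is recovered with shared_from_this() only for that map lookup.

#include <map>
#include <memory>
#include <vector>

struct Node : std::enable_shared_from_this<Node> {
  std::vector<Node *> Successors; // non-owning edges, as successors() yields
  void registerSuccessor(std::shared_ptr<Node> const &Succ) {
    Successors.push_back(Succ.get());
  }
};

int main() {
  // Original two-node chain: A -> B.
  auto A = std::make_shared<Node>();
  auto B = std::make_shared<Node>();
  A->registerSuccessor(B);

  // Duplicate the nodes and remember original -> copy.
  std::map<std::shared_ptr<Node>, std::shared_ptr<Node>> NodesMap;
  auto ACopy = std::make_shared<Node>();
  auto BCopy = std::make_shared<Node>();
  NodesMap[A] = ACopy;
  NodesMap[B] = BCopy;

  // Re-create the edges on the copies: walk the original node's successors and
  // use shared_from_this() only for the map lookup.
  for (Node *Succ : A->Successors) {
    ACopy->registerSuccessor(NodesMap.at(Succ->shared_from_this()));
  }
  return 0;
}
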
}

Expand Down Expand Up @@ -1729,7 +1724,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");

auto Command = MCommandMap.find(ExecNode->second);
auto Command = MCommandMap.find(ExecNode->second.get());
assert(Command != MCommandMap.end());
UpdateDesc.hCommand = Command->second;

Expand Down Expand Up @@ -1759,7 +1754,7 @@ exec_graph_impl::getURUpdatableNodes(

auto ExecNode = MIDCache.find(Node->MID);
assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache");
auto PartitionIndex = MPartitionNodes.find(ExecNode->second);
auto PartitionIndex = MPartitionNodes.find(ExecNode->second.get());
assert(PartitionIndex != MPartitionNodes.end());
PartitionedNodes[PartitionIndex->second].push_back(Node);
}
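
An illustrative sketch of the lookup chain in the two update helpers above, with hypothetical names and container choices: a node coming from the modifiable graph is matched to its executable-graph counterpart through the ID cache, and that counterpart's raw pointer then indexes the raw-pointer keyed maps (here the partition map) before the node is grouped under its partition.

#include <cassert>
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

struct Node { int Id = 0; };

int main() {
  // Executable-graph side: nodes owned by shared_ptrs, a partition lookup
  // keyed by raw pointer, and an ID cache from node IDs to owning shared_ptrs.
  auto N = std::make_shared<Node>();
  N->Id = 42;
  std::unordered_map<Node *, int> PartitionNodes{{N.get(), 0}};
  std::map<int, std::shared_ptr<Node>> IDCache{{N->Id, N}};

  // Update side: a node arriving from the modifiable graph carries only an ID;
  // resolve it to the executable node, then to that node's partition.
  int IncomingId = 42;
  auto ExecNode = IDCache.find(IncomingId);
  assert(ExecNode != IDCache.end());
  auto PartitionIndex = PartitionNodes.find(ExecNode->second.get());
  assert(PartitionIndex != PartitionNodes.end());

  // Group the node under its partition for the per-partition update call.
  std::vector<std::vector<int>> PartitionedNodes(1);
  PartitionedNodes[PartitionIndex->second].push_back(IncomingId);
  return 0;
}
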