Skip to content

[NFC][SYCL][Graph] Use raw node_impl * in MRoots/MSchedule #19350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 45 additions & 54 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,31 +85,29 @@ inline const char *nodeTypeToString(node_type NodeType) {
/// @param[in] PartitionBounded If set to true, the topological sort is stopped
/// at partition borders. Hence, nodes belonging to a partition different from
/// the NodeImpl partition are not processed.
void sortTopological(std::set<std::weak_ptr<node_impl>,
std::owner_less<std::weak_ptr<node_impl>>> &Roots,
std::list<std::shared_ptr<node_impl>> &SortedNodes,
void sortTopological(nodes_range Roots, std::list<node_impl *> &SortedNodes,
bool PartitionBounded) {
std::stack<std::weak_ptr<node_impl>> Source;
std::stack<node_impl *> Source;

for (auto &Node : Roots) {
Source.push(Node);
for (node_impl &Node : Roots) {
Source.push(&Node);
}

while (!Source.empty()) {
auto Node = Source.top().lock();
node_impl &Node = *Source.top();
Source.pop();
SortedNodes.push_back(Node);
SortedNodes.push_back(&Node);

for (node_impl &Succ : Node->successors()) {
for (node_impl &Succ : Node.successors()) {

if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
if (PartitionBounded && (Succ.MPartitionNum != Node.MPartitionNum)) {
continue;
}

auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
++TotalVisitedEdges;
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
Source.push(Succ.weak_from_this());
Source.push(&Succ);
}
}
}
Expand Down Expand Up @@ -163,17 +161,17 @@ void propagatePartitionDown(
/// belong to the same partition)
/// @param Node node to test
/// @return True is `Node` is a root of its partition
bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
for (node_impl &Predecessor : Node->predecessors()) {
if (Predecessor.MPartitionNum == Node->MPartitionNum) {
bool isPartitionRoot(node_impl &Node) {
for (node_impl &Predecessor : Node.predecessors()) {
if (Predecessor.MPartitionNum == Node.MPartitionNum) {
return false;
}
}
return true;
}
} // anonymous namespace

void partition::schedule() {
void partition::updateSchedule() {
if (MSchedule.empty()) {
// There is no need to reset MTotalVisitedEdges before calling
// sortTopological because this function is only called once per partition.
Expand Down Expand Up @@ -256,16 +254,16 @@ void exec_graph_impl::makePartitions() {
for (auto &Node : MNodeStorage) {
if (Node->MPartitionNum == i) {
MPartitionNodes[Node.get()] = PartitionFinalNum;
if (isPartitionRoot(Node)) {
Partition->MRoots.insert(Node);
if (isPartitionRoot(*Node)) {
Partition->MRoots.insert(Node.get());
if (Node->MCGType == CGType::CodeplayHostTask) {
Partition->MIsHostTask = true;
}
}
}
}
if (Partition->MRoots.size() > 0) {
Partition->schedule();
Partition->updateSchedule();
Partition->MIsInOrderGraph = Partition->checkIfGraphIsSinglePath();
MPartitions.push_back(Partition);
MRootPartitions.push_back(Partition);
Expand All @@ -287,9 +285,8 @@ void exec_graph_impl::makePartitions() {

// Compute partition dependencies
for (const auto &Partition : MPartitions) {
for (auto const &Root : Partition->MRoots) {
auto RootNode = Root.lock();
for (node_impl &NodeDep : RootNode->predecessors()) {
for (node_impl &Root : Partition->roots()) {
for (node_impl &NodeDep : Root.predecessors()) {
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
Partition->MPredecessors.push_back(Predecessor.get());
Predecessor->MSuccessors.push_back(Partition.get());
Expand Down Expand Up @@ -340,13 +337,9 @@ graph_impl::~graph_impl() {
}
}

void graph_impl::addRoot(const std::shared_ptr<node_impl> &Root) {
MRoots.insert(Root);
}
void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); }

void graph_impl::removeRoot(const std::shared_ptr<node_impl> &Root) {
MRoots.erase(Root);
}
void graph_impl::removeRoot(node_impl &Root) { MRoots.erase(&Root); }

std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
Expand Down Expand Up @@ -593,7 +586,7 @@ bool graph_impl::clearQueues() {
}

bool graph_impl::checkForCycles() {
std::list<std::shared_ptr<node_impl>> SortedNodes;
std::list<node_impl *> SortedNodes;
sortTopological(MRoots, SortedNodes, false);

// If after a topological sort, not all the nodes in the graph are sorted,
Expand Down Expand Up @@ -664,7 +657,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
if (DestLostRootStatus) {
// Dest is no longer a Root node, so we need to remove it from MRoots.
MRoots.erase(Dest);
MRoots.erase(Dest.get());
}

// We can skip cycle checks if either Dest has no successors (cycle not
Expand All @@ -679,14 +672,14 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
Dest->MPredecessors.pop_back();
if (DestLostRootStatus) {
// Add Dest back into MRoots.
MRoots.insert(Dest);
MRoots.insert(Dest.get());
}

throw sycl::exception(make_error_code(sycl::errc::invalid),
"Command graphs cannot contain cycles.");
}
}
removeRoot(Dest); // remove receiver from root node list
removeRoot(*Dest); // remove receiver from root node list
}

std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
Expand Down Expand Up @@ -740,14 +733,12 @@ void exec_graph_impl::findRealDeps(
}
}

ur_exp_command_buffer_sync_point_t
exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
sycl::detail::device_impl &DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node) {
ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and enqueue are called at around line 870, with shared_ptrs removed from that area of the code. I could have used shared_from_this() and kept these two the same, but they don't seem do be called anywhere else and it made sense to me to update them as part of this PR.

const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node) {
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
for (node_impl &N : Node.predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[&Node]);
}
ur_exp_command_buffer_sync_point_t NewSyncPoint;
ur_exp_command_buffer_command_handle_t NewCommand = 0;
Expand All @@ -760,7 +751,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
if (xptiEnabled) {
StreamID = xptiRegisterStream(sycl::detail::SYCL_STREAM_NAME);
sycl::detail::CGExecKernel *CGExec =
static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get());
sycl::detail::code_location CodeLoc(CGExec->MFileName.c_str(),
CGExec->MFunctionName.c_str(),
CGExec->MLine, CGExec->MColumn);
Expand All @@ -776,11 +767,11 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,

ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel(
Ctx, DeviceImpl, CommandBuffer,
*static_cast<sycl::detail::CGExecKernel *>((Node->MCommandGroup.get())),
*static_cast<sycl::detail::CGExecKernel *>((Node.MCommandGroup.get())),
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);

if (MIsUpdatable) {
MCommandMap[Node.get()] = NewCommand;
MCommandMap[&Node] = NewCommand;
}

if (Res != UR_RESULT_SUCCESS) {
Expand All @@ -799,20 +790,20 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,

ur_exp_command_buffer_sync_point_t
exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
std::shared_ptr<node_impl> Node) {
node_impl &Node) {

std::vector<ur_exp_command_buffer_sync_point_t> Deps;
for (node_impl &N : Node->predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
for (node_impl &N : Node.predecessors()) {
findRealDeps(Deps, N, MPartitionNodes[&Node]);
}

sycl::detail::EventImplPtr Event =
sycl::detail::Scheduler::getInstance().addCG(
Node->getCGCopy(), *MQueueImpl,
Node.getCGCopy(), *MQueueImpl,
/*EventNeeded=*/true, CommandBuffer, Deps);

if (MIsUpdatable) {
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
MCommandMap[&Node] = Event->getCommandBufferCommand();
}

return Event->getSyncPoint();
Expand Down Expand Up @@ -861,25 +852,25 @@ void exec_graph_impl::createCommandBuffers(

Partition->MCommandBuffers[Device] = OutCommandBuffer;

for (const auto &Node : Partition->MSchedule) {
for (node_impl &Node : Partition->schedule()) {
// Some nodes are not scheduled like other nodes, and only their
// dependencies are propagated in findRealDeps
if (!Node->requiresEnqueue())
if (!Node.requiresEnqueue())
continue;

sycl::detail::CGType type = Node->MCGType;
sycl::detail::CGType type = Node.MCGType;
// If the node is a kernel with no special requirements we can enqueue it
// directly.
if (type == sycl::detail::CGType::Kernel &&
Node->MCommandGroup->getRequirements().size() +
Node.MCommandGroup->getRequirements().size() +
static_cast<sycl::detail::CGExecKernel *>(
Node->MCommandGroup.get())
Node.MCommandGroup.get())
->MStreams.size() ==
0) {
MSyncPoints[Node.get()] =
MSyncPoints[&Node] =
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
} else {
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
MSyncPoints[&Node] = enqueueNode(OutCommandBuffer, Node);
}
}

Expand Down Expand Up @@ -2007,7 +1998,7 @@ std::vector<node> modifiable_command_graph::get_nodes() const {
std::vector<node> modifiable_command_graph::get_root_nodes() const {
graph_impl::ReadLock Lock(impl->MMutex);
auto &Roots = impl->MRoots;
std::vector<std::weak_ptr<node_impl>> Impls{};
std::vector<node_impl *> Impls{};

std::copy(Roots.begin(), Roots.end(), std::back_inserter(Impls));
return createNodesFromImpls(Impls);
Expand Down
Loading