Skip to content

Commit a468d08

Browse files
[NFC][SYCL][Graph] Use raw node_impl * in MRoots/MSchedule
... and update the code surrounding their uses in the same spirit. Continuation of intel#19295, intel#19332, and intel#19334.
1 parent 0bba746 commit a468d08

File tree

8 files changed

+160
-196
lines changed

8 files changed

+160
-196
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 45 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -85,31 +85,29 @@ inline const char *nodeTypeToString(node_type NodeType) {
8585
/// @param[in] PartitionBounded If set to true, the topological sort is stopped
8686
/// at partition borders. Hence, nodes belonging to a partition different from
8787
/// the NodeImpl partition are not processed.
88-
void sortTopological(std::set<std::weak_ptr<node_impl>,
89-
std::owner_less<std::weak_ptr<node_impl>>> &Roots,
90-
std::list<std::shared_ptr<node_impl>> &SortedNodes,
88+
void sortTopological(nodes_range Roots, std::list<node_impl *> &SortedNodes,
9189
bool PartitionBounded) {
92-
std::stack<std::weak_ptr<node_impl>> Source;
90+
std::stack<node_impl *> Source;
9391

94-
for (auto &Node : Roots) {
95-
Source.push(Node);
92+
for (node_impl &Node : Roots) {
93+
Source.push(&Node);
9694
}
9795

9896
while (!Source.empty()) {
99-
auto Node = Source.top().lock();
97+
node_impl &Node = *Source.top();
10098
Source.pop();
101-
SortedNodes.push_back(Node);
99+
SortedNodes.push_back(&Node);
102100

103-
for (node_impl &Succ : Node->successors()) {
101+
for (node_impl &Succ : Node.successors()) {
104102

105-
if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
103+
if (PartitionBounded && (Succ.MPartitionNum != Node.MPartitionNum)) {
106104
continue;
107105
}
108106

109107
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
110108
++TotalVisitedEdges;
111109
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
112-
Source.push(Succ.weak_from_this());
110+
Source.push(&Succ);
113111
}
114112
}
115113
}
@@ -163,17 +161,17 @@ void propagatePartitionDown(
163161
/// belong to the same partition)
164162
/// @param Node node to test
165163
/// @return True if `Node` is a root of its partition
166-
bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
167-
for (node_impl &Predecessor : Node->predecessors()) {
168-
if (Predecessor.MPartitionNum == Node->MPartitionNum) {
164+
bool isPartitionRoot(node_impl &Node) {
165+
for (node_impl &Predecessor : Node.predecessors()) {
166+
if (Predecessor.MPartitionNum == Node.MPartitionNum) {
169167
return false;
170168
}
171169
}
172170
return true;
173171
}
174172
} // anonymous namespace
175173

176-
void partition::schedule() {
174+
void partition::updateSchedule() {
177175
if (MSchedule.empty()) {
178176
// There is no need to reset MTotalVisitedEdges before calling
179177
// sortTopological because this function is only called once per partition.
@@ -256,16 +254,16 @@ void exec_graph_impl::makePartitions() {
256254
for (auto &Node : MNodeStorage) {
257255
if (Node->MPartitionNum == i) {
258256
MPartitionNodes[Node.get()] = PartitionFinalNum;
259-
if (isPartitionRoot(Node)) {
260-
Partition->MRoots.insert(Node);
257+
if (isPartitionRoot(*Node)) {
258+
Partition->MRoots.insert(Node.get());
261259
if (Node->MCGType == CGType::CodeplayHostTask) {
262260
Partition->MIsHostTask = true;
263261
}
264262
}
265263
}
266264
}
267265
if (Partition->MRoots.size() > 0) {
268-
Partition->schedule();
266+
Partition->updateSchedule();
269267
Partition->MIsInOrderGraph = Partition->checkIfGraphIsSinglePath();
270268
MPartitions.push_back(Partition);
271269
MRootPartitions.push_back(Partition);
@@ -287,9 +285,8 @@ void exec_graph_impl::makePartitions() {
287285

288286
// Compute partition dependencies
289287
for (const auto &Partition : MPartitions) {
290-
for (auto const &Root : Partition->MRoots) {
291-
auto RootNode = Root.lock();
292-
for (node_impl &NodeDep : RootNode->predecessors()) {
288+
for (node_impl &Root : Partition->roots()) {
289+
for (node_impl &NodeDep : Root.predecessors()) {
293290
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
294291
Partition->MPredecessors.push_back(Predecessor.get());
295292
Predecessor->MSuccessors.push_back(Partition.get());
@@ -340,13 +337,9 @@ graph_impl::~graph_impl() {
340337
}
341338
}
342339

343-
void graph_impl::addRoot(const std::shared_ptr<node_impl> &Root) {
344-
MRoots.insert(Root);
345-
}
340+
void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); }
346341

347-
void graph_impl::removeRoot(const std::shared_ptr<node_impl> &Root) {
348-
MRoots.erase(Root);
349-
}
342+
void graph_impl::removeRoot(node_impl &Root) { MRoots.erase(&Root); }
350343

351344
std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
352345
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
@@ -593,7 +586,7 @@ bool graph_impl::clearQueues() {
593586
}
594587

595588
bool graph_impl::checkForCycles() {
596-
std::list<std::shared_ptr<node_impl>> SortedNodes;
589+
std::list<node_impl *> SortedNodes;
597590
sortTopological(MRoots, SortedNodes, false);
598591

599592
// If after a topological sort, not all the nodes in the graph are sorted,
@@ -664,7 +657,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
664657
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
665658
if (DestLostRootStatus) {
666659
// Dest is no longer a Root node, so we need to remove it from MRoots.
667-
MRoots.erase(Dest);
660+
MRoots.erase(Dest.get());
668661
}
669662

670663
// We can skip cycle checks if either Dest has no successors (cycle not
@@ -679,14 +672,14 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
679672
Dest->MPredecessors.pop_back();
680673
if (DestLostRootStatus) {
681674
// Add Dest back into MRoots.
682-
MRoots.insert(Dest);
675+
MRoots.insert(Dest.get());
683676
}
684677

685678
throw sycl::exception(make_error_code(sycl::errc::invalid),
686679
"Command graphs cannot contain cycles.");
687680
}
688681
}
689-
removeRoot(Dest); // remove receiver from root node list
682+
removeRoot(*Dest); // remove receiver from root node list
690683
}
691684

692685
std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
@@ -740,14 +733,12 @@ void exec_graph_impl::findRealDeps(
740733
}
741734
}
742735

743-
ur_exp_command_buffer_sync_point_t
744-
exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
745-
sycl::detail::device_impl &DeviceImpl,
746-
ur_exp_command_buffer_handle_t CommandBuffer,
747-
std::shared_ptr<node_impl> Node) {
736+
ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect(
737+
const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
738+
ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node) {
748739
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
749-
for (node_impl &N : Node->predecessors()) {
750-
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
740+
for (node_impl &N : Node.predecessors()) {
741+
findRealDeps(Deps, N, MPartitionNodes[&Node]);
751742
}
752743
ur_exp_command_buffer_sync_point_t NewSyncPoint;
753744
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -760,7 +751,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
760751
if (xptiEnabled) {
761752
StreamID = xptiRegisterStream(sycl::detail::SYCL_STREAM_NAME);
762753
sycl::detail::CGExecKernel *CGExec =
763-
static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
754+
static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get());
764755
sycl::detail::code_location CodeLoc(CGExec->MFileName.c_str(),
765756
CGExec->MFunctionName.c_str(),
766757
CGExec->MLine, CGExec->MColumn);
@@ -776,11 +767,11 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
776767

777768
ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel(
778769
Ctx, DeviceImpl, CommandBuffer,
779-
*static_cast<sycl::detail::CGExecKernel *>((Node->MCommandGroup.get())),
770+
*static_cast<sycl::detail::CGExecKernel *>((Node.MCommandGroup.get())),
780771
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);
781772

782773
if (MIsUpdatable) {
783-
MCommandMap[Node.get()] = NewCommand;
774+
MCommandMap[&Node] = NewCommand;
784775
}
785776

786777
if (Res != UR_RESULT_SUCCESS) {
@@ -799,20 +790,20 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
799790

800791
ur_exp_command_buffer_sync_point_t
801792
exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
802-
std::shared_ptr<node_impl> Node) {
793+
node_impl &Node) {
803794

804795
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
805-
for (node_impl &N : Node->predecessors()) {
806-
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
796+
for (node_impl &N : Node.predecessors()) {
797+
findRealDeps(Deps, N, MPartitionNodes[&Node]);
807798
}
808799

809800
sycl::detail::EventImplPtr Event =
810801
sycl::detail::Scheduler::getInstance().addCG(
811-
Node->getCGCopy(), *MQueueImpl,
802+
Node.getCGCopy(), *MQueueImpl,
812803
/*EventNeeded=*/true, CommandBuffer, Deps);
813804

814805
if (MIsUpdatable) {
815-
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
806+
MCommandMap[&Node] = Event->getCommandBufferCommand();
816807
}
817808

818809
return Event->getSyncPoint();
@@ -861,25 +852,25 @@ void exec_graph_impl::createCommandBuffers(
861852

862853
Partition->MCommandBuffers[Device] = OutCommandBuffer;
863854

864-
for (const auto &Node : Partition->MSchedule) {
855+
for (node_impl &Node : Partition->schedule()) {
865856
// Some nodes are not scheduled like other nodes, and only their
866857
// dependencies are propagated in findRealDeps
867-
if (!Node->requiresEnqueue())
858+
if (!Node.requiresEnqueue())
868859
continue;
869860

870-
sycl::detail::CGType type = Node->MCGType;
861+
sycl::detail::CGType type = Node.MCGType;
871862
// If the node is a kernel with no special requirements we can enqueue it
872863
// directly.
873864
if (type == sycl::detail::CGType::Kernel &&
874-
Node->MCommandGroup->getRequirements().size() +
865+
Node.MCommandGroup->getRequirements().size() +
875866
static_cast<sycl::detail::CGExecKernel *>(
876-
Node->MCommandGroup.get())
867+
Node.MCommandGroup.get())
877868
->MStreams.size() ==
878869
0) {
879-
MSyncPoints[Node.get()] =
870+
MSyncPoints[&Node] =
880871
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
881872
} else {
882-
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
873+
MSyncPoints[&Node] = enqueueNode(OutCommandBuffer, Node);
883874
}
884875
}
885876

@@ -2007,7 +1998,7 @@ std::vector<node> modifiable_command_graph::get_nodes() const {
20071998
std::vector<node> modifiable_command_graph::get_root_nodes() const {
20081999
graph_impl::ReadLock Lock(impl->MMutex);
20092000
auto &Roots = impl->MRoots;
2010-
std::vector<std::weak_ptr<node_impl>> Impls{};
2001+
std::vector<node_impl *> Impls{};
20112002

20122003
std::copy(Roots.begin(), Roots.end(), std::back_inserter(Impls));
20132004
return createNodesFromImpls(Impls);

0 commit comments

Comments
 (0)