Skip to content

Commit 4662cf8

Browse files
[NFC][SYCL][Graph] Use raw node_impl * in MRoots/MSchedule
... and update the code surrounding their uses in the same spirit. Continuation of intel#19295, intel#19332, and intel#19334.
1 parent 45c2e52 commit 4662cf8

File tree

8 files changed

+156
-192
lines changed

8 files changed

+156
-192
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 41 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -85,31 +85,29 @@ inline const char *nodeTypeToString(node_type NodeType) {
8585
/// @param[in] PartitionBounded If set to true, the topological sort is stopped
8686
/// at partition borders. Hence, nodes belonging to a partition different from
8787
/// the NodeImpl partition are not processed.
88-
void sortTopological(std::set<std::weak_ptr<node_impl>,
89-
std::owner_less<std::weak_ptr<node_impl>>> &Roots,
90-
std::list<std::shared_ptr<node_impl>> &SortedNodes,
88+
void sortTopological(nodes_range Roots, std::list<node_impl *> &SortedNodes,
9189
bool PartitionBounded) {
92-
std::stack<std::weak_ptr<node_impl>> Source;
90+
std::stack<node_impl *> Source;
9391

94-
for (auto &Node : Roots) {
95-
Source.push(Node);
92+
for (node_impl &Node : Roots) {
93+
Source.push(&Node);
9694
}
9795

9896
while (!Source.empty()) {
99-
auto Node = Source.top().lock();
97+
node_impl &Node = *Source.top();
10098
Source.pop();
101-
SortedNodes.push_back(Node);
99+
SortedNodes.push_back(&Node);
102100

103-
for (node_impl &Succ : Node->successors()) {
101+
for (node_impl &Succ : Node.successors()) {
104102

105-
if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
103+
if (PartitionBounded && (Succ.MPartitionNum != Node.MPartitionNum)) {
106104
continue;
107105
}
108106

109107
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
110108
++TotalVisitedEdges;
111109
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
112-
Source.push(Succ.weak_from_this());
110+
Source.push(&Succ);
113111
}
114112
}
115113
}
@@ -173,7 +171,7 @@ bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
173171
}
174172
} // anonymous namespace
175173

176-
void partition::schedule() {
174+
void partition::updateSchedule() {
177175
if (MSchedule.empty()) {
178176
// There is no need to reset MTotalVisitedEdges before calling
179177
// sortTopological because this function is only called once per partition.
@@ -257,15 +255,15 @@ void exec_graph_impl::makePartitions() {
257255
if (Node->MPartitionNum == i) {
258256
MPartitionNodes[Node.get()] = PartitionFinalNum;
259257
if (isPartitionRoot(Node)) {
260-
Partition->MRoots.insert(Node);
258+
Partition->MRoots.insert(Node.get());
261259
if (Node->MCGType == CGType::CodeplayHostTask) {
262260
Partition->MIsHostTask = true;
263261
}
264262
}
265263
}
266264
}
267265
if (Partition->MRoots.size() > 0) {
268-
Partition->schedule();
266+
Partition->updateSchedule();
269267
Partition->MIsInOrderGraph = Partition->checkIfGraphIsSinglePath();
270268
MPartitions.push_back(Partition);
271269
MRootPartitions.push_back(Partition);
@@ -287,9 +285,8 @@ void exec_graph_impl::makePartitions() {
287285

288286
// Compute partition dependencies
289287
for (const auto &Partition : MPartitions) {
290-
for (auto const &Root : Partition->MRoots) {
291-
auto RootNode = Root.lock();
292-
for (node_impl &NodeDep : RootNode->predecessors()) {
288+
for (node_impl &Root : Partition->roots()) {
289+
for (node_impl &NodeDep : Root.predecessors()) {
293290
auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
294291
Partition->MPredecessors.push_back(Predecessor.get());
295292
Predecessor->MSuccessors.push_back(Partition.get());
@@ -340,13 +337,9 @@ graph_impl::~graph_impl() {
340337
}
341338
}
342339

343-
void graph_impl::addRoot(const std::shared_ptr<node_impl> &Root) {
344-
MRoots.insert(Root);
345-
}
340+
void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); }
346341

347-
void graph_impl::removeRoot(const std::shared_ptr<node_impl> &Root) {
348-
MRoots.erase(Root);
349-
}
342+
void graph_impl::removeRoot(node_impl &Root) { MRoots.erase(&Root); }
350343

351344
std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
352345
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
@@ -593,7 +586,7 @@ bool graph_impl::clearQueues() {
593586
}
594587

595588
bool graph_impl::checkForCycles() {
596-
std::list<std::shared_ptr<node_impl>> SortedNodes;
589+
std::list<node_impl *> SortedNodes;
597590
sortTopological(MRoots, SortedNodes, false);
598591

599592
// If after a topological sort, not all the nodes in the graph are sorted,
@@ -664,7 +657,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
664657
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
665658
if (DestLostRootStatus) {
666659
// Dest is no longer a Root node, so we need to remove it from MRoots.
667-
MRoots.erase(Dest);
660+
MRoots.erase(Dest.get());
668661
}
669662

670663
// We can skip cycle checks if either Dest has no successors (cycle not
@@ -679,14 +672,14 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
679672
Dest->MPredecessors.pop_back();
680673
if (DestLostRootStatus) {
681674
// Add Dest back into MRoots.
682-
MRoots.insert(Dest);
675+
MRoots.insert(Dest.get());
683676
}
684677

685678
throw sycl::exception(make_error_code(sycl::errc::invalid),
686679
"Command graphs cannot contain cycles.");
687680
}
688681
}
689-
removeRoot(Dest); // remove receiver from root node list
682+
removeRoot(*Dest); // remove receiver from root node list
690683
}
691684

692685
std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
@@ -739,14 +732,12 @@ void exec_graph_impl::findRealDeps(
739732
}
740733
}
741734

742-
ur_exp_command_buffer_sync_point_t
743-
exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
744-
sycl::detail::device_impl &DeviceImpl,
745-
ur_exp_command_buffer_handle_t CommandBuffer,
746-
std::shared_ptr<node_impl> Node) {
735+
ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect(
736+
const sycl::context &Ctx, sycl::detail::device_impl &DeviceImpl,
737+
ur_exp_command_buffer_handle_t CommandBuffer, node_impl &Node) {
747738
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
748-
for (node_impl &N : Node->predecessors()) {
749-
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
739+
for (node_impl &N : Node.predecessors()) {
740+
findRealDeps(Deps, N, MPartitionNodes[&Node]);
750741
}
751742
ur_exp_command_buffer_sync_point_t NewSyncPoint;
752743
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -759,7 +750,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
759750
if (xptiEnabled) {
760751
StreamID = xptiRegisterStream(sycl::detail::SYCL_STREAM_NAME);
761752
sycl::detail::CGExecKernel *CGExec =
762-
static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
753+
static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get());
763754
sycl::detail::code_location CodeLoc(CGExec->MFileName.c_str(),
764755
CGExec->MFunctionName.c_str(),
765756
CGExec->MLine, CGExec->MColumn);
@@ -775,11 +766,11 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
775766

776767
ur_result_t Res = sycl::detail::enqueueImpCommandBufferKernel(
777768
Ctx, DeviceImpl, CommandBuffer,
778-
*static_cast<sycl::detail::CGExecKernel *>((Node->MCommandGroup.get())),
769+
*static_cast<sycl::detail::CGExecKernel *>((Node.MCommandGroup.get())),
779770
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr, nullptr);
780771

781772
if (MIsUpdatable) {
782-
MCommandMap[Node.get()] = NewCommand;
773+
MCommandMap[&Node] = NewCommand;
783774
}
784775

785776
if (Res != UR_RESULT_SUCCESS) {
@@ -798,20 +789,20 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
798789

799790
ur_exp_command_buffer_sync_point_t
800791
exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
801-
std::shared_ptr<node_impl> Node) {
792+
node_impl &Node) {
802793

803794
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
804-
for (node_impl &N : Node->predecessors()) {
805-
findRealDeps(Deps, N, MPartitionNodes[Node.get()]);
795+
for (node_impl &N : Node.predecessors()) {
796+
findRealDeps(Deps, N, MPartitionNodes[&Node]);
806797
}
807798

808799
sycl::detail::EventImplPtr Event =
809800
sycl::detail::Scheduler::getInstance().addCG(
810-
Node->getCGCopy(), *MQueueImpl,
801+
Node.getCGCopy(), *MQueueImpl,
811802
/*EventNeeded=*/true, CommandBuffer, Deps);
812803

813804
if (MIsUpdatable) {
814-
MCommandMap[Node.get()] = Event->getCommandBufferCommand();
805+
MCommandMap[&Node] = Event->getCommandBufferCommand();
815806
}
816807

817808
return Event->getSyncPoint();
@@ -860,25 +851,25 @@ void exec_graph_impl::createCommandBuffers(
860851

861852
Partition->MCommandBuffers[Device] = OutCommandBuffer;
862853

863-
for (const auto &Node : Partition->MSchedule) {
854+
for (node_impl &Node : Partition->schedule()) {
864855
// Some nodes are not scheduled like other nodes, and only their
865856
// dependencies are propagated in findRealDeps
866-
if (!Node->requiresEnqueue())
857+
if (!Node.requiresEnqueue())
867858
continue;
868859

869-
sycl::detail::CGType type = Node->MCGType;
860+
sycl::detail::CGType type = Node.MCGType;
870861
// If the node is a kernel with no special requirements we can enqueue it
871862
// directly.
872863
if (type == sycl::detail::CGType::Kernel &&
873-
Node->MCommandGroup->getRequirements().size() +
864+
Node.MCommandGroup->getRequirements().size() +
874865
static_cast<sycl::detail::CGExecKernel *>(
875-
Node->MCommandGroup.get())
866+
Node.MCommandGroup.get())
876867
->MStreams.size() ==
877868
0) {
878-
MSyncPoints[Node.get()] =
869+
MSyncPoints[&Node] =
879870
enqueueNodeDirect(MContext, DeviceImpl, OutCommandBuffer, Node);
880871
} else {
881-
MSyncPoints[Node.get()] = enqueueNode(OutCommandBuffer, Node);
872+
MSyncPoints[&Node] = enqueueNode(OutCommandBuffer, Node);
882873
}
883874
}
884875

@@ -2006,7 +1997,7 @@ std::vector<node> modifiable_command_graph::get_nodes() const {
20061997
std::vector<node> modifiable_command_graph::get_root_nodes() const {
20071998
graph_impl::ReadLock Lock(impl->MMutex);
20081999
auto &Roots = impl->MRoots;
2009-
std::vector<std::weak_ptr<node_impl>> Impls{};
2000+
std::vector<node_impl *> Impls{};
20102001

20112002
std::copy(Roots.begin(), Roots.end(), std::back_inserter(Impls));
20122003
return createNodesFromImpls(Impls);

0 commit comments

Comments
 (0)