Skip to content

Commit d2ff9fe

Browse files
[NFC][SYCL][Graph] Add successors/predecessors views + cleanup
Part of the refactoring to get rid of most (all?) `std::weak_ptr<node_impl>` and some of `std::shared_ptr<node_impl>` started in #19295. Use `nodes_range` from that PR to implement `successors`/`predecessors` views and update read-only accesses to the successors/predecessors to go through them. I'm not changing the data members `MSuccessors`/`MPredecessors` yet because it would affect unit tests. I'd prefer to refactor most of the code in future PRs before making that change and updating the unit tests in one go. I'm updating some APIs to accept `node_impl &` instead of `std::shared_ptr` where the change is mostly localized to the callers iterating over successors/predecessors and doesn't spill into other code too much. For those that weren't updated here we (temporarily) use `shared_from_this()`, but I expect to eliminate those unnecessary copies when those interfaces are updated in the subsequent PRs.
1 parent f409eb7 commit d2ff9fe

File tree

5 files changed

+78
-80
lines changed

5 files changed

+78
-80
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 49 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,16 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
100100
Source.pop();
101101
SortedNodes.push_back(Node);
102102

103-
for (auto &SuccWP : Node->MSuccessors) {
104-
auto Succ = SuccWP.lock();
103+
for (node_impl &Succ : Node->successors()) {
105104

106-
if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) {
105+
if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
107106
continue;
108107
}
109108

110-
auto &TotalVisitedEdges = Succ->MTotalVisitedEdges;
109+
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
111110
++TotalVisitedEdges;
112-
if (TotalVisitedEdges == Succ->MPredecessors.size()) {
113-
Source.push(Succ);
111+
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
112+
Source.push(Succ.weak_from_this());
114113
}
115114
}
116115
}
@@ -127,14 +126,14 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
127126
/// a node with a smaller partition number.
128127
/// @param Node Node to assign to the partition.
129128
/// @param PartitionNum Number to propagate.
130-
void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
131-
if (((Node->MPartitionNum != -1) && (Node->MPartitionNum <= PartitionNum)) ||
132-
(Node->MCGType == sycl::detail::CGType::CodeplayHostTask)) {
129+
void propagatePartitionUp(node_impl &Node, int PartitionNum) {
130+
if (((Node.MPartitionNum != -1) && (Node.MPartitionNum <= PartitionNum)) ||
131+
(Node.MCGType == sycl::detail::CGType::CodeplayHostTask)) {
133132
return;
134133
}
135-
Node->MPartitionNum = PartitionNum;
136-
for (auto &Predecessor : Node->MPredecessors) {
137-
propagatePartitionUp(Predecessor.lock(), PartitionNum);
134+
Node.MPartitionNum = PartitionNum;
135+
for (node_impl &Predecessor : Node.predecessors()) {
136+
propagatePartitionUp(Predecessor, PartitionNum);
138137
}
139138
}
140139

@@ -146,17 +145,17 @@ void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
146145
/// @param HostTaskList List of host tasks that have already been processed and
147146
/// are encountered as successors to the node Node.
148147
void propagatePartitionDown(
149-
const std::shared_ptr<node_impl> &Node, int PartitionNum,
148+
node_impl &Node, int PartitionNum,
150149
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
151-
if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) {
152-
if (Node->MPartitionNum != -1) {
153-
HostTaskList.push_front(Node);
150+
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
151+
if (Node.MPartitionNum != -1) {
152+
HostTaskList.push_front(Node.shared_from_this());
154153
}
155154
return;
156155
}
157-
Node->MPartitionNum = PartitionNum;
158-
for (auto &Successor : Node->MSuccessors) {
159-
propagatePartitionDown(Successor.lock(), PartitionNum, HostTaskList);
156+
Node.MPartitionNum = PartitionNum;
157+
for (node_impl &Successor : Node.successors()) {
158+
propagatePartitionDown(Successor, PartitionNum, HostTaskList);
160159
}
161160
}
162161

@@ -165,8 +164,8 @@ void propagatePartitionDown(
165164
/// @param Node node to test
166165
/// @return True if `Node` is a root of its partition
167166
bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
168-
for (auto &Predecessor : Node->MPredecessors) {
169-
if (Predecessor.lock()->MPartitionNum == Node->MPartitionNum) {
167+
for (node_impl &Predecessor : Node->predecessors()) {
168+
if (Predecessor.MPartitionNum == Node->MPartitionNum) {
170169
return false;
171170
}
172171
}
@@ -221,15 +220,15 @@ void exec_graph_impl::makePartitions() {
221220
auto Node = HostTaskList.front();
222221
HostTaskList.pop_front();
223222
CurrentPartition++;
224-
for (auto &Predecessor : Node->MPredecessors) {
225-
propagatePartitionUp(Predecessor.lock(), CurrentPartition);
223+
for (node_impl &Predecessor : Node->predecessors()) {
224+
propagatePartitionUp(Predecessor, CurrentPartition);
226225
}
227226
CurrentPartition++;
228227
Node->MPartitionNum = CurrentPartition;
229228
CurrentPartition++;
230229
auto TmpSize = HostTaskList.size();
231-
for (auto &Successor : Node->MSuccessors) {
232-
propagatePartitionDown(Successor.lock(), CurrentPartition, HostTaskList);
230+
for (node_impl &Successor : Node->successors()) {
231+
propagatePartitionDown(Successor, CurrentPartition, HostTaskList);
233232
}
234233
if (HostTaskList.size() > TmpSize) {
235234
// At least one HostTask has been re-numbered so group merge opportunities
@@ -290,9 +289,9 @@ void exec_graph_impl::makePartitions() {
290289
for (const auto &Partition : MPartitions) {
291290
for (auto const &Root : Partition->MRoots) {
292291
auto RootNode = Root.lock();
293-
for (const auto &Dep : RootNode->MPredecessors) {
294-
auto NodeDep = Dep.lock();
295-
auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]];
292+
for (node_impl &NodeDep : RootNode->predecessors()) {
293+
auto &Predecessor =
294+
MPartitions[MPartitionNodes[NodeDep.shared_from_this()]];
296295
Partition->MPredecessors.push_back(Predecessor.get());
297296
Predecessor->MSuccessors.push_back(Partition.get());
298297
}
@@ -390,8 +389,8 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
390389
bool ShouldAddDep = true;
391390
// If any of this node's successors have this requirement then we skip
392391
// adding the current node as a dependency.
393-
for (auto &Succ : Node->MSuccessors) {
394-
if (Succ.lock()->hasRequirementDependency(Req)) {
392+
for (node_impl &Succ : Node->successors()) {
393+
if (Succ.hasRequirementDependency(Req)) {
395394
ShouldAddDep = false;
396395
break;
397396
}
@@ -721,17 +720,17 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
721720
// predecessors until we find the real dependency.
722721
void exec_graph_impl::findRealDeps(
723722
std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
724-
std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
725-
if (!CurrentNode->requiresEnqueue()) {
726-
for (auto &N : CurrentNode->MPredecessors) {
727-
auto NodeImpl = N.lock();
723+
node_impl &CurrentNode, int ReferencePartitionNum) {
724+
if (!CurrentNode.requiresEnqueue()) {
725+
for (node_impl &NodeImpl : CurrentNode.predecessors()) {
728726
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
729727
}
730728
} else {
729+
auto CurrentNodePtr = CurrentNode.shared_from_this();
731730
// Verify that CurrentNode belongs to the same partition
732-
if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) {
731+
if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) {
733732
// Verify that the sync point has actually been set for this node.
734-
auto SyncPoint = MSyncPoints.find(CurrentNode);
733+
auto SyncPoint = MSyncPoints.find(CurrentNodePtr);
735734
assert(SyncPoint != MSyncPoints.end() &&
736735
"No sync point has been set for node dependency.");
737736
// Check if the dependency has already been added.
@@ -749,8 +748,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749748
ur_exp_command_buffer_handle_t CommandBuffer,
750749
std::shared_ptr<node_impl> Node) {
751750
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
752-
for (auto &N : Node->MPredecessors) {
753-
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
751+
for (node_impl &N : Node->predecessors()) {
752+
findRealDeps(Deps, N, MPartitionNodes[Node]);
754753
}
755754
ur_exp_command_buffer_sync_point_t NewSyncPoint;
756755
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -805,8 +804,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805804
std::shared_ptr<node_impl> Node) {
806805

807806
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
808-
for (auto &N : Node->MPredecessors) {
809-
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
807+
for (node_impl &N : Node->predecessors()) {
808+
findRealDeps(Deps, N, MPartitionNodes[Node]);
810809
}
811810

812811
sycl::detail::EventImplPtr Event =
@@ -1275,8 +1274,8 @@ void exec_graph_impl::duplicateNodes() {
12751274
auto NodeCopy = NewNodes[i];
12761275
// Look through all the original node successors, find their copies and
12771276
// register those as successors with the current copied node
1278-
for (auto &NextNode : OriginalNode->MSuccessors) {
1279-
auto Successor = NodesMap.at(NextNode.lock());
1277+
for (node_impl &NextNode : OriginalNode->successors()) {
1278+
auto Successor = NodesMap.at(NextNode.shared_from_this());
12801279
NodeCopy->registerSuccessor(Successor);
12811280
}
12821281
}
@@ -1317,8 +1316,8 @@ void exec_graph_impl::duplicateNodes() {
13171316
auto SubgraphNode = SubgraphNodes[i];
13181317
auto NodeCopy = NewSubgraphNodes[i];
13191318

1320-
for (auto &NextNode : SubgraphNode->MSuccessors) {
1321-
auto Successor = SubgraphNodesMap.at(NextNode.lock());
1319+
for (node_impl &NextNode : SubgraphNode->successors()) {
1320+
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
13221321
NodeCopy->registerSuccessor(Successor);
13231322
}
13241323
}
@@ -1339,9 +1338,8 @@ void exec_graph_impl::duplicateNodes() {
13391338
// original subgraph node
13401339

13411340
// Predecessors
1342-
for (auto &PredNodeWeak : NewNode->MPredecessors) {
1343-
auto PredNode = PredNodeWeak.lock();
1344-
auto &Successors = PredNode->MSuccessors;
1341+
for (node_impl &PredNode : NewNode->predecessors()) {
1342+
auto &Successors = PredNode.MSuccessors;
13451343

13461344
// Remove the subgraph node from this node's successors
13471345
Successors.erase(std::remove_if(Successors.begin(), Successors.end(),
@@ -1353,14 +1351,13 @@ void exec_graph_impl::duplicateNodes() {
13531351
// Add all input nodes from the subgraph as successors for this node
13541352
// instead
13551353
for (auto &Input : Inputs) {
1356-
PredNode->registerSuccessor(Input);
1354+
PredNode.registerSuccessor(Input);
13571355
}
13581356
}
13591357

13601358
// Successors
1361-
for (auto &SuccNodeWeak : NewNode->MSuccessors) {
1362-
auto SuccNode = SuccNodeWeak.lock();
1363-
auto &Predecessors = SuccNode->MPredecessors;
1359+
for (node_impl &SuccNode : NewNode->successors()) {
1360+
auto &Predecessors = SuccNode.MPredecessors;
13641361

13651362
// Remove the subgraph node from this node's predecessors
13661363
Predecessors.erase(std::remove_if(Predecessors.begin(),
@@ -1373,7 +1370,7 @@ void exec_graph_impl::duplicateNodes() {
13731370
// Add all Output nodes from the subgraph as predecessors for this node
13741371
// instead
13751372
for (auto &Output : Outputs) {
1376-
Output->registerSuccessor(SuccNode);
1373+
Output->registerSuccessor(SuccNode.shared_from_this());
13771374
}
13781375
}
13791376

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,17 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
346346
/// @param NodeA pointer to the first node for comparison
347347
/// @param NodeB pointer to the second node for comparison
348348
/// @return true if the same structure is found, false otherwise
349-
static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
350-
const std::shared_ptr<node_impl> &NodeB) {
349+
static bool checkNodeRecursive(node_impl &NodeA, node_impl &NodeB) {
351350
size_t FoundCnt = 0;
352-
for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
353-
for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
354-
if (NodeA->isSimilar(*NodeB) &&
355-
checkNodeRecursive(SuccA.lock(), SuccB.lock())) {
351+
for (node_impl &SuccA : NodeA.successors()) {
352+
for (node_impl &SuccB : NodeB.successors()) {
353+
if (NodeA.isSimilar(NodeB) && checkNodeRecursive(SuccA, SuccB)) {
356354
FoundCnt++;
357355
break;
358356
}
359357
}
360358
}
361-
if (FoundCnt != NodeA->MSuccessors.size()) {
359+
if (FoundCnt != NodeA.MSuccessors.size()) {
362360
return false;
363361
}
364362

@@ -428,7 +426,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
428426
auto NodeBLocked = NodeB.lock();
429427

430428
if (NodeALocked->isSimilar(*NodeBLocked)) {
431-
if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
429+
if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) {
432430
RootsFound++;
433431
break;
434432
}
@@ -817,8 +815,7 @@ class exec_graph_impl {
817815
/// SyncPoint for CurrentNode, otherwise we need to
818816
/// synchronize on the host with the completion of previous partitions.
819817
void findRealDeps(std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
820-
std::shared_ptr<node_impl> CurrentNode,
821-
int ReferencePartitionNum);
818+
node_impl &CurrentNode, int ReferencePartitionNum);
822819

823820
/// Duplicate nodes from the modifiable graph associated with this executable
824821
/// graph and store them locally. Any subgraph nodes in the modifiable graph

sycl/source/detail/graph/memory_pool.cpp

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -116,49 +116,44 @@ graph_mem_pool::tryReuseExistingAllocation(
116116
// free nodes. We do this in a breadth-first approach because we want to find
117117
// the shortest path to a reusable allocation.
118118

119-
std::queue<std::weak_ptr<node_impl>> NodesToCheck;
119+
std::queue<node_impl *> NodesToCheck;
120120

121121
// Add all the dependent nodes to the queue, they will be popped first
122122
for (auto &Dep : DepNodes) {
123-
NodesToCheck.push(Dep);
123+
NodesToCheck.push(&*Dep);
124124
}
125125

126126
// Called when traversing over nodes to check if the current node is a free
127127
// node for one of the available allocations. If it is we populate AllocInfo
128128
// with the allocation to be reused.
129129
auto CheckNodeEqual =
130-
[&CompatibleAllocs](const std::shared_ptr<node_impl> &CurrentNode)
131-
-> std::optional<alloc_info> {
130+
[&CompatibleAllocs](node_impl &CurrentNode) -> std::optional<alloc_info> {
132131
for (auto &Alloc : CompatibleAllocs) {
133-
const auto &AllocFreeNode = Alloc.LastFreeNode;
134-
// Compare control blocks without having to lock AllocFreeNode to check
135-
// for node equality
136-
if (!CurrentNode.owner_before(AllocFreeNode) &&
137-
!AllocFreeNode.owner_before(CurrentNode)) {
132+
if (&CurrentNode == Alloc.LastFreeNode) {
138133
return Alloc;
139134
}
140135
}
141136
return std::nullopt;
142137
};
143138

144139
while (!NodesToCheck.empty()) {
145-
auto CurrentNode = NodesToCheck.front().lock();
140+
node_impl &CurrentNode = *NodesToCheck.front();
146141

147-
if (CurrentNode->MTotalVisitedEdges > 0) {
142+
if (CurrentNode.MTotalVisitedEdges > 0) {
148143
continue;
149144
}
150145

151146
// Check if the node is a free node and, if so, check if it is a free node
152147
// for any of the allocations which are free for reuse. We should not bother
153148
// checking nodes that are not free nodes, so we continue and check their
154149
// predecessors.
155-
if (CurrentNode->MNodeType == node_type::async_free) {
150+
if (CurrentNode.MNodeType == node_type::async_free) {
156151
std::optional<alloc_info> AllocFound = CheckNodeEqual(CurrentNode);
157152
if (AllocFound) {
158153
// Reset visited nodes tracking
159154
MGraph.resetNodeVisitedEdges();
160155
// Reset last free node for allocation
161-
MAllocations.at(AllocFound.value().Ptr).LastFreeNode.reset();
156+
MAllocations.at(AllocFound.value().Ptr).LastFreeNode = nullptr;
162157
// Remove found allocation from the free list
163158
MFreeAllocations.erase(std::find(MFreeAllocations.begin(),
164159
MFreeAllocations.end(),
@@ -168,12 +163,12 @@ graph_mem_pool::tryReuseExistingAllocation(
168163
}
169164

170165
// Add CurrentNode predecessors to queue
171-
for (auto &Pred : CurrentNode->MPredecessors) {
172-
NodesToCheck.push(Pred);
166+
for (node_impl &Pred : CurrentNode.predecessors()) {
167+
NodesToCheck.push(&Pred);
173168
}
174169

175170
// Mark node as visited
176-
CurrentNode->MTotalVisitedEdges = 1;
171+
CurrentNode.MTotalVisitedEdges = 1;
177172
NodesToCheck.pop();
178173
}
179174

@@ -183,7 +178,7 @@ graph_mem_pool::tryReuseExistingAllocation(
183178
void graph_mem_pool::markAllocationAsAvailable(
184179
void *Ptr, const std::shared_ptr<node_impl> &FreeNode) {
185180
MFreeAllocations.push_back(Ptr);
186-
MAllocations.at(Ptr).LastFreeNode = FreeNode;
181+
MAllocations.at(Ptr).LastFreeNode = FreeNode.get();
187182
}
188183

189184
} // namespace detail

sycl/source/detail/graph/memory_pool.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class graph_mem_pool {
4444
// Should the allocation be zero initialized during initial allocation
4545
bool ZeroInit = false;
4646
// Last free node for this allocation in the graph
47-
std::weak_ptr<node_impl> LastFreeNode = {};
47+
node_impl *LastFreeNode = nullptr;
4848
};
4949

5050
public:

sycl/source/detail/graph/node_impl.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class node;
3131
namespace detail {
3232
// Forward declarations
3333
class node_impl;
34+
class nodes_range;
3435
class exec_graph_impl;
3536

3637
/// Takes a vector of weak_ptrs to node_impls and returns a vector of node
@@ -116,6 +117,10 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
116117
/// cannot be used to find out the partition of a node outside of this process.
117118
int MPartitionNum = -1;
118119

120+
// Out-of-class as need "complete" `nodes_range`:
121+
inline nodes_range successors() const;
122+
inline nodes_range predecessors() const;
123+
119124
/// Add successor to the node.
120125
/// @param Node Node to add as a successor.
121126
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
@@ -830,6 +835,10 @@ class nodes_range {
830835
size_t size() const { return Size; }
831836
bool empty() const { return Size == 0; }
832837
};
838+
839+
inline nodes_range node_impl::successors() const { return MSuccessors; }
840+
inline nodes_range node_impl::predecessors() const { return MPredecessors; }
841+
833842
} // namespace detail
834843
} // namespace experimental
835844
} // namespace oneapi

0 commit comments

Comments
 (0)