@@ -255,7 +255,7 @@ void exec_graph_impl::makePartitions() {
255
255
const std::shared_ptr<partition> &Partition = std::make_shared<partition>();
256
256
for (auto &Node : MNodeStorage) {
257
257
if (Node->MPartitionNum == i) {
258
- MPartitionNodes[Node] = PartitionFinalNum;
258
+ MPartitionNodes[Node. get () ] = PartitionFinalNum;
259
259
if (isPartitionRoot (Node)) {
260
260
Partition->MRoots .insert (Node);
261
261
if (Node->MCGType == CGType::CodeplayHostTask) {
@@ -290,8 +290,7 @@ void exec_graph_impl::makePartitions() {
290
290
for (auto const &Root : Partition->MRoots ) {
291
291
auto RootNode = Root.lock ();
292
292
for (node_impl &NodeDep : RootNode->predecessors ()) {
293
- auto &Predecessor =
294
- MPartitions[MPartitionNodes[NodeDep.shared_from_this ()]];
293
+ auto &Predecessor = MPartitions[MPartitionNodes[&NodeDep]];
295
294
Partition->MPredecessors .push_back (Predecessor.get ());
296
295
Predecessor->MSuccessors .push_back (Partition.get ());
297
296
}
@@ -610,8 +609,7 @@ bool graph_impl::checkForCycles() {
610
609
return CycleFound;
611
610
}
612
611
613
- std::shared_ptr<node_impl>
614
- graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
612
+ node_impl *graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
615
613
if (!Queue) {
616
614
assert (0 ==
617
615
MInorderQueueMap.count (std::weak_ptr<sycl::detail::queue_impl>{}));
@@ -625,7 +623,7 @@ graph_impl::getLastInorderNode(sycl::detail::queue_impl *Queue) {
625
623
626
624
void graph_impl::setLastInorderNode (sycl::detail::queue_impl &Queue,
627
625
std::shared_ptr<node_impl> Node) {
628
- MInorderQueueMap[Queue.weak_from_this ()] = std::move ( Node) ;
626
+ MInorderQueueMap[Queue.weak_from_this ()] = &* Node;
629
627
}
630
628
631
629
void graph_impl::makeEdge (std::shared_ptr<node_impl> Src,
@@ -726,11 +724,10 @@ void exec_graph_impl::findRealDeps(
726
724
findRealDeps (Deps, NodeImpl, ReferencePartitionNum);
727
725
}
728
726
} else {
729
- auto CurrentNodePtr = CurrentNode.shared_from_this ();
730
727
// Verify if CurrentNode belong the the same partition
731
- if (MPartitionNodes[CurrentNodePtr ] == ReferencePartitionNum) {
728
+ if (MPartitionNodes[&CurrentNode ] == ReferencePartitionNum) {
732
729
// Verify that the sync point has actually been set for this node.
733
- auto SyncPoint = MSyncPoints.find (CurrentNodePtr );
730
+ auto SyncPoint = MSyncPoints.find (&CurrentNode );
734
731
assert (SyncPoint != MSyncPoints.end () &&
735
732
" No sync point has been set for node dependency." );
736
733
// Check if the dependency has already been added.
@@ -749,7 +746,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
749
746
std::shared_ptr<node_impl> Node) {
750
747
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
751
748
for (node_impl &N : Node->predecessors ()) {
752
- findRealDeps (Deps, N, MPartitionNodes[Node]);
749
+ findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
753
750
}
754
751
ur_exp_command_buffer_sync_point_t NewSyncPoint;
755
752
ur_exp_command_buffer_command_handle_t NewCommand = 0 ;
@@ -782,7 +779,7 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
782
779
Deps, &NewSyncPoint, MIsUpdatable ? &NewCommand : nullptr , nullptr );
783
780
784
781
if (MIsUpdatable) {
785
- MCommandMap[Node] = NewCommand;
782
+ MCommandMap[Node. get () ] = NewCommand;
786
783
}
787
784
788
785
if (Res != UR_RESULT_SUCCESS) {
@@ -805,7 +802,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
805
802
806
803
std::vector<ur_exp_command_buffer_sync_point_t > Deps;
807
804
for (node_impl &N : Node->predecessors ()) {
808
- findRealDeps (Deps, N, MPartitionNodes[Node]);
805
+ findRealDeps (Deps, N, MPartitionNodes[Node. get () ]);
809
806
}
810
807
811
808
sycl::detail::EventImplPtr Event =
@@ -814,7 +811,7 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
814
811
/* EventNeeded=*/ true , CommandBuffer, Deps);
815
812
816
813
if (MIsUpdatable) {
817
- MCommandMap[Node] = Event->getCommandBufferCommand ();
814
+ MCommandMap[Node. get () ] = Event->getCommandBufferCommand ();
818
815
}
819
816
820
817
return Event->getSyncPoint ();
@@ -830,7 +827,8 @@ void exec_graph_impl::buildRequirements() {
830
827
Node->MCommandGroup ->getRequirements ().begin (),
831
828
Node->MCommandGroup ->getRequirements ().end ());
832
829
833
- std::shared_ptr<partition> &Partition = MPartitions[MPartitionNodes[Node]];
830
+ std::shared_ptr<partition> &Partition =
831
+ MPartitions[MPartitionNodes[Node.get ()]];
834
832
835
833
Partition->MRequirements .insert (
836
834
Partition->MRequirements .end (),
@@ -877,10 +875,10 @@ void exec_graph_impl::createCommandBuffers(
877
875
Node->MCommandGroup .get ())
878
876
->MStreams .size () ==
879
877
0 ) {
880
- MSyncPoints[Node] =
878
+ MSyncPoints[Node. get () ] =
881
879
enqueueNodeDirect (MContext, DeviceImpl, OutCommandBuffer, Node);
882
880
} else {
883
- MSyncPoints[Node] = enqueueNode (OutCommandBuffer, Node);
881
+ MSyncPoints[Node. get () ] = enqueueNode (OutCommandBuffer, Node);
884
882
}
885
883
}
886
884
@@ -1726,7 +1724,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
1726
1724
auto ExecNode = MIDCache.find (Node->MID );
1727
1725
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1728
1726
1729
- auto Command = MCommandMap.find (ExecNode->second );
1727
+ auto Command = MCommandMap.find (ExecNode->second . get () );
1730
1728
assert (Command != MCommandMap.end ());
1731
1729
UpdateDesc.hCommand = Command->second ;
1732
1730
@@ -1756,7 +1754,7 @@ exec_graph_impl::getURUpdatableNodes(
1756
1754
1757
1755
auto ExecNode = MIDCache.find (Node->MID );
1758
1756
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1759
- auto PartitionIndex = MPartitionNodes.find (ExecNode->second );
1757
+ auto PartitionIndex = MPartitionNodes.find (ExecNode->second . get () );
1760
1758
assert (PartitionIndex != MPartitionNodes.end ());
1761
1759
PartitionedNodes[PartitionIndex->second ].push_back (Node);
1762
1760
}
0 commit comments