@@ -341,7 +341,7 @@ void graph_impl::addRoot(node_impl &Root) { MRoots.insert(&Root); }
341
341
342
342
void graph_impl::removeRoot (node_impl &Root) { MRoots.erase (&Root); }
343
343
344
- std::set<std::shared_ptr< node_impl> > graph_impl::getCGEdges (
344
+ std::set<node_impl * > graph_impl::getCGEdges (
345
345
const std::shared_ptr<sycl::detail::CG> &CommandGroup) const {
346
346
const auto &Requirements = CommandGroup->getRequirements ();
347
347
if (!MAllowBuffers && Requirements.size ()) {
@@ -362,14 +362,14 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
362
362
}
363
363
364
364
// Add any nodes specified by event dependencies into the dependency list
365
- std::set<std::shared_ptr< node_impl> > UniqueDeps;
365
+ std::set<node_impl * > UniqueDeps;
366
366
for (auto &Dep : CommandGroup->getEvents ()) {
367
367
if (auto NodeImpl = MEventsMap.find (Dep); NodeImpl == MEventsMap.end ()) {
368
368
throw sycl::exception (sycl::make_error_code (errc::invalid),
369
369
" Event dependency from handler::depends_on does "
370
370
" not correspond to a node within the graph" );
371
371
} else {
372
- UniqueDeps.insert (NodeImpl->second );
372
+ UniqueDeps.insert (NodeImpl->second . get () );
373
373
}
374
374
}
375
375
@@ -388,7 +388,7 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
388
388
}
389
389
}
390
390
if (ShouldAddDep) {
391
- UniqueDeps.insert (Node);
391
+ UniqueDeps.insert (Node. get () );
392
392
}
393
393
}
394
394
}
@@ -501,7 +501,7 @@ graph_impl::add(node_type NodeType,
501
501
nodes_range Deps) {
502
502
503
503
// A unique set of dependencies obtained by checking requirements and events
504
- std::set<std::shared_ptr< node_impl> > UniqueDeps = getCGEdges (CommandGroup);
504
+ std::set<node_impl * > UniqueDeps = getCGEdges (CommandGroup);
505
505
506
506
// Track and mark the memory objects being used by the graph.
507
507
markCGMemObjs (CommandGroup);
@@ -530,8 +530,7 @@ std::shared_ptr<node_impl>
530
530
graph_impl::add (std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
531
531
nodes_range Deps) {
532
532
// Set of Dependent nodes based on CG event and accessor dependencies.
533
- std::set<std::shared_ptr<node_impl>> DynCGDeps =
534
- getCGEdges (DynCGImpl->MCommandGroups [0 ]);
533
+ std::set<node_impl *> DynCGDeps = getCGEdges (DynCGImpl->MCommandGroups [0 ]);
535
534
for (unsigned i = 1 ; i < DynCGImpl->getNumCGs (); i++) {
536
535
auto &CG = DynCGImpl->MCommandGroups [i];
537
536
auto CGEdges = getCGEdges (CG);
@@ -1559,7 +1558,7 @@ bool exec_graph_impl::needsScheduledUpdate(
1559
1558
}
1560
1559
1561
1560
void exec_graph_impl::populateURKernelUpdateStructs (
1562
- const std::shared_ptr< node_impl> &Node, FastKernelCacheValPtr &BundleObjs,
1561
+ node_impl &Node, FastKernelCacheValPtr &BundleObjs,
1563
1562
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t > &MemobjDescs,
1564
1563
std::vector<ur_kernel_arg_mem_obj_properties_t > &MemobjProps,
1565
1564
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t > &PtrDescs,
@@ -1574,7 +1573,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
1574
1573
1575
1574
// Gather arg information from Node
1576
1575
auto &ExecCG =
1577
- *(static_cast <sycl::detail::CGExecKernel *>(Node-> MCommandGroup .get ()));
1576
+ *(static_cast <sycl::detail::CGExecKernel *>(Node. MCommandGroup .get ()));
1578
1577
// Copy args because we may modify them
1579
1578
std::vector<sycl::detail::ArgDesc> NodeArgs = ExecCG.getArguments ();
1580
1579
// Copy NDR desc since we need to modify it
@@ -1713,7 +1712,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
1713
1712
// TODO: Handle subgraphs or any other cases where multiple nodes may be
1714
1713
// associated with a single key, once those node types are supported for
1715
1714
// update.
1716
- auto ExecNode = MIDCache.find (Node-> MID );
1715
+ auto ExecNode = MIDCache.find (Node. MID );
1717
1716
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1718
1717
1719
1718
auto Command = MCommandMap.find (ExecNode->second .get ());
@@ -1725,30 +1724,29 @@ void exec_graph_impl::populateURKernelUpdateStructs(
1725
1724
ExecNode->second ->updateFromOtherNode (Node);
1726
1725
}
1727
1726
1728
- std::map<int , std::vector<std::shared_ptr<node_impl>>>
1729
- exec_graph_impl::getURUpdatableNodes (
1730
- const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1727
+ std::map<int , std::vector<node_impl *>>
1728
+ exec_graph_impl::getURUpdatableNodes (nodes_range Nodes) const {
1731
1729
// Iterate over the list of nodes, and for every node that can
1732
1730
// be updated through UR, add it to the list of nodes for
1733
1731
// that can be updated for the UR command-buffer partition.
1734
- std::map<int , std::vector<std::shared_ptr< node_impl> >> PartitionedNodes;
1732
+ std::map<int , std::vector<node_impl * >> PartitionedNodes;
1735
1733
1736
1734
// Initialize vector for each partition
1737
1735
for (size_t i = 0 ; i < MPartitions.size (); i++) {
1738
1736
PartitionedNodes[i] = {};
1739
1737
}
1740
1738
1741
- for (auto &Node : Nodes) {
1739
+ for (node_impl &Node : Nodes) {
1742
1740
// Kernel node update is the only command type supported in UR for update.
1743
- if (Node-> MCGType != sycl::detail::CGType::Kernel) {
1741
+ if (Node. MCGType != sycl::detail::CGType::Kernel) {
1744
1742
continue ;
1745
1743
}
1746
1744
1747
- auto ExecNode = MIDCache.find (Node-> MID );
1745
+ auto ExecNode = MIDCache.find (Node. MID );
1748
1746
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1749
1747
auto PartitionIndex = MPartitionNodes.find (ExecNode->second .get ());
1750
1748
assert (PartitionIndex != MPartitionNodes.end ());
1751
- PartitionedNodes[PartitionIndex->second ].push_back (Node);
1749
+ PartitionedNodes[PartitionIndex->second ].push_back (& Node);
1752
1750
}
1753
1751
1754
1752
return PartitionedNodes;
@@ -1765,13 +1763,12 @@ void exec_graph_impl::updateHostTasksImpl(
1765
1763
auto ExecNode = MIDCache.find (Node->MID );
1766
1764
assert (ExecNode != MIDCache.end () && " Node ID was not found in ID cache" );
1767
1765
1768
- ExecNode->second ->updateFromOtherNode (Node);
1766
+ ExecNode->second ->updateFromOtherNode (* Node);
1769
1767
}
1770
1768
}
1771
1769
1772
- void exec_graph_impl::updateURImpl (
1773
- ur_exp_command_buffer_handle_t CommandBuffer,
1774
- const std::vector<std::shared_ptr<node_impl>> &Nodes) const {
1770
+ void exec_graph_impl::updateURImpl (ur_exp_command_buffer_handle_t CommandBuffer,
1771
+ nodes_range Nodes) const {
1775
1772
const size_t NumUpdatableNodes = Nodes.size ();
1776
1773
if (NumUpdatableNodes == 0 ) {
1777
1774
return ;
@@ -1797,10 +1794,10 @@ void exec_graph_impl::updateURImpl(
1797
1794
std::vector<FastKernelCacheValPtr> KernelBundleObjList (NumUpdatableNodes);
1798
1795
1799
1796
size_t StructListIndex = 0 ;
1800
- for (auto &Node : Nodes) {
1797
+ for (node_impl &Node : Nodes) {
1801
1798
// This should be the case when getURUpdatableNodes() is used to
1802
1799
// create the list of nodes.
1803
- assert (Node-> MCGType == sycl::detail::CGType::Kernel);
1800
+ assert (Node. MCGType == sycl::detail::CGType::Kernel);
1804
1801
1805
1802
auto &MemobjDescs = MemobjDescsList[StructListIndex];
1806
1803
auto &MemobjProps = MemobjPropsList[StructListIndex];
0 commit comments