@@ -409,10 +409,8 @@ void graph_impl::markCGMemObjs(
409
409
}
410
410
}
411
411
412
- std::shared_ptr<node_impl> graph_impl::add (nodes_range Deps) {
413
- const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>();
414
-
415
- MNodeStorage.push_back (NodeImpl);
412
+ node_impl &graph_impl::add (nodes_range Deps) {
413
+ node_impl &NodeImpl = createNode ();
416
414
417
415
addDepsToNode (NodeImpl, Deps);
418
416
// Add an event associated with this explicit node for mixed usage
@@ -421,10 +419,9 @@ std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
421
419
return NodeImpl;
422
420
}
423
421
424
- std::shared_ptr<node_impl>
425
- graph_impl::add (std::function<void (handler &)> CGF,
426
- const std::vector<sycl::detail::ArgDesc> &Args,
427
- std::vector<std::shared_ptr<node_impl>> &Deps) {
422
+ node_impl &graph_impl::add (std::function<void (handler &)> CGF,
423
+ const std::vector<sycl::detail::ArgDesc> &Args,
424
+ nodes_range Deps) {
428
425
(void )Args;
429
426
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
430
427
detail::handler_impl HandlerImpl{*this };
@@ -435,7 +432,9 @@ graph_impl::add(std::function<void(handler &)> CGF,
435
432
436
433
// Pass the node deps to the handler so they are available when processing the
437
434
// CGF, need for async_malloc nodes.
438
- Handler.impl ->MNodeDeps = Deps;
435
+ Handler.impl ->MNodeDeps .clear (); // TODO: Is that right?
436
+ for (node_impl &N : Deps)
437
+ Handler.impl ->MNodeDeps .push_back (N.shared_from_this ());
439
438
440
439
#if XPTI_ENABLE_INSTRUMENTATION
441
440
// Save code location if one was set in TLS.
@@ -471,7 +470,7 @@ graph_impl::add(std::function<void(handler &)> CGF,
471
470
: ext::oneapi::experimental::detail::getNodeTypeFromCG (
472
471
Handler.getType ());
473
472
474
- auto NodeImpl =
473
+ node_impl & NodeImpl =
475
474
this ->add (NodeType, std::move (Handler.impl ->MGraphNodeCG ), Deps);
476
475
477
476
// Add an event associated with this explicit node for mixed usage
@@ -489,26 +488,23 @@ graph_impl::add(std::function<void(handler &)> CGF,
489
488
}
490
489
491
490
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
492
- DynamicParam->registerNode (NodeImpl, ArgIndex);
491
+ DynamicParam->registerNode (NodeImpl. shared_from_this () , ArgIndex);
493
492
}
494
493
495
494
return NodeImpl;
496
495
}
497
496
498
- std::shared_ptr<node_impl>
499
- graph_impl::add (node_type NodeType,
500
- std::shared_ptr<sycl::detail::CG> CommandGroup,
501
- nodes_range Deps) {
497
+ node_impl &graph_impl::add (node_type NodeType,
498
+ std::shared_ptr<sycl::detail::CG> CommandGroup,
499
+ nodes_range Deps) {
502
500
503
501
// A unique set of dependencies obtained by checking requirements and events
504
502
std::set<std::shared_ptr<node_impl>> UniqueDeps = getCGEdges (CommandGroup);
505
503
506
504
// Track and mark the memory objects being used by the graph.
507
505
markCGMemObjs (CommandGroup);
508
506
509
- const std::shared_ptr<node_impl> &NodeImpl =
510
- std::make_shared<node_impl>(NodeType, std::move (CommandGroup));
511
- MNodeStorage.push_back (NodeImpl);
507
+ node_impl &NodeImpl = createNode (NodeType, std::move (CommandGroup));
512
508
513
509
// Add any deps determined from requirements and events into the dependency
514
510
// list
@@ -517,16 +513,17 @@ graph_impl::add(node_type NodeType,
517
513
518
514
if (NodeType == node_type::async_free) {
519
515
auto AsyncFreeCG =
520
- static_cast <CGAsyncFree *>(NodeImpl-> MCommandGroup .get ());
516
+ static_cast <CGAsyncFree *>(NodeImpl. MCommandGroup .get ());
521
517
// If this is an async free node mark that it is now available for reuse,
522
518
// and pass the async free node for tracking.
523
- MGraphMemPool.markAllocationAsAvailable (AsyncFreeCG->getPtr (), NodeImpl);
519
+ MGraphMemPool.markAllocationAsAvailable (AsyncFreeCG->getPtr (),
520
+ NodeImpl.shared_from_this ());
524
521
}
525
522
526
523
return NodeImpl;
527
524
}
528
525
529
- std::shared_ptr< node_impl>
526
+ node_impl&
530
527
graph_impl::add (std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
531
528
nodes_range Deps) {
532
529
// Set of Dependent nodes based on CG event and accessor dependencies.
@@ -551,15 +548,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
551
548
const auto &ActiveKernel = DynCGImpl->getActiveCG ();
552
549
node_type NodeType =
553
550
ext::oneapi::experimental::detail::getNodeTypeFromCG (DynCGImpl->MCGType );
554
- std::shared_ptr<detail::node_impl> NodeImpl =
555
- add (NodeType, ActiveKernel, Deps);
551
+ detail::node_impl &NodeImpl = add (NodeType, ActiveKernel, Deps);
556
552
557
553
// Add an event associated with this explicit node for mixed usage
558
554
addEventForNode (sycl::detail::event_impl::create_completed_host_event (),
559
555
NodeImpl);
560
556
561
557
// Track the dynamic command-group used inside the node object
562
- DynCGImpl->MNodes .push_back (NodeImpl);
558
+ DynCGImpl->MNodes .push_back (NodeImpl. shared_from_this () );
563
559
564
560
return NodeImpl;
565
561
}
@@ -652,7 +648,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
652
648
bool DestWasGraphRoot = Dest->MPredecessors .size () == 0 ;
653
649
654
650
// We need to add the edges first before checking for cycles
655
- Src->registerSuccessor (Dest);
651
+ Src->registerSuccessor (* Dest);
656
652
657
653
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors .size () == 1 ;
658
654
if (DestLostRootStatus) {
@@ -1265,7 +1261,7 @@ void exec_graph_impl::duplicateNodes() {
1265
1261
// Look through all the original node successors, find their copies and
1266
1262
// register those as successors with the current copied node
1267
1263
for (node_impl &NextNode : OriginalNode->successors ()) {
1268
- auto Successor = NodesMap.at (NextNode.shared_from_this ());
1264
+ node_impl & Successor = * NodesMap.at (NextNode.shared_from_this ());
1269
1265
NodeCopy->registerSuccessor (Successor);
1270
1266
}
1271
1267
}
@@ -1307,7 +1303,7 @@ void exec_graph_impl::duplicateNodes() {
1307
1303
auto NodeCopy = NewSubgraphNodes[i];
1308
1304
1309
1305
for (node_impl &NextNode : SubgraphNode->successors ()) {
1310
- auto Successor = SubgraphNodesMap.at (NextNode.shared_from_this ());
1306
+ node_impl & Successor = * SubgraphNodesMap.at (NextNode.shared_from_this ());
1311
1307
NodeCopy->registerSuccessor (Successor);
1312
1308
}
1313
1309
}
@@ -1341,7 +1337,7 @@ void exec_graph_impl::duplicateNodes() {
1341
1337
// Add all input nodes from the subgraph as successors for this node
1342
1338
// instead
1343
1339
for (auto &Input : Inputs) {
1344
- PredNode.registerSuccessor (Input);
1340
+ PredNode.registerSuccessor (* Input);
1345
1341
}
1346
1342
}
1347
1343
@@ -1360,7 +1356,7 @@ void exec_graph_impl::duplicateNodes() {
1360
1356
// Add all Output nodes from the subgraph as predecessors for this node
1361
1357
// instead
1362
1358
for (auto &Output : Outputs) {
1363
- Output->registerSuccessor (SuccNode. shared_from_this () );
1359
+ Output->registerSuccessor (SuccNode);
1364
1360
}
1365
1361
}
1366
1362
@@ -1843,38 +1839,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF,
1843
1839
" dynamic command-group." );
1844
1840
}
1845
1841
1846
- std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1847
- for (auto &D : Deps) {
1848
- DepImpls.push_back (sycl::detail::getSyclObjImpl (D));
1849
- }
1850
-
1851
1842
graph_impl::WriteLock Lock (impl->MMutex );
1852
- std::shared_ptr< detail::node_impl> NodeImpl = impl->add (DynCGFImpl, DepImpls );
1853
- return sycl::detail::createSyclObjFromImpl<node>(std::move ( NodeImpl) );
1843
+ detail::node_impl & NodeImpl = impl->add (DynCGFImpl, Deps );
1844
+ return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
1854
1845
}
1855
1846
1856
1847
node modifiable_command_graph::addImpl (const std::vector<node> &Deps) {
1857
1848
impl->throwIfGraphRecordingQueue (" Explicit API \" Add()\" function" );
1858
- std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1859
- for (auto &D : Deps) {
1860
- DepImpls.push_back (sycl::detail::getSyclObjImpl (D));
1861
- }
1862
1849
1863
1850
graph_impl::WriteLock Lock (impl->MMutex );
1864
- std::shared_ptr< detail::node_impl> NodeImpl = impl->add (DepImpls );
1865
- return sycl::detail::createSyclObjFromImpl<node>(std::move ( NodeImpl) );
1851
+ detail::node_impl & NodeImpl = impl->add (Deps );
1852
+ return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
1866
1853
}
1867
1854
1868
1855
node modifiable_command_graph::addImpl (std::function<void (handler &)> CGF,
1869
1856
const std::vector<node> &Deps) {
1870
1857
impl->throwIfGraphRecordingQueue (" Explicit API \" Add()\" function" );
1871
- std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1872
- for (auto &D : Deps) {
1873
- DepImpls.push_back (sycl::detail::getSyclObjImpl (D));
1874
- }
1875
1858
1876
- std::shared_ptr< detail::node_impl> NodeImpl = impl->add (CGF, {}, DepImpls );
1877
- return sycl::detail::createSyclObjFromImpl<node>(std::move ( NodeImpl) );
1859
+ detail::node_impl & NodeImpl = impl->add (CGF, {}, Deps );
1860
+ return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
1878
1861
}
1879
1862
1880
1863
void modifiable_command_graph::addGraphLeafDependencies (node Node) {
0 commit comments