Skip to content

Commit 3b208cc

Browse files
[NFC][SYCL][Graph] Prepare for getSyclObjImpl to return raw ref (#19224)
I'm planning to change `getSyclObjImpl` to return a raw reference in a later patch, so I'm uploading a series of PRs in preparation for that to make the subsequent review easier. I also did `s/sycl::detail::getSyclObjImpl/getSyclObjImpl/g` here to improve readability and to make the future patch achievable with a simple `sed` (avoiding multiline changes). This PR originated in `unittests/Extensions/CommandGraph`, but I also had to modify `isSimilar` and `hasSimilarStructure` in the implementation itself.
1 parent b9b441f commit 3b208cc

File tree

12 files changed

+505
-570
lines changed

12 files changed

+505
-570
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -905,12 +905,12 @@ void exec_graph_impl::createCommandBuffers(
905905
ur_exp_command_buffer_desc_t Desc{
906906
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC, nullptr, MIsUpdatable,
907907
Partition->MIsInOrderGraph && !MEnableProfiling, MEnableProfiling};
908-
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
909-
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
908+
context_impl &ContextImpl = *sycl::detail::getSyclObjImpl(MContext);
909+
const sycl::detail::AdapterPtr &Adapter = ContextImpl.getAdapter();
910910
sycl::detail::device_impl &DeviceImpl = *sycl::detail::getSyclObjImpl(Device);
911911
ur_result_t Res =
912912
Adapter->call_nocheck<sycl::detail::UrApiKind::urCommandBufferCreateExp>(
913-
ContextImpl->getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
913+
ContextImpl.getHandleRef(), DeviceImpl.getHandleRef(), &Desc,
914914
&OutCommandBuffer);
915915
if (Res != UR_RESULT_SUCCESS) {
916916
throw sycl::exception(errc::invalid, "Failed to create UR command-buffer");
@@ -1636,8 +1636,9 @@ void exec_graph_impl::populateURKernelUpdateStructs(
16361636
std::vector<ur_exp_command_buffer_update_value_arg_desc_t> &ValueDescs,
16371637
sycl::detail::NDRDescT &NDRDesc,
16381638
ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) const {
1639-
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
1640-
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
1639+
sycl::detail::context_impl &ContextImpl =
1640+
*sycl::detail::getSyclObjImpl(MContext);
1641+
const sycl::detail::AdapterPtr &Adapter = ContextImpl.getAdapter();
16411642
sycl::detail::device_impl &DeviceImpl =
16421643
*sycl::detail::getSyclObjImpl(MGraphImpl->getDevice());
16431644

@@ -1665,7 +1666,7 @@ void exec_graph_impl::populateURKernelUpdateStructs(
16651666
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
16661667
} else {
16671668
BundleObjs = sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
1668-
*ContextImpl, DeviceImpl, ExecCG.MKernelName,
1669+
ContextImpl, DeviceImpl, ExecCG.MKernelName,
16691670
ExecCG.MKernelNameBasedCachePtr);
16701671
UrKernel = BundleObjs->MKernelHandle;
16711672
EliminatedArgMask = BundleObjs->MKernelArgMask;
@@ -1884,8 +1885,8 @@ void exec_graph_impl::updateURImpl(
18841885
StructListIndex++;
18851886
}
18861887

1887-
auto ContextImpl = sycl::detail::getSyclObjImpl(MContext);
1888-
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
1888+
context_impl &ContextImpl = *sycl::detail::getSyclObjImpl(MContext);
1889+
const sycl::detail::AdapterPtr &Adapter = ContextImpl.getAdapter();
18891890
Adapter->call<sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
18901891
CommandBuffer, UpdateDescList.size(), UpdateDescList.data());
18911892
}

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
358358
size_t FoundCnt = 0;
359359
for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
360360
for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
361-
if (NodeA->isSimilar(NodeB) &&
361+
if (NodeA->isSimilar(*NodeB) &&
362362
checkNodeRecursive(SuccA.lock(), SuccB.lock())) {
363363
FoundCnt++;
364364
break;
@@ -383,44 +383,44 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
383383
/// @param DebugPrint if set to true throw exception with additional debug
384384
/// information about the spotted graph differences.
385385
/// @return true if the two graphs are similar, false otherwise
386-
bool hasSimilarStructure(std::shared_ptr<detail::graph_impl> Graph,
386+
bool hasSimilarStructure(detail::graph_impl &Graph,
387387
bool DebugPrint = false) const {
388-
if (this == Graph.get())
388+
if (this == &Graph)
389389
return true;
390390

391-
if (MContext != Graph->MContext) {
391+
if (MContext != Graph.MContext) {
392392
if (DebugPrint) {
393393
throw sycl::exception(sycl::make_error_code(errc::invalid),
394394
"MContext are not the same.");
395395
}
396396
return false;
397397
}
398398

399-
if (MDevice != Graph->MDevice) {
399+
if (MDevice != Graph.MDevice) {
400400
if (DebugPrint) {
401401
throw sycl::exception(sycl::make_error_code(errc::invalid),
402402
"MDevice are not the same.");
403403
}
404404
return false;
405405
}
406406

407-
if (MEventsMap.size() != Graph->MEventsMap.size()) {
407+
if (MEventsMap.size() != Graph.MEventsMap.size()) {
408408
if (DebugPrint) {
409409
throw sycl::exception(sycl::make_error_code(errc::invalid),
410410
"MEventsMap sizes are not the same.");
411411
}
412412
return false;
413413
}
414414

415-
if (MInorderQueueMap.size() != Graph->MInorderQueueMap.size()) {
415+
if (MInorderQueueMap.size() != Graph.MInorderQueueMap.size()) {
416416
if (DebugPrint) {
417417
throw sycl::exception(sycl::make_error_code(errc::invalid),
418418
"MInorderQueueMap sizes are not the same.");
419419
}
420420
return false;
421421
}
422422

423-
if (MRoots.size() != Graph->MRoots.size()) {
423+
if (MRoots.size() != Graph.MRoots.size()) {
424424
if (DebugPrint) {
425425
throw sycl::exception(sycl::make_error_code(errc::invalid),
426426
"MRoots sizes are not the same.");
@@ -430,11 +430,11 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
430430

431431
size_t RootsFound = 0;
432432
for (std::weak_ptr<node_impl> NodeA : MRoots) {
433-
for (std::weak_ptr<node_impl> NodeB : Graph->MRoots) {
433+
for (std::weak_ptr<node_impl> NodeB : Graph.MRoots) {
434434
auto NodeALocked = NodeA.lock();
435435
auto NodeBLocked = NodeB.lock();
436436

437-
if (NodeALocked->isSimilar(NodeBLocked)) {
437+
if (NodeALocked->isSimilar(*NodeBLocked)) {
438438
if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
439439
RootsFound++;
440440
break;

sycl/source/detail/graph/node_impl.hpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -326,31 +326,30 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
326326
/// @param CompareContentOnly Skip comparisons related to graph structure,
327327
/// compare only the type and command groups of the nodes
328328
/// @return True if the two nodes are similar
329-
bool isSimilar(const std::shared_ptr<node_impl> &Node,
330-
bool CompareContentOnly = false) const {
329+
bool isSimilar(node_impl &Node, bool CompareContentOnly = false) const {
331330
if (!CompareContentOnly) {
332-
if (MSuccessors.size() != Node->MSuccessors.size())
331+
if (MSuccessors.size() != Node.MSuccessors.size())
333332
return false;
334333

335-
if (MPredecessors.size() != Node->MPredecessors.size())
334+
if (MPredecessors.size() != Node.MPredecessors.size())
336335
return false;
337336
}
338-
if (MCGType != Node->MCGType)
337+
if (MCGType != Node.MCGType)
339338
return false;
340339

341340
switch (MCGType) {
342341
case sycl::detail::CGType::Kernel: {
343342
sycl::detail::CGExecKernel *ExecKernelA =
344343
static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
345344
sycl::detail::CGExecKernel *ExecKernelB =
346-
static_cast<sycl::detail::CGExecKernel *>(Node->MCommandGroup.get());
345+
static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get());
347346
return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
348347
}
349348
case sycl::detail::CGType::CopyUSM: {
350349
sycl::detail::CGCopyUSM *CopyA =
351350
static_cast<sycl::detail::CGCopyUSM *>(MCommandGroup.get());
352351
sycl::detail::CGCopyUSM *CopyB =
353-
static_cast<sycl::detail::CGCopyUSM *>(Node->MCommandGroup.get());
352+
static_cast<sycl::detail::CGCopyUSM *>(Node.MCommandGroup.get());
354353
return (CopyA->getSrc() == CopyB->getSrc()) &&
355354
(CopyA->getDst() == CopyB->getDst()) &&
356355
(CopyA->getLength() == CopyB->getLength());
@@ -361,7 +360,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
361360
sycl::detail::CGCopy *CopyA =
362361
static_cast<sycl::detail::CGCopy *>(MCommandGroup.get());
363362
sycl::detail::CGCopy *CopyB =
364-
static_cast<sycl::detail::CGCopy *>(Node->MCommandGroup.get());
363+
static_cast<sycl::detail::CGCopy *>(Node.MCommandGroup.get());
365364
return (CopyA->getSrc() == CopyB->getSrc()) &&
366365
(CopyA->getDst() == CopyB->getDst());
367366
}

0 commit comments

Comments
 (0)