Skip to content

Commit e8fadef

Browse files
[NFC][SYCL] Prepare program_manager.cpp for getSyclObjImpl to return raw ref (#19252)
I'm planning to change `getSyclObjImpl` to return a raw reference in a later patch. I'm uploading a bunch of PRs in preparation for that, to make the subsequent review easier.
1 parent 9258938 commit e8fadef

File tree

1 file changed

+42
-49
lines changed

1 file changed

+42
-49
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 42 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -871,14 +871,12 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
871871
if (!DeviceImpl.isRootDevice()) {
872872
RootDevImpl = &DeviceImpl;
873873
while (!RootDevImpl->isRootDevice()) {
874-
device_impl *ParentDev =
875-
detail::getSyclObjImpl(
876-
RootDevImpl->get_info<info::device::parent_device>())
877-
.get();
874+
device_impl &ParentDev = *detail::getSyclObjImpl(
875+
RootDevImpl->get_info<info::device::parent_device>());
878876
// Sharing is allowed within a single context only
879-
if (!ContextImpl.hasDevice(*ParentDev))
877+
if (!ContextImpl.hasDevice(ParentDev))
880878
break;
881-
RootDevImpl = ParentDev;
879+
RootDevImpl = &ParentDev;
882880
}
883881

884882
ContextImpl.getAdapter()->call<UrApiKind::urDeviceGetInfo>(
@@ -2714,7 +2712,7 @@ device_image_plain ProgramManager::createDependencyImage(
27142712
void ProgramManager::bringSYCLDeviceImageToState(
27152713
DevImgPlainWithDeps &DeviceImage, bundle_state TargetState) {
27162714
device_image_plain &MainImg = DeviceImage.getMain();
2717-
const DeviceImageImplPtr &MainImgImpl = getSyclObjImpl(MainImg);
2715+
device_image_impl &MainImgImpl = *getSyclObjImpl(MainImg);
27182716
const bundle_state DevImageState = getSyclObjImpl(MainImg)->get_state();
27192717
// At this time, there is no circumstance where a device image should ever
27202718
// be in the source state. That's not good.
@@ -2731,7 +2729,7 @@ void ProgramManager::bringSYCLDeviceImageToState(
27312729
break;
27322730
case bundle_state::object:
27332731
if (DevImageState == bundle_state::input) {
2734-
DeviceImage = compile(DeviceImage, MainImgImpl->get_devices(),
2732+
DeviceImage = compile(DeviceImage, MainImgImpl.get_devices(),
27352733
/*PropList=*/{});
27362734
break;
27372735
}
@@ -2746,12 +2744,12 @@ void ProgramManager::bringSYCLDeviceImageToState(
27462744
assert(DevImageState != bundle_state::ext_oneapi_source);
27472745
break;
27482746
case bundle_state::input:
2749-
DeviceImage = build(DeviceImage, MainImgImpl->get_devices(),
2747+
DeviceImage = build(DeviceImage, MainImgImpl.get_devices(),
27502748
/*PropList=*/{});
27512749
break;
27522750
case bundle_state::object: {
27532751
std::vector<device_image_plain> LinkedDevImages =
2754-
link(DeviceImage.getAll(), MainImgImpl->get_devices(),
2752+
link(DeviceImage.getAll(), MainImgImpl.get_devices(),
27552753
/*PropList=*/{});
27562754
// Since only one device image is passed here one output device image is
27572755
// expected
@@ -2760,7 +2758,7 @@ void ProgramManager::bringSYCLDeviceImageToState(
27602758
break;
27612759
}
27622760
case bundle_state::executable:
2763-
DeviceImage = build(DeviceImage, MainImgImpl->get_devices(),
2761+
DeviceImage = build(DeviceImage, MainImgImpl.get_devices(),
27642762
/*PropList=*/{});
27652763
break;
27662764
}
@@ -2920,21 +2918,20 @@ mergeImageData(const std::vector<device_image_plain> &Imgs,
29202918
device_image_impl::SpecConstMapT &NewSpecConstMap,
29212919
std::unique_ptr<DynRTDeviceBinaryImage> &MergedImageStorage) {
29222920
for (const device_image_plain &Img : Imgs) {
2923-
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
2924-
getSyclObjImpl(Img);
2921+
device_image_impl &DeviceImageImpl = *getSyclObjImpl(Img);
29252922
// Duplicates are not expected here, otherwise urProgramLink should fail
2926-
if (DeviceImageImpl->get_kernel_ids_ptr())
2923+
if (DeviceImageImpl.get_kernel_ids_ptr())
29272924
KernelIDs.insert(KernelIDs.end(),
2928-
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
2929-
DeviceImageImpl->get_kernel_ids_ptr()->end());
2925+
DeviceImageImpl.get_kernel_ids_ptr()->begin(),
2926+
DeviceImageImpl.get_kernel_ids_ptr()->end());
29302927
// To be able to answer queries about specialization constants, the new
29312928
// device image should have the specialization constants from all the linked
29322929
// images.
29332930
const std::lock_guard<std::mutex> SpecConstLock(
2934-
DeviceImageImpl->get_spec_const_data_lock());
2931+
DeviceImageImpl.get_spec_const_data_lock());
29352932
// Copy all map entries to the new map. Since the blob will be copied to
29362933
// the end of the new blob we need to move the blob offset of each entry.
2937-
for (const auto &SpecConstIt : DeviceImageImpl->get_spec_const_data_ref()) {
2934+
for (const auto &SpecConstIt : DeviceImageImpl.get_spec_const_data_ref()) {
29382935
std::vector<device_image_impl::SpecConstDescT> &NewDescEntries =
29392936
NewSpecConstMap[SpecConstIt.first];
29402937

@@ -2952,8 +2949,8 @@ mergeImageData(const std::vector<device_image_plain> &Imgs,
29522949
// Copy the blob from the device image into the new blob. This moves the
29532950
// offsets of the following blobs.
29542951
NewSpecConstBlob.insert(NewSpecConstBlob.end(),
2955-
DeviceImageImpl->get_spec_const_blob_ref().begin(),
2956-
DeviceImageImpl->get_spec_const_blob_ref().end());
2952+
DeviceImageImpl.get_spec_const_blob_ref().begin(),
2953+
DeviceImageImpl.get_spec_const_blob_ref().end());
29572954
}
29582955
// device_image_impl expects kernel ids to be sorted for fast search
29592956
std::sort(KernelIDs.begin(), KernelIDs.end(), LessByHash<kernel_id>{});
@@ -2999,14 +2996,13 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
29992996
// FIXME: Linker options are picked from the first object, but is that safe?
30002997
std::string LinkOptionsStr;
30012998
applyLinkOptionsFromEnvironment(LinkOptionsStr);
3002-
const std::shared_ptr<device_image_impl> &FirstImgImpl =
3003-
getSyclObjImpl(Imgs[0]);
3004-
if (LinkOptionsStr.empty() && FirstImgImpl->get_bin_image_ref())
2999+
device_image_impl &FirstImgImpl = *getSyclObjImpl(Imgs[0]);
3000+
if (LinkOptionsStr.empty() && FirstImgImpl.get_bin_image_ref())
30053001
appendLinkOptionsFromImage(LinkOptionsStr,
3006-
*(FirstImgImpl->get_bin_image_ref()));
3002+
*(FirstImgImpl.get_bin_image_ref()));
30073003
// Should always come last!
30083004
appendLinkEnvironmentVariablesThatAppend(LinkOptionsStr);
3009-
const context &Context = FirstImgImpl->get_context();
3005+
const context &Context = FirstImgImpl.get_context();
30103006
context_impl &ContextImpl = *getSyclObjImpl(Context);
30113007
const AdapterPtr &Adapter = ContextImpl.getAdapter();
30123008

@@ -3059,11 +3055,9 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
30593055
// removal of map entries with same handle (obviously invalid entries).
30603056
std::ignore = NativePrograms.erase(LinkedProg);
30613057
for (const device_image_plain &Img : Imgs) {
3062-
const std::shared_ptr<device_image_impl> &ImgImpl = getSyclObjImpl(Img);
3063-
if (ImgImpl->get_bin_image_ref())
3058+
if (auto BinImageRef = getSyclObjImpl(Img)->get_bin_image_ref())
30643059
NativePrograms.insert(
3065-
{LinkedProg,
3066-
{ContextImpl.shared_from_this(), ImgImpl->get_bin_image_ref()}});
3060+
{LinkedProg, {ContextImpl.shared_from_this(), BinImageRef}});
30673061
}
30683062
}
30693063

@@ -3077,14 +3071,14 @@ ProgramManager::link(const std::vector<device_image_plain> &Imgs,
30773071
KernelNameSetT MergedKernelNames;
30783072
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
30793073
for (const device_image_plain &DevImg : Imgs) {
3080-
const DeviceImageImplPtr &DevImgImpl = getSyclObjImpl(DevImg);
3081-
CombinedOrigins |= DevImgImpl->getOriginMask();
3082-
RTCInfoPtrs.emplace_back(&(DevImgImpl->getRTCInfo()));
3083-
MergedKernelNames.insert(DevImgImpl->getKernelNames().begin(),
3084-
DevImgImpl->getKernelNames().end());
3074+
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
3075+
CombinedOrigins |= DevImgImpl.getOriginMask();
3076+
RTCInfoPtrs.emplace_back(&(DevImgImpl.getRTCInfo()));
3077+
MergedKernelNames.insert(DevImgImpl.getKernelNames().begin(),
3078+
DevImgImpl.getKernelNames().end());
30853079
MergedEliminatedKernelArgMasks.insert(
3086-
DevImgImpl->getEliminatedKernelArgMasks().begin(),
3087-
DevImgImpl->getEliminatedKernelArgMasks().end());
3080+
DevImgImpl.getEliminatedKernelArgMasks().begin(),
3081+
DevImgImpl.getEliminatedKernelArgMasks().end());
30883082
}
30893083
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);
30903084

@@ -3114,10 +3108,9 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31143108
PropList, NoAllowedPropertiesCheck, NoAllowedPropertiesCheck);
31153109
}
31163110

3117-
const std::shared_ptr<device_image_impl> &MainInputImpl =
3118-
getSyclObjImpl(DevImgWithDeps.getMain());
3111+
device_image_impl &MainInputImpl = *getSyclObjImpl(DevImgWithDeps.getMain());
31193112

3120-
const context &Context = MainInputImpl->get_context();
3113+
const context &Context = MainInputImpl.get_context();
31213114
context_impl &ContextImpl = *detail::getSyclObjImpl(Context);
31223115

31233116
std::vector<const RTDeviceBinaryImage *> BinImgs;
@@ -3130,7 +3123,7 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31303123
device_image_impl::SpecConstMapT SpecConstMap;
31313124

31323125
std::unique_ptr<DynRTDeviceBinaryImage> MergedImageStorage;
3133-
const RTDeviceBinaryImage *ResultBinImg = MainInputImpl->get_bin_image_ref();
3126+
const RTDeviceBinaryImage *ResultBinImg = MainInputImpl.get_bin_image_ref();
31343127
if (DevImgWithDeps.hasDeps()) {
31353128
KernelIDs = std::make_shared<std::vector<kernel_id>>();
31363129
// Sort the images to make the order of spec constant values used for
@@ -3144,9 +3137,9 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31443137
ResultBinImg = mergeImageData(SortedImgs, *KernelIDs, SpecConstBlob,
31453138
SpecConstMap, MergedImageStorage);
31463139
} else {
3147-
KernelIDs = MainInputImpl->get_kernel_ids_ptr();
3148-
SpecConstBlob = MainInputImpl->get_spec_const_blob_ref();
3149-
SpecConstMap = MainInputImpl->get_spec_const_data_ref();
3140+
KernelIDs = MainInputImpl.get_kernel_ids_ptr();
3141+
SpecConstBlob = MainInputImpl.get_spec_const_blob_ref();
3142+
SpecConstMap = MainInputImpl.get_spec_const_data_ref();
31503143
}
31513144

31523145
ur_program_handle_t ResProgram = getBuiltURProgram(
@@ -3163,13 +3156,13 @@ ProgramManager::build(const DevImgPlainWithDeps &DevImgWithDeps,
31633156
KernelNameSetT MergedKernelNames;
31643157
std::unordered_map<std::string, KernelArgMask> MergedEliminatedKernelArgMasks;
31653158
for (const device_image_plain &DevImg : DevImgWithDeps) {
3166-
const auto &DevImgImpl = getSyclObjImpl(DevImg);
3167-
RTCInfoPtrs.emplace_back(&(DevImgImpl->getRTCInfo()));
3168-
MergedKernelNames.insert(DevImgImpl->getKernelNames().begin(),
3169-
DevImgImpl->getKernelNames().end());
3159+
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
3160+
RTCInfoPtrs.emplace_back(&(DevImgImpl.getRTCInfo()));
3161+
MergedKernelNames.insert(DevImgImpl.getKernelNames().begin(),
3162+
DevImgImpl.getKernelNames().end());
31703163
MergedEliminatedKernelArgMasks.insert(
3171-
DevImgImpl->getEliminatedKernelArgMasks().begin(),
3172-
DevImgImpl->getEliminatedKernelArgMasks().end());
3164+
DevImgImpl.getEliminatedKernelArgMasks().begin(),
3165+
DevImgImpl.getEliminatedKernelArgMasks().end());
31733166
}
31743167
auto MergedRTCInfo = detail::KernelCompilerBinaryInfo::Merge(RTCInfoPtrs);
31753168

0 commit comments

Comments
 (0)