Skip to content

Commit 9805c6c

Browse files
[NFC][SYCL] Prepare kernel_bundle_impl.hpp for getSyclObjImpl to return raw ref (#19251)
I'm planning to change `getSyclObjImpl` to return a raw reference in a later patch, uploading a bunch of PRs in preparation to that to make the subsequent review easier.
1 parent 3b208cc commit 9805c6c

File tree

1 file changed

+27
-31
lines changed

1 file changed

+27
-31
lines changed

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,11 @@ class kernel_bundle_impl
138138
: MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
139139
MState(TargetState) {
140140

141-
const std::shared_ptr<kernel_bundle_impl> &InputBundleImpl =
142-
getSyclObjImpl(InputBundle);
143-
MSpecConstValues = InputBundleImpl->get_spec_const_map_ref();
141+
kernel_bundle_impl &InputBundleImpl = *getSyclObjImpl(InputBundle);
142+
MSpecConstValues = InputBundleImpl.get_spec_const_map_ref();
144143

145144
const std::vector<device> &InputBundleDevices =
146-
InputBundleImpl->get_devices();
145+
InputBundleImpl.get_devices();
147146
const bool AllDevsAssociatedWithInputBundle =
148147
std::all_of(MDevices.begin(), MDevices.end(),
149148
[&InputBundleDevices](const device &Dev) {
@@ -158,11 +157,11 @@ class kernel_bundle_impl
158157
"devices for input bundle or vector of devices is empty");
159158

160159
// Copy SYCLBINs to ensure lifetime is preserved by the executable bundle.
161-
MSYCLBINs.insert(MSYCLBINs.end(), InputBundleImpl->MSYCLBINs.begin(),
162-
InputBundleImpl->MSYCLBINs.end());
160+
MSYCLBINs.insert(MSYCLBINs.end(), InputBundleImpl.MSYCLBINs.begin(),
161+
InputBundleImpl.MSYCLBINs.end());
163162

164163
for (const DevImgPlainWithDeps &DevImgWithDeps :
165-
InputBundleImpl->MDeviceImages) {
164+
InputBundleImpl.MDeviceImages) {
166165
// Skip images which are not compatible with devices provided
167166
if (std::none_of(MDevices.begin(), MDevices.end(),
168167
[&DevImgWithDeps](const device &Dev) {
@@ -311,11 +310,11 @@ class kernel_bundle_impl
311310
// images collection.
312311
std::map<std::string_view, size_t> ExportMap;
313312
for (size_t I = 0; I < DevImages.size(); ++I) {
314-
auto DevImageImpl = getSyclObjImpl(DevImages[I]);
315-
if (DevImageImpl->get_bin_image_ref() == nullptr)
313+
device_image_impl &DevImageImpl = *getSyclObjImpl(DevImages[I]);
314+
if (DevImageImpl.get_bin_image_ref() == nullptr)
316315
continue;
317316
for (const sycl_device_binary_property &ESProp :
318-
DevImageImpl->get_bin_image_ref()->getExportedSymbols()) {
317+
DevImageImpl.get_bin_image_ref()->getExportedSymbols()) {
319318
if (ExportMap.find(ESProp->Name) != ExportMap.end())
320319
throw sycl::exception(make_error_code(errc::invalid),
321320
"Duplicate exported symbol \"" +
@@ -329,12 +328,12 @@ class kernel_bundle_impl
329328
std::vector<std::vector<size_t>> Dependencies;
330329
Dependencies.resize(DevImages.size());
331330
for (size_t I = 0; I < DevImages.size(); ++I) {
332-
auto DevImageImpl = getSyclObjImpl(DevImages[I]);
333-
if (DevImageImpl->get_bin_image_ref() == nullptr)
331+
device_image_impl &DevImageImpl = *getSyclObjImpl(DevImages[I]);
332+
if (DevImageImpl.get_bin_image_ref() == nullptr)
334333
continue;
335334
std::set<size_t> DeviceImageDepsSet;
336335
for (const sycl_device_binary_property &ISProp :
337-
DevImageImpl->get_bin_image_ref()->getImportedSymbols()) {
336+
DevImageImpl.get_bin_image_ref()->getImportedSymbols()) {
338337
auto ExportSymbolIt = ExportMap.find(ISProp->Name);
339338
if (ExportSymbolIt == ExportMap.end())
340339
throw sycl::exception(make_error_code(errc::invalid),
@@ -348,13 +347,12 @@ class kernel_bundle_impl
348347
}
349348

350349
// Create a link graph and clone it for each device.
351-
const std::shared_ptr<device_impl> &FirstDevice =
352-
getSyclObjImpl(MDevices[0]);
350+
device_impl &FirstDevice = *getSyclObjImpl(MDevices[0]);
353351
std::map<std::shared_ptr<device_impl>, LinkGraph<device_image_plain>>
354352
DevImageLinkGraphs;
355353
const auto &FirstGraph =
356354
DevImageLinkGraphs
357-
.emplace(FirstDevice,
355+
.emplace(FirstDevice.shared_from_this(),
358356
LinkGraph<device_image_plain>{DevImages, Dependencies})
359357
.first->second;
360358
for (size_t I = 1; I < MDevices.size(); ++I)
@@ -498,12 +496,12 @@ class kernel_bundle_impl
498496
if (get_bundle_state() == bundle_state::input) {
499497
// Copy spec constants values from the device images.
500498
auto MergeSpecConstants = [this](const device_image_plain &Img) {
501-
const detail::DeviceImageImplPtr &ImgImpl = getSyclObjImpl(Img);
499+
detail::device_image_impl &ImgImpl = *getSyclObjImpl(Img);
502500
const std::map<std::string,
503501
std::vector<device_image_impl::SpecConstDescT>>
504-
&SpecConsts = ImgImpl->get_spec_const_data_ref();
502+
&SpecConsts = ImgImpl.get_spec_const_data_ref();
505503
const std::vector<unsigned char> &Blob =
506-
ImgImpl->get_spec_const_blob_ref();
504+
ImgImpl.get_spec_const_blob_ref();
507505
for (const std::pair<const std::string,
508506
std::vector<device_image_impl::SpecConstDescT>>
509507
&SpecConst : SpecConsts) {
@@ -675,10 +673,9 @@ class kernel_bundle_impl
675673
// resulting kernel object should be able to map devices to their
676674
// respective backend kernel objects.
677675
for (const device_image_plain &DevImg : MUniqueDeviceImages) {
678-
const std::shared_ptr<device_image_impl> &DevImgImpl =
679-
getSyclObjImpl(DevImg);
676+
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
680677
if (std::shared_ptr<kernel_impl> PotentialKernelImpl =
681-
DevImgImpl->tryGetExtensionKernel(Name, MContext, *this))
678+
DevImgImpl.tryGetExtensionKernel(Name, MContext, *this))
682679
return detail::createSyclObjFromImpl<kernel>(
683680
std::move(PotentialKernelImpl));
684681
}
@@ -731,10 +728,10 @@ class kernel_bundle_impl
731728
"'device_image_scope' property");
732729
}
733730

734-
const auto &DeviceImpl = getSyclObjImpl(Dev);
731+
device_impl &DeviceImpl = *getSyclObjImpl(Dev);
735732
bool SupportContextMemcpy = false;
736-
DeviceImpl->getAdapter()->call<UrApiKind::urDeviceGetInfo>(
737-
DeviceImpl->getHandleRef(),
733+
DeviceImpl.getAdapter()->call<UrApiKind::urDeviceGetInfo>(
734+
DeviceImpl.getHandleRef(),
738735
UR_DEVICE_INFO_USM_CONTEXT_MEMCPY_SUPPORT_EXP,
739736
sizeof(SupportContextMemcpy), &SupportContextMemcpy, nullptr);
740737
if (SupportContextMemcpy) {
@@ -764,14 +761,14 @@ class kernel_bundle_impl
764761
// Collect kernel ids from all device images, then remove duplicates
765762
std::vector<kernel_id> Result;
766763
for (const device_image_plain &DeviceImage : MUniqueDeviceImages) {
767-
const auto &DevImgImpl = getSyclObjImpl(DeviceImage);
764+
detail::device_image_impl &DevImgImpl = *getSyclObjImpl(DeviceImage);
768765

769766
// RTC kernel bundles shouldn't have user-facing kernel ids, return an
770767
// empty vector when the bundle contains RTC kernels.
771-
if (DevImgImpl->getRTCInfo())
768+
if (DevImgImpl.getRTCInfo())
772769
continue;
773770

774-
const std::vector<kernel_id> &KernelIDs = DevImgImpl->get_kernel_ids();
771+
const std::vector<kernel_id> &KernelIDs = DevImgImpl.get_kernel_ids();
775772

776773
Result.insert(Result.end(), KernelIDs.begin(), KernelIDs.end());
777774
}
@@ -1016,10 +1013,9 @@ class kernel_bundle_impl
10161013
// {kernel_name, device} and their corresponding image.
10171014
// First look through the kernels registered in source-based images.
10181015
for (const device_image_plain &DevImg : MUniqueDeviceImages) {
1019-
const std::shared_ptr<device_image_impl> &DevImgImpl =
1020-
getSyclObjImpl(DevImg);
1016+
device_image_impl &DevImgImpl = *getSyclObjImpl(DevImg);
10211017
if (std::shared_ptr<kernel_impl> SourceBasedKernel =
1022-
DevImgImpl->tryGetExtensionKernel(Name, MContext, *this))
1018+
DevImgImpl.tryGetExtensionKernel(Name, MContext, *this))
10231019
return SourceBasedKernel;
10241020
}
10251021

0 commit comments

Comments
 (0)