Skip to content

Commit a782779

Browse files
[SYCL] Relax kernel bundle device check to allow descendent devices (#7334)
Change kernel_bundle_impl constructors to treat descendent devices of context members as valid in accordance with SYCL 2020.
1 parent 4cc1094 commit a782779

File tree

5 files changed

+114
-34
lines changed

5 files changed

+114
-34
lines changed

sycl/source/detail/context_impl.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,27 @@ class context_impl {
163163
/// Returns true if and only if context contains the given device.
164164
bool hasDevice(std::shared_ptr<detail::device_impl> Device) const;
165165

166+
/// Returns true if and only if the device can be used within this context.
167+
/// For OpenCL this is currently equivalent to hasDevice, for other backends
168+
/// it returns true if the device is either a member of the context or a
169+
/// descendant of a member.
170+
bool isDeviceValid(DeviceImplPtr Device) {
171+
// OpenCL does not support using descendants of context members within that
172+
// context yet.
173+
// TODO remove once this limitation is lifted
174+
if (!is_host() && getPlugin().getBackend() == backend::opencl)
175+
return hasDevice(Device);
176+
177+
while (!hasDevice(Device)) {
178+
if (Device->isRootDevice())
179+
return false;
180+
Device = detail::getSyclObjImpl(
181+
Device->get_info<info::device::parent_device>());
182+
}
183+
184+
return true;
185+
}
186+
166187
/// Given a PiDevice, returns the matching shared_ptr<device_impl>
167188
/// within this context. May return nullptr if no match discovered.
168189
DeviceImplPtr findMatchingDeviceImpl(RT::PiDevice &DevicePI) const;

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,9 @@ namespace detail {
3030

3131
static bool checkAllDevicesAreInContext(const std::vector<device> &Devices,
3232
const context &Context) {
33-
const std::vector<device> &ContextDevices = Context.get_devices();
3433
return std::all_of(
35-
Devices.begin(), Devices.end(), [&ContextDevices](const device &Dev) {
36-
return ContextDevices.end() !=
37-
std::find(ContextDevices.begin(), ContextDevices.end(), Dev);
34+
Devices.begin(), Devices.end(), [&Context](const device &Dev) {
35+
return getSyclObjImpl(Context)->isDeviceValid(getSyclObjImpl(Dev));
3836
});
3937
}
4038

sycl/source/detail/queue_impl.hpp

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class queue_impl {
6262

6363
ContextImplPtr DefaultContext = detail::getSyclObjImpl(
6464
Device->get_platform().ext_oneapi_get_default_context());
65-
if (isValidDevice(DefaultContext, Device))
65+
if (DefaultContext->isDeviceValid(Device))
6666
return DefaultContext;
6767
return detail::getSyclObjImpl(
6868
context{createSyclObjFromImpl<device>(Device), {}, {}});
@@ -104,7 +104,7 @@ class queue_impl {
104104
"Queue cannot be constructed with both of "
105105
"discard_events and enable_profiling.");
106106
}
107-
if (!isValidDevice(Context, Device)) {
107+
if (!Context->isDeviceValid(Device)) {
108108
if (!Context->is_host() &&
109109
Context->getPlugin().getBackend() == backend::opencl)
110110
throw sycl::invalid_object_error(
@@ -486,27 +486,6 @@ class queue_impl {
486486
}
487487

488488
protected:
489-
/// Helper function for checking whether a device is either a member of a
490-
/// context or a descendnant of its member.
491-
/// \return True iff the device or its parent is a member of the context.
492-
static bool isValidDevice(const ContextImplPtr &Context,
493-
DeviceImplPtr Device) {
494-
// OpenCL does not support creating a queue with a descendant of a device
495-
// from the given context yet.
496-
// TODO remove once this limitation is lifted
497-
if (!Context->is_host() &&
498-
Context->getPlugin().getBackend() == backend::opencl)
499-
return Context->hasDevice(Device);
500-
501-
while (!Context->hasDevice(Device)) {
502-
if (Device->isRootDevice())
503-
return false;
504-
Device = detail::getSyclObjImpl(
505-
Device->get_info<info::device::parent_device>());
506-
}
507-
return true;
508-
}
509-
510489
/// Performs command group submission to the queue.
511490
///
512491
/// \param CGF is a function object containing command group.

sycl/unittests/SYCL2020/KernelBundle.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <detail/device_impl.hpp>
910
#include <detail/kernel_bundle_impl.hpp>
1011
#include <sycl/sycl.hpp>
1112

@@ -459,3 +460,80 @@ TEST(KernelBundle, EmptyDevicesKernelBundleLinkException) {
459460
FAIL() << "Unexpected exception was thrown in sycl::link.";
460461
}
461462
}
463+
464+
pi_device ParentDevice = nullptr;
465+
pi_platform PiPlatform = nullptr;
466+
467+
pi_result redefinedDeviceGetInfoAfter(pi_device device,
468+
pi_device_info param_name,
469+
size_t param_value_size,
470+
void *param_value,
471+
size_t *param_value_size_ret) {
472+
if (param_name == PI_DEVICE_INFO_PARTITION_PROPERTIES) {
473+
if (param_value) {
474+
auto *Result =
475+
reinterpret_cast<pi_device_partition_property *>(param_value);
476+
*Result = PI_DEVICE_PARTITION_EQUALLY;
477+
}
478+
if (param_value_size_ret)
479+
*param_value_size_ret = sizeof(pi_device_partition_property);
480+
} else if (param_name == PI_DEVICE_INFO_MAX_COMPUTE_UNITS) {
481+
auto *Result = reinterpret_cast<pi_uint32 *>(param_value);
482+
*Result = 2;
483+
} else if (param_name == PI_DEVICE_INFO_PARENT_DEVICE) {
484+
auto *Result = reinterpret_cast<pi_device *>(param_value);
485+
*Result = (device == ParentDevice) ? nullptr : ParentDevice;
486+
} else if (param_name == PI_DEVICE_INFO_PLATFORM) {
487+
auto *Result = reinterpret_cast<pi_platform *>(param_value);
488+
*Result = PiPlatform;
489+
}
490+
return PI_SUCCESS;
491+
}
492+
493+
pi_result redefinedDevicePartitionAfter(
494+
pi_device device, const pi_device_partition_property *properties,
495+
pi_uint32 num_devices, pi_device *out_devices, pi_uint32 *out_num_devices) {
496+
if (out_devices) {
497+
for (size_t I = 0; I < num_devices; ++I) {
498+
out_devices[I] = reinterpret_cast<pi_device>(1000 + I);
499+
}
500+
}
501+
if (out_num_devices)
502+
*out_num_devices = num_devices;
503+
return PI_SUCCESS;
504+
}
505+
506+
TEST(KernelBundle, DescendentDevice) {
507+
// Mock a non-OpenCL plugin since use of descendent devices of context members
508+
// is not supported there yet.
509+
sycl::unittest::PiMock Mock(sycl::backend::level_zero);
510+
511+
sycl::platform Plt = Mock.getPlatform();
512+
513+
PiPlatform = sycl::detail::getSyclObjImpl(Plt)->getHandleRef();
514+
515+
Mock.redefineAfter<sycl::detail::PiApiKind::piDeviceGetInfo>(
516+
redefinedDeviceGetInfoAfter);
517+
Mock.redefineAfter<sycl::detail::PiApiKind::piDevicePartition>(
518+
redefinedDevicePartitionAfter);
519+
520+
const sycl::device Dev = Mock.getPlatform().get_devices()[0];
521+
ParentDevice = sycl::detail::getSyclObjImpl(Dev)->getHandleRef();
522+
sycl::context Ctx{Dev};
523+
sycl::device Subdev =
524+
Dev.create_sub_devices<sycl::info::partition_property::partition_equally>(
525+
2)[0];
526+
527+
sycl::queue Queue{Ctx, Subdev};
528+
529+
sycl::kernel_bundle<sycl::bundle_state::executable> KernelBundle =
530+
sycl::get_kernel_bundle<sycl::bundle_state::executable>(Ctx, {Subdev});
531+
532+
sycl::kernel Kernel =
533+
KernelBundle.get_kernel(sycl::get_kernel_id<TestKernel>());
534+
535+
sycl::kernel_bundle<sycl::bundle_state::executable> RetKernelBundle =
536+
Kernel.get_kernel_bundle();
537+
538+
EXPECT_EQ(KernelBundle, RetKernelBundle);
539+
}

sycl/unittests/helpers/PiMock.hpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,12 @@ class PiMock {
192192
/// within the given context. A separate platform instance will be
193193
/// held by the PiMock instance.
194194
///
195-
PiMock() {
195+
/// \param Backend is the backend type to mock, intended for testing backend
196+
/// specific runtime logic.
197+
PiMock(backend Backend = backend::opencl) {
196198
// Create new mock plugin platform and plugin handles
197199
// Note: Mock plugin will be generated if it has not been yet.
198-
MPlatformImpl = GetMockPlatformImpl();
200+
MPlatformImpl = GetMockPlatformImpl(Backend);
199201
std::shared_ptr<detail::plugin> NewPluginPtr;
200202
{
201203
const detail::plugin &OriginalPiPlugin = MPlatformImpl->getPlugin();
@@ -328,7 +330,9 @@ class PiMock {
328330
/// in the global handler. Additionally, all existing plugins will be removed
329331
/// and unloaded to avoid them being accidentally picked up by tests using
330332
/// selectors.
331-
static void EnsureMockPluginInitialized() {
333+
/// \param Backend is the backend type to mock, intended for testing backend
334+
/// specific runtime logic.
335+
static void EnsureMockPluginInitialized(backend Backend = backend::opencl) {
332336
// Only initialize the plugin once.
333337
if (MMockPluginPtr)
334338
return;
@@ -346,8 +350,7 @@ class PiMock {
346350
RT::PiPlugin{"pi.ver.mock", "plugin.ver.mock", /*Targets=*/nullptr,
347351
getProxyMockedFunctionPointers()});
348352

349-
// FIXME: which backend to pass here? does it affect anything?
350-
MMockPluginPtr = std::make_unique<detail::plugin>(RTPlugin, backend::opencl,
353+
MMockPluginPtr = std::make_unique<detail::plugin>(RTPlugin, Backend,
351354
/*Library=*/nullptr);
352355
Plugins.push_back(*MMockPluginPtr);
353356
}
@@ -357,8 +360,9 @@ class PiMock {
357360
/// platform_impl from it.
358361
///
359362
/// \return a shared_ptr to a platform_impl created from the mock PI plugin.
360-
static std::shared_ptr<sycl::detail::platform_impl> GetMockPlatformImpl() {
361-
EnsureMockPluginInitialized();
363+
static std::shared_ptr<sycl::detail::platform_impl>
364+
GetMockPlatformImpl(backend Backend) {
365+
EnsureMockPluginInitialized(Backend);
362366

363367
pi_uint32 NumPlatforms = 0;
364368
MMockPluginPtr->call_nocheck<detail::PiApiKind::piPlatformsGet>(

0 commit comments

Comments
 (0)