Skip to content

[SYCL][NFC] Remove AdapterPtr from SYCL RT #19315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: sycl
Choose a base branch
from
7 changes: 3 additions & 4 deletions sycl/source/detail/allowlist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,7 @@ bool deviceIsAllowed(const DeviceDescT &DeviceDesc,
}

void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
ur_platform_handle_t UrPlatform,
const AdapterPtr &Adapter) {
ur_platform_handle_t UrPlatform, adapter_impl &Adapter) {

AllowListParsedT AllowListParsed =
parseAllowList(SYCLConfig<SYCL_DEVICE_ALLOWLIST>::get());
Expand All @@ -375,7 +374,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
// Get platform's backend and put it to DeviceDesc
DeviceDescT DeviceDesc;
platform_impl &PlatformImpl =
platform_impl::getOrMakePlatformImpl(UrPlatform, *Adapter);
platform_impl::getOrMakePlatformImpl(UrPlatform, Adapter);
backend Backend = PlatformImpl.getBackend();

for (const auto &SyclBe : getSyclBeMap()) {
Expand All @@ -396,7 +395,7 @@ void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
device_impl &DeviceImpl = PlatformImpl.getOrMakeDeviceImpl(Device);
// get DeviceType value and put it to DeviceDesc
ur_device_type_t UrDevType = UR_DEVICE_TYPE_ALL;
Adapter->call<UrApiKind::urDeviceGetInfo>(
Adapter.call<UrApiKind::urDeviceGetInfo>(
Device, UR_DEVICE_INFO_TYPE, sizeof(UrDevType), &UrDevType, nullptr);
// TODO need mechanism to do these casts, there's a bunch of this sort of
// thing
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/allowlist.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ bool deviceIsAllowed(const DeviceDescT &DeviceDesc,
const AllowListParsedT &AllowListParsed);

void applyAllowList(std::vector<ur_device_handle_t> &UrDevices,
ur_platform_handle_t UrPlatform, const AdapterPtr &Adapter);
ur_platform_handle_t UrPlatform, adapter_impl &Adapter);

} // namespace detail
} // namespace _V1
Expand Down
12 changes: 6 additions & 6 deletions sycl/source/detail/async_alloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
sycl::make_error_code(sycl::errc::feature_not_supported),
"Only device backed asynchronous allocations are supported!");

auto &Adapter = h.getContextImpl().getAdapter();
detail::adapter_impl &Adapter = h.getContextImpl().getAdapter();

// Get CG event dependencies for this allocation.
const auto &DepEvents = h.impl->CGData.MEvents;
Expand All @@ -84,8 +84,8 @@ void *async_malloc(sycl::handler &h, sycl::usm::alloc kind, size_t size) {
alloc = Graph->getMemPool().malloc(size, kind, DepNodes);
} else {
ur_queue_handle_t Q = h.impl->get_queue().getHandleRef();
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
Adapter.call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
Q, (ur_usm_pool_handle_t)0, size, nullptr, UREvents.size(),
UREvents.data(), &alloc, &Event);
}
Expand Down Expand Up @@ -118,7 +118,7 @@ __SYCL_EXPORT void *async_malloc(const sycl::queue &q, sycl::usm::alloc kind,
__SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
const memory_pool &pool) {

auto &Adapter = h.getContextImpl().getAdapter();
detail::adapter_impl &Adapter = h.getContextImpl().getAdapter();
detail::memory_pool_impl &memPoolImpl = *detail::getSyclObjImpl(pool);

// Get CG event dependencies for this allocation.
Expand All @@ -138,8 +138,8 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
detail::getSyclObjImpl(pool).get());
} else {
ur_queue_handle_t Q = h.impl->get_queue().getHandleRef();
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
Adapter.call<sycl::errc::runtime,
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
Q, memPoolImpl.get_handle(), size, nullptr, UREvents.size(),
UREvents.data(), &alloc, &Event);
}
Expand Down
15 changes: 7 additions & 8 deletions sycl/source/detail/buffer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ void buffer_impl::addInteropObject(
if (std::find(Handles.begin(), Handles.end(),
ur::cast<ur_native_handle_t>(MInteropMemObject)) ==
Handles.end()) {
const AdapterPtr &Adapter = getAdapter();
Adapter->call<UrApiKind::urMemRetain>(
adapter_impl &Adapter = getAdapter();
Adapter.call<UrApiKind::urMemRetain>(
ur::cast<ur_mem_handle_t>(MInteropMemObject));
ur_native_handle_t NativeHandle = 0;
Adapter->call<UrApiKind::urMemGetNativeHandle>(MInteropMemObject, nullptr,
&NativeHandle);
Adapter.call<UrApiKind::urMemGetNativeHandle>(MInteropMemObject, nullptr,
&NativeHandle);
Handles.push_back(NativeHandle);
}
}
Expand All @@ -83,14 +83,13 @@ buffer_impl::getNativeVector(backend BackendName) const {
if (Platform.getBackend() != BackendName)
continue;

auto Adapter = Platform.getAdapter();

adapter_impl &Adapter = Platform.getAdapter();
ur_native_handle_t Handle = 0;
// When doing buffer interop we don't know what device the memory should be
// resident on, so pass nullptr for Device param. Buffer interop may not be
// supported by all backends.
Adapter->call<UrApiKind::urMemGetNativeHandle>(NativeMem, /*Dev*/ nullptr,
&Handle);
Adapter.call<UrApiKind::urMemGetNativeHandle>(NativeMem, /*Dev*/ nullptr,
&Handle);
Handles.push_back(Handle);

if (Platform.getBackend() == backend::opencl) {
Expand Down
39 changes: 19 additions & 20 deletions sycl/source/detail/context_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ context_impl::context_impl(const std::vector<sycl::device> Devices,
DeviceIds.push_back(getSyclObjImpl(D)->getHandleRef());
}

getAdapter()->call<UrApiKind::urContextCreate>(
getAdapter().call<UrApiKind::urContextCreate>(
DeviceIds.size(), DeviceIds.data(), nullptr, &MContext);

MKernelProgramCache.setContextPtr(this);
Expand Down Expand Up @@ -102,17 +102,17 @@ context_impl::context_impl(ur_context_handle_t UrContext,
// TODO: Move this backend-specific retain of the context to SYCL-2020 style
// make_context<backend::opencl> interop, when that is created.
if (getBackend() == sycl::backend::opencl) {
getAdapter()->call<UrApiKind::urContextRetain>(MContext);
getAdapter().call<UrApiKind::urContextRetain>(MContext);
}
MKernelProgramCache.setContextPtr(this);
}

cl_context context_impl::get() const {
// TODO catch an exception and put it to list of asynchronous exceptions
getAdapter()->call<UrApiKind::urContextRetain>(MContext);
getAdapter().call<UrApiKind::urContextRetain>(MContext);
ur_native_handle_t nativeHandle = 0;
getAdapter()->call<UrApiKind::urContextGetNativeHandle>(MContext,
&nativeHandle);
getAdapter().call<UrApiKind::urContextGetNativeHandle>(MContext,
&nativeHandle);
return ur::cast<cl_context>(nativeHandle);
}

Expand All @@ -130,10 +130,10 @@ context_impl::~context_impl() {
}
for (auto LibProg : MCachedLibPrograms) {
assert(LibProg.second && "Null program must not be kept in the cache");
getAdapter()->call<UrApiKind::urProgramRelease>(LibProg.second);
getAdapter().call<UrApiKind::urProgramRelease>(LibProg.second);
}
// TODO catch an exception and put it to list of asynchronous exceptions
getAdapter()->call_nocheck<UrApiKind::urContextRelease>(MContext);
getAdapter().call_nocheck<UrApiKind::urContextRelease>(MContext);
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~context_impl", e);
}
Expand Down Expand Up @@ -292,9 +292,9 @@ context_impl::findMatchingDeviceImpl(ur_device_handle_t &DeviceUR) const {
}

ur_native_handle_t context_impl::getNative() const {
const auto &Adapter = getAdapter();
detail::adapter_impl &Adapter = getAdapter();
ur_native_handle_t Handle;
Adapter->call<UrApiKind::urContextGetNativeHandle>(getHandleRef(), &Handle);
Adapter.call<UrApiKind::urContextGetNativeHandle>(getHandleRef(), &Handle);
if (getBackend() == backend::opencl) {
__SYCL_OCL_CALL(clRetainContext, ur::cast<cl_context>(Handle));
}
Expand Down Expand Up @@ -345,7 +345,7 @@ std::vector<ur_event_handle_t> context_impl::initializeDeviceGlobals(
if (!MDeviceGlobalNotInitializedCnt.load(std::memory_order_acquire))
return {};

const AdapterPtr &Adapter = getAdapter();
detail::adapter_impl &Adapter = getAdapter();
device_impl &DeviceImpl = QueueImpl.getDeviceImpl();
std::lock_guard<std::mutex> NativeProgramLock(MDeviceGlobalInitializersMutex);
auto ImgIt = MDeviceGlobalInitializers.find(
Expand All @@ -365,11 +365,11 @@ std::vector<ur_event_handle_t> context_impl::initializeDeviceGlobals(
InitEventsRef.begin(), InitEventsRef.end(),
[&Adapter](const ur_event_handle_t &Event) {
return get_event_info<info::event::command_execution_status>(
Event, *Adapter) == info::event_command_status::complete;
Event, Adapter) == info::event_command_status::complete;
});
// Release the removed events.
for (auto EventIt = NewEnd; EventIt != InitEventsRef.end(); ++EventIt)
Adapter->call<UrApiKind::urEventRelease>(*EventIt);
Adapter.call<UrApiKind::urEventRelease>(*EventIt);
// Remove them from the collection.
InitEventsRef.erase(NewEnd, InitEventsRef.end());
// If there are no more events, we can mark it as fully initialized.
Expand Down Expand Up @@ -431,14 +431,14 @@ std::vector<ur_event_handle_t> context_impl::initializeDeviceGlobals(
// are cleaned up separately from cleaning up the device global USM memory
// this must retain the event.
{
if (OwnedUrEvent ZIEvent = DeviceGlobalUSM.getInitEvent(*Adapter))
if (OwnedUrEvent ZIEvent = DeviceGlobalUSM.getInitEvent(Adapter))
InitEventsRef.push_back(ZIEvent.TransferOwnership());
}
// Write the pointer to the device global and store the event in the
// initialize events list.
ur_event_handle_t InitEvent;
void *const &USMPtr = DeviceGlobalUSM.getPtr();
Adapter->call<UrApiKind::urEnqueueDeviceGlobalVariableWrite>(
Adapter.call<UrApiKind::urEnqueueDeviceGlobalVariableWrite>(
QueueImpl.getHandleRef(), NativePrg,
DeviceGlobalEntry->MUniqueId.c_str(), false, sizeof(void *), 0,
&USMPtr, 0, nullptr, &InitEvent);
Expand All @@ -449,10 +449,9 @@ std::vector<ur_event_handle_t> context_impl::initializeDeviceGlobals(
}
}

void context_impl::DeviceGlobalInitializer::ClearEvents(
const AdapterPtr &Adapter) {
void context_impl::DeviceGlobalInitializer::ClearEvents(adapter_impl &Adapter) {
for (const ur_event_handle_t &Event : MDeviceGlobalInitEvents)
Adapter->call<UrApiKind::urEventRelease>(Event);
Adapter.call<UrApiKind::urEventRelease>(Event);
MDeviceGlobalInitEvents.clear();
}

Expand Down Expand Up @@ -577,7 +576,7 @@ context_impl::get_default_memory_pool(const context &Context,

detail::device_impl &DevImpl = *detail::getSyclObjImpl(Device);
ur_device_handle_t DeviceHandle = DevImpl.getHandleRef();
const sycl::detail::AdapterPtr &Adapter = this->getAdapter();
detail::adapter_impl &Adapter = this->getAdapter();

// Check dev is already in our list of device pool pairs.
if (auto it = std::find_if(MMemPoolImplPtrs.begin(), MMemPoolImplPtrs.end(),
Expand All @@ -590,8 +589,8 @@ context_impl::get_default_memory_pool(const context &Context,

// The memory_pool_impl does not exist for this device yet.
ur_usm_pool_handle_t PoolHandle;
Adapter->call<sycl::errc::runtime,
sycl::detail::UrApiKind::urUSMPoolGetDefaultDevicePoolExp>(
Adapter.call<sycl::errc::runtime,
sycl::detail::UrApiKind::urUSMPoolGetDefaultDevicePoolExp>(
this->getHandleRef(), DeviceHandle, &PoolHandle);

auto MemPoolImplPtr = std::make_shared<
Expand Down
8 changes: 4 additions & 4 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
const async_handler &get_async_handler() const;

/// \return the Adapter associated with the platform of this context.
const AdapterPtr &getAdapter() const { return MPlatform->getAdapter(); }
adapter_impl &getAdapter() const { return MPlatform->getAdapter(); }

/// \return the PlatformImpl associated with this context.
platform_impl &getPlatformImpl() const { return *MPlatform; }
Expand Down Expand Up @@ -295,7 +295,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
}

/// Clears all events of the initializer. This will not acquire the lock.
void ClearEvents(const AdapterPtr &Adapter);
void ClearEvents(adapter_impl &Adapter);

/// The binary image of the program.
const RTDeviceBinaryImage *MBinImage = nullptr;
Expand Down Expand Up @@ -367,7 +367,7 @@ void GetCapabilitiesIntersectionSet(const std::vector<sycl::device> &Devices,
// convenient to be able to reference them without extra `detail::`.
inline auto get_ur_handles(sycl::detail::context_impl &Ctx) {
ur_context_handle_t urCtx = Ctx.getHandleRef();
return std::tuple{urCtx, Ctx.getAdapter()};
return std::tuple{urCtx, &Ctx.getAdapter()};
}
inline auto get_ur_handles(const sycl::context &syclContext) {
return get_ur_handles(*sycl::detail::getSyclObjImpl(syclContext));
Expand All @@ -382,7 +382,7 @@ inline auto get_ur_handles(const sycl::device &syclDevice,
inline auto get_ur_handles(const sycl::device &syclDevice) {
auto &implDevice = *sycl::detail::getSyclObjImpl(syclDevice);
ur_device_handle_t urDevice = implDevice.getHandleRef();
return std::tuple{urDevice, implDevice.getAdapter()};
return std::tuple{urDevice, &implDevice.getAdapter()};
}

} // namespace _V1
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/context_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ namespace detail {

template <typename Param>
typename Param::return_type get_context_info(ur_context_handle_t Ctx,
const AdapterPtr &Adapter) {
adapter_impl &Adapter) {
static_assert(is_context_info_desc<Param>::value,
"Invalid context information descriptor");
typename Param::return_type Result = 0;
// TODO catch an exception and put it to list of asynchronous exceptions
Adapter->call<UrApiKind::urContextGetInfo>(Ctx, UrInfoCode<Param>::value,
sizeof(Result), &Result, nullptr);
Adapter.call<UrApiKind::urContextGetInfo>(Ctx, UrInfoCode<Param>::value,
sizeof(Result), &Result, nullptr);
return Result;
}

Expand Down
5 changes: 2 additions & 3 deletions sycl/source/detail/device_global_map_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void DeviceGlobalMapEntry::removeAssociatedResources(
DeviceGlobalUSMMem &USMMem = USMPtrIt->second;
detail::usm::freeInternal(USMMem.MPtr, CtxImpl);
if (USMMem.MInitEvent.has_value())
CtxImpl->getAdapter()->call<UrApiKind::urEventRelease>(
CtxImpl->getAdapter().call<UrApiKind::urEventRelease>(
*USMMem.MInitEvent);
#ifndef NDEBUG
// For debugging we set the event and memory to some recognizable values
Expand All @@ -185,8 +185,7 @@ void DeviceGlobalMapEntry::cleanup() {
DeviceGlobalUSMMem &USMMem = USMPtrIt.second;
detail::usm::freeInternal(USMMem.MPtr, CtxImpl);
if (USMMem.MInitEvent.has_value())
CtxImpl->getAdapter()->call<UrApiKind::urEventRelease>(
*USMMem.MInitEvent);
CtxImpl->getAdapter().call<UrApiKind::urEventRelease>(*USMMem.MInitEvent);
#ifndef NDEBUG
// For debugging we set the event and memory to some recognizable values
// to allow us to check that this cleanup happens before erasure.
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/device_image_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ std::shared_ptr<kernel_impl> device_image_impl::tryGetExtensionKernel(
}

ur_program_handle_t UrProgram = get_ur_program_ref();
const AdapterPtr &Adapter = getSyclObjImpl(Context)->getAdapter();
detail::adapter_impl &Adapter = getSyclObjImpl(Context)->getAdapter();
ur_kernel_handle_t UrKernel = nullptr;
Adapter->call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),
&UrKernel);
Adapter.call<UrApiKind::urKernelCreate>(UrProgram, AdjustedName.c_str(),
&UrKernel);
// Kernel created by urKernelCreate is implicitly retained.

const KernelArgMask *ArgMask = nullptr;
Expand Down
Loading