Skip to content

Commit fb58e23

Browse files
authored
[SYCL] Avoid device_impl shared_ptr copy in getMaxWorkGroups() (#17705)
1 parent 88827d4 commit fb58e23

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

sycl/include/sycl/handler.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class HandlerAccess;
213213
class HostTask;
214214

215215
using EventImplPtr = std::shared_ptr<event_impl>;
216+
using DeviceImplPtr = std::shared_ptr<device_impl>;
216217

217218
template <typename RetType, typename Func, typename Arg>
218219
static Arg member_ptr_helper(RetType (Func::*)(Arg) const);
@@ -249,6 +250,7 @@ template <typename Type> struct get_kernel_wrapper_name_t {
249250
};
250251

251252
__SYCL_EXPORT device getDeviceFromHandler(handler &);
253+
const DeviceImplPtr &getDeviceImplFromHandler(handler &);
252254

253255
// Checks if a device_global has any registered kernel usage.
254256
__SYCL_EXPORT bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr);
@@ -3481,6 +3483,8 @@ class __SYCL_EXPORT handler {
34813483
typename PropertyListT>
34823484
friend class accessor;
34833485
friend device detail::getDeviceFromHandler(handler &);
3486+
friend const detail::DeviceImplPtr &
3487+
detail::getDeviceImplFromHandler(handler &);
34843488

34853489
template <typename DataT, int Dimensions, access::mode AccessMode,
34863490
access::target AccessTarget, access::placeholder IsPlaceholder>

sycl/source/detail/graph_impl.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
919919
/// @return Context associated with graph.
920920
sycl::context getContext() const { return MContext; }
921921

922+
/// Query for the device_impl tied to this graph.
923+
/// @return device_impl shared ptr reference associated with graph.
924+
const DeviceImplPtr &getDeviceImplPtr() const {
925+
return getSyclObjImpl(MDevice);
926+
}
927+
922928
/// Query for the device tied to this graph.
923929
/// @return Device associated with graph.
924930
sycl::device getDevice() const { return MDevice; }

sycl/source/handler.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ inline namespace _V1 {
4343

4444
namespace detail {
4545

46+
const DeviceImplPtr &getDeviceImplFromHandler(handler &CGH) {
47+
assert((CGH.MQueue || getSyclObjImpl(CGH)->MGraph) &&
48+
"One of MQueue or MGraph should be nonnull!");
49+
if (CGH.MQueue)
50+
return CGH.MQueue->getDeviceImplPtr();
51+
52+
return getSyclObjImpl(CGH)->MGraph->getDeviceImplPtr();
53+
}
54+
4655
bool isDeviceGlobalUsedInKernel(const void *DeviceGlobalPtr) {
4756
DeviceGlobalMapEntry *DGEntry =
4857
detail::ProgramManager::getInstance().getDeviceGlobalEntry(
@@ -2059,10 +2068,10 @@ void handler::setUserFacingNodeType(ext::oneapi::experimental::node_type Type) {
20592068
}
20602069

20612070
std::optional<std::array<size_t, 3>> handler::getMaxWorkGroups() {
2062-
auto Dev = detail::getSyclObjImpl(detail::getDeviceFromHandler(*this));
2071+
const auto &DeviceImpl = detail::getDeviceImplFromHandler(*this);
20632072
std::array<size_t, 3> UrResult = {};
2064-
auto Ret = Dev->getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
2065-
Dev->getHandleRef(),
2073+
auto Ret = DeviceImpl->getAdapter()->call_nocheck<UrApiKind::urDeviceGetInfo>(
2074+
DeviceImpl->getHandleRef(),
20662075
UrInfoCode<
20672076
ext::oneapi::experimental::info::device::max_work_groups<3>>::value,
20682077
sizeof(UrResult), &UrResult, nullptr);

0 commit comments

Comments
 (0)