Skip to content

Commit 512f347

Browse files
[NFC][SYCL] Prepare memory_pool/async_alloc for getSyclObjImpl to return raw ref (#19249)
I'm planning to change `getSyclObjImpl` to return a raw reference in a later patch, uploading a bunch of PRs in preparation to that to make the subsequent review easier.
1 parent 6d97d98 commit 512f347

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

sycl/source/detail/async_alloc.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ std::vector<std::shared_ptr<detail::node_impl>> getDepGraphNodes(
3737
sycl::handler &Handler, detail::queue_impl *Queue,
3838
const std::shared_ptr<detail::graph_impl> &Graph,
3939
const std::vector<std::shared_ptr<detail::event_impl>> &DepEvents) {
40-
auto HandlerImpl = detail::getSyclObjImpl(Handler);
40+
detail::handler_impl &HandlerImpl = *detail::getSyclObjImpl(Handler);
4141
// Get dependent graph nodes from any events
4242
auto DepNodes = Graph->getNodesForEvents(DepEvents);
4343
// If this node was added explicitly we may have node deps in the handler as
4444
// well, so add them to the list
45-
DepNodes.insert(DepNodes.end(), HandlerImpl->MNodeDeps.begin(),
46-
HandlerImpl->MNodeDeps.end());
45+
DepNodes.insert(DepNodes.end(), HandlerImpl.MNodeDeps.begin(),
46+
HandlerImpl.MNodeDeps.end());
4747
// If this is being recorded from an in-order queue we need to get the last
4848
// in-order node if any, since this will later become a dependency of the
4949
// node being processed here.
@@ -119,7 +119,7 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
119119
const memory_pool &pool) {
120120

121121
auto &Adapter = h.getContextImpl().getAdapter();
122-
auto &memPoolImpl = sycl::detail::getSyclObjImpl(pool);
122+
detail::memory_pool_impl &memPoolImpl = *detail::getSyclObjImpl(pool);
123123

124124
// Get CG event dependencies for this allocation.
125125
const auto &DepEvents = h.impl->CGData.MEvents;
@@ -135,12 +135,12 @@ __SYCL_EXPORT void *async_malloc_from_pool(sycl::handler &h, size_t size,
135135

136136
// Memory pool is passed as the graph may use some properties of it.
137137
alloc = Graph->getMemPool().malloc(size, pool.get_alloc_kind(), DepNodes,
138-
sycl::detail::getSyclObjImpl(pool));
138+
detail::getSyclObjImpl(pool).get());
139139
} else {
140140
ur_queue_handle_t Q = h.impl->get_queue().getHandleRef();
141141
Adapter->call<sycl::errc::runtime,
142142
sycl::detail::UrApiKind::urEnqueueUSMDeviceAllocExp>(
143-
Q, memPoolImpl.get()->get_handle(), size, nullptr, UREvents.size(),
143+
Q, memPoolImpl.get_handle(), size, nullptr, UREvents.size(),
144144
UREvents.data(), &alloc, &Event);
145145
}
146146
// Async malloc must return a void* immediately.

sycl/source/detail/graph/memory_pool.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace detail {
2222
void *
2323
graph_mem_pool::malloc(size_t Size, usm::alloc AllocType,
2424
const std::vector<std::shared_ptr<node_impl>> &DepNodes,
25-
const std::shared_ptr<memory_pool_impl> &MemPool) {
25+
memory_pool_impl *MemPool) {
2626
// We are potentially modifying contents of this memory pool and the owning
2727
// graph, so take a lock here.
2828
graph_impl::WriteLock Lock(MGraph.MMutex);
@@ -41,8 +41,8 @@ graph_mem_pool::malloc(size_t Size, usm::alloc AllocType,
4141
switch (AllocType) {
4242
case usm::alloc::device: {
4343

44-
auto &CtxImpl = sycl::detail::getSyclObjImpl(MContext);
45-
auto &Adapter = CtxImpl->getAdapter();
44+
context_impl &CtxImpl = *getSyclObjImpl(MContext);
45+
auto &Adapter = CtxImpl.getAdapter();
4646

4747
size_t Granularity = get_mem_granularity(MDevice, MContext);
4848
uintptr_t StartPtr = 0;
@@ -60,8 +60,8 @@ graph_mem_pool::malloc(size_t Size, usm::alloc AllocType,
6060
// If no allocation could be reused, do a new virtual reservation
6161
Adapter->call<sycl::errc::runtime,
6262
sycl::detail::UrApiKind::urVirtualMemReserve>(
63-
CtxImpl->getHandleRef(), reinterpret_cast<void *>(StartPtr),
64-
AlignedSize, &Alloc);
63+
CtxImpl.getHandleRef(), reinterpret_cast<void *>(StartPtr), AlignedSize,
64+
&Alloc);
6565

6666
AllocInfo.Size = AlignedSize;
6767
AllocInfo.Ptr = Alloc;

sycl/source/detail/graph/memory_pool.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class graph_mem_pool {
8484
/// @return A pointer to the start of the allocation
8585
void *malloc(size_t Size, usm::alloc AllocType,
8686
const std::vector<std::shared_ptr<node_impl>> &DepNodes,
87-
const std::shared_ptr<memory_pool_impl> &MemPool = nullptr);
87+
memory_pool_impl *MemPool = nullptr);
8888

8989
/// Return the total amount of memory being used by this pool
9090
size_t getMemUseCurrent() const {

0 commit comments

Comments
 (0)