Skip to content

Use reference counting on factories #2048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,17 @@ namespace ur_loader

%endif
%endif
## Before we can re-enable the releases we will need ref-counted object_t.
## See unified-runtime github issue #1784
##%if item['release']:
##// release loader handle
##${item['factory']}.release( ${item['name']} );
## Possibly handle release/retain ref counting - there are no ur_exp-image factories
%if 'factory' in item and '_exp_image_' not in item['factory']:
%if item['release']:
// release loader handle
context->factories.${item['factory']}.release( ${item['name']} );
%endif
%if item['retain']:
// increment refcount of handle
context->factories.${item['factory']}.retain( ${item['name']} );
%endif
%endif
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
try
{
Expand Down
29 changes: 25 additions & 4 deletions source/common/ur_singleton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@
#ifndef UR_SINGLETON_H
#define UR_SINGLETON_H 1

#include <cassert>
#include <memory>
#include <mutex>
#include <unordered_map>

//////////////////////////////////////////////////////////////////////////
/// a abstract factory for creation of singleton objects
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
struct entry_t {
std::unique_ptr<singleton_tn> ptr;
size_t ref_count;
};

protected:
using singleton_t = singleton_tn;
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
size_t, key_tn>::type;

using ptr_t = std::unique_ptr<singleton_t>;
using map_t = std::unordered_map<key_t, ptr_t>;
using map_t = std::unordered_map<key_t, entry_t>;

std::mutex mut; ///< lock for thread-safety
map_t map; ///< single instance of singleton for each unique key
Expand Down Expand Up @@ -60,16 +66,31 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
if (map.end() == iter) {
auto ptr =
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
iter = map.emplace(key, std::move(ptr)).first;
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
} else {
iter->second.ref_count++;
}
return iter->second.get();
return iter->second.ptr.get();
}

void retain(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
auto iter = map.find(getKey(key));
assert(iter != map.end());
iter->second.ref_count++;
}

//////////////////////////////////////////////////////////////////////////
/// once the key is no longer valid, release the singleton
void release(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
map.erase(getKey(key));
auto iter = map.find(getKey(key));
assert(iter != map.end());
if (iter->second.ref_count == 0) {
map.erase(iter);
} else {
iter->second.ref_count--;
}
}

void clear() {
Expand Down
85 changes: 85 additions & 0 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
// forward to device-platform
result = pfnAdapterRelease(hAdapter);

// release loader handle
context->factories.ur_adapter_factory.release(hAdapter);

return result;
}

Expand All @@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
// forward to device-platform
result = pfnAdapterRetain(hAdapter);

// increment refcount of handle
context->factories.ur_adapter_factory.retain(hAdapter);

return result;
}

Expand Down Expand Up @@ -647,6 +653,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
// forward to device-platform
result = pfnRetain(hDevice);

// increment refcount of handle
context->factories.ur_device_factory.retain(hDevice);

return result;
}

Expand All @@ -673,6 +682,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
// forward to device-platform
result = pfnRelease(hDevice);

// release loader handle
context->factories.ur_device_factory.release(hDevice);

return result;
}

Expand Down Expand Up @@ -943,6 +955,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
// forward to device-platform
result = pfnRetain(hContext);

// increment refcount of handle
context->factories.ur_context_factory.retain(hContext);

return result;
}

Expand All @@ -969,6 +984,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
// forward to device-platform
result = pfnRelease(hContext);

// release loader handle
context->factories.ur_context_factory.release(hContext);

return result;
}

Expand Down Expand Up @@ -1271,6 +1289,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
// forward to device-platform
result = pfnRetain(hMem);

// increment refcount of handle
context->factories.ur_mem_factory.retain(hMem);

return result;
}

Expand All @@ -1297,6 +1318,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
// forward to device-platform
result = pfnRelease(hMem);

// release loader handle
context->factories.ur_mem_factory.release(hMem);

return result;
}

Expand Down Expand Up @@ -1648,6 +1672,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
// forward to device-platform
result = pfnRetain(hSampler);

// increment refcount of handle
context->factories.ur_sampler_factory.retain(hSampler);

return result;
}

Expand All @@ -1674,6 +1701,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
// forward to device-platform
result = pfnRelease(hSampler);

// release loader handle
context->factories.ur_sampler_factory.release(hSampler);

return result;
}

Expand Down Expand Up @@ -2107,6 +2137,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
// forward to device-platform
result = pfnPoolRetain(pPool);

// increment refcount of handle
context->factories.ur_usm_pool_factory.retain(pPool);

return result;
}

Expand All @@ -2132,6 +2165,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
// forward to device-platform
result = pfnPoolRelease(pPool);

// release loader handle
context->factories.ur_usm_pool_factory.release(pPool);

return result;
}

Expand Down Expand Up @@ -2517,6 +2553,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
// forward to device-platform
result = pfnRetain(hPhysicalMem);

// increment refcount of handle
context->factories.ur_physical_mem_factory.retain(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2545,6 +2584,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
// forward to device-platform
result = pfnRelease(hPhysicalMem);

// release loader handle
context->factories.ur_physical_mem_factory.release(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2876,6 +2918,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
// forward to device-platform
result = pfnRetain(hProgram);

// increment refcount of handle
context->factories.ur_program_factory.retain(hProgram);

return result;
}

Expand All @@ -2902,6 +2947,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
// forward to device-platform
result = pfnRelease(hProgram);

// release loader handle
context->factories.ur_program_factory.release(hProgram);

return result;
}

Expand Down Expand Up @@ -3499,6 +3547,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
// forward to device-platform
result = pfnRetain(hKernel);

// increment refcount of handle
context->factories.ur_kernel_factory.retain(hKernel);

return result;
}

Expand All @@ -3525,6 +3576,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
// forward to device-platform
result = pfnRelease(hKernel);

// release loader handle
context->factories.ur_kernel_factory.release(hKernel);

return result;
}

Expand Down Expand Up @@ -3975,6 +4029,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
// forward to device-platform
result = pfnRetain(hQueue);

// increment refcount of handle
context->factories.ur_queue_factory.retain(hQueue);

return result;
}

Expand All @@ -4001,6 +4058,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
// forward to device-platform
result = pfnRelease(hQueue);

// release loader handle
context->factories.ur_queue_factory.release(hQueue);

return result;
}

Expand Down Expand Up @@ -4305,6 +4365,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
// forward to device-platform
result = pfnRetain(hEvent);

// increment refcount of handle
context->factories.ur_event_factory.retain(hEvent);

return result;
}

Expand All @@ -4330,6 +4393,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
// forward to device-platform
result = pfnRelease(hEvent);

// release loader handle
context->factories.ur_event_factory.release(hEvent);

return result;
}

Expand Down Expand Up @@ -6862,6 +6928,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp(
// forward to device-platform
result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem);

// release loader handle
context->factories.ur_exp_external_mem_factory.release(hExternalMem);

return result;
}

Expand Down Expand Up @@ -6952,6 +7021,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp(
result =
pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore);

// release loader handle
context->factories.ur_exp_external_semaphore_factory.release(
hExternalSemaphore);

return result;
}

Expand Down Expand Up @@ -7179,6 +7252,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp(
// forward to device-platform
result = pfnRetainExp(hCommandBuffer);

// increment refcount of handle
context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer);

return result;
}

Expand Down Expand Up @@ -7209,6 +7285,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp(
// forward to device-platform
result = pfnReleaseExp(hCommandBuffer);

// release loader handle
context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer);

return result;
}

Expand Down Expand Up @@ -8525,6 +8604,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp(
// forward to device-platform
result = pfnRetainCommandExp(hCommand);

// increment refcount of handle
context->factories.ur_exp_command_buffer_command_factory.retain(hCommand);

return result;
}

Expand Down Expand Up @@ -8556,6 +8638,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
// forward to device-platform
result = pfnReleaseCommandExp(hCommand);

// release loader handle
context->factories.ur_exp_command_buffer_command_factory.release(hCommand);

return result;
}

Expand Down
Loading