Skip to content

[WIP] [Performance Improvement] Fine-granularity locking in stream_ordered_memory_resource #1912

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: branch-25.06
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/include/rmm/mr/device/detail/free_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <rmm/detail/export.hpp>

#include <algorithm>
#include <mutex>
#ifdef RMM_DEBUG_PRINT
#include <iostream>
#endif
Expand Down Expand Up @@ -138,6 +139,12 @@ class free_list {
}
#endif

/**
 * @brief Returns a reference to the mutex guarding this free list.
 *
 * Callers lock this mutex (e.g. via `std::lock_guard`) before calling `insert`
 * or `get_block` so that concurrent threads do not mutate the list simultaneously.
 *
 * @return Reference to the free list's internal mutex.
 */
[[nodiscard]] std::mutex& get_mutex() { return mtx_; }

protected:
/**
* @brief Insert a block in the free list before the specified position
Expand Down Expand Up @@ -182,6 +189,7 @@ class free_list {

private:
list_type blocks; // The internal container of blocks
std::mutex mtx_; // The mutex for each free list
};

} // namespace mr::detail
Expand Down
126 changes: 111 additions & 15 deletions cpp/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <cstddef>
#include <map>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#ifdef RMM_DEBUG_PRINT
#include <iostream>
Expand Down Expand Up @@ -87,9 +88,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
stream_ordered_memory_resource& operator=(stream_ordered_memory_resource&&) = delete;

protected:
using free_list = FreeListType;
using block_type = typename free_list::block_type;
using lock_guard = std::lock_guard<std::mutex>;
using free_list = FreeListType;
using block_type = typename free_list::block_type;
using lock_guard = std::lock_guard<std::mutex>;
using read_lock_guard = std::shared_lock<std::shared_mutex>;
using write_lock_guard = std::unique_lock<std::shared_mutex>;

// Derived classes must implement these four methods

Expand Down Expand Up @@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
void* do_allocate(std::size_t size, cuda_stream_view stream) override
{
RMM_FUNC_RANGE();
RMM_LOG_TRACE("[A][stream %s][%zuB]", rmm::detail::format_stream(stream), size);

if (size <= 0) { return nullptr; }

lock_guard lock(mtx_);

auto stream_event = get_event(stream);

size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
Expand All @@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
size,
block.pointer());

log_summary_trace();
// TODO(jigao): this logging is not protected by mutex!
// log_summary_trace();

return block.pointer();
}
Expand All @@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
void do_deallocate(void* ptr, std::size_t size, cuda_stream_view stream) override
{
RMM_FUNC_RANGE();
RMM_LOG_TRACE("[D][stream %s][%zuB][%p]", rmm::detail::format_stream(stream), size, ptr);

if (size <= 0 || ptr == nullptr) { return; }

lock_guard lock(mtx_);
auto stream_event = get_event(stream);

size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
Expand All @@ -253,9 +256,60 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
// streams allows stealing from deleted streams.
RMM_ASSERT_CUDA_SUCCESS(cudaEventRecord(stream_event.event, stream.value()));

stream_free_blocks_[stream_event].insert(block);
read_lock_guard rlock(stream_free_blocks_mtx_);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds a lot of complexity to do_deallocate that I think should be in the free list implementation. I originally designed this to be portable to other free list implementations, which is why this function was originally so simple -- it more or less just called insert on the stream's free list.

This allocator is already quite fast, and I think in your exploration ultimately you found that it's not the actual bottleneck? Is it worth adding so much complexity? If the complexity can be put in the free list it might be better.

// Try to find a satisfactory block in free list for the same stream (no sync required)
auto iter = stream_free_blocks_.find(stream_event);
if (iter != stream_free_blocks_.end()) {
// Hot path
lock_guard free_list_lock(iter->second.get_mutex());
iter->second.insert(block);
} else {
rlock.unlock();
// Cold path
write_lock_guard wlock(stream_free_blocks_mtx_);
// Recheck the map since another thread from the same stream
// might have acquired the write lock first and inserted a new free_list into map.
auto iter = stream_free_blocks_.find(stream_event);
free_list& blocks =
(iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
lock_guard free_list_lock(blocks.get_mutex());
blocks.insert(block);
}

{
// Hot Path of do_deallocate:
// 1. Acquire shared read-lock on map for fast lookup
// 2. If entry exists, proceed to hot path
// 3. Acquire exclusive write-lock on free_list for block insertion
read_lock_guard rlock(stream_free_blocks_mtx_);
auto iter = stream_free_blocks_.find(stream_event);
if (iter != stream_free_blocks_.end()) {
lock_guard free_list_lock(iter->second.get_mutex());
iter->second.insert(block);
return;
}
}

log_summary_trace();
{
// Cold Path of do_deallocate:
// 1. Acquire exclusive write-lock on map to:
// - Recheck map state (another thread might have inserted a new free_list)
// - Insert a new free_list into map if still empty
// 2. Acquire exclusive write-lock on the new free_list for block insertion
// (Locking the newly created free_list is redundant as it is protected by the map's
// write-lock, but retained for consistency and readability.)
write_lock_guard wlock(stream_free_blocks_mtx_);
auto iter = stream_free_blocks_.find(stream_event);
free_list& blocks =
(iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
lock_guard free_list_lock(blocks.get_mutex());
blocks.insert(block);
return;
}

// TODO(jigao): this logging is not protected by mutex!
// TODO(jigao): once log_summary_trace() is thread-safe, call it before the early returns above
// log_summary_trace();
}

private:
Expand All @@ -271,7 +325,12 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
stream_event_pair get_event(cuda_stream_view stream)
{
RMM_FUNC_RANGE();
if (stream.is_per_thread_default()) {
// Hot Path (PTDS optimization):
// Leverage thread-local storage for each stream to eliminate contention
// and avoid locking entirely.

// Create a thread-local event for each device. These events are
// deliberately leaked since the destructor needs to call into
// the CUDA runtime and thread_local destructors (can) run below
Expand All @@ -289,6 +348,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
}();
return stream_event_pair{stream.value(), event};
}
write_lock_guard wlock(stream_events_mtx_);
// Cold Path:
// Without PTDS, use pessimistic locking with a broader critical section
// to handle potential future writes to the stream_events_ map.

// We use cudaStreamLegacy as the event map key for the default stream for consistency between
// PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the
// user explicitly passes it, so it is used as the default location for the free list
Expand Down Expand Up @@ -319,6 +383,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
block_type allocate_and_insert_remainder(block_type block, std::size_t size, free_list& blocks)
{
RMM_FUNC_RANGE();
auto const [allocated, remainder] = this->underlying().allocate_from_block(block, size);
if (remainder.is_valid()) { blocks.insert(remainder); }
return allocated;
Expand All @@ -333,15 +398,30 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
block_type get_block(std::size_t size, stream_event_pair stream_event)
{
// Try to find a satisfactory block in free list for the same stream (no sync required)
auto iter = stream_free_blocks_.find(stream_event);
if (iter != stream_free_blocks_.end()) {
block_type const block = iter->second.get_block(size);
if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
RMM_FUNC_RANGE();
{
// Hot Path of get_block:
// 1. Acquire shared read-lock on map for fast lookup
// 2. Acquire exclusive write-lock on free_list for local block allocation

read_lock_guard rlock(stream_free_blocks_mtx_);
// Try to find a satisfactory block in free list for the same stream (no sync required)
auto iter = stream_free_blocks_.find(stream_event);
if (iter != stream_free_blocks_.end()) {
lock_guard free_list_lock(iter->second.get_mutex());
block_type const block = iter->second.get_block(size);
if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
}
}

// Cold Path of get_block:
// Acquire write-lock on map to lookup again and modify map entries if needed
// This exclusive write-lock prevents concurrent access to map and its free_lists
write_lock_guard wlock(stream_free_blocks_mtx_);
auto iter = stream_free_blocks_.find(stream_event);
free_list& blocks =
(iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
lock_guard free_list_lock(blocks.get_mutex());

// Try to find an existing block in another stream
{
Expand Down Expand Up @@ -382,6 +462,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
free_list& blocks,
bool merge_first)
{
RMM_FUNC_RANGE();
auto find_block = [&](auto iter) {
auto other_event = iter->first.event;
auto& other_blocks = iter->second;
Expand Down Expand Up @@ -415,6 +496,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
++next_iter; // Points to element after `iter` to allow erasing `iter` in the loop body

if (iter->first.event != stream_event.event) {
lock_guard free_list_lock(iter->second.get_mutex());
block_type const block = find_block(iter);

if (block.is_valid()) {
Expand All @@ -435,6 +517,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
cudaEvent_t other_event,
free_list&& other_blocks)
{
RMM_FUNC_RANGE();
// Since we found a block associated with a different stream, we have to insert a wait
// on the stream's associated event into the allocating stream.
RMM_CUDA_TRY(cudaStreamWaitEvent(stream_event.stream, other_event, 0));
Expand All @@ -450,7 +533,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
void release()
{
lock_guard lock(mtx_);
RMM_FUNC_RANGE();
// lock_guard lock(mtx_); TODO(jigao): rethink mtx_
write_lock_guard stream_event_lock(stream_events_mtx_);
write_lock_guard wlock(stream_free_blocks_mtx_);

for (auto s_e : stream_events_) {
RMM_ASSERT_CUDA_SUCCESS(cudaEventSynchronize(s_e.second.event));
Expand All @@ -464,6 +550,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
void log_summary_trace()
{
#if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE)
RMM_FUNC_RANGE();
std::size_t num_blocks{0};
std::size_t max_block{0};
std::size_t free_mem{0};
Expand Down Expand Up @@ -491,8 +578,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
// bidirectional mapping between non-default streams and events
std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;

// TODO(jigao): think about get_mutex function?
std::mutex mtx_; // mutex for thread-safe access

// mutex for thread-safe access to stream_free_blocks_
// Used in the writing part of get_block, get_block_from_other_stream
std::shared_mutex stream_free_blocks_mtx_;

// mutex for thread-safe access to stream_events_
// Used in the NON-PTDS part of get_event
std::shared_mutex stream_events_mtx_;

rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()};
}; // class stream_ordered_memory_resource

Expand Down
1 change: 1 addition & 0 deletions cpp/include/rmm/mr/device/pool_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ class pool_memory_resource final
*/
block_type free_block(void* ptr, std::size_t size) noexcept
{
RMM_FUNC_RANGE();
#ifdef RMM_POOL_TRACK_ALLOCATIONS
if (ptr == nullptr) return block_type{};
auto const iter = allocated_blocks_.find(static_cast<char*>(ptr));
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/mr/device/mr_ref_multithreaded_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void spawn_n(std::size_t num_threads, Task task, Arguments&&... args)
template <typename Task, typename... Arguments>
void spawn(Task task, Arguments&&... args)
{
spawn_n(4, task, std::forward<Arguments>(args)...);
spawn_n(16, task, std::forward<Arguments>(args)...);
}

TEST(DefaultTest, UseCurrentDeviceResource_mt) { spawn(test_get_current_device_resource); }
Expand Down