Skip to content

Commit e0633c4

Browse files
committed
fine-grained locking instead of only one mutex: now each free_list has its own mutex
Signed-off-by: Jigao Luo <jigao.luo@outlook.com>
1 parent 8e19009 commit e0633c4

File tree

4 files changed

+77
-17
lines changed

4 files changed

+77
-17
lines changed

cpp/include/rmm/mr/device/detail/free_list.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <rmm/detail/export.hpp>
2020

2121
#include <algorithm>
22+
#include <mutex>
2223
#ifdef RMM_DEBUG_PRINT
2324
#include <iostream>
2425
#endif
@@ -138,6 +139,12 @@ class free_list {
138139
}
139140
#endif
140141

142+
/**
143+
* @brief Returns a reference to the mutex used for synchronizing the free list.
144+
*
145+
*/
146+
[[nodiscard]] std::mutex& get_mutex() { return mtx_; }
147+
141148
protected:
142149
/**
143150
* @brief Insert a block in the free list before the specified position
@@ -182,6 +189,7 @@ class free_list {
182189

183190
private:
184191
list_type blocks; // The internal container of blocks
192+
std::mutex mtx_; // The mutex for each free list
185193
};
186194

187195
} // namespace mr::detail

cpp/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include <cstddef>
3030
#include <map>
3131
#include <mutex>
32+
#include <shared_mutex>
3233
#include <unordered_map>
3334
#ifdef RMM_DEBUG_PRINT
3435
#include <iostream>
@@ -87,9 +88,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
8788
stream_ordered_memory_resource& operator=(stream_ordered_memory_resource&&) = delete;
8889

8990
protected:
90-
using free_list = FreeListType;
91-
using block_type = typename free_list::block_type;
92-
using lock_guard = std::lock_guard<std::mutex>;
91+
using free_list = FreeListType;
92+
using block_type = typename free_list::block_type;
93+
using lock_guard = std::lock_guard<std::mutex>;
94+
using read_lock_guard = std::shared_lock<std::shared_mutex>;
95+
using write_lock_guard = std::unique_lock<std::shared_mutex>;
9396

9497
// Derived classes must implement these four methods
9598

@@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
204207
*/
205208
void* do_allocate(std::size_t size, cuda_stream_view stream) override
206209
{
210+
RMM_FUNC_RANGE();
207211
RMM_LOG_TRACE("[A][stream %s][%zuB]", rmm::detail::format_stream(stream), size);
208212

209213
if (size <= 0) { return nullptr; }
210214

211-
lock_guard lock(mtx_);
212-
213215
auto stream_event = get_event(stream);
214216

215217
size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
224226
size,
225227
block.pointer());
226228

227-
log_summary_trace();
229+
// TODO(jigao): this logging is not protected by mutex!
230+
// log_summary_trace();
228231

229232
return block.pointer();
230233
}
@@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
238241
*/
239242
void do_deallocate(void* ptr, std::size_t size, cuda_stream_view stream) override
240243
{
244+
RMM_FUNC_RANGE();
241245
RMM_LOG_TRACE("[D][stream %s][%zuB][%p]", rmm::detail::format_stream(stream), size, ptr);
242246

243247
if (size <= 0 || ptr == nullptr) { return; }
244248

245-
lock_guard lock(mtx_);
246249
auto stream_event = get_event(stream);
247250

248251
size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -253,9 +256,21 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
253256
// streams allows stealing from deleted streams.
254257
RMM_ASSERT_CUDA_SUCCESS(cudaEventRecord(stream_event.event, stream.value()));
255258

256-
stream_free_blocks_[stream_event].insert(block);
257-
258-
log_summary_trace();
259+
read_lock_guard rlock(stream_free_blocks_mtx_);
260+
// Try to find a satisfactory block in free list for the same stream (no sync required)
261+
auto iter = stream_free_blocks_.find(stream_event);
262+
if (iter != stream_free_blocks_.end()) {
263+
// Hot path
264+
lock_guard free_list_lock(iter->second.get_mutex());
265+
iter->second.insert(block);
266+
} else {
267+
rlock.unlock();
268+
// Cold path
269+
write_lock_guard wlock(stream_free_blocks_mtx_);
270+
stream_free_blocks_[stream_event].insert(block); // TODO(jigao): is it thread-safe?
271+
}
272+
// TODO(jigao): this logging is not protected by mutex!
273+
// log_summary_trace();
259274
}
260275

261276
private:
@@ -271,7 +286,9 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
271286
*/
272287
stream_event_pair get_event(cuda_stream_view stream)
273288
{
289+
RMM_FUNC_RANGE();
274290
if (stream.is_per_thread_default()) {
291+
// Hot path
275292
// Create a thread-local event for each device. These events are
276293
// deliberately leaked since the destructor needs to call into
277294
// the CUDA runtime and thread_local destructors (can) run below
@@ -289,6 +306,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
289306
}();
290307
return stream_event_pair{stream.value(), event};
291308
}
309+
write_lock_guard wlock(stream_events_mtx_);
310+
// Cold path
292311
// We use cudaStreamLegacy as the event map key for the default stream for consistency between
293312
// PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the
294313
// user explicitly passes it, so it is used as the default location for the free list
@@ -319,6 +338,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
319338
*/
320339
block_type allocate_and_insert_remainder(block_type block, std::size_t size, free_list& blocks)
321340
{
341+
RMM_FUNC_RANGE();
322342
auto const [allocated, remainder] = this->underlying().allocate_from_block(block, size);
323343
if (remainder.is_valid()) { blocks.insert(remainder); }
324344
return allocated;
@@ -333,15 +353,30 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
333353
*/
334354
block_type get_block(std::size_t size, stream_event_pair stream_event)
335355
{
336-
// Try to find a satisfactory block in free list for the same stream (no sync required)
337-
auto iter = stream_free_blocks_.find(stream_event);
338-
if (iter != stream_free_blocks_.end()) {
339-
block_type const block = iter->second.get_block(size);
340-
if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
356+
RMM_FUNC_RANGE();
357+
{
358+
// The hot path of get_block:
359+
// 1. Read-lock the map for lookup
360+
// 2. then exclusively lock the free_list to get a block locally.
361+
read_lock_guard rlock(stream_free_blocks_mtx_);
362+
// Try to find a satisfactory block in free list for the same stream (no sync required)
363+
auto iter = stream_free_blocks_.find(stream_event);
364+
if (iter != stream_free_blocks_.end()) {
365+
lock_guard free_list_lock(iter->second.get_mutex());
366+
block_type const block = iter->second.get_block(size);
367+
if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
368+
}
341369
}
342370

371+
// The cold path of get_block:
372+
// Write lock the map to safely perform another lookup and possibly modify entries.
373+
// This exclusive lock ensures no other threads can access the map and all free lists in the
374+
// map.
375+
write_lock_guard wlock(stream_free_blocks_mtx_);
376+
auto iter = stream_free_blocks_.find(stream_event);
343377
free_list& blocks =
344378
(iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
379+
lock_guard free_list_lock(blocks.get_mutex());
345380

346381
// Try to find an existing block in another stream
347382
{
@@ -382,6 +417,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
382417
free_list& blocks,
383418
bool merge_first)
384419
{
420+
RMM_FUNC_RANGE();
385421
auto find_block = [&](auto iter) {
386422
auto other_event = iter->first.event;
387423
auto& other_blocks = iter->second;
@@ -415,6 +451,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
415451
++next_iter; // Points to element after `iter` to allow erasing `iter` in the loop body
416452

417453
if (iter->first.event != stream_event.event) {
454+
lock_guard free_list_lock(iter->second.get_mutex());
418455
block_type const block = find_block(iter);
419456

420457
if (block.is_valid()) {
@@ -435,6 +472,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
435472
cudaEvent_t other_event,
436473
free_list&& other_blocks)
437474
{
475+
RMM_FUNC_RANGE();
438476
// Since we found a block associated with a different stream, we have to insert a wait
439477
// on the stream's associated event into the allocating stream.
440478
RMM_CUDA_TRY(cudaStreamWaitEvent(stream_event.stream, other_event, 0));
@@ -450,7 +488,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
450488
*/
451489
void release()
452490
{
453-
lock_guard lock(mtx_);
491+
RMM_FUNC_RANGE();
492+
// lock_guard lock(mtx_); TODO(jigao): rethink mtx_
493+
write_lock_guard stream_event_lock(stream_events_mtx_);
494+
write_lock_guard wlock(stream_free_blocks_mtx_);
454495

455496
for (auto s_e : stream_events_) {
456497
RMM_ASSERT_CUDA_SUCCESS(cudaEventSynchronize(s_e.second.event));
@@ -464,6 +505,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
464505
void log_summary_trace()
465506
{
466507
#if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE)
508+
RMM_FUNC_RANGE();
467509
std::size_t num_blocks{0};
468510
std::size_t max_block{0};
469511
std::size_t free_mem{0};
@@ -491,8 +533,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
491533
// bidirectional mapping between non-default streams and events
492534
std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;
493535

536+
// TODO(jigao): think about get_mutex function?
494537
std::mutex mtx_; // mutex for thread-safe access
495538

539+
// mutex for thread-safe access to stream_free_blocks_
540+
// Used in the writing part of get_block, get_block_from_other_stream
541+
std::shared_mutex stream_free_blocks_mtx_;
542+
543+
// mutex for thread-safe access to stream_events_
544+
// Used in the NON-PTDS part of get_event
545+
std::shared_mutex stream_events_mtx_;
546+
496547
rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()};
497548
}; // namespace detail
498549

cpp/include/rmm/mr/device/pool_memory_resource.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ class pool_memory_resource final
400400
*/
401401
block_type free_block(void* ptr, std::size_t size) noexcept
402402
{
403+
RMM_FUNC_RANGE();
403404
#ifdef RMM_POOL_TRACK_ALLOCATIONS
404405
if (ptr == nullptr) return block_type{};
405406
auto const iter = allocated_blocks_.find(static_cast<char*>(ptr));

cpp/tests/mr/device/mr_ref_multithreaded_tests.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ void spawn_n(std::size_t num_threads, Task task, Arguments&&... args)
5757
template <typename Task, typename... Arguments>
5858
void spawn(Task task, Arguments&&... args)
5959
{
60-
spawn_n(4, task, std::forward<Arguments>(args)...);
60+
spawn_n(16, task, std::forward<Arguments>(args)...);
6161
}
6262

6363
TEST(DefaultTest, UseCurrentDeviceResource_mt) { spawn(test_get_current_device_resource); }

0 commit comments

Comments
 (0)