
Commit d056d6f

Fine-grained locking instead of a single mutex: each free_list now has its own mutex
Signed-off-by: Jigao Luo <jigao.luo@outlook.com>
1 parent 8e19009 commit d056d6f
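
The pattern this commit introduces, as a stand-alone sketch (Block, FreeList, Pool, and the int stream key below are illustrative placeholders, not RMM types): a std::shared_mutex guards the stream-to-free-list map, and each free list carries its own std::mutex. The hot path takes the shared lock only to find an existing list and then that list's mutex to mutate it; only creating a new map entry falls back to the exclusive lock.

#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#include <vector>

// Placeholder types standing in for RMM's block and free_list.
using Block = void*;

struct FreeList {
  std::mutex& get_mutex() { return mtx_; }
  void insert(Block b) { blocks_.push_back(b); }

 private:
  std::vector<Block> blocks_;
  std::mutex mtx_;  // one mutex per free list (the fine-grained part)
};

class Pool {
 public:
  void deallocate(int stream_key, Block b)
  {
    {
      // Hot path: a shared lock is enough to *find* an existing list;
      // the per-list mutex serializes mutation of that list only.
      std::shared_lock<std::shared_mutex> rlock(map_mtx_);
      auto iter = free_lists_.find(stream_key);
      if (iter != free_lists_.end()) {
        std::lock_guard<std::mutex> list_lock(iter->second.get_mutex());
        iter->second.insert(b);
        return;
      }
    }
    // Cold path: creating a new map entry needs exclusive access to the map.
    // While the exclusive lock is held no reader can reach any list, so the
    // new list can be filled without taking its own mutex.
    std::unique_lock<std::shared_mutex> wlock(map_mtx_);
    free_lists_[stream_key].insert(b);
  }

 private:
  std::unordered_map<int, FreeList> free_lists_;
  std::shared_mutex map_mtx_;  // guards the map structure itself
};

This is the same read-lock / per-list-lock / write-lock split that the do_deallocate and get_block changes below implement.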

4 files changed: 73 additions, 14 deletions


cpp/include/rmm/mr/device/detail/free_list.hpp

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@
 #include <rmm/detail/export.hpp>
 
 #include <algorithm>
+#include <mutex>
 #ifdef RMM_DEBUG_PRINT
 #include <iostream>
 #endif
@@ -138,6 +139,12 @@ class free_list {
   }
 #endif
 
+  /**
+   * @brief Returns a reference to the mutex used for synchronizing the free list.
+   *
+   */
+  [[nodiscard]] std::mutex& get_mutex() { return mtx_; }
+
  protected:
   /**
    * @brief Insert a block in the free list before the specified position
@@ -182,6 +189,7 @@
 
  private:
   list_type blocks;  // The internal container of blocks
+  std::mutex mtx_;   // The mutex for each free list
 };
 
 }  // namespace mr::detail
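
A usage note on the new get_mutex() accessor: exposing the mutex lets a caller take several per-list locks at once when blocks move between lists. A minimal sketch under that assumption (FreeList and merge_lists here are simplified placeholders, not the RMM free_list API):

#include <mutex>
#include <vector>

// Hypothetical, simplified stand-in for rmm::mr::detail::free_list.
struct FreeList {
  std::mutex& get_mutex() { return mtx_; }
  std::vector<void*> blocks;

 private:
  std::mutex mtx_;
};

// Move every block from `src` into `dst` while holding both per-list mutexes.
void merge_lists(FreeList& dst, FreeList& src)
{
  std::scoped_lock lock(dst.get_mutex(), src.get_mutex());
  dst.blocks.insert(dst.blocks.end(), src.blocks.begin(), src.blocks.end());
  src.blocks.clear();
}

std::scoped_lock acquires the two mutexes with a deadlock-avoidance algorithm, which matters once more than one per-list lock can be held at the same time.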

cpp/include/rmm/mr/device/detail/stream_ordered_memory_resource.hpp

Lines changed: 63 additions & 13 deletions
@@ -29,6 +29,7 @@
 #include <cstddef>
 #include <map>
 #include <mutex>
+#include <shared_mutex>
 #include <unordered_map>
 #ifdef RMM_DEBUG_PRINT
 #include <iostream>
@@ -90,6 +91,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   using free_list  = FreeListType;
   using block_type = typename free_list::block_type;
   using lock_guard = std::lock_guard<std::mutex>;
+  using read_lock_guard  = std::shared_lock<std::shared_mutex>;
+  using write_lock_guard = std::unique_lock<std::shared_mutex>;
 
   // Derived classes must implement these four methods
 
@@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
    */
   void* do_allocate(std::size_t size, cuda_stream_view stream) override
   {
+    RMM_FUNC_RANGE();
     RMM_LOG_TRACE("[A][stream %s][%zuB]", rmm::detail::format_stream(stream), size);
 
     if (size <= 0) { return nullptr; }
 
-    lock_guard lock(mtx_);
-
     auto stream_event = get_event(stream);
 
     size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
                   size,
                   block.pointer());
 
-    log_summary_trace();
+    // TODO(jigao): this logging is not protected by mutex!
+    // log_summary_trace();
 
     return block.pointer();
   }
@@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
    */
   void do_deallocate(void* ptr, std::size_t size, cuda_stream_view stream) override
   {
+    RMM_FUNC_RANGE();
     RMM_LOG_TRACE("[D][stream %s][%zuB][%p]", rmm::detail::format_stream(stream), size, ptr);
 
     if (size <= 0 || ptr == nullptr) { return; }
 
-    lock_guard lock(mtx_);
     auto stream_event = get_event(stream);
 
     size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -253,9 +256,21 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
     // streams allows stealing from deleted streams.
     RMM_ASSERT_CUDA_SUCCESS(cudaEventRecord(stream_event.event, stream.value()));
 
-    stream_free_blocks_[stream_event].insert(block);
-
-    log_summary_trace();
+    read_lock_guard rlock(stream_free_blocks_mtx_);
+    // Try to find a satisfactory block in free list for the same stream (no sync required)
+    auto iter = stream_free_blocks_.find(stream_event);
+    if (iter != stream_free_blocks_.end()) {
+      // Hot path
+      lock_guard free_list_lock(iter->second.get_mutex());
+      iter->second.insert(block);
+    } else {
+      rlock.unlock();
+      // Cold path
+      write_lock_guard wlock(stream_free_blocks_mtx_);
+      stream_free_blocks_[stream_event].insert(block);  // TODO(jigao): is it thread-safe?
+    }
+    // TODO(jigao): this logging is not protected by mutex!
+    // log_summary_trace();
   }
 
  private:
@@ -271,7 +286,9 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
    */
   stream_event_pair get_event(cuda_stream_view stream)
   {
+    RMM_FUNC_RANGE();
     if (stream.is_per_thread_default()) {
+      // Hot path
       // Create a thread-local event for each device. These events are
       // deliberately leaked since the destructor needs to call into
       // the CUDA runtime and thread_local destructors (can) run below
@@ -289,6 +306,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
       }();
       return stream_event_pair{stream.value(), event};
     }
+    write_lock_guard wlock(stream_events_mtx_);
+    // Cold path
     // We use cudaStreamLegacy as the event map key for the default stream for consistency between
     // PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the
     // user explicitly passes it, so it is used as the default location for the free list
@@ -319,6 +338,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
    */
   block_type allocate_and_insert_remainder(block_type block, std::size_t size, free_list& blocks)
   {
+    RMM_FUNC_RANGE();
     auto const [allocated, remainder] = this->underlying().allocate_from_block(block, size);
     if (remainder.is_valid()) { blocks.insert(remainder); }
     return allocated;
@@ -333,15 +353,29 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
    */
   block_type get_block(std::size_t size, stream_event_pair stream_event)
   {
-    // Try to find a satisfactory block in free list for the same stream (no sync required)
-    auto iter = stream_free_blocks_.find(stream_event);
-    if (iter != stream_free_blocks_.end()) {
-      block_type const block = iter->second.get_block(size);
-      if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
+    RMM_FUNC_RANGE();
+    {
+      // The hot path of get_block:
+      // 1. Read-lock the map for lookup
+      // 2. then exclusively lock the free_list to get a block locally.
+      read_lock_guard rlock(stream_free_blocks_mtx_);
+      // Try to find a satisfactory block in free list for the same stream (no sync required)
+      auto iter = stream_free_blocks_.find(stream_event);
+      if (iter != stream_free_blocks_.end()) {
+        lock_guard free_list_lock(iter->second.get_mutex());
+        block_type const block = iter->second.get_block(size);
+        if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
+      }
     }
 
+    // The cold path of get_block:
+    // Write lock the map to safely perform another lookup and possibly modify entries.
+    // This exclusive lock ensures no other threads can access the map and all free lists in the map.
+    write_lock_guard wlock(stream_free_blocks_mtx_);
+    auto iter = stream_free_blocks_.find(stream_event);
     free_list& blocks =
       (iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
+    lock_guard free_list_lock(blocks.get_mutex());
 
     // Try to find an existing block in another stream
     {
@@ -382,6 +416,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
                                          free_list& blocks,
                                          bool merge_first)
   {
+    RMM_FUNC_RANGE();
     auto find_block = [&](auto iter) {
       auto other_event = iter->first.event;
       auto& other_blocks = iter->second;
@@ -415,6 +450,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
       ++next_iter;  // Points to element after `iter` to allow erasing `iter` in the loop body
 
       if (iter->first.event != stream_event.event) {
+        lock_guard free_list_lock(iter->second.get_mutex());
         block_type const block = find_block(iter);
 
         if (block.is_valid()) {
@@ -435,6 +471,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
                          cudaEvent_t other_event,
                          free_list&& other_blocks)
   {
+    RMM_FUNC_RANGE();
     // Since we found a block associated with a different stream, we have to insert a wait
    // on the stream's associated event into the allocating stream.
     RMM_CUDA_TRY(cudaStreamWaitEvent(stream_event.stream, other_event, 0));
@@ -450,7 +487,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
    */
   void release()
   {
-    lock_guard lock(mtx_);
+    RMM_FUNC_RANGE();
+    // lock_guard lock(mtx_);  TODO(jigao): rethink mtx_
+    write_lock_guard stream_event_lock(stream_events_mtx_);
+    write_lock_guard wlock(stream_free_blocks_mtx_);
 
     for (auto s_e : stream_events_) {
       RMM_ASSERT_CUDA_SUCCESS(cudaEventSynchronize(s_e.second.event));
@@ -464,6 +504,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   void log_summary_trace()
   {
 #if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE)
+    RMM_FUNC_RANGE();
     std::size_t num_blocks{0};
     std::size_t max_block{0};
     std::size_t free_mem{0};
@@ -491,8 +532,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   // bidirectional mapping between non-default streams and events
   std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;
 
+  // TODO(jigao): think about get_mutex function?
   std::mutex mtx_;  // mutex for thread-safe access
 
+  // mutex for thread-safe access to stream_free_blocks_
+  // Used in the writing part of get_block, get_block_from_other_stream
+  std::shared_mutex stream_free_blocks_mtx_;
+
+  // mutex for thread-safe access to stream_events_
+  // Used in the NON-PTDS part of get_event
+  std::shared_mutex stream_events_mtx_;
+
   rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()};
 };  // namespace detail
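
The get_event() change above follows the same hot/cold split: per-thread default-stream lookups stay lock-free through a thread_local cache, and only the shared map of explicitly passed streams is taken under the new stream_events_mtx_ write lock. A stand-alone sketch of that shape (EventHandle, create_event, and the int stream key are placeholders, not CUDA or RMM types):

#include <atomic>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>

using EventHandle = int;  // placeholder for cudaEvent_t

EventHandle create_event()
{
  static std::atomic<int> next{0};  // stands in for cudaEventCreateWithFlags
  return ++next;
}

class EventRegistry {
 public:
  EventHandle get_event(bool per_thread_default, int stream_key)
  {
    if (per_thread_default) {
      // Hot path: one event per thread, created on first use, no locking needed.
      // (The real resource keeps one per device per thread and deliberately leaks it.)
      thread_local EventHandle event = create_event();
      return event;
    }
    // Cold path: explicitly passed streams share a map, so guard it exclusively.
    std::unique_lock<std::shared_mutex> wlock(events_mtx_);
    auto iter = events_.find(stream_key);
    if (iter == events_.end()) { iter = events_.emplace(stream_key, create_event()).first; }
    return iter->second;
  }

 private:
  std::unordered_map<int, EventHandle> events_;
  std::shared_mutex events_mtx_;
};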

cpp/include/rmm/mr/device/pool_memory_resource.hpp

Lines changed: 1 addition & 0 deletions
@@ -400,6 +400,7 @@ class pool_memory_resource final
    */
   block_type free_block(void* ptr, std::size_t size) noexcept
   {
+    RMM_FUNC_RANGE();
 #ifdef RMM_POOL_TRACK_ALLOCATIONS
     if (ptr == nullptr) return block_type{};
     auto const iter = allocated_blocks_.find(static_cast<char*>(ptr));
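
RMM_FUNC_RANGE() is sprinkled through this commit so the new hot and cold paths show up in profiler timelines; it opens a profiling range that covers the enclosing function. A rough equivalent using the plain NVTX C API is sketched below purely for illustration (FUNC_RANGE and ScopedNvtxRange are hypothetical names, not RMM's actual macro definition):

#include <nvtx3/nvToolsExt.h>  // NVTX C API; header path varies by install

// RAII helper: push an NVTX range on construction, pop it on destruction.
struct ScopedNvtxRange {
  explicit ScopedNvtxRange(const char* name) { nvtxRangePushA(name); }
  ~ScopedNvtxRange() { nvtxRangePop(); }
  ScopedNvtxRange(const ScopedNvtxRange&)            = delete;
  ScopedNvtxRange& operator=(const ScopedNvtxRange&) = delete;
};

// Hypothetical macro in the spirit of RMM_FUNC_RANGE(): names the range
// after the enclosing function so it appears in Nsight Systems timelines.
#define FUNC_RANGE() ScopedNvtxRange _nvtx_func_range_{__func__}

void expensive_step()
{
  FUNC_RANGE();  // the whole function body is covered by one range
  // ... work ...
}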

cpp/tests/mr/device/mr_ref_multithreaded_tests.cpp

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ void spawn_n(std::size_t num_threads, Task task, Arguments&&... args)
 template <typename Task, typename... Arguments>
 void spawn(Task task, Arguments&&... args)
 {
-  spawn_n(4, task, std::forward<Arguments>(args)...);
+  spawn_n(16, task, std::forward<Arguments>(args)...);
 }
 
 TEST(DefaultTest, UseCurrentDeviceResource_mt) { spawn(test_get_current_device_resource); }
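
For context on the test change: spawn_n launches num_threads copies of the same task on std::threads and joins them, so bumping spawn from 4 to 16 makes every spawn-based multithreaded test exercise the new locking with four times as many concurrent threads. A minimal sketch of such a helper, written from the signature shown above rather than copied from the test file:

#include <cstddef>
#include <thread>
#include <vector>

// Run `task(args...)` on `num_threads` threads and wait for all of them.
template <typename Task, typename... Arguments>
void spawn_n(std::size_t num_threads, Task task, Arguments&&... args)
{
  std::vector<std::thread> threads;
  threads.reserve(num_threads);
  for (std::size_t i = 0; i < num_threads; ++i) {
    threads.emplace_back(task, args...);  // copy args into each thread
  }
  for (auto& thread : threads) {
    thread.join();
  }
}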
