 #include <cstddef>
 #include <map>
 #include <mutex>
+#include <shared_mutex>
 #include <unordered_map>
 #ifdef RMM_DEBUG_PRINT
 #include <iostream>

@@ -90,6 +91,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
  using free_list = FreeListType;
  using block_type = typename free_list::block_type;
  using lock_guard = std::lock_guard<std::mutex>;
+  using read_lock_guard  = std::shared_lock<std::shared_mutex>;
+  using write_lock_guard = std::unique_lock<std::shared_mutex>;

  // Derived classes must implement these four methods

@@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   */
  void* do_allocate(std::size_t size, cuda_stream_view stream) override
  {
+    RMM_FUNC_RANGE();
    RMM_LOG_TRACE("[A][stream %s][%zuB]", rmm::detail::format_stream(stream), size);

    if (size <= 0) { return nullptr; }

-    lock_guard lock(mtx_);
-
    auto stream_event = get_event(stream);

    size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);

@@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
                  size,
                  block.pointer());

-    log_summary_trace();
+    // TODO(jigao): this logging is not protected by mutex!
+    // log_summary_trace();

    return block.pointer();
  }

@@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   */
  void do_deallocate(void* ptr, std::size_t size, cuda_stream_view stream) override
  {
+    RMM_FUNC_RANGE();
    RMM_LOG_TRACE("[D][stream %s][%zuB][%p]", rmm::detail::format_stream(stream), size, ptr);

    if (size <= 0 || ptr == nullptr) { return; }

-    lock_guard lock(mtx_);
    auto stream_event = get_event(stream);

    size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);

@@ -253,9 +256,21 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
    // streams allows stealing from deleted streams.
    RMM_ASSERT_CUDA_SUCCESS(cudaEventRecord(stream_event.event, stream.value()));

-    stream_free_blocks_[stream_event].insert(block);
-
-    log_summary_trace();
+    read_lock_guard rlock(stream_free_blocks_mtx_);
+    // Try to find a satisfactory block in free list for the same stream (no sync required)
+    auto iter = stream_free_blocks_.find(stream_event);
+    if (iter != stream_free_blocks_.end()) {
+      // Hot path
+      lock_guard free_list_lock(iter->second.get_mutex());
+      iter->second.insert(block);
+    } else {
+      rlock.unlock();
+      // Cold path
+      write_lock_guard wlock(stream_free_blocks_mtx_);
+      stream_free_blocks_[stream_event].insert(block);  // TODO(jigao): is it thread-safe?
+    }
+    // TODO(jigao): this logging is not protected by mutex!
+    // log_summary_trace();
  }

 private:

@@ -271,7 +286,9 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   */
  stream_event_pair get_event(cuda_stream_view stream)
  {
+    RMM_FUNC_RANGE();
    if (stream.is_per_thread_default()) {
+      // Hot path
      // Create a thread-local event for each device. These events are
      // deliberately leaked since the destructor needs to call into
      // the CUDA runtime and thread_local destructors (can) run below

@@ -289,6 +306,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
      }();
      return stream_event_pair{stream.value(), event};
    }
+    write_lock_guard wlock(stream_events_mtx_);
+    // Cold path
    // We use cudaStreamLegacy as the event map key for the default stream for consistency between
    // PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the
    // user explicitly passes it, so it is used as the default location for the free list
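A note on the `get_event()` change above: the per-thread-default-stream branch stays lock-free because each thread owns its leaked `thread_local` event, while the non-PTDS branch now takes `stream_events_mtx_` in exclusive mode because it may insert into `stream_events_`. The following is a minimal standalone sketch of that hot/cold split, not the RMM code; the names `registry`, `get_or_create`, and `make_value` are illustrative stand-ins.

```cpp
#include <mutex>
#include <shared_mutex>
#include <unordered_map>

class registry {
 public:
  // per_thread == true mimics the PTDS hot path: each thread keeps its own
  // cache, so no lock is needed. Otherwise we fall back to a process-wide map
  // guarded by a shared_mutex taken in exclusive mode, because the lookup may
  // insert a new entry.
  int get_or_create(int key, bool per_thread)
  {
    if (per_thread) {
      // Hot path: thread_local storage needs no synchronization.
      thread_local std::unordered_map<int, int> local_cache;
      return local_cache.try_emplace(key, make_value(key)).first->second;
    }
    // Cold path: exclusive lock; a shared (read) lock would not be enough,
    // since try_emplace can modify the map.
    std::unique_lock<std::shared_mutex> lock(mtx_);
    return shared_cache_.try_emplace(key, make_value(key)).first->second;
  }

 private:
  static int make_value(int key) { return key * 2; }  // stand-in for creating a CUDA event

  std::shared_mutex mtx_;
  std::unordered_map<int, int> shared_cache_;
};
```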
@@ -319,6 +338,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   */
  block_type allocate_and_insert_remainder(block_type block, std::size_t size, free_list& blocks)
  {
+    RMM_FUNC_RANGE();
    auto const [allocated, remainder] = this->underlying().allocate_from_block(block, size);
    if (remainder.is_valid()) { blocks.insert(remainder); }
    return allocated;

@@ -333,15 +353,29 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   */
  block_type get_block(std::size_t size, stream_event_pair stream_event)
  {
-    // Try to find a satisfactory block in free list for the same stream (no sync required)
-    auto iter = stream_free_blocks_.find(stream_event);
-    if (iter != stream_free_blocks_.end()) {
-      block_type const block = iter->second.get_block(size);
-      if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
+    RMM_FUNC_RANGE();
+    {
+      // The hot path of get_block:
+      // 1. read-lock the map for the lookup,
+      // 2. then exclusively lock the free_list to get a block locally.
+      read_lock_guard rlock(stream_free_blocks_mtx_);
+      // Try to find a satisfactory block in free list for the same stream (no sync required)
+      auto iter = stream_free_blocks_.find(stream_event);
+      if (iter != stream_free_blocks_.end()) {
+        lock_guard free_list_lock(iter->second.get_mutex());
+        block_type const block = iter->second.get_block(size);
+        if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
+      }
    }

+    // The cold path of get_block:
+    // write-lock the map to safely perform another lookup and possibly modify entries.
+    // This exclusive lock ensures no other thread can access the map or any free list in it.
+    write_lock_guard wlock(stream_free_blocks_mtx_);
+    auto iter = stream_free_blocks_.find(stream_event);
    free_list& blocks =
      (iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
+    lock_guard free_list_lock(blocks.get_mutex());

    // Try to find an existing block in another stream
    {
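A note on the `get_block()`/`do_deallocate()` changes above: the map of per-stream free lists is now guarded by a `shared_mutex`, and each free list carries its own mutex (`get_mutex()` in the diff), so threads working on different streams only contend on their own list. A miss on the hot path releases the shared lock and retries under the exclusive lock, which also re-resolves the entry in case another thread created it in between. Below is a minimal standalone sketch of that pattern, not the RMM code; `bucket_map` and `bucket` are illustrative names, and the per-entry mutex stands in for the diff's `free_list::get_mutex()`.

```cpp
#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#include <vector>

// Each map entry carries its own mutex: mutating one entry only needs that
// entry's lock plus a shared (read) lock on the map structure.
struct bucket {
  std::mutex mtx;
  std::vector<int> items;
};

class bucket_map {
 public:
  void insert(int key, int value)
  {
    {
      // Hot path: shared lock on the map, exclusive lock on a single bucket.
      std::shared_lock<std::shared_mutex> map_lock(map_mtx_);
      auto iter = buckets_.find(key);
      if (iter != buckets_.end()) {
        std::lock_guard<std::mutex> bucket_lock(iter->second.mtx);
        iter->second.items.push_back(value);
        return;
      }
    }
    // Cold path: the key is missing, so the map itself must change. Take the
    // exclusive lock and look up again; operator[] returns the bucket another
    // thread may have created between releasing the shared lock and here.
    std::unique_lock<std::shared_mutex> map_lock(map_mtx_);
    bucket& entry = buckets_[key];
    std::lock_guard<std::mutex> bucket_lock(entry.mtx);
    entry.items.push_back(value);
  }

 private:
  std::shared_mutex map_mtx_;
  std::unordered_map<int, bucket> buckets_;
};
```

Because every hot-path reader holds the shared lock, the exclusive lock on the cold path also excludes them; this is what lets `get_block` safely walk and steal from other streams' free lists while it holds `stream_free_blocks_mtx_` in write mode.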
@@ -382,6 +416,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
                                         free_list& blocks,
                                         bool merge_first)
  {
+    RMM_FUNC_RANGE();
    auto find_block = [&](auto iter) {
      auto other_event = iter->first.event;
      auto& other_blocks = iter->second;

@@ -415,6 +450,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
      ++next_iter;  // Points to element after `iter` to allow erasing `iter` in the loop body

      if (iter->first.event != stream_event.event) {
+        lock_guard free_list_lock(iter->second.get_mutex());
        block_type const block = find_block(iter);

        if (block.is_valid()) {

@@ -435,6 +471,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
                   cudaEvent_t other_event,
                   free_list&& other_blocks)
  {
+    RMM_FUNC_RANGE();
    // Since we found a block associated with a different stream, we have to insert a wait
    // on the stream's associated event into the allocating stream.
    RMM_CUDA_TRY(cudaStreamWaitEvent(stream_event.stream, other_event, 0));

@@ -450,7 +487,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
   */
  void release()
  {
-    lock_guard lock(mtx_);
+    RMM_FUNC_RANGE();
+    // lock_guard lock(mtx_); TODO(jigao): rethink mtx_
+    write_lock_guard stream_event_lock(stream_events_mtx_);
+    write_lock_guard wlock(stream_free_blocks_mtx_);

    for (auto s_e : stream_events_) {
      RMM_ASSERT_CUDA_SUCCESS(cudaEventSynchronize(s_e.second.event));

@@ -464,6 +504,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
  void log_summary_trace()
  {
#if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE)
+    RMM_FUNC_RANGE();
    std::size_t num_blocks{0};
    std::size_t max_block{0};
    std::size_t free_mem{0};

@@ -491,8 +532,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
  // bidirectional mapping between non-default streams and events
  std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;

+  // TODO(jigao): think about a get_mutex() function?
  std::mutex mtx_;  // mutex for thread-safe access

+  // mutex for thread-safe access to stream_free_blocks_
+  // Used in the writing part of get_block, get_block_from_other_stream
+  std::shared_mutex stream_free_blocks_mtx_;
+
+  // mutex for thread-safe access to stream_events_
+  // Used in the non-PTDS part of get_event
+  std::shared_mutex stream_events_mtx_;
+
  rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()};
};  // namespace detail