29
29
#include < cstddef>
30
30
#include < map>
31
31
#include < mutex>
32
+ #include < shared_mutex>
32
33
#include < unordered_map>
33
34
#ifdef RMM_DEBUG_PRINT
34
35
#include < iostream>
@@ -87,9 +88,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
87
88
stream_ordered_memory_resource& operator =(stream_ordered_memory_resource&&) = delete ;
88
89
89
90
protected:
90
- using free_list = FreeListType;
91
- using block_type = typename free_list::block_type;
92
- using lock_guard = std::lock_guard<std::mutex>;
91
+ using free_list = FreeListType;
92
+ using block_type = typename free_list::block_type;
93
+ using lock_guard = std::lock_guard<std::mutex>;
94
+ using read_lock_guard = std::shared_lock<std::shared_mutex>;
95
+ using write_lock_guard = std::unique_lock<std::shared_mutex>;
93
96
94
97
// Derived classes must implement these four methods
95
98
@@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
204
207
*/
205
208
void * do_allocate (std::size_t size, cuda_stream_view stream) override
206
209
{
210
+ RMM_FUNC_RANGE ();
207
211
RMM_LOG_TRACE (" [A][stream %s][%zuB]" , rmm::detail::format_stream (stream), size);
208
212
209
213
if (size <= 0 ) { return nullptr ; }
210
214
211
- lock_guard lock (mtx_);
212
-
213
215
auto stream_event = get_event (stream);
214
216
215
217
size = rmm::align_up (size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
224
226
size,
225
227
block.pointer ());
226
228
227
- log_summary_trace ();
229
+ // TODO(jigao): this logging is not protected by mutex!
230
+ // log_summary_trace();
228
231
229
232
return block.pointer ();
230
233
}
@@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
238
241
*/
239
242
void do_deallocate (void * ptr, std::size_t size, cuda_stream_view stream) override
240
243
{
244
+ RMM_FUNC_RANGE ();
241
245
RMM_LOG_TRACE (" [D][stream %s][%zuB][%p]" , rmm::detail::format_stream (stream), size, ptr);
242
246
243
247
if (size <= 0 || ptr == nullptr ) { return ; }
244
248
245
- lock_guard lock (mtx_);
246
249
auto stream_event = get_event (stream);
247
250
248
251
size = rmm::align_up (size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -253,9 +256,21 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
253
256
// streams allows stealing from deleted streams.
254
257
RMM_ASSERT_CUDA_SUCCESS (cudaEventRecord (stream_event.event , stream.value ()));
255
258
256
- stream_free_blocks_[stream_event].insert (block);
257
-
258
- log_summary_trace ();
259
+ read_lock_guard rlock (stream_free_blocks_mtx_);
260
+ // Look up the free list of this stream so the freed block can be returned to it
261
+ auto iter = stream_free_blocks_.find (stream_event);
262
+ if (iter != stream_free_blocks_.end ()) {
263
+ // Hot path
264
+ lock_guard free_list_lock (iter->second .get_mutex ());
265
+ iter->second .insert (block);
266
+ } else {
267
+ rlock.unlock ();
268
+ // Cold path
269
+ write_lock_guard wlock (stream_free_blocks_mtx_);
270
+ stream_free_blocks_[stream_event].insert (block); // TODO(jigao): is it thread-safe?
271
+ }
272
+ // TODO(jigao): this logging is not protected by mutex!
273
+ // log_summary_trace();
259
274
}
260
275
261
276
private:
@@ -271,7 +286,9 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
271
286
*/
272
287
stream_event_pair get_event (cuda_stream_view stream)
273
288
{
289
+ RMM_FUNC_RANGE ();
274
290
if (stream.is_per_thread_default ()) {
291
+ // Hot path
275
292
// Create a thread-local event for each device. These events are
276
293
// deliberately leaked since the destructor needs to call into
277
294
// the CUDA runtime and thread_local destructors (can) run below
@@ -289,6 +306,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
289
306
}();
290
307
return stream_event_pair{stream.value (), event};
291
308
}
309
+ write_lock_guard wlock (stream_events_mtx_);
310
+ // Cold path
292
311
// We use cudaStreamLegacy as the event map key for the default stream for consistency between
293
312
// PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the
294
313
// user explicitly passes it, so it is used as the default location for the free list
@@ -319,6 +338,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
319
338
*/
320
339
block_type allocate_and_insert_remainder (block_type block, std::size_t size, free_list& blocks)
321
340
{
341
+ RMM_FUNC_RANGE ();
322
342
auto const [allocated, remainder] = this ->underlying ().allocate_from_block (block, size);
323
343
if (remainder.is_valid ()) { blocks.insert (remainder); }
324
344
return allocated;
@@ -333,15 +353,30 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
333
353
*/
334
354
block_type get_block (std::size_t size, stream_event_pair stream_event)
335
355
{
336
- // Try to find a satisfactory block in free list for the same stream (no sync required)
337
- auto iter = stream_free_blocks_.find (stream_event);
338
- if (iter != stream_free_blocks_.end ()) {
339
- block_type const block = iter->second .get_block (size);
340
- if (block.is_valid ()) { return allocate_and_insert_remainder (block, size, iter->second ); }
356
+ RMM_FUNC_RANGE ();
357
+ {
358
+ // The hot path of get_block:
359
+ // 1. Read-lock the map for lookup
360
+ // 2. then exclusively lock the free_list to get a block locally.
361
+ read_lock_guard rlock (stream_free_blocks_mtx_);
362
+ // Try to find a satisfactory block in free list for the same stream (no sync required)
363
+ auto iter = stream_free_blocks_.find (stream_event);
364
+ if (iter != stream_free_blocks_.end ()) {
365
+ lock_guard free_list_lock (iter->second .get_mutex ());
366
+ block_type const block = iter->second .get_block (size);
367
+ if (block.is_valid ()) { return allocate_and_insert_remainder (block, size, iter->second ); }
368
+ }
341
369
}
342
370
371
+ // The cold path of get_block:
372
+ // Write lock the map to safely perform another lookup and possibly modify entries.
373
+ // This exclusive lock ensures no other threads can access the map and all free lists in the
374
+ // map.
375
+ write_lock_guard wlock (stream_free_blocks_mtx_);
376
+ auto iter = stream_free_blocks_.find (stream_event);
343
377
free_list& blocks =
344
378
(iter != stream_free_blocks_.end ()) ? iter->second : stream_free_blocks_[stream_event];
379
+ lock_guard free_list_lock (blocks.get_mutex ());
345
380
346
381
// Try to find an existing block in another stream
347
382
{
@@ -382,6 +417,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
382
417
free_list& blocks,
383
418
bool merge_first)
384
419
{
420
+ RMM_FUNC_RANGE ();
385
421
auto find_block = [&](auto iter) {
386
422
auto other_event = iter->first .event ;
387
423
auto & other_blocks = iter->second ;
@@ -415,6 +451,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
415
451
++next_iter; // Points to element after `iter` to allow erasing `iter` in the loop body
416
452
417
453
if (iter->first .event != stream_event.event ) {
454
+ lock_guard free_list_lock (iter->second .get_mutex ());
418
455
block_type const block = find_block (iter);
419
456
420
457
if (block.is_valid ()) {
@@ -435,6 +472,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
435
472
cudaEvent_t other_event,
436
473
free_list&& other_blocks)
437
474
{
475
+ RMM_FUNC_RANGE ();
438
476
// Since we found a block associated with a different stream, we have to insert a wait
439
477
// on the stream's associated event into the allocating stream.
440
478
RMM_CUDA_TRY (cudaStreamWaitEvent (stream_event.stream , other_event, 0 ));
@@ -450,7 +488,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
450
488
*/
451
489
void release ()
452
490
{
453
- lock_guard lock (mtx_);
491
+ RMM_FUNC_RANGE ();
492
+ // lock_guard lock(mtx_); TODO(jigao): rethink mtx_
493
+ write_lock_guard stream_event_lock (stream_events_mtx_);
494
+ write_lock_guard wlock (stream_free_blocks_mtx_);
454
495
455
496
for (auto s_e : stream_events_) {
456
497
RMM_ASSERT_CUDA_SUCCESS (cudaEventSynchronize (s_e.second .event ));
@@ -464,6 +505,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
464
505
void log_summary_trace ()
465
506
{
466
507
#if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE)
508
+ RMM_FUNC_RANGE ();
467
509
std::size_t num_blocks{0 };
468
510
std::size_t max_block{0 };
469
511
std::size_t free_mem{0 };
@@ -491,8 +533,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
491
533
// bidirectional mapping between non-default streams and events
492
534
std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;
493
535
536
+ // TODO(jigao): think about get_mutex function?
494
537
std::mutex mtx_; // mutex for thread-safe access
495
538
539
+ // mutex for thread-safe access to stream_free_blocks_
540
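+ // Shared-locked for lookups in get_block and do_deallocate; exclusively locked in the writing part of get_block, get_block_from_other_stream, do_deallocate and in release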
541
+ std::shared_mutex stream_free_blocks_mtx_;
542
+
543
+ // mutex for thread-safe access to stream_events_
544
+ // Used in the NON-PTDS part of get_event and in release
545
+ std::shared_mutex stream_events_mtx_;
546
+
496
547
rmm::cuda_device_id device_id_{rmm::get_current_cuda_device ()};
497
548
}; // namespace detail
498
549
0 commit comments