Skip to content

Commit 75dd112

Browse files
banitag1 authored and facebook-github-bot committed
using different mechanism for host mapped pinned memory (#1638)
Summary: Pull Request resolved: #1638 This diff adds another mechanism for allocating the host mapped pinned memory to reduce the adverse effect on other processes running on the same host when one process is doing some large allocations. Reviewed By: zyan0, jianyuh Differential Revision: D43950253 fbshipit-source-id: 41a434cb63354509d32e00c851c5f3a2d68be686
1 parent 8616ed7 commit 75dd112

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

fbgemm_gpu/src/cumem_utils.cu

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ struct CUDAHostMappedContext {
4141
~CUDAHostMappedContext() {
4242
at::cuda::OptionalCUDAGuard device_guard;
4343
device_guard.set_index(cuda_device_);
44-
AT_CUDA_CHECK(cudaFreeHost(ptr_));
44+
AT_CUDA_CHECK(cudaHostUnregister(ptr_));
45+
free(ptr_);
4546
}
4647

4748
static void release(void* ptr) {
@@ -206,9 +207,28 @@ Tensor new_host_mapped_tensor(
206207
auto strides = defaultStrides(sizes);
207208
size_t size_bytes =
208209
at::detail::computeStorageNbytes(sizes, strides, self.dtype().itemsize());
209-
void* ptr;
210-
AT_CUDA_CHECK(cudaHostAlloc(
211-
&ptr, size_bytes, cudaHostAllocWriteCombined | cudaHostAllocMapped));
210+
211+
// When using cudaHostAlloc for large allocations, we found that it can
212+
// potentially take a global lock and lock out CUDA APIs from other processes.
213+
// The main cost in cudaHostAlloc is faulting/mapping the pages. So, instead
214+
// of using this cuda API, we can do regular malloc, pre-fault the pages, and
215+
// then do cudaHostRegister with GPU mapping flags to lock the pages, so we
216+
// can minimize the cost while holding this global lock.
217+
void* const ptr = malloc(size_bytes);
218+
219+
// advise the kernel to allocate large 2M pages
220+
madvise(ptr, size_bytes, MADV_HUGEPAGE);
221+
222+
// pre-fault/map the pages by setting the first byte of the page
223+
size_t pageSize = (1 << 21);
224+
uintptr_t alignedPtr = (((uintptr_t)ptr + pageSize - 1) & ~(pageSize - 1));
225+
for (uintptr_t p = alignedPtr; p < ((uintptr_t)ptr + size_bytes);
226+
p += pageSize) {
227+
memset((void*)p, 0, 1);
228+
}
229+
230+
AT_CUDA_CHECK(cudaHostRegister(
231+
ptr, size_bytes, cudaHostRegisterMapped | cudaHostRegisterPortable));
212232
void* dev_ptr;
213233
AT_CUDA_CHECK(cudaHostGetDevicePointer(&dev_ptr, ptr, 0));
214234

0 commit comments

Comments (0)