
Commit e9ce63d

q10 authored and facebook-github-bot committed
Add set_max_dynamic_smem (pytorch#4398)

Summary:
Pull Request resolved: pytorch#4398

X-link: facebookresearch/FBGEMM#1469

- Fold the duplicated code that sets `cudaFuncAttributeMaxDynamicSharedMemorySize` into a shared `set_max_dynamic_smem` helper

Reviewed By: jianyuh, ionuthristodorescu

Differential Revision: D76700646

fbshipit-source-id: 01c4b651735f3b1c5c5d24d0af9b13ccd4da7398
1 parent 06247d1 commit e9ce63d
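
For context, here is a minimal sketch (not part of the diff) of the call-site pattern this commit consolidates: the per-site `#ifndef USE_ROCM` / `cudaFuncSetAttribute` / `C10_CUDA_KERNEL_LAUNCH_CHECK()` boilerplate becomes a single call to the new helper. The kernel name `my_kernel`, the launcher, and the launch configuration are illustrative assumptions, not code from this repository.

#include <cuda_runtime.h>

#include "fbgemm_gpu/utils/cuda_utilities.cuh"

// Hypothetical kernel; any __global__ function that wants more than the
// default 48 KB of dynamic shared memory per block needs the opt-in below.
__global__ void my_kernel(float* out);

void launch_my_kernel(float* out, int32_t used_shared_bytes, cudaStream_t stream) {
  // One helper call replaces the per-call-site cudaFuncSetAttribute(
  //     kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, ...) block
  // and is a no-op under USE_ROCM.
  fbgemm_gpu::utils::cuda::set_max_dynamic_smem(my_kernel, used_shared_bytes);

  my_kernel<<<1, 256, used_shared_bytes, stream>>>(out);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}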

File tree

9 files changed, +109 -97 lines


fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 3 additions & 14 deletions
@@ -24,6 +24,7 @@
 #include "fbgemm_gpu/sparse_ops.h"
 #include "fbgemm_gpu/config/feature_gates.h"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
+#include "fbgemm_gpu/utils/cuda_utilities.cuh"
 #include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/tensor_accessor_builder.h"
@@ -483,20 +484,8 @@ int32_t compute_num_groups_and_dynamic_smem_bytes(
   }
   TORCH_CHECK_GE(*num_groups, 1);
 
-  // Check https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x
-  // "Compute capability 7.x devices allow a single thread block to
-  // address the full capacity of shared memory: 96 KB on Volta,
-  // 64 KB on Turing. Kernels relying on shared memory allocations
-  // over 48 KB per block are architecture-specific, as such they
-  // must use dynamic shared memory (rather than statically sized
-  // arrays) and require an explicit opt-in using cudaFuncSetAttribute()".
-#ifndef USE_ROCM
-  cudaFuncSetAttribute(
-      bwd_kernel_fn,
-      cudaFuncAttributeMaxDynamicSharedMemorySize,
-      used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-#endif
+  utils::cuda::set_max_dynamic_smem(bwd_kernel_fn, used_shared_bytes);
+
   return smem_bytes;
 }

fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
 ////////////////////////////////////////////////////////////////////////////////
 #include "fbgemm_gpu/utils/ops_utils.h"
 {%- endif %}
-#include "fbgemm_gpu/utils/device_properties.cuh"
+#include "fbgemm_gpu/utils/cuda_utilities.cuh"
 #include "fbgemm_gpu/utils/kernel_launcher.cuh"
 #include "fbgemm_gpu/embedding_forward_template_helpers.cuh"
 #include "fbgemm_gpu/split_embeddings_cache_cuda.cuh"
fbgemm_gpu/include/fbgemm_gpu/utils/cuda_utilities.cuh

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAStream.h>
+#include <cuda.h>
+
+namespace fbgemm_gpu::utils::cuda {
+
+// Based on the empirical study, max grid size that is 64x larger than the
+// number of SMs gives good performance across the board
+constexpr int32_t MAX_THREAD_BLOCKS_FACTOR = 64;
+
+inline auto get_max_thread_blocks(const c10::cuda::CUDAStream& stream) {
+  const auto device = stream.device_index();
+  return MAX_THREAD_BLOCKS_FACTOR *
+      at::cuda::getDeviceProperties(device)->multiProcessorCount;
+}
+
+inline auto get_compute_versions() {
+  static const auto versions = [] {
+    int runtime_version = 0;
+    cudaRuntimeGetVersion(&runtime_version);
+
+    int driver_version = 0;
+    cudaDriverGetVersion(&driver_version);
+
+    return std::make_tuple(runtime_version, driver_version);
+  }();
+
+  return versions;
+}
+
+template <typename func_t>
+inline void set_max_dynamic_smem(
+    func_t kernel,
+    const int32_t smem_bytes,
+    const int32_t device = at::cuda::current_device()) {
+#ifndef USE_ROCM
+
+  // Check
+  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x
+  // "Compute capability 7.x devices allow a single thread block to
+  // address the full capacity of shared memory: 96 KB on Volta,
+  // 64 KB on Turing. Kernels relying on shared memory allocations
+  // over 48 KB per block are architecture-specific, as such they
+  // must use dynamic shared memory (rather than statically sized
+  // arrays) and require an explicit opt-in using cudaFuncSetAttribute()".
+
+  TORCH_CHECK(smem_bytes > 0);
+
+  int max_smem_bytes = 0;
+  C10_CUDA_CHECK(cudaDeviceGetAttribute(
+      &max_smem_bytes,
+#ifndef __HIP_PLATFORM_AMD__
+      cudaDevAttrMaxSharedMemoryPerBlockOptin,
+#else
+      hipDeviceAttributeMaxSharedMemoryPerBlock,
+#endif
+      device));
+
+  TORCH_CHECK(
+      smem_bytes <= max_smem_bytes,
+      "Attempted to allocate ",
+      smem_bytes / 1024,
+      " KB of shared memory but only ",
+      max_smem_bytes / 1024,
+      " KB is available");
+
+  C10_CUDA_CHECK(cudaFuncSetAttribute(
+      reinterpret_cast<void*>(kernel),
+      cudaFuncAttributeMaxDynamicSharedMemorySize,
+      // V100: 64 KB; A100: 96 KB; H100: 144 KB
+      smem_bytes));
+
+#endif
+}
+
+} // namespace fbgemm_gpu::utils::cuda

fbgemm_gpu/include/fbgemm_gpu/utils/device_properties.cuh

Lines changed: 0 additions & 42 deletions
This file was deleted.

fbgemm_gpu/src/jagged_tensor_ops/common.cuh

Lines changed: 5 additions & 6 deletions
@@ -28,6 +28,7 @@
 #include "fbgemm_gpu/sparse_ops.h"
 #include "fbgemm_gpu/utils/binary_search_range.cuh"
 #include "fbgemm_gpu/utils/cuda_block_count.h"
+#include "fbgemm_gpu/utils/cuda_utilities.cuh"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/utils/fixed_divisor.cuh"
 #include "fbgemm_gpu/utils/inclusive_sum_scan.cuh"
@@ -834,14 +835,12 @@ void jagged_dense_elementwise_jagged_output_opt_(
     int used_shared_kb = shared_kb;
 #endif
     int used_shared_bytes = used_shared_kb << 10;
-#ifndef USE_ROCM
-    C10_CUDA_CHECK(cudaFuncSetAttribute(
+
+    utils::cuda::set_max_dynamic_smem(
         jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
             index_t>,
-        cudaFuncAttributeMaxDynamicSharedMemorySize,
-        used_shared_bytes)); // V100: 64 KB; A100: 96 KB; H100: 144 KB
-#endif
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
+        used_shared_bytes);
+
     TORCH_CHECK(dynamic_smem_size <= used_shared_bytes);
   }
 

fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu

Lines changed: 7 additions & 9 deletions
@@ -124,18 +124,16 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
     int used_shared_kb = shared_kb;
 #endif
     int used_shared_bytes = used_shared_kb << 10;
-#ifndef USE_ROCM
-    C10_CUDA_CHECK(cudaFuncSetAttribute(
+    TORCH_CHECK_LE(dynamic_smem_size, used_shared_bytes);
+
+    utils::cuda::set_max_dynamic_smem(
         jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
             index_t>,
-        cudaFuncAttributeMaxDynamicSharedMemorySize,
-        used_shared_bytes)); // V100: 64 KB; A100: 96 KB.
-#endif
-    C10_CUDA_KERNEL_LAUNCH_CHECK();
-    TORCH_CHECK_LE(dynamic_smem_size, used_shared_bytes);
+        used_shared_bytes);
   }
-  dim3 threads_bs = dim3(1024, 1, 1);
-  dim3 blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
+
+  const auto threads_bs = dim3(1024, 1, 1);
+  const auto blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
 
 #ifdef FBGEMM_GPU_MEMCHECK
   const auto func_name1 =

fbgemm_gpu/src/quantize_ops/quantize_mx.cu

Lines changed: 3 additions & 15 deletions
@@ -15,6 +15,7 @@
 #include <torch/types.h>
 
 #include "c10/core/ScalarType.h"
+#include "fbgemm_gpu/utils/cuda_utilities.cuh"
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/tensor_utils.h"
 
@@ -81,21 +82,8 @@ int32_t compute_num_groups_and_dynamic_smem_bytes(
   }
   TORCH_CHECK_GE(*num_groups_per_block, 1);
 
-  // Check
-  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x
-  // "Compute capability 7.x devices allow a single thread block to
-  // address the full capacity of shared memory: 96 KB on Volta,
-  // 64 KB on Turing. Kernels relying on shared memory allocations
-  // over 48 KB per block are architecture-specific, as such they
-  // must use dynamic shared memory (rather than statically sized
-  // arrays) and require an explicit opt-in using cudaFuncSetAttribute()".
-#ifndef USE_ROCM
-  cudaFuncSetAttribute(
-      kernel_func_name,
-      cudaFuncAttributeMaxDynamicSharedMemorySize,
-      used_shared_bytes); // V100: 64 KB; A100: 96 KB; H100: 144 KB
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-#endif
+  utils::cuda::set_max_dynamic_smem(kernel_func_name, used_shared_bytes);
+
   return smem_bytes;
 }

fbgemm_gpu/src/sparse_ops/common.cuh

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 #include "fbgemm_gpu/sparse_ops.cuh"
 #include "fbgemm_gpu/sparse_ops.h"
 #include "fbgemm_gpu/utils/cuda_block_count.h"
+#include "fbgemm_gpu/utils/cuda_utilities.cuh"
 #include "fbgemm_gpu/utils/ops_utils.h"
 
 #include <ATen/ATen.h>

fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu

Lines changed: 2 additions & 10 deletions
@@ -71,14 +71,6 @@ void adjust_block_bucketize_sparse_features_kernel_launch_configs_based_on_smem(
   grid_dims->x = cuda_calc_xblock_count(lengths_size, block_dims->y);
 }
 
-template <typename func_t>
-void increase_gpu_max_dynamic_shared_memory(func_t kernel, const int max_smem) {
-  TORCH_CHECK(max_smem > 0);
-  C10_CUDA_CHECK(cudaFuncSetAttribute(
-      (void*)kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_smem));
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
 // Kernel for bucketize lengths, with the Block distribution (vs. cyclic,
 // block-cyclic distribution). Used for bucketize sparse feature, especially for
 // checkpointing with row-wise partition (sparse_feature is partitioned
@@ -562,7 +554,7 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           index_t, \
           scalar_t>; \
       if (smem_size > smem_adjust_threshold) { \
-        increase_gpu_max_dynamic_shared_memory( \
+        utils::cuda::set_max_dynamic_smem( \
            block_bucketize_kernel, max_smem); \
       } \
       block_bucketize_kernel<<< \
@@ -625,7 +617,7 @@ __launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
           index_t, \
           std::nullptr_t>; \
       if (smem_size > smem_adjust_threshold) { \
-        increase_gpu_max_dynamic_shared_memory( \
+        utils::cuda::set_max_dynamic_smem( \
            block_bucketize_kernel, max_smem); \
       } \
       block_bucketize_kernel<<< \
