Skip to content

Commit 3571258

Browse files
ghq24int authored and facebook-github-bot committed
Invoke AMD specific kernel reorder_batched_ad_indices_kernel_vec (#4412)
Summary: Pull Request resolved: #4412 X-link: facebookresearch/FBGEMM#1483 For the benchmark in the codebase, the larger the product of length and num-ads is, the better the performance. Two optimizations: 1. Vector loading in a warp. 2. The product of batch-size and table-size determines the # of thread blocks (https://www.internalfb.com/code/fbsource/[cecfed562b79afad0eb9c44259141f50352da342]/fbcode/deeplearning/fbgemm/fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu?lines=361). In MRS models, we expect more thread blocks in our use cases. As such, we shrink the block size to achieve more thread blocks, thus improving compute utilization. Performance results and local test benchmarks: D77066925 Reviewed By: jwfromm, jianyuh, q10 Differential Revision: D77459476 fbshipit-source-id: 178a111cbcc67a59986410027bacbe75fc92ab26
1 parent b16ce18 commit 3571258

File tree

2 files changed

+30
-23
lines changed

2 files changed

+30
-23
lines changed

fbgemm_gpu/src/sparse_ops/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "fbgemm_gpu/split_embeddings_utils.cuh"
3333
#include "fbgemm_gpu/utils/binary_search_range.cuh"
3434
#include "fbgemm_gpu/utils/dispatch_macros.h"
35+
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
3536
#include "fbgemm_gpu/utils/log2.h"
3637
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
3738

fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ __launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kerne
300300
// Idea: we want to copy the entire segment of size sum_a(length_{b, t, a})
301301
// from starting point (given by cat_ad_offsets[b, t])
302302
// to end point (given by reordered_cat_ad_indices[t][b])
303-
if (num_elements <= 64) {
303+
if (num_elements <= 64 || !(sizeof(Dtype) == 4 || sizeof(Dtype) == 8)) {
304304
for (auto i = threadIdx.x; i < input_segment_end - input_segment_start;
305305
i += blockDim.x) {
306306
// coalesced global memory access, can be optimzed through ILP with the
@@ -450,11 +450,6 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
450450
return reordered_cat_ad_indices;
451451
}
452452
}
453-
constexpr auto NUM_WARPS = 32;
454-
auto maxWarpSize = kMaxThreads / NUM_WARPS;
455-
const dim3 threads(
456-
NUM_WARPS, maxWarpSize < kWarpSize ? maxWarpSize : kWarpSize); // 32 x 32
457-
const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS));
458453
FBGEMM_DISPATCH_ALL_TYPES(
459454
cat_ad_indices.scalar_type(),
460455
"reorder_batched_ad_indices_gpu_kernel_1",
@@ -463,24 +458,35 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
463458
cat_ad_offsets.scalar_type(),
464459
"reorder_batched_ad_indices_gpu_kernel_2",
465460
[&] {
466-
#ifdef FBGEMM_GPU_MEMCHECK
467-
const auto func_name = "reorder_batched_ad_indices_kernel";
461+
#if defined __HIP_PLATFORM_AMD__
462+
constexpr auto NUM_WARPS = 4;
463+
const dim3 threads(32, NUM_WARPS); // 32 x 4
464+
const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS));
465+
constexpr auto reorder_batched_ad_indices_kernel_name =
466+
reorder_batched_ad_indices_kernel_vec<scalar_t, index_t>;
467+
#else
468+
constexpr auto NUM_WARPS = 32;
469+
auto maxWarpSize = kMaxThreads / NUM_WARPS;
470+
const dim3 threads(
471+
NUM_WARPS,
472+
maxWarpSize < kWarpSize ? maxWarpSize : kWarpSize); // 32 x 32
473+
const dim3 blocks(cuda_calc_xblock_count(B * T, NUM_WARPS));
474+
constexpr auto reorder_batched_ad_indices_kernel_name =
475+
reorder_batched_ad_indices_kernel<scalar_t, index_t>;
468476
#endif
469-
reorder_batched_ad_indices_kernel<scalar_t, index_t>
470-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
471-
MAKE_PTA_WITH_NAME(
472-
func_name, cat_ad_offsets, index_t, 1, 32),
473-
MAKE_PTA_WITH_NAME(
474-
func_name, cat_ad_indices, scalar_t, 1, 32),
475-
MAKE_PTA_WITH_NAME(
476-
func_name, reordered_cat_ad_offsets, index_t, 1, 32),
477-
MAKE_PTA_WITH_NAME(
478-
func_name, reordered_cat_ad_indices, scalar_t, 1, 32),
479-
MAKE_PTA_WITH_NAME(
480-
func_name, batch_offsets, int32_t, 1, 32),
481-
T,
482-
broadcast_indices);
483-
C10_CUDA_KERNEL_LAUNCH_CHECK();
477+
FBGEMM_LAUNCH_KERNEL(
478+
(reorder_batched_ad_indices_kernel_name),
479+
blocks,
480+
threads,
481+
0,
482+
at::cuda::getCurrentCUDAStream(),
483+
PTA_B(cat_ad_offsets, index_t, 1, 32),
484+
PTA_B(cat_ad_indices, scalar_t, 1, 32),
485+
PTA_B(reordered_cat_ad_offsets, index_t, 1, 32),
486+
PTA_B(reordered_cat_ad_indices, scalar_t, 1, 32),
487+
PTA_B(batch_offsets, int32_t, 1, 32),
488+
T,
489+
broadcast_indices);
484490
});
485491
});
486492

0 commit comments

Comments
 (0)