From c7d94bb226478259bd5d6f5273e11b1870fc1fa7 Mon Sep 17 00:00:00 2001 From: Jason Xie Date: Wed, 2 Jul 2025 11:42:40 -0700 Subject: [PATCH] Optimize tbe_input_combine_with_length_cuda on AMD (#4430) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4430 X-link: https://github.com/facebookresearch/FBGEMM/pull/1496 DPER frontend benchmarks show tbe_input_combine_with_length_cuda as one of the top contributors to local net latency on the CMF fully remote model. Especially on AMD, where latency for this kernel is ~2x that of NVIDIA (albeit AMD executes with more kernels in parallel). VEC_WIDTH=32 increases the number of items processed per thread on AMD, improving memory access patterns and taking advantage of the AMD GPU's larger memory bandwidth. Reviewed By: q10 Differential Revision: D75886673 --- fbgemm_gpu/src/input_combine_ops/input_combine.cu | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine.cu b/fbgemm_gpu/src/input_combine_ops/input_combine.cu index 64003b026c..57e6804585 100644 --- a/fbgemm_gpu/src/input_combine_ops/input_combine.cu +++ b/fbgemm_gpu/src/input_combine_ops/input_combine.cu @@ -24,7 +24,11 @@ DEVICE_INLINE void vec_copy_with_implicit_type_cast( const uint64_t src_bound) { // TODO: Use vector load/store if address aligns with the vector type const src_t* const src = reinterpret_cast(src_addr); +#ifdef __HIP_PLATFORM_AMD__ +#pragma unroll 4 +#else #pragma unroll +#endif for (uint64_t i = 0; i < VEC_WIDTH && src_offset + i < src_bound; i++) { dst[dst_offset + i] = src[src_offset + i]; } @@ -130,9 +134,12 @@ std::tuple tbe_input_combine_with_length_cuda( .dtype(at::kFloat) .device(at::kCUDA, at::cuda::current_device())); - // Each thread loads 4 elements (rule of thumb; should work well with 32-bit - // inputs) - constexpr uint32_t VEC_WIDTH = 4; +// Each thread loads VEC_WIDTH elements (tuned for specific hardware) +#ifdef __HIP_PLATFORM_AMD__ + constexpr uint32_t VEC_WIDTH = 32; 
+#else + constexpr uint32_t VEC_WIDTH = 8; +#endif constexpr uint32_t NUM_WARPS_PER_BLOCK = kMaxThreads / kWarpSize; const auto num_warps_per_list = div_round_up(max_list_size, kWarpSize * VEC_WIDTH);