Skip to content

Commit 3cd28ce

Browse files
JChunXfacebook-github-bot
authored andcommitted
Optimize tbe_input_combine_with_length_cuda on AMD
Summary: X-link: facebookresearch/FBGEMM#1496 DPER frontend benchmark show tbe_input_combine_with_length_cuda as one of the top contributors to local net latency on CMF fully remote model. Especially on AMD, where latency for this kernel is ~2x of NVIDIA (albeit AMD executes with more kernels in parallel). VEC_WIDTH=32 Increase items processed on AMD per thread, improving memory access patterns and taking advantage of AMD GPU larger memory bandwidth. Reviewed By: q10 Differential Revision: D75886673
1 parent 0dbd1bc commit 3cd28ce

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

fbgemm_gpu/src/input_combine_ops/input_combine.cu

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ DEVICE_INLINE void vec_copy_with_implicit_type_cast(
2424
const uint64_t src_bound) {
2525
// TODO: Use vector load/store if address aligns with the vector type
2626
const src_t* const src = reinterpret_cast<src_t*>(src_addr);
27+
#ifdef __HIP_PLATFORM_AMD__
28+
#pragma unroll 4
29+
#else
2730
#pragma unroll
31+
#endif
2832
for (uint64_t i = 0; i < VEC_WIDTH && src_offset + i < src_bound; i++) {
2933
dst[dst_offset + i] = src[src_offset + i];
3034
}
@@ -130,9 +134,12 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cuda(
130134
.dtype(at::kFloat)
131135
.device(at::kCUDA, at::cuda::current_device()));
132136

133-
// Each thread loads 4 elements (rule of thumb; should work well with 32-bit
134-
// inputs)
135-
constexpr uint32_t VEC_WIDTH = 4;
137+
// Each thread loads VEC_WIDTH elements (tuned for specific hardware)
138+
#ifdef __HIP_PLATFORM_AMD__
139+
constexpr uint32_t VEC_WIDTH = 32;
140+
#else
141+
constexpr uint32_t VEC_WIDTH = 8;
142+
#endif
136143
constexpr uint32_t NUM_WARPS_PER_BLOCK = kMaxThreads / kWarpSize;
137144
const auto num_warps_per_list =
138145
div_round_up(max_list_size, kWarpSize * VEC_WIDTH);

0 commit comments

Comments
 (0)