Optimize tbe_input_combine_with_length_cuda on AMD #4430

Open: wants to merge 1 commit into base: main
13 changes: 10 additions & 3 deletions fbgemm_gpu/src/input_combine_ops/input_combine.cu
@@ -24,7 +24,11 @@ DEVICE_INLINE void vec_copy_with_implicit_type_cast(
const uint64_t src_bound) {
// TODO: Use vector load/store if address aligns with the vector type
const src_t* const src = reinterpret_cast<src_t*>(src_addr);
#ifdef __HIP_PLATFORM_AMD__
#pragma unroll 4
#else
#pragma unroll
#endif
for (uint64_t i = 0; i < VEC_WIDTH && src_offset + i < src_bound; i++) {
dst[dst_offset + i] = src[src_offset + i];
}
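The hunk above caps the unroll factor at 4 on AMD while leaving full unrolling on NVIDIA. As a rough host-side sketch of what that loop does (the template shape and typed pointers here are assumptions; the real device code reinterpret-casts a raw address), it copies up to `VEC_WIDTH` elements with an implicit type cast, clamped at `src_bound`:

```cpp
#include <cstdint>

// Host-side analogue of the device copy loop in the hunk above (sketch only).
// Copies at most VEC_WIDTH elements starting at src_offset, converting src_t
// to dst_t element-wise, and stops early once src_offset + i reaches src_bound.
template <typename dst_t, typename src_t, uint32_t VEC_WIDTH>
void vec_copy_with_implicit_type_cast(
    dst_t* dst,
    uint64_t dst_offset,
    const src_t* src,
    uint64_t src_offset,
    uint64_t src_bound) {
  for (uint64_t i = 0; i < VEC_WIDTH && src_offset + i < src_bound; i++) {
    dst[dst_offset + i] = static_cast<dst_t>(src[src_offset + i]);
  }
}
```

Because the trip count depends on the runtime `src_bound` check, a fully unrolled loop still needs per-element guards; the PR's `#pragma unroll 4` on HIP trades some unrolling for lower register pressure.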
@@ -130,9 +134,12 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cuda(
.dtype(at::kFloat)
.device(at::kCUDA, at::cuda::current_device()));

// Each thread loads 4 elements (rule of thumb; should work well with 32-bit
// inputs)
constexpr uint32_t VEC_WIDTH = 4;
// Each thread loads VEC_WIDTH elements (tuned for specific hardware)
#ifdef __HIP_PLATFORM_AMD__
constexpr uint32_t VEC_WIDTH = 32;
#else
constexpr uint32_t VEC_WIDTH = 8;
#endif
constexpr uint32_t NUM_WARPS_PER_BLOCK = kMaxThreads / kWarpSize;
const auto num_warps_per_list =
div_round_up(max_list_size, kWarpSize * VEC_WIDTH);
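Each warp covers `kWarpSize * VEC_WIDTH` elements per pass, so widening `VEC_WIDTH` shrinks `num_warps_per_list` and the grid size. A minimal sketch of that arithmetic (assuming a 64-lane wavefront on AMD CDNA and a 32-lane warp on NVIDIA; `div_round_up` written out here as ordinary ceiling division, not taken verbatim from FBGEMM):

```cpp
#include <cstdint>

// Ceiling division, as used to size num_warps_per_list above.
constexpr uint32_t div_round_up(uint32_t a, uint32_t b) {
  return (a + b - 1) / b;
}

// Assumptions for illustration: 64-lane wavefronts on AMD CDNA, 32-lane
// warps on NVIDIA. These are not read from the diff itself.
constexpr uint32_t kWarpSizeAMD = 64;
constexpr uint32_t kWarpSizeCUDA = 32;

// Warps needed per list with the PR's per-platform VEC_WIDTH values.
constexpr uint32_t num_warps_per_list_amd(uint32_t max_list_size) {
  return div_round_up(max_list_size, kWarpSizeAMD * 32);  // VEC_WIDTH = 32
}
constexpr uint32_t num_warps_per_list_cuda(uint32_t max_list_size) {
  return div_round_up(max_list_size, kWarpSizeCUDA * 8);  // VEC_WIDTH = 8
}
```

With these assumed warp sizes, one AMD wavefront covers 2048 elements per pass versus 256 on NVIDIA, so a 1000-element list needs 1 warp on AMD but 4 on NVIDIA.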