diff --git a/fbgemm_gpu/src/input_combine_ops/input_combine.cu b/fbgemm_gpu/src/input_combine_ops/input_combine.cu
index 64003b026c..57e6804585 100644
--- a/fbgemm_gpu/src/input_combine_ops/input_combine.cu
+++ b/fbgemm_gpu/src/input_combine_ops/input_combine.cu
@@ -24,7 +24,11 @@ DEVICE_INLINE void vec_copy_with_implicit_type_cast(
     const uint64_t src_bound) {
   // TODO: Use vector load/store if address aligns with the vector type
   const src_t* const src = reinterpret_cast<const src_t*>(src_addr);
+#ifdef __HIP_PLATFORM_AMD__
+#pragma unroll 4
+#else
 #pragma unroll
+#endif
   for (uint64_t i = 0; i < VEC_WIDTH && src_offset + i < src_bound; i++) {
     dst[dst_offset + i] = src[src_offset + i];
   }
@@ -130,9 +134,12 @@ std::tuple<Tensor, Tensor, Tensor> tbe_input_combine_with_length_cuda(
           .dtype(at::kFloat)
           .device(at::kCUDA, at::cuda::current_device()));
 
-  // Each thread loads 4 elements (rule of thumb; should work well with 32-bit
-  // inputs)
-  constexpr uint32_t VEC_WIDTH = 4;
+// Each thread loads VEC_WIDTH elements (tuned for specific hardware)
+#ifdef __HIP_PLATFORM_AMD__
+  constexpr uint32_t VEC_WIDTH = 32;
+#else
+  constexpr uint32_t VEC_WIDTH = 8;
+#endif
   constexpr uint32_t NUM_WARPS_PER_BLOCK = kMaxThreads / kWarpSize;
   const auto num_warps_per_list =
       div_round_up(max_list_size, kWarpSize * VEC_WIDTH);
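
For context on how the wider per-platform `VEC_WIDTH` feeds into the launch-geometry calculation at the bottom of the second hunk, the minimal host-side sketch below recomputes `num_warps_per_list` under both branches. The warp/wavefront sizes (64 on AMD, 32 on NVIDIA) and the example `max_list_size` are illustrative assumptions, not values taken from the surrounding kernel-launch code.

```cpp
// Minimal sketch (not the actual launch code): shows how the per-platform
// VEC_WIDTH chosen in the diff changes the number of warps assigned per list.
#include <cstdint>
#include <iostream>

constexpr uint32_t div_round_up(uint32_t a, uint32_t b) {
  return (a + b - 1) / b;
}

int main() {
  constexpr uint32_t max_list_size = 1 << 20;  // assumption: example input size

#ifdef __HIP_PLATFORM_AMD__
  constexpr uint32_t kWarpSize = 64;  // assumption: AMD wavefront size
  constexpr uint32_t VEC_WIDTH = 32;  // value chosen in the diff for ROCm
#else
  constexpr uint32_t kWarpSize = 32;  // assumption: NVIDIA warp size
  constexpr uint32_t VEC_WIDTH = 8;   // value chosen in the diff for CUDA
#endif

  // Each warp covers kWarpSize * VEC_WIDTH elements, so a wider VEC_WIDTH
  // means fewer warps (and fewer blocks) are needed per list.
  const uint32_t num_warps_per_list =
      div_round_up(max_list_size, kWarpSize * VEC_WIDTH);
  std::cout << "warps per list: " << num_warps_per_list << "\n";
  return 0;
}
```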