
Commit be11e98

sryap authored and facebook-github-bot committed
Add variable number of columns support to group_index_select (#1592)
Summary:
Pull Request resolved: #1592

Before this diff, `group_index_select` required all input tensors to have the same shape. This diff allows the input tensors to have different shapes; however, it still requires them to have the same number of dimensions and matching first dimensions.

Reviewed By: jianyuh

Differential Revision: D43216425

fbshipit-source-id: fc8c74a472616c0e46785eb71482fa9012155f7e
1 parent 935d8b5 commit be11e98
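
A note on the new shape contract: the sketch below only illustrates the checks the summary describes. It is not code from this commit; the helper name and the TORCH_CHECK messages are assumptions made for illustration. Every member of the group must still have the same number of dimensions and the same first dimension, while the remaining dimensions (flattened into a per-member column count) may now differ.

// Illustrative only; not part of this diff.
#include <ATen/ATen.h>
#include <algorithm>
#include <vector>

std::vector<int64_t> check_group_shapes(const std::vector<at::Tensor>& inputs) {
  TORCH_CHECK(!inputs.empty(), "group must not be empty");
  const auto ndim = inputs[0].dim();
  const auto num_rows = inputs[0].size(0);
  std::vector<int64_t> num_cols_group;
  num_cols_group.reserve(inputs.size());
  for (const auto& t : inputs) {
    // Still required after this diff: same rank and same first dimension.
    TORCH_CHECK(t.dim() == ndim, "members must have the same number of dimensions");
    TORCH_CHECK(t.size(0) == num_rows, "members must have the same first dimension");
    // Newly allowed: the trailing ("column") dimensions may differ per member.
    num_cols_group.push_back(t.numel() / std::max<int64_t>(num_rows, 1));
  }
  return num_cols_group;
}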

File tree

5 files changed: +311 -260 lines changed


fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 12 additions & 20 deletions
@@ -714,31 +714,23 @@ at::Tensor index_add_with_unique_indices_cuda(
     const int consecutive_range_length);
 
 ///@ingroup sparse-data-cuda
-std::vector<at::Tensor> group_index_select_cuda(
+void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
-    const c10::TensorOptions& input_tensor_options,
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
     const c10::ScalarType& input_scalar_type,
     const c10::ScalarType& indices_scalar_type,
     const c10::DeviceIndex& device,
-    const std::vector<int64_t>& output_shape,
-    const int num_input_rows,
-    const int num_output_rows,
-    const int num_cols,
-    const int num_groups);
-
-std::vector<at::Tensor> group_index_add_cuda(
-    const int64_t* input_ptrs,
-    const int64_t* indices_ptrs,
-    const c10::TensorOptions& input_tensor_options,
-    const c10::ScalarType& input_scalar_type,
-    const c10::ScalarType& indices_scalar_type,
-    const c10::DeviceIndex& device,
-    const std::vector<int64_t>& output_shape,
-    const int num_input_rows,
-    const int num_output_rows,
-    const int num_cols,
-    const int num_groups);
+    const int max_indices,
+    const int num_work_rows,
+    const int64_t total_num_warps,
+    const int group_size,
+    const bool use_index_select,
+    const bool use_var_cols);
+
+int get_group_index_select_cols_per_warp();
 
 std::vector<at::Tensor> jagged_index_select_2d(
     const at::Tensor& values,

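For context on the declaration above: group_index_select_or_add_cuda now takes the per-member tensors as flat arrays of device pointers (stored as int64_t) plus per-member metadata (warp_offsets_group, num_cols_group). The sketch below shows how a caller might assemble those arrays; it is an illustration only, the actual packing is done by FBGEMM's host-side wrapper that is not shown in this commit view, and the function and variable names here are assumptions.

// Illustrative only; not part of this diff.
#include <ATen/ATen.h>
#include <cstdint>
#include <vector>

void pack_group_args(
    const std::vector<at::Tensor>& inputs,
    const std::vector<at::Tensor>& outputs,
    const std::vector<at::Tensor>& indices,
    std::vector<int64_t>& input_ptrs, // device pointers stored as int64_t
    std::vector<int64_t>& output_ptrs,
    std::vector<int64_t>& indices_ptrs,
    std::vector<int32_t>& num_cols_group) {
  for (size_t i = 0; i < inputs.size(); ++i) {
    input_ptrs.push_back(reinterpret_cast<int64_t>(inputs[i].data_ptr()));
    output_ptrs.push_back(reinterpret_cast<int64_t>(outputs[i].data_ptr()));
    indices_ptrs.push_back(reinterpret_cast<int64_t>(indices[i].data_ptr()));
    // Per-member column count: everything after dim 0, flattened.
    num_cols_group.push_back(
        static_cast<int32_t>(inputs[i].numel() / inputs[i].size(0)));
  }
  // These host-side arrays must then be copied to device memory, since the
  // kernel reads them as device arrays indexed by member_id.
}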
fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h

Lines changed: 20 additions & 0 deletions
@@ -376,3 +376,23 @@ void binary_search_range_cpu(
   }
   *found = found_;
 }
+
+template <int x>
+struct log2_calc_ {
+  enum { value = log2_calc_<(x >> 1)>::value + 1 };
+};
+template <>
+struct log2_calc_<0> {
+  enum { value = 0 };
+};
+
+template <int x>
+struct log2_calc {
+  enum { value = log2_calc_<(x - 1)>::value };
+};
+#if 0
+template <>
+struct log2_calc<0> { enum { value = 0 }; };
+template <>
+struct log2_calc<1> { enum { value = 0 }; };
+#endif

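The log2_calc helpers added above compute ceil(log2(x)) at compile time: log2_calc_<x> counts the right shifts needed to drive x to zero (i.e. the bit width of x), and log2_calc<x> applies that to x - 1. The small check below is an illustration, not part of the diff; it spells out the values the CUDA file relies on (and why the caller requires a power-of-two COLS_PER_WARP, so that shifting by the result is the same as dividing by COLS_PER_WARP).

// Compile-time sanity check of log2_calc; illustration only.
template <int x>
struct log2_calc_ {
  enum { value = log2_calc_<(x >> 1)>::value + 1 };
};
template <>
struct log2_calc_<0> {
  enum { value = 0 };
};
template <int x>
struct log2_calc {
  enum { value = log2_calc_<(x - 1)>::value };
};

static_assert(log2_calc<1>::value == 0, "");
static_assert(log2_calc<2>::value == 1, "");
static_assert(log2_calc<32>::value == 5, ""); // warp size on NVIDIA GPUs
static_assert(log2_calc<64>::value == 6, ""); // wavefront size on AMD GPUs
static_assert(log2_calc<5>::value == 3, "");  // non-powers of two round up

int main() {
  return 0;
}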
fbgemm_gpu/src/sparse_ops.cu

Lines changed: 118 additions & 156 deletions
@@ -2994,198 +2994,160 @@ Tensor index_add_with_unique_indices_cuda(
   return input_grad.reshape(input_shape);
 }
 
-template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
-__global__ __launch_bounds__(kMaxThreads) void group_index_select_2d_kernel(
-    const int64_t* input_ptrs,
-    const int64_t* indices_ptrs,
-    scalar_t* output,
-    const int64_t num_input_rows,
-    const int64_t num_output_rows,
-    const int64_t num_cols,
-    const int64_t num_groups) {
-  for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
-       bid < num_groups * num_output_rows;
-       bid += gridDim.x * blockDim.y) {
-    const int64_t group_id = bid / num_output_rows;
-    const int64_t row = bid % num_output_rows;
-    scalar_t* input = (scalar_t*)input_ptrs[group_id];
-    index_t* indices = (index_t*)indices_ptrs[group_id];
-    const index_t idx = indices[row];
-    CUDA_KERNEL_ASSERT(idx < num_input_rows)
-    int col;
-    scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
-    for (col = threadIdx.x * UNROLL_FACTOR;
-         col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
-         col += blockDim.x * UNROLL_FACTOR) {
-#pragma unroll
-      for (int i = 0; i < UNROLL_FACTOR; i++) {
-        output_[row * num_cols + col + i] =
-            LDG(&input[idx * num_cols + col + i]);
-      }
-    }
-    for (; col < num_cols; ++col) {
-      output_[row * num_cols + col] = LDG(&input[idx * num_cols + col]);
-    }
-  }
+// TODO: Update UNROLL_FACTOR
+constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
+constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
+// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
+constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
+    log2_calc<GROUP_INDEX_SELECT_COLS_PER_WARP>::value;
+
+int get_group_index_select_cols_per_warp() {
+  return GROUP_INDEX_SELECT_COLS_PER_WARP;
 }
 
-template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
-__global__ __launch_bounds__(kMaxThreads) void group_index_add_2d_kernel(
+template <
+    typename index_t,
+    typename scalar_t,
+    bool USE_INDEX_SELECT,
+    bool USE_VAR_COLS,
+    int UNROLL_FACTOR,
+    int COLS_PER_WARP,
+    int LOG_COLS_PER_WARP>
+__global__
+__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
-    scalar_t* output,
-    const int64_t num_input_rows,
-    const int64_t num_output_rows,
-    const int64_t num_cols,
-    const int64_t num_groups) {
-  for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
-       bid < num_groups * num_input_rows;
-       bid += gridDim.x * blockDim.y) {
-    const int64_t group_id = bid / num_input_rows;
-    const int64_t row = bid % num_input_rows;
-    scalar_t* input = (scalar_t*)input_ptrs[group_id];
-    index_t* indices = (index_t*)indices_ptrs[group_id];
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
+    const int64_t max_indices,
+    const int64_t num_work_rows, // number of rows to work on per member
+    const int64_t group_size) {
+  const auto total_num_warps = warp_offsets_group[group_size];
+  for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+       warp_id < total_num_warps;
+       warp_id += gridDim.x * blockDim.y) {
+    int32_t member_id, member_warp_id, num_cols, warps_per_row;
+    if (USE_VAR_COLS) {
+      __shared__ int member_ids[kMaxThreads / kWarpSize];
+      if (threadIdx.x == 0) {
+        binary_search_range(
+            &member_ids[threadIdx.y],
+            warp_offsets_group + 1,
+            warp_id,
+            group_size);
+      }
+      syncwarp();
+      member_id = member_ids[threadIdx.y];
+      num_cols = num_cols_group[member_id];
+      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+      member_warp_id = warp_id - warp_offsets_group[member_id];
+    } else {
+      // All columns are the same
+      num_cols = num_cols_group[0];
+      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+      member_id = warp_id / (warps_per_row * num_work_rows);
+      member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+    }
+    const auto row = member_warp_id / warps_per_row;
+    const auto col_offset =
+        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+        (threadIdx.x * UNROLL_FACTOR);
+    scalar_t* input =
+        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+    scalar_t* output =
+        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
     const index_t idx = indices[row];
-    CUDA_KERNEL_ASSERT(idx < num_output_rows)
-    int col;
-    scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
-    for (col = threadIdx.x * UNROLL_FACTOR;
-         col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
-         col += blockDim.x * UNROLL_FACTOR) {
+    CUDA_KERNEL_ASSERT(idx < max_indices)
 #pragma unroll
-      for (int i = 0; i < UNROLL_FACTOR; i++) {
-        // PyTorch also uses atomicAdd. It does not require sorting and
-        // provides better parallelism. But this can lead to numerical
-        // indeterminisim.
+    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+      // Compile time conditional
+      if (USE_INDEX_SELECT) {
+        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+      } else {
         gpuAtomicAddNoReturn(
-            &output_[idx * num_cols + col + i],
-            input[row * num_cols + col + i]);
+            &output[idx * num_cols + i], input[row * num_cols + i]);
       }
     }
-    for (; col < num_cols; ++col) {
-      gpuAtomicAddNoReturn(
-          &output[idx * num_cols + col], input[row * num_cols + col]);
-    }
   }
 }
 
-std::vector<Tensor> group_index_select_cuda(
+void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
-    const c10::TensorOptions& input_tensor_options,
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
     const c10::ScalarType& input_scalar_type,
    const c10::ScalarType& indices_scalar_type,
     const c10::DeviceIndex& device,
-    const std::vector<int64_t>& output_shape,
-    const int num_input_rows,
-    const int num_output_rows,
-    const int num_cols,
-    const int num_groups) {
-  if (num_groups == 0) {
-    return std::vector<Tensor>();
+    const int max_indices,
+    const int num_work_rows,
+    const int64_t total_num_warps,
+    const int group_size,
+    const bool use_index_select,
+    const bool use_var_cols) {
+  if (group_size == 0) {
+    return;
   }
 
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(device);
 
-  Tensor output = at::empty(output_shape, input_tensor_options);
-
-  // Partition work based on num_output_rows
-  const int UNROLL_FACTOR = 1;
+  // Partition work based on num_work_rows
+  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
   uint32_t max_grid_size =
       at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
   uint32_t grid_size = std::min(
-      cuda_calc_xblock_count(num_groups * num_output_rows, 1), max_grid_size);
-  uint32_t block_size_x =
-      std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
-  uint32_t block_size_y =
-      std::max((num_groups * num_output_rows) / grid_size, (uint32_t)1);
-  dim3 block_size(
-      block_size_x,
-      std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
-      1);
+      cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
+      max_grid_size);
+  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
+
+#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS) \
+  group_index_select_or_add_2d_kernel< \
+      index_t, \
+      scalar_t, \
+      USE_INDEX_SELECT, \
+      USE_VAR_COLS, \
+      GROUP_INDEX_SELECT_UNROLL_FACTOR, \
+      GROUP_INDEX_SELECT_COLS_PER_WARP, \
+      GROUP_INDEX_SELECT_LOG_COLS_PER_WARP> \
+      <<<grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>( \
+          input_ptrs, \
+          output_ptrs, \
+          indices_ptrs, \
+          warp_offsets_group, \
+          num_cols_group, \
+          max_indices, \
+          num_work_rows, \
+          group_size)
 
   AT_DISPATCH_INDEX_TYPES(
       indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
         AT_DISPATCH_FLOATING_TYPES_AND_HALF(
             input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
-              group_index_select_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
-                  <<<grid_size,
-                     block_size,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      input_ptrs,
-                      indices_ptrs,
-                      output.data_ptr<scalar_t>(),
-                      num_input_rows,
-                      num_output_rows,
-                      num_cols,
-                      num_groups);
+              if (use_index_select) {
+                if (use_var_cols) {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true);
+                } else {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false);
+                }
+              } else {
+                if (use_var_cols) {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true);
+                } else {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false);
+                }
+              }
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
      });
 
-  return output.split(num_output_rows, 0);
+#undef INVOKE_GROUP_INDEX_SELECT_OR_ADD
 }
 
-std::vector<Tensor> group_index_add_cuda(
-    const int64_t* input_ptrs,
-    const int64_t* indices_ptrs,
-    const c10::TensorOptions& input_tensor_options,
-    const c10::ScalarType& input_scalar_type,
-    const c10::ScalarType& indices_scalar_type,
-    const c10::DeviceIndex& device,
-    const std::vector<int64_t>& output_shape,
-    const int num_input_rows,
-    const int num_output_rows,
-    const int num_cols,
-    const int num_groups) {
-  if (num_groups == 0) {
-    return std::vector<Tensor>();
-  }
-
-  at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(device);
-
-  Tensor output = at::zeros(output_shape, input_tensor_options);
-
-  // Partition work based on num_input_rows
-  const int UNROLL_FACTOR = 1;
-  uint32_t max_grid_size =
-      at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
-  uint32_t grid_size = std::min(
-      cuda_calc_xblock_count(num_groups * num_input_rows, 1), max_grid_size);
-  uint32_t block_size_x =
-      std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
-  uint32_t block_size_y =
-      std::max((num_groups * num_input_rows) / grid_size, (uint32_t)1);
-  dim3 block_size(
-      block_size_x,
-      std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
-      1);
-
-  AT_DISPATCH_INDEX_TYPES(
-      indices_scalar_type, "group_index_add_2d_wrapper_1", [&] {
-        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            input_scalar_type, "group_index_add_2d_wrapper_2", [&] {
-              group_index_add_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
-                  <<<grid_size,
-                     block_size,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      input_ptrs,
-                      indices_ptrs,
-                      output.data_ptr<scalar_t>(),
-                      num_input_rows,
-                      num_output_rows,
-                      num_cols,
-                      num_groups);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
-            });
-      });
-
-  return output.split(num_output_rows, 0);
-}
 // Copied from cupy/random/_kernels.py v11
 // (commit id 420e41fd41157d4cf526b0e94eb86a3f8eb5a231)

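The new kernel partitions work by warps rather than by rows: each member i contributes warps_per_row_i * num_work_rows warps, warp_offsets_group holds the running (exclusive) prefix sum of those counts, and the kernel inverts the mapping from a flat warp_id back to (member_id, row, col_offset), using binary_search_range when column counts differ. The host-side sketch below mirrors that arithmetic for clarity; it is not code from this diff, the helper names are assumptions, and it assumes UNROLL_FACTOR == 1 and a 32-lane warp.

// Illustration of the warp bookkeeping used by
// group_index_select_or_add_2d_kernel; not code from this diff.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int kWarpSize = 32;            // assumption: NVIDIA warp size
constexpr int COLS_PER_WARP = kWarpSize; // UNROLL_FACTOR == 1
constexpr int LOG_COLS_PER_WARP = 5;     // log2(COLS_PER_WARP)

// warp_offsets[i] = number of warps assigned to members 0..i-1, so
// warp_offsets[group_size] is the total number of warps to launch.
std::vector<int64_t> build_warp_offsets(
    const std::vector<int32_t>& num_cols_group,
    int64_t num_work_rows) {
  std::vector<int64_t> offsets(num_cols_group.size() + 1, 0);
  for (size_t i = 0; i < num_cols_group.size(); ++i) {
    const int64_t warps_per_row =
        (num_cols_group[i] + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
    offsets[i + 1] = offsets[i] + warps_per_row * num_work_rows;
  }
  return offsets;
}

// Recover which member, row, and starting column a flat warp_id works on.
// Each lane then adds threadIdx.x * UNROLL_FACTOR to the column offset. The
// kernel's USE_VAR_COLS path does the member lookup with binary_search_range;
// a linear scan stands in for it here.
void map_warp(
    int64_t warp_id,
    const std::vector<int64_t>& offsets,
    const std::vector<int32_t>& num_cols_group,
    int32_t* member_id,
    int64_t* row,
    int64_t* col_offset) {
  int32_t m = 0;
  while (offsets[m + 1] <= warp_id) {
    ++m; // linear stand-in for the device-side binary search
  }
  const int64_t warps_per_row =
      (num_cols_group[m] + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
  const int64_t member_warp_id = warp_id - offsets[m];
  *member_id = m;
  *row = member_warp_id / warps_per_row;
  *col_offset = (member_warp_id % warps_per_row) << LOG_COLS_PER_WARP;
}

int main() {
  // Three group members with 64, 40, and 8 columns; 4 output rows each.
  const std::vector<int32_t> num_cols_group = {64, 40, 8};
  const auto offsets = build_warp_offsets(num_cols_group, /*num_work_rows=*/4);
  for (int64_t warp_id = 0; warp_id < offsets.back(); ++warp_id) {
    int32_t member_id;
    int64_t row, col_offset;
    map_warp(warp_id, offsets, num_cols_group, &member_id, &row, &col_offset);
    const int64_t col_end = std::min<int64_t>(
        col_offset + COLS_PER_WARP, num_cols_group[member_id]);
    std::printf(
        "warp %2ld -> member %d, row %ld, cols [%ld, %ld)\n",
        static_cast<long>(warp_id),
        member_id,
        static_cast<long>(row),
        static_cast<long>(col_offset),
        static_cast<long>(col_end));
  }
  return 0;
}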