@@ -2994,198 +2994,160 @@ Tensor index_add_with_unique_indices_cuda(
  return input_grad.reshape(input_shape);
}

-template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
-__global__ __launch_bounds__(kMaxThreads) void group_index_select_2d_kernel(
-    const int64_t* input_ptrs,
-    const int64_t* indices_ptrs,
-    scalar_t* output,
-    const int64_t num_input_rows,
-    const int64_t num_output_rows,
-    const int64_t num_cols,
-    const int64_t num_groups) {
-  for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
-       bid < num_groups * num_output_rows;
-       bid += gridDim.x * blockDim.y) {
-    const int64_t group_id = bid / num_output_rows;
-    const int64_t row = bid % num_output_rows;
-    scalar_t* input = (scalar_t*)input_ptrs[group_id];
-    index_t* indices = (index_t*)indices_ptrs[group_id];
-    const index_t idx = indices[row];
-    CUDA_KERNEL_ASSERT(idx < num_input_rows)
-    int col;
-    scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
-    for (col = threadIdx.x * UNROLL_FACTOR;
-         col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
-         col += blockDim.x * UNROLL_FACTOR) {
-#pragma unroll
-      for (int i = 0; i < UNROLL_FACTOR; i++) {
-        output_[row * num_cols + col + i] =
-            LDG(&input[idx * num_cols + col + i]);
-      }
-    }
-    for (; col < num_cols; ++col) {
-      output_[row * num_cols + col] = LDG(&input[idx * num_cols + col]);
-    }
-  }
+// TODO: Update UNROLL_FACTOR
+constexpr int GROUP_INDEX_SELECT_UNROLL_FACTOR = 1;
+constexpr int GROUP_INDEX_SELECT_COLS_PER_WARP =
+    GROUP_INDEX_SELECT_UNROLL_FACTOR * kWarpSize;
+// GROUP_INDEX_SELECT_COLS_PER_WARP must be power of two
+constexpr int GROUP_INDEX_SELECT_LOG_COLS_PER_WARP =
+    log2_calc<GROUP_INDEX_SELECT_COLS_PER_WARP>::value;
+
+int get_group_index_select_cols_per_warp() {
+  return GROUP_INDEX_SELECT_COLS_PER_WARP;
}
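// Side note: log2_calc is defined elsewhere in the codebase and is not part of
// this hunk. For readers, a minimal compile-time log2 that covers the
// power-of-two usage above could look like the sketch below; this is an
// illustrative assumption, not the actual FBGEMM definition.

template <int N>
struct log2_calc_ {
  enum { value = log2_calc_<N / 2>::value + 1 };
};
template <>
struct log2_calc_<1> {
  enum { value = 0 };
};
template <int N>
struct log2_calc {
  // For a power-of-two N, value is its exponent, e.g. log2_calc<32>::value == 5.
  enum { value = log2_calc_<N>::value };
};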

-template <typename index_t, typename scalar_t, int UNROLL_FACTOR>
-__global__ __launch_bounds__(kMaxThreads) void group_index_add_2d_kernel(
+template <
+    typename index_t,
+    typename scalar_t,
+    bool USE_INDEX_SELECT,
+    bool USE_VAR_COLS,
+    int UNROLL_FACTOR,
+    int COLS_PER_WARP,
+    int LOG_COLS_PER_WARP>
+__global__
+__launch_bounds__(kMaxThreads) void group_index_select_or_add_2d_kernel(
     const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
-    scalar_t* output,
-    const int64_t num_input_rows,
-    const int64_t num_output_rows,
-    const int64_t num_cols,
-    const int64_t num_groups) {
-  for (int64_t bid = threadIdx.y * gridDim.x + blockIdx.x;
-       bid < num_groups * num_input_rows;
-       bid += gridDim.x * blockDim.y) {
-    const int64_t group_id = bid / num_input_rows;
-    const int64_t row = bid % num_input_rows;
-    scalar_t* input = (scalar_t*)input_ptrs[group_id];
-    index_t* indices = (index_t*)indices_ptrs[group_id];
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
+    const int64_t max_indices,
+    const int64_t num_work_rows, // number of rows to work on per member
+    const int64_t group_size) {
+  const auto total_num_warps = warp_offsets_group[group_size];
+  for (int64_t warp_id = threadIdx.y * gridDim.x + blockIdx.x;
+       warp_id < total_num_warps;
+       warp_id += gridDim.x * blockDim.y) {
+    int32_t member_id, member_warp_id, num_cols, warps_per_row;
+    if (USE_VAR_COLS) {
+      __shared__ int member_ids[kMaxThreads / kWarpSize];
+      if (threadIdx.x == 0) {
+        binary_search_range(
+            &member_ids[threadIdx.y],
+            warp_offsets_group + 1,
+            warp_id,
+            group_size);
+      }
+      syncwarp();
+      member_id = member_ids[threadIdx.y];
+      num_cols = num_cols_group[member_id];
+      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+      member_warp_id = warp_id - warp_offsets_group[member_id];
+    } else {
+      // All columns are the same
+      num_cols = num_cols_group[0];
+      warps_per_row = (num_cols + COLS_PER_WARP - 1) >> LOG_COLS_PER_WARP;
+      member_id = warp_id / (warps_per_row * num_work_rows);
+      member_warp_id = warp_id - (member_id * warps_per_row * num_work_rows);
+    }
+    const auto row = member_warp_id / warps_per_row;
+    const auto col_offset =
+        ((member_warp_id % warps_per_row) << LOG_COLS_PER_WARP) +
+        (threadIdx.x * UNROLL_FACTOR);
+    scalar_t* input =
+        reinterpret_cast<scalar_t*>(input_ptrs[member_id]) + col_offset;
+    scalar_t* output =
+        reinterpret_cast<scalar_t*>(output_ptrs[member_id]) + col_offset;
+    index_t* indices = reinterpret_cast<index_t*>(indices_ptrs[member_id]);
     const index_t idx = indices[row];
-    CUDA_KERNEL_ASSERT(idx < num_output_rows)
-    int col;
-    scalar_t* output_ = output + (num_output_rows * num_cols * group_id);
-    for (col = threadIdx.x * UNROLL_FACTOR;
-         col < num_cols / UNROLL_FACTOR * UNROLL_FACTOR;
-         col += blockDim.x * UNROLL_FACTOR) {
+    CUDA_KERNEL_ASSERT(idx < max_indices)
#pragma unroll
-      for (int i = 0; i < UNROLL_FACTOR; i++) {
-        // PyTorch also uses atomicAdd. It does not require sorting and
-        // provides better parallelism. But this can lead to numerical
-        // indeterminisim.
+    for (int i = 0; i < UNROLL_FACTOR && col_offset + i < num_cols; i++) {
+      // Compile time conditional
+      if (USE_INDEX_SELECT) {
+        output[row * num_cols + i] = LDG(&input[idx * num_cols + i]);
+      } else {
        gpuAtomicAddNoReturn(
-            &output_[idx * num_cols + col + i],
-            input[row * num_cols + col + i]);
+            &output[idx * num_cols + i], input[row * num_cols + i]);
      }
    }
-    for (; col < num_cols; ++col) {
-      gpuAtomicAddNoReturn(
-          &output[idx * num_cols + col], input[row * num_cols + col]);
-    }
  }
}
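// For reference, the fixed-column branch above (USE_VAR_COLS == false) maps a
// global warp_id to (member_id, row, col_offset) with pure integer arithmetic,
// while the variable-column branch finds member_id by binary search over
// warp_offsets_group. The host-side sketch below only mirrors the fixed-column
// arithmetic; decompose_warp_id is a hypothetical helper written to illustrate
// the partitioning and is not part of this change.

#include <cstdint>

struct WarpWork {
  int32_t member_id;   // which group member (tensor) the warp works on
  int64_t row;         // which row of that member
  int64_t col_offset;  // first column handled by lane 0 of the warp
};

inline WarpWork decompose_warp_id(
    int64_t warp_id,
    int32_t num_cols,       // same for every member in this branch
    int64_t num_work_rows,  // rows to process per member
    int cols_per_warp,      // GROUP_INDEX_SELECT_COLS_PER_WARP (power of two)
    int log_cols_per_warp) {
  const int64_t warps_per_row =
      (num_cols + cols_per_warp - 1) >> log_cols_per_warp;
  const int32_t member_id =
      static_cast<int32_t>(warp_id / (warps_per_row * num_work_rows));
  const int64_t member_warp_id =
      warp_id - member_id * warps_per_row * num_work_rows;
  return {member_id,
          member_warp_id / warps_per_row,
          (member_warp_id % warps_per_row) << log_cols_per_warp};
}

// Example: num_cols = 100 and cols_per_warp = 32 give warps_per_row = 4; with
// num_work_rows = 10, warp_id 47 falls in member 1, row 1, col_offset 96. The
// per-lane bound (col_offset + i < num_cols) in the kernel masks the tail of
// the last warp in each row.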

-std::vector<Tensor> group_index_select_cuda(
+void group_index_select_or_add_cuda(
     const int64_t* input_ptrs,
+    const int64_t* output_ptrs,
     const int64_t* indices_ptrs,
-    const c10::TensorOptions& input_tensor_options,
+    const int64_t* warp_offsets_group,
+    const int32_t* num_cols_group,
     const c10::ScalarType& input_scalar_type,
     const c10::ScalarType& indices_scalar_type,
     const c10::DeviceIndex& device,
-    const std::vector<int64_t>& output_shape,
-    const int num_input_rows,
-    const int num_output_rows,
-    const int num_cols,
-    const int num_groups) {
-  if (num_groups == 0) {
-    return std::vector<Tensor>();
+    const int max_indices,
+    const int num_work_rows,
+    const int64_t total_num_warps,
+    const int group_size,
+    const bool use_index_select,
+    const bool use_var_cols) {
+  if (group_size == 0) {
+    return;
  }

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(device);

-  Tensor output = at::empty(output_shape, input_tensor_options);
-
-  // Partition work based on num_output_rows
-  const int UNROLL_FACTOR = 1;
+  // Partition work based on num_work_rows
+  uint32_t num_warps_per_threadblock = kMaxThreads / kWarpSize;
  uint32_t max_grid_size =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
  uint32_t grid_size = std::min(
-      cuda_calc_xblock_count(num_groups * num_output_rows, 1), max_grid_size);
-  uint32_t block_size_x =
-      std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
-  uint32_t block_size_y =
-      std::max((num_groups * num_output_rows) / grid_size, (uint32_t)1);
-  dim3 block_size(
-      block_size_x,
-      std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
-      1);
+      cuda_calc_xblock_count(total_num_warps, num_warps_per_threadblock),
+      max_grid_size);
+  dim3 block_size(kWarpSize, num_warps_per_threadblock, 1);
+
+#define INVOKE_GROUP_INDEX_SELECT_OR_ADD(USE_INDEX_SELECT, USE_VAR_COLS)  \
+  group_index_select_or_add_2d_kernel<                                    \
+      index_t,                                                            \
+      scalar_t,                                                           \
+      USE_INDEX_SELECT,                                                   \
+      USE_VAR_COLS,                                                       \
+      GROUP_INDEX_SELECT_UNROLL_FACTOR,                                   \
+      GROUP_INDEX_SELECT_COLS_PER_WARP,                                   \
+      GROUP_INDEX_SELECT_LOG_COLS_PER_WARP>                               \
+      <<<grid_size, block_size, 0, at::cuda::getCurrentCUDAStream()>>>(   \
+          input_ptrs,                                                     \
+          output_ptrs,                                                    \
+          indices_ptrs,                                                   \
+          warp_offsets_group,                                             \
+          num_cols_group,                                                 \
+          max_indices,                                                    \
+          num_work_rows,                                                  \
+          group_size)

  AT_DISPATCH_INDEX_TYPES(
      indices_scalar_type, "group_index_select_2d_wrapper_1", [&] {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            input_scalar_type, "group_index_select_2d_wrapper_2", [&] {
-              group_index_select_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
-                  <<<grid_size,
-                     block_size,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      input_ptrs,
-                      indices_ptrs,
-                      output.data_ptr<scalar_t>(),
-                      num_input_rows,
-                      num_output_rows,
-                      num_cols,
-                      num_groups);
+              if (use_index_select) {
+                if (use_var_cols) {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, true);
+                } else {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(true, false);
+                }
+              } else {
+                if (use_var_cols) {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, true);
+                } else {
+                  INVOKE_GROUP_INDEX_SELECT_OR_ADD(false, false);
+                }
+              }
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            });
      });

-  return output.split(num_output_rows, 0);
+#undef INVOKE_GROUP_INDEX_SELECT_OR_ADD
}

-std::vector<Tensor> group_index_add_cuda(
-    const int64_t* input_ptrs,
-    const int64_t* indices_ptrs,
-    const c10::TensorOptions& input_tensor_options,
-    const c10::ScalarType& input_scalar_type,
-    const c10::ScalarType& indices_scalar_type,
-    const c10::DeviceIndex& device,
-    const std::vector<int64_t>& output_shape,
-    const int num_input_rows,
-    const int num_output_rows,
-    const int num_cols,
-    const int num_groups) {
-  if (num_groups == 0) {
-    return std::vector<Tensor>();
-  }
-
-  at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(device);
-
-  Tensor output = at::zeros(output_shape, input_tensor_options);
-
-  // Partition work based on num_input_rows
-  const int UNROLL_FACTOR = 1;
-  uint32_t max_grid_size =
-      at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 8;
-  uint32_t grid_size = std::min(
-      cuda_calc_xblock_count(num_groups * num_input_rows, 1), max_grid_size);
-  uint32_t block_size_x =
-      std::min(div_round_up(num_cols, UNROLL_FACTOR), kMaxThreads);
-  uint32_t block_size_y =
-      std::max((num_groups * num_input_rows) / grid_size, (uint32_t)1);
-  dim3 block_size(
-      block_size_x,
-      std::min(block_size_y, (uint32_t)(kMaxThreads / block_size_x)),
-      1);
-
-  AT_DISPATCH_INDEX_TYPES(
-      indices_scalar_type, "group_index_add_2d_wrapper_1", [&] {
-        AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-            input_scalar_type, "group_index_add_2d_wrapper_2", [&] {
-              group_index_add_2d_kernel<index_t, scalar_t, UNROLL_FACTOR>
-                  <<<grid_size,
-                     block_size,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      input_ptrs,
-                      indices_ptrs,
-                      output.data_ptr<scalar_t>(),
-                      num_input_rows,
-                      num_output_rows,
-                      num_cols,
-                      num_groups);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
-            });
-      });
-
-  return output.split(num_output_rows, 0);
-}
// Copied from cupy/random/_kernels.py v11
// (commit id 420e41fd41157d4cf526b0e94eb86a3f8eb5a231)
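// For context: the new group_index_select_or_add_cuda entry point expects
// warp_offsets_group to be a per-member prefix sum of warp counts with a
// leading zero, so warp_offsets_group[g] is the first global warp id of member
// g and warp_offsets_group[group_size] equals total_num_warps. The helper
// below is a sketch of how a caller could build that array on the host under
// those assumptions; build_warp_offsets is a hypothetical name, and the real
// caller (outside this hunk) may compute it differently.

#include <cstdint>
#include <vector>

std::vector<int64_t> build_warp_offsets(
    const std::vector<int32_t>& num_cols_group, // columns per group member
    int64_t num_work_rows,                      // rows processed per member
    int cols_per_warp) {                        // e.g. GROUP_INDEX_SELECT_COLS_PER_WARP
  std::vector<int64_t> warp_offsets(num_cols_group.size() + 1, 0);
  for (size_t g = 0; g < num_cols_group.size(); ++g) {
    // Each row of member g needs ceil(num_cols / cols_per_warp) warps.
    const int64_t warps_per_row =
        (num_cols_group[g] + cols_per_warp - 1) / cols_per_warp;
    warp_offsets[g + 1] = warp_offsets[g] + warps_per_row * num_work_rows;
  }
  // warp_offsets.back() is the total_num_warps argument of the wrapper.
  return warp_offsets;
}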