Skip to content

Commit ae6235b

Browse files
sryapfacebook-github-bot
authored andcommitted
Add support for group size > 54 in group_index_select (#1611)
Summary: Pull Request resolved: #1611 If group size is larger than 54, internally breaks the group down into smaller groups (each subgroup size is less than or equal to 54). Reviewed By: jianyuh Differential Revision: D43585937 fbshipit-source-id: bf14eeb79881a5737dcf7660e3e0f56d21f7b326
1 parent da01a59 commit ae6235b

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

fbgemm_gpu/src/sparse_ops_gpu.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -500,12 +500,41 @@ Tensor index_select_dim0_gpu(
500500
std::vector<Tensor> group_index_select_dim0_gpu(
501501
const std::vector<Tensor>& input_group,
502502
const std::vector<Tensor>& indices_group) {
503+
const auto group_size = input_group.size();
503504
std::vector<Tensor> output_group;
504-
apply_(
505-
[&](auto&&... args) {
506-
output_group = GroupIndexSelectDim0GPUOp::apply(indices_group, args...);
507-
},
508-
input_group);
505+
// We use the APPLY_AUTOGRAD_FN macros to instantiate
506+
// GroupIndexSelectDim0GPUOp for different group sizes. We only instantiate
507+
// up to group size of 54.
508+
constexpr size_t max_group_size = 54;
509+
// Specialize this path to avoid copy
510+
if (group_size <= max_group_size) {
511+
apply_(
512+
[&](auto&&... args) {
513+
output_group =
514+
GroupIndexSelectDim0GPUOp::apply(indices_group, args...);
515+
},
516+
input_group);
517+
return output_group;
518+
}
519+
520+
const auto input_itr = input_group.begin();
521+
const auto indices_itr = indices_group.begin();
522+
523+
for (size_t start = 0; start < group_size; start += max_group_size) {
524+
const auto end = std::min(start + max_group_size, group_size);
525+
std::vector<Tensor> input_subgroup(input_itr + start, input_itr + end);
526+
std::vector<Tensor> indices_subgroup(
527+
indices_itr + start, indices_itr + end);
528+
std::vector<Tensor> output_subgroup;
529+
apply_(
530+
[&](auto&&... args) {
531+
output_subgroup =
532+
GroupIndexSelectDim0GPUOp::apply(indices_subgroup, args...);
533+
},
534+
input_subgroup);
535+
output_group.insert(
536+
output_group.end(), output_subgroup.begin(), output_subgroup.end());
537+
}
509538
return output_group;
510539
}
511540

0 commit comments

Comments
 (0)