Commit c778063

Manual cherry-pick: Parallelize optimized op_log_softmax (pytorch#12246)
This landed internally and PR pytorch#12099 was closed, but the bot couldn't cherry-pick it, so here's a manual pick. Differential Revision: D76831122
1 parent 71522c4 · commit c778063

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 34 additions & 18 deletions
@@ -52,33 +52,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
   }
 
   if (dim == input.dim() - 1) {
-    at::native::serial_vec_log_softmax_lastdim_range(
-        input_data_base,
-        output_data_base,
-        dim_size,
-        at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
-            executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
-        // TODO: parallelize.
+    ::executorch::extension::parallel_for(
         0,
-        outer_size);
+        outer_size,
+        ::executorch::extension::internal::GRAIN_SIZE,
+        [&](const auto begin, const auto end) {
+          at::native::serial_vec_log_softmax_lastdim_range(
+              input_data_base,
+              output_data_base,
+              dim_size,
+              at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
+                  executorch::extension::internal::GRAIN_SIZE,
+                  outer_size,
+                  dim_size),
+              begin,
+              end);
+        });
   } else {
     // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
     // halve it to try and have a better chance of fitting in mobile
     // chip caches.
-    const auto [chunk_size, num_chunks] =
+    const auto [chunk_size_binding, num_chunks_binding] =
         at::native::vec_logsoftmax_chunk_size_and_num_chunks<
             float,
             /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
-    at::native::serial_vec_logsoftmax_range(
-        input_data_base,
-        output_data_base,
-        inner_size,
-        chunk_size,
-        num_chunks,
-        dim_size,
-        // TODO: parallelize
+    // Work around "capturing a structured binding is not yet supported in
+    // OpenMP".
+    const auto chunk_size = chunk_size_binding;
+    const auto num_chunks = num_chunks_binding;
+    ::executorch::extension::parallel_for(
         0,
-        outer_size * num_chunks);
+        outer_size * num_chunks,
+        ::executorch::extension::internal::GRAIN_SIZE,
+        [&](const auto begin, const auto end) {
+          at::native::serial_vec_logsoftmax_range(
+              input_data_base,
+              output_data_base,
+              inner_size,
+              chunk_size,
+              num_chunks,
+              dim_size,
+              begin,
+              end);
+        });
   }
   return;
 }
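For readers skimming the diff: the change replaces each single serial call over the whole index range with ::executorch::extension::parallel_for, which splits the range into sub-ranges of roughly GRAIN_SIZE elements and invokes the existing serial kernel on each [begin, end) slice. The sketch below only illustrates that range-splitting shape and is not ExecuTorch code: demo_parallel_for, process_rows, and the thread-per-chunk scheduling are assumptions made up for the example; the only thing taken from the diff is the call shape parallel_for(begin, end, grain_size, body(begin, end)).

// Minimal sketch of the range-splitting pattern adopted above (assumptions noted
// in the text): demo_parallel_for is a made-up stand-in for
// ::executorch::extension::parallel_for and only its call shape mirrors the diff.
#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

template <typename F>
void demo_parallel_for(int64_t begin, int64_t end, int64_t grain_size, F body) {
  if (end - begin <= grain_size) {
    // Small workloads stay on the calling thread, equivalent to the old serial call.
    body(begin, end);
    return;
  }
  // Split [begin, end) into grain_size-sized chunks and run each chunk on a thread.
  std::vector<std::thread> workers;
  for (int64_t chunk_begin = begin; chunk_begin < end; chunk_begin += grain_size) {
    const int64_t chunk_end = std::min(chunk_begin + grain_size, end);
    workers.emplace_back([=] { body(chunk_begin, chunk_end); });
  }
  for (auto& worker : workers) {
    worker.join();
  }
}

// Toy per-range worker standing in for serial_vec_log_softmax_lastdim_range: it
// touches only rows [row_begin, row_end) of a [outer_size, dim_size] buffer, so
// workers on disjoint ranges never write the same memory.
void process_rows(
    const float* in,
    float* out,
    int64_t dim_size,
    int64_t row_begin,
    int64_t row_end) {
  for (int64_t row = row_begin; row < row_end; ++row) {
    for (int64_t j = 0; j < dim_size; ++j) {
      out[row * dim_size + j] = in[row * dim_size + j];  // placeholder per-row math
    }
  }
}

int main() {
  const int64_t outer_size = 1024;
  const int64_t dim_size = 256;
  const int64_t grain_size = 64;  // stand-in for internal::GRAIN_SIZE
  std::vector<float> input(outer_size * dim_size, 1.0f);
  std::vector<float> output(outer_size * dim_size, 0.0f);

  // Mirrors the call shape in the diff: hand each worker a contiguous row range.
  demo_parallel_for(
      0, outer_size, grain_size, [&](const int64_t begin, const int64_t end) {
        process_rows(input.data(), output.data(), dim_size, begin, end);
      });
  return 0;
}

One detail worth noting in the second hunk: chunk_size and num_chunks are copied out of the structured binding into plain locals before the lambda because, as the in-diff comment says, capturing a structured binding is not yet supported when the parallel backend is OpenMP.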
