
Commit de24e18

pytorchbot and swolchok authored
Use shared log_softmax kernels from PyTorch (#12172)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #12098 by @swolchok
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/swolchok/485/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/swolchok/485/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/swolchok/485/orig
@diff-train-skip-merge

Co-authored-by: Scott Wolchok <swolchok@meta.com>
1 parent 9409774 commit de24e18

2 files changed: +32 -54 lines changed

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 31 additions & 54 deletions
@@ -14,10 +14,10 @@
 #include <cmath>
 #include <type_traits>
 
-#include <ATen/cpu/vec/functional.h>
-#include <ATen/cpu/vec/vec.h>
+#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
 #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 
 // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
 // Tensor rescaling them so that the elements of the n-dimensional output
@@ -51,59 +51,36 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
     inner_size *= input.size(i);
   }
 
-  int64_t dim_stride = inner_size;
-  int64_t outer_stride = dim_size * dim_stride;
-
-  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
-    for (size_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) {
-      const IN_T* input_data =
-          input_data_base + outer_idx * outer_stride + inner_idx;
-      OUT_T* output_data =
-          output_data_base + outer_idx * outer_stride + inner_idx;
-
-      // calculate max in softmax dim
-      IN_T max_input = input_data[0];
-      for (auto d = 0; d < dim_size; ++d) {
-        max_input = std::max(max_input, input_data[d * dim_stride]);
-      }
-      // calculate sum and exponential in softmax dim
-      OUT_T temp_sum = 0;
-      using VecOut = at::vec::Vectorized<OUT_T>;
-      using VecIn = at::vec::Vectorized<IN_T>;
-      auto d = 0;
-      static_assert(sizeof(IN_T) == sizeof(OUT_T));
-      static_assert(
-          std::is_same_v<OUT_T, float>,
-          "Below loop actually only supports float.");
-      // It is not correct to vectorize if dim is not contiguous!
-      if (dim_stride == 1) {
-        const VecIn max_input_vec(max_input);
-        for (; d + VecOut::size() < dim_size; d += VecOut::size()) {
-          auto index = d * dim_stride;
-          auto in = VecIn::loadu(&input_data[index]);
-          auto out_ = (in - max_input_vec).exp();
-          out_.store(&output_data[index]);
-#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE)
-          temp_sum += vaddvq_f32(out_);
-#else
-          temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_);
-#endif
-        }
-      }
-      for (; d < dim_size; ++d) {
-        output_data[d * dim_stride] =
-            std::exp(input_data[d * dim_stride] - max_input);
-        temp_sum += output_data[d * dim_stride];
-      }
-
-      temp_sum = std::log(temp_sum);
-
-      for (auto dd = 0; dd < dim_size; ++dd) {
-        output_data[dd * dim_stride] =
-            input_data[dd * dim_stride] - max_input - temp_sum;
-      }
-    }
+  if (dim == input.dim() - 1) {
+    at::native::serial_vec_log_softmax_lastdim_range(
+        input_data_base,
+        output_data_base,
+        dim_size,
+        at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
+            executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
+        // TODO: parallelize.
+        0,
+        outer_size);
+  } else {
+    // BLOCK_SIZE in PyTorch is intended for server CPUs; let's
+    // halve it to try and have a better chance of fitting in mobile
+    // chip caches.
+    const auto [chunk_size, num_chunks] =
+        at::native::vec_logsoftmax_chunk_size_and_num_chunks<
+            float,
+            /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size);
+    at::native::serial_vec_logsoftmax_range(
+        input_data_base,
+        output_data_base,
+        inner_size,
+        chunk_size,
+        num_chunks,
+        dim_size,
+        // TODO: parallelize
+        0,
+        outer_size * num_chunks);
   }
+  return;
 }
 
 // OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
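
For context, the hand-rolled loop removed above and the shared ATen kernels that replace it (serial_vec_log_softmax_lastdim_range and serial_vec_logsoftmax_range) compute the same numerically stable log-softmax: subtract the per-slice maximum before exponentiating so exp() cannot overflow, accumulate the sum of the shifted exponentials, then subtract its log. A minimal scalar sketch of that algorithm over one contiguous slice (illustrative only, not the committed code):

#include <algorithm>
#include <cmath>
#include <vector>

// out[i] = x[i] - max(x) - log(sum_j exp(x[j] - max(x)))
std::vector<float> log_softmax_1d(const std::vector<float>& x) {
  // Shift by the max so every exponent is <= 0 and exp() cannot overflow.
  const float max_x = *std::max_element(x.begin(), x.end());
  float sum = 0.0f;
  for (const float v : x) {
    sum += std::exp(v - max_x);
  }
  const float log_sum = std::log(sum);
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] - max_x - log_sum;
  }
  return out;
}

The removed code vectorized this by hand (with an aarch64-specific reduction); the shared kernels provide the same vectorization for all builds, which is what lets this file shrink by 54 lines.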

shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ OPTIMIZED_ATEN_OPS = (
     op_target(
         name = "op_log_softmax",
         deps = [
+            "//executorch/extension/threadpool:threadpool",
             "//executorch/kernels/portable/cpu/util:activation_ops_util",
             "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
         ],
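
This new threadpool dependency backs the thread_parallel_interface.h include added in the C++ file; the kernel itself still runs serially (both calls pass the full 0 .. outer_size range on one thread, per the TODO comments). A hedged sketch of what the deferred parallelization might look like, assuming executorch::extension::parallel_for in that header follows the at::parallel_for convention (begin, end, grain_size, callback over a [begin, end) sub-range); parallel_lastdim_log_softmax is a hypothetical helper, not part of this commit:

#include <ATen/native/cpu/LogSoftmaxKernelImpl.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>

// Hypothetical follow-up for the "TODO: parallelize" markers above.
// Parameter names mirror log_softmax_kernel in op_log_softmax.cpp.
template <typename IN_T, typename OUT_T>
void parallel_lastdim_log_softmax(
    const IN_T* input_data_base,
    OUT_T* output_data_base,
    int64_t dim_size,
    int64_t outer_size) {
  const auto chunk_size = at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
      executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size);
  executorch::extension::parallel_for(
      0,
      outer_size,
      executorch::extension::internal::GRAIN_SIZE,
      [&](int64_t begin, int64_t end) {
        // Each worker processes rows [begin, end) of the outer dimension.
        at::native::serial_vec_log_softmax_lastdim_range(
            input_data_base, output_data_base, dim_size, chunk_size, begin, end);
      });
}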
