|
14 | 14 | #include <cmath>
|
15 | 15 | #include <type_traits>
|
16 | 16 |
|
17 |
| -#include <ATen/cpu/vec/functional.h> |
18 |
| -#include <ATen/cpu/vec/vec.h> |
| 17 | +#include <ATen/native/cpu/LogSoftmaxKernelImpl.h> |
19 | 18 | #include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
|
20 | 19 | #include <executorch/runtime/kernel/kernel_includes.h>
|
| 20 | +#include <executorch/runtime/kernel/thread_parallel_interface.h> |
21 | 21 |
|
22 | 22 | // `_log_softmax_out` Applies the Log_Softmax function to an n-dimensional input
|
23 | 23 | // Tensor rescaling them so that the elements of the n-dimensional output
|
@@ -51,59 +51,36 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
|
51 | 51 | inner_size *= input.size(i);
|
52 | 52 | }
|
53 | 53 |
|
54 |
| - int64_t dim_stride = inner_size; |
55 |
| - int64_t outer_stride = dim_size * dim_stride; |
56 |
| - |
57 |
| - for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) { |
58 |
| - for (size_t inner_idx = 0; inner_idx < inner_size; ++inner_idx) { |
59 |
| - const IN_T* input_data = |
60 |
| - input_data_base + outer_idx * outer_stride + inner_idx; |
61 |
| - OUT_T* output_data = |
62 |
| - output_data_base + outer_idx * outer_stride + inner_idx; |
63 |
| - |
64 |
| - // calculate max in softmax dim |
65 |
| - IN_T max_input = input_data[0]; |
66 |
| - for (auto d = 0; d < dim_size; ++d) { |
67 |
| - max_input = std::max(max_input, input_data[d * dim_stride]); |
68 |
| - } |
69 |
| - // calculate sum and exponential in softmax dim |
70 |
| - OUT_T temp_sum = 0; |
71 |
| - using VecOut = at::vec::Vectorized<OUT_T>; |
72 |
| - using VecIn = at::vec::Vectorized<IN_T>; |
73 |
| - auto d = 0; |
74 |
| - static_assert(sizeof(IN_T) == sizeof(OUT_T)); |
75 |
| - static_assert( |
76 |
| - std::is_same_v<OUT_T, float>, |
77 |
| - "Below loop actually only supports float."); |
78 |
| - // It is not correct to vectorize if dim is not contiguous! |
79 |
| - if (dim_stride == 1) { |
80 |
| - const VecIn max_input_vec(max_input); |
81 |
| - for (; d + VecOut::size() < dim_size; d += VecOut::size()) { |
82 |
| - auto index = d * dim_stride; |
83 |
| - auto in = VecIn::loadu(&input_data[index]); |
84 |
| - auto out_ = (in - max_input_vec).exp(); |
85 |
| - out_.store(&output_data[index]); |
86 |
| -#if defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) |
87 |
| - temp_sum += vaddvq_f32(out_); |
88 |
| -#else |
89 |
| - temp_sum += at::vec::vec_reduce_all<float>(std::plus<VecOut>(), out_); |
90 |
| -#endif |
91 |
| - } |
92 |
| - } |
93 |
| - for (; d < dim_size; ++d) { |
94 |
| - output_data[d * dim_stride] = |
95 |
| - std::exp(input_data[d * dim_stride] - max_input); |
96 |
| - temp_sum += output_data[d * dim_stride]; |
97 |
| - } |
98 |
| - |
99 |
| - temp_sum = std::log(temp_sum); |
100 |
| - |
101 |
| - for (auto dd = 0; dd < dim_size; ++dd) { |
102 |
| - output_data[dd * dim_stride] = |
103 |
| - input_data[dd * dim_stride] - max_input - temp_sum; |
104 |
| - } |
105 |
| - } |
| 54 | + if (dim == input.dim() - 1) { |
| 55 | + at::native::serial_vec_log_softmax_lastdim_range( |
| 56 | + input_data_base, |
| 57 | + output_data_base, |
| 58 | + dim_size, |
| 59 | + at::native::vec_log_softmax_lastdim_chunk_size<IN_T>( |
| 60 | + executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size), |
| 61 | + // TODO: parallelize. |
| 62 | + 0, |
| 63 | + outer_size); |
| 64 | + } else { |
| 65 | + // BLOCK_SIZE in PyTorch is intended for server CPUs; let's |
| 66 | + // halve it to try and have a better chance of fitting in mobile |
| 67 | + // chip caches. |
| 68 | + const auto [chunk_size, num_chunks] = |
| 69 | + at::native::vec_logsoftmax_chunk_size_and_num_chunks< |
| 70 | + float, |
| 71 | + /*BLOCK_SIZE=*/64 * 1024>(inner_size, dim_size); |
| 72 | + at::native::serial_vec_logsoftmax_range( |
| 73 | + input_data_base, |
| 74 | + output_data_base, |
| 75 | + inner_size, |
| 76 | + chunk_size, |
| 77 | + num_chunks, |
| 78 | + dim_size, |
| 79 | + // TODO: parallelize |
| 80 | + 0, |
| 81 | + outer_size * num_chunks); |
106 | 82 | }
|
| 83 | + return; |
107 | 84 | }
|
108 | 85 |
|
109 | 86 | // OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
|
|
0 commit comments