@@ -52,33 +52,49 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
52
52
}
53
53
54
54
if (dim == input.dim () - 1 ) {
55
- at::native::serial_vec_log_softmax_lastdim_range (
56
- input_data_base,
57
- output_data_base,
58
- dim_size,
59
- at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
60
- executorch::extension::internal::GRAIN_SIZE, outer_size, dim_size),
61
- // TODO: parallelize.
55
+ ::executorch::extension::parallel_for (
62
56
0 ,
63
- outer_size);
57
+ outer_size,
58
+ ::executorch::extension::internal::GRAIN_SIZE,
59
+ [&](const auto begin, const auto end) {
60
+ at::native::serial_vec_log_softmax_lastdim_range (
61
+ input_data_base,
62
+ output_data_base,
63
+ dim_size,
64
+ at::native::vec_log_softmax_lastdim_chunk_size<IN_T>(
65
+ executorch::extension::internal::GRAIN_SIZE,
66
+ outer_size,
67
+ dim_size),
68
+ begin,
69
+ end);
70
+ });
64
71
} else {
65
72
// BLOCK_SIZE in PyTorch is intended for server CPUs; let's
66
73
// halve it to try and have a better chance of fitting in mobile
67
74
// chip caches.
68
- const auto [chunk_size, num_chunks ] =
75
+ const auto [chunk_size_binding, num_chunks_binding ] =
69
76
at::native::vec_logsoftmax_chunk_size_and_num_chunks<
70
77
float ,
71
78
/* BLOCK_SIZE=*/ 64 * 1024 >(inner_size, dim_size);
72
- at::native::serial_vec_logsoftmax_range (
73
- input_data_base,
74
- output_data_base,
75
- inner_size,
76
- chunk_size,
77
- num_chunks,
78
- dim_size,
79
- // TODO: parallelize
79
+ // Work around "capturing a structured binding is not yet supported in
80
+ // OpenMP".
81
+ const auto chunk_size = chunk_size_binding;
82
+ const auto num_chunks = num_chunks_binding;
83
+ ::executorch::extension::parallel_for (
80
84
0 ,
81
- outer_size * num_chunks);
85
+ outer_size * num_chunks,
86
+ ::executorch::extension::internal::GRAIN_SIZE,
87
+ [&](const auto begin, const auto end) {
88
+ at::native::serial_vec_logsoftmax_range (
89
+ input_data_base,
90
+ output_data_base,
91
+ inner_size,
92
+ chunk_size,
93
+ num_chunks,
94
+ dim_size,
95
+ begin,
96
+ end);
97
+ });
82
98
}
83
99
return ;
84
100
}
0 commit comments