Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions common/src/KokkosFFT_transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,28 @@

namespace KokkosFFT {
namespace Impl {
template <typename ViewType, std::size_t DIM>
auto get_map_axes(const ViewType& view, axis_type<DIM> axes) {
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(view, axes),
"get_map_axes: input axes are not valid for the view");

template <typename Layout, std::size_t DIM, typename IntType,
std::size_t FFT_DIM>
auto get_map_axes(const std::array<IntType, FFT_DIM>& axes) {
static_assert(std::is_integral_v<IntType>,
"get_map_axes: IntType must be an integral type.");
static_assert(
FFT_DIM >= 1 && FFT_DIM <= DIM,
"get_map_axes: the Rank of FFT axes must be between 1 and View rank");

// Convert the input axes to be in the range of [0, rank-1]
auto non_negative_axes = convert_negative_axes(axes, ViewType::rank());
auto non_negative_axes = convert_negative_axes(axes, DIM);

// how indices are map
// For 5D View and axes are (2,3), map would be (0, 1, 4, 2, 3)
constexpr int rank = static_cast<int>(ViewType::rank());
std::vector<int> map;
constexpr IntType rank = static_cast<IntType>(DIM);
std::vector<IntType> map;
map.reserve(rank);

if (std::is_same_v<typename ViewType::array_layout, Kokkos::LayoutRight>) {
if (std::is_same_v<Layout, Kokkos::LayoutRight>) {
// Stack axes not specified by axes (0, 1, 4)
for (int i = 0; i < rank; i++) {
for (IntType i = 0; i < rank; i++) {
if (!is_found(non_negative_axes, i)) {
map.push_back(i);
}
Expand All @@ -47,23 +52,31 @@ auto get_map_axes(const ViewType& view, axis_type<DIM> axes) {
}

// Then stack remaining axes
for (int i = 0; i < rank; i++) {
for (IntType i = 0; i < rank; i++) {
if (!is_found(non_negative_axes, i)) {
map.push_back(i);
}
}
}

using full_axis_type = axis_type<rank>;
using full_axis_type = std::array<IntType, rank>;
full_axis_type array_map = {}, array_map_inv = {};
std::copy_n(map.begin(), rank, array_map.begin());

// Construct inverse map
for (int i = 0; i < rank; i++) {
for (IntType i = 0; i < rank; i++) {
array_map_inv.at(i) = get_index(array_map, i);
}

return std::tuple<full_axis_type, full_axis_type>({array_map, array_map_inv});
return std::make_tuple(array_map, array_map_inv);
}

template <typename ViewType, std::size_t FFT_DIM>
auto get_map_axes(const ViewType& view, const axis_type<FFT_DIM>& axes) {
KOKKOSFFT_THROW_IF(!KokkosFFT::Impl::are_valid_axes(view, axes),
"get_map_axes: input axes are not valid for the view");
using LayoutType = typename ViewType::array_layout;
return get_map_axes<LayoutType, ViewType::rank()>(axes);
}

template <typename ViewType>
Expand Down
Loading