diff --git a/common/unified/components/format_conversion_kernels.cpp b/common/unified/components/format_conversion_kernels.cpp index 0f54cb04879..f886d434915 100644 --- a/common/unified/components/format_conversion_kernels.cpp +++ b/common/unified/components/format_conversion_kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -7,6 +7,9 @@ #include #include "common/unified/base/kernel_launch.hpp" +#include "common/unified/components/bitvector.hpp" +#include "core/base/index_range.hpp" +#include "core/base/iterator_factory.hpp" #include "core/components/fill_array_kernels.hpp" @@ -21,16 +24,23 @@ void convert_ptrs_to_idxs(std::shared_ptr exec, const RowPtrType* ptrs, size_type num_blocks, IndexType* idxs) { + const auto num_elements = exec->copy_val_to_host(ptrs + num_blocks); + // transform the ptrs to a bitvector in unary delta encoding, i.e. + // every row with n elements is encoded as 1 0 ... n times ... 0 + auto it = gko::detail::make_transform_iterator( + index_iterator{0}, + [ptrs] GKO_KERNEL(IndexType i) -> RowPtrType { return ptrs[i] + i; }); + auto bv = bitvector::from_sorted_indices(exec, it, num_blocks, + num_blocks + num_elements); run_kernel( exec, - [] GKO_KERNEL(auto block, auto ptrs, auto idxs) { - auto begin = ptrs[block]; - auto end = ptrs[block + 1]; - for (auto i = begin; i < end; i++) { - idxs[i] = block; + [] GKO_KERNEL(RowPtrType i, auto bv, auto idxs) { + if (!bv[i]) { + auto rank = bv.get_rank(i); + idxs[i - rank] = rank - 1; } }, - num_blocks, ptrs, idxs); + num_blocks + num_elements, bv.device_view(), idxs); } GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS32); diff --git a/test/components/format_conversion_kernels.cpp b/test/components/format_conversion_kernels.cpp index 217ecd22600..88c1930c82d 100644 --- a/test/components/format_conversion_kernels.cpp +++ b/test/components/format_conversion_kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors // // SPDX-License-Identifier: BSD-3-Clause @@ -10,6 +10,7 @@ #include +#include "core/base/index_range.hpp" #include "core/test/utils.hpp" #include "test/utils/common_fixture.hpp" @@ -67,6 +68,30 @@ TYPED_TEST(FormatConversion, ConvertsEmptyPtrsToIdxs) } +TYPED_TEST(FormatConversion, ConvertPtrsToIdxsImbalanced) +{ + using index_type = typename TestFixture::index_type; + std::vector ptrs{0}; + std::vector idxs; + std::geometric_distribution size_dist{0.01}; + for (auto i : gko::irange{10000}) { + auto count = size_dist(this->rand); + ptrs.push_back(ptrs.back() + count); + idxs.insert(idxs.end(), count, i); + } + gko::array ptr_array{this->exec, ptrs.begin(), ptrs.end()}; + gko::array idx_array{this->exec, idxs.begin(), idxs.end()}; + auto ref_idx_array = idx_array; + idx_array.fill(-1); + + gko::kernels::GKO_DEVICE_NAMESPACE::components::convert_ptrs_to_idxs( + this->exec, ptr_array.get_const_data(), ptrs.size() - 1, + idx_array.get_data()); + + GKO_ASSERT_ARRAY_EQ(idx_array, ref_idx_array); +} + + TYPED_TEST(FormatConversion, ConvertPtrsToIdxs) { auto ref_idxs = this->idxs;