Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions common/unified/components/format_conversion_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -7,6 +7,9 @@
#include <ginkgo/core/base/types.hpp>

#include "common/unified/base/kernel_launch.hpp"
#include "common/unified/components/bitvector.hpp"
#include "core/base/index_range.hpp"
#include "core/base/iterator_factory.hpp"
#include "core/components/fill_array_kernels.hpp"


Expand All @@ -21,16 +24,23 @@ void convert_ptrs_to_idxs(std::shared_ptr<const DefaultExecutor> exec,
const RowPtrType* ptrs, size_type num_blocks,
IndexType* idxs)
{
const auto num_elements = exec->copy_val_to_host(ptrs + num_blocks);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could just pass the idxs size in from outside; we should already have it when allocating the idxs array.

// transform the ptrs to a bitvector in unary delta encoding, i.e.
// every row with n elements is encoded as 1 0 ... n times ... 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// every row with n elements is encoded as 1 0 ... n times ... 0
// every row with n elements is encoded as 1 0 ... n times ... 0
// we only process positions where bv is zero; there, the prefix sum of bv minus 1, i.e. get_rank() - 1, is the row index.

auto it = gko::detail::make_transform_iterator(
index_iterator<IndexType>{0},
[ptrs] GKO_KERNEL(IndexType i) -> RowPtrType { return ptrs[i] + i; });
auto bv = bitvector::from_sorted_indices(exec, it, num_blocks,
num_blocks + num_elements);
run_kernel(
exec,
[] GKO_KERNEL(auto block, auto ptrs, auto idxs) {
auto begin = ptrs[block];
auto end = ptrs[block + 1];
for (auto i = begin; i < end; i++) {
idxs[i] = block;
[] GKO_KERNEL(RowPtrType i, auto bv, auto idxs) {
if (!bv[i]) {
auto rank = bv.get_rank(i);
idxs[i - rank] = rank - 1;
}
},
num_blocks, ptrs, idxs);
num_blocks + num_elements, bv.device_view(), idxs);
}

GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_CONVERT_PTRS_TO_IDXS32);
Expand Down
27 changes: 26 additions & 1 deletion test/components/format_conversion_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -10,6 +10,7 @@

#include <gtest/gtest.h>

#include "core/base/index_range.hpp"
#include "core/test/utils.hpp"
#include "test/utils/common_fixture.hpp"

Expand Down Expand Up @@ -67,6 +68,30 @@ TYPED_TEST(FormatConversion, ConvertsEmptyPtrsToIdxs)
}


// Checks convert_ptrs_to_idxs on 10000 blocks whose sizes are drawn from a
// geometric distribution (p = 0.01), comparing the kernel output against
// block indices expanded on the host.
TYPED_TEST(FormatConversion, ConvertPtrsToIdxsImbalanced)
{
    using index_type = typename TestFixture::index_type;
    std::geometric_distribution<int> block_size_dist{0.01};
    // host_ptrs holds the running offsets (starting at 0), host_idxs holds
    // each block index repeated once per element of that block.
    std::vector<index_type> host_ptrs{0};
    std::vector<index_type> host_idxs;
    for (auto block : gko::irange{10000}) {
        const auto block_size = block_size_dist(this->rand);
        host_idxs.insert(host_idxs.end(), block_size, block);
        host_ptrs.push_back(host_ptrs.back() + block_size);
    }
    gko::array<index_type> ptr_array{this->exec, host_ptrs.begin(),
                                     host_ptrs.end()};
    gko::array<index_type> idx_array{this->exec, host_idxs.begin(),
                                     host_idxs.end()};
    // Keep a copy of the expected result, then overwrite the output buffer
    // with -1 so entries the kernel does not write show up in the comparison.
    auto ref_idx_array = idx_array;
    idx_array.fill(-1);
    const auto num_blocks = host_ptrs.size() - 1;

    gko::kernels::GKO_DEVICE_NAMESPACE::components::convert_ptrs_to_idxs(
        this->exec, ptr_array.get_const_data(), num_blocks,
        idx_array.get_data());

    GKO_ASSERT_ARRAY_EQ(idx_array, ref_idx_array);
}


TYPED_TEST(FormatConversion, ConvertPtrsToIdxs)
{
auto ref_idxs = this->idxs;
Expand Down
Loading