Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion common/cuda_hip/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -24,6 +24,7 @@
#include "common/cuda_hip/components/reduction.hpp"
#include "common/cuda_hip/components/thread_ids.hpp"
#include "common/cuda_hip/components/uninitialized_array.hpp"
#include "core/base/mixed_precision_types.hpp"
#include "core/base/utils.hpp"
#include "core/components/prefix_sum_kernels.hpp"

Expand Down
1 change: 1 addition & 0 deletions common/unified/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(UNIFIED_SOURCES
matrix/ell_kernels.cpp
matrix/hybrid_kernels.cpp
matrix/permutation_kernels.cpp
matrix/row_scatterer.cpp
matrix/scaled_permutation_kernels.cpp
matrix/sellp_kernels.cpp
matrix/sparsity_csr_kernels.cpp
Expand Down
60 changes: 60 additions & 0 deletions common/unified/matrix/row_scatterer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "common/unified/base/kernel_launch.hpp"
#include "core/base/mixed_precision_types.hpp"
#include "core/matrix/row_scatterer_kernels.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace row_scatter {


template <typename ValueType, typename OutputType, typename IndexType>
void row_scatter(std::shared_ptr<const DefaultExecutor> exec,
const array<IndexType>* row_idxs,
const matrix::Dense<ValueType>* orig,
matrix::Dense<OutputType>* target, bool& invalid_access)
{
array<bool> invalid_access_arr{exec, {false}};
run_kernel(
exec,
[num_rows = target->get_size()[0]] GKO_KERNEL(
auto row, auto col, auto orig, auto rows, auto scattered,
auto* invalid_access_ptr) {
if (rows[row] >= num_rows) {
*invalid_access_ptr = true;
return;
}
scattered(rows[row], col) = orig(row, col);
},
dim<2>{row_idxs->get_size(), orig->get_size()[1]}, orig, *row_idxs,
target, invalid_access_arr.get_data());
invalid_access = exec->copy_val_to_host(invalid_access_arr.get_data());
}

GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY);


template <typename ValueType, typename OutputType, typename IndexType>
void advanced_row_scatter(std::shared_ptr<const DefaultExecutor> exec,
const array<IndexType>* row_idxs,
const matrix::Dense<ValueType>* alpha,
const matrix::Dense<ValueType>* orig,
const matrix::Dense<OutputType>* beta,
matrix::Dense<OutputType>* target,
bit_packed_span<bool, IndexType, uint32> mask,
bool& invalid_access) GKO_NOT_IMPLEMENTED;

GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE_2(
GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY);


} // namespace row_scatter
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
1 change: 1 addition & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ target_sources(
matrix/identity.cpp
matrix/permutation.cpp
matrix/row_gatherer.cpp
matrix/row_scatterer.cpp
matrix/scaled_permutation.cpp
matrix/sellp.cpp
matrix/sparsity_csr.cpp
Expand Down
12 changes: 12 additions & 0 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include <type_traits>

#include <core/matrix/row_scatterer_kernels.hpp>

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/base/types.hpp>

Expand Down Expand Up @@ -1114,6 +1116,16 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_IMPLICIT_RESIDUAL_NORM_KERNEL);


} // namespace implicit_residual_norm


namespace row_scatter {


GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY);
GKO_STUB_MIXED_VALUE_AND_INDEX_TYPE_2(GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY);


} // namespace row_scatter
} // namespace GKO_HOOK_MODULE
} // namespace kernels
} // namespace gko
52 changes: 51 additions & 1 deletion core/matrix/dense.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -22,6 +22,7 @@
#include <ginkgo/core/matrix/fbcsr.hpp>
#include <ginkgo/core/matrix/hybrid.hpp>
#include <ginkgo/core/matrix/permutation.hpp>
#include <ginkgo/core/matrix/row_scatterer.hpp>
#include <ginkgo/core/matrix/scaled_permutation.hpp>
#include <ginkgo/core/matrix/sellp.hpp>
#include <ginkgo/core/matrix/sparsity_csr.hpp>
Expand Down Expand Up @@ -1310,6 +1311,19 @@ void Dense<ValueType>::row_gather_impl(const Dense<ValueType>* alpha,
}


template <typename IndexType>
size_type get_size(const array<IndexType>* arr)
{
return arr->get_size();
}

template <typename IndexType>
size_type get_size(const index_set<IndexType>* is)
{
return is->get_num_elems();
}


template <typename ValueType>
std::unique_ptr<LinOp> Dense<ValueType>::permute(
const array<int32>* permutation_indices) const
Expand Down Expand Up @@ -1613,6 +1627,42 @@ void Dense<ValueType>::row_gather(ptr_param<const LinOp> alpha,
}


template <typename ValueType>
void Dense<ValueType>::row_scatter(
ptr_param<const RowScatterer<int32>> scatterer, ptr_param<LinOp> target)
{
scatterer->apply(this, target);
}


template <typename ValueType>
void Dense<ValueType>::row_scatter(
ptr_param<const LinOp> alpha,
ptr_param<const RowScatterer<int32>> scatterer, ptr_param<const LinOp> beta,
ptr_param<LinOp> target)
{
scatterer->apply(alpha, this, beta, target);
}


template <typename ValueType>
void Dense<ValueType>::row_scatter(
ptr_param<const RowScatterer<int64>> scatterer, ptr_param<LinOp> target)
{
scatterer->apply(this, target);
}


template <typename ValueType>
void Dense<ValueType>::row_scatter(
ptr_param<const LinOp> alpha,
ptr_param<const RowScatterer<int64>> scatterer, ptr_param<const LinOp> beta,
ptr_param<LinOp> target)
{
scatterer->apply(alpha, this, beta, target);
}


template <typename ValueType>
std::unique_ptr<LinOp> Dense<ValueType>::column_permute(
const array<int32>* permutation_indices) const
Expand Down
3 changes: 2 additions & 1 deletion core/matrix/dense_kernels.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -8,6 +8,7 @@

#include <memory>

#include <ginkgo/core/base/index_set.hpp>
#include <ginkgo/core/base/math.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/matrix/dense.hpp>
Expand Down
124 changes: 124 additions & 0 deletions core/matrix/row_scatterer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// SPDX-FileCopyrightText: 2024 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/matrix/row_scatterer.hpp"

#include <ginkgo/core/base/precision_dispatch.hpp>
#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/dispatch_helper.hpp"
#include "core/components/bit_packed_storage.hpp"
#include "core/matrix/row_scatterer_kernels.hpp"

namespace gko {
namespace matrix {
namespace {


GKO_REGISTER_OPERATION(row_scatter, row_scatter::row_scatter);
GKO_REGISTER_OPERATION(advanced_row_scatter, row_scatter::advanced_row_scatter);


} // namespace


template <typename IndexType>
std::unique_ptr<RowScatterer<IndexType>> RowScatterer<IndexType>::create(
std::shared_ptr<const Executor> exec, array<IndexType> idxs,
size_type to_size)
{
return std::unique_ptr<RowScatterer>(
new RowScatterer(std::move(exec), std::move(idxs), to_size));
}


template <typename IndexType>
RowScatterer<IndexType>::RowScatterer(std::shared_ptr<const Executor> exec)
: EnableLinOp<RowScatterer<IndexType>>(std::move(exec))
{}


template <typename IndexType>
RowScatterer<IndexType>::RowScatterer(std::shared_ptr<const Executor> exec,
array<IndexType> idxs, size_type to_size)
: EnableLinOp<RowScatterer<IndexType>>(exec, {to_size, idxs.get_size()}),
idxs_(exec, std::move(idxs)),
mask_(exec,
bit_packed_span<bool, IndexType, uint32>::storage_size(to_size, 1))
{}


template <typename IndexType>
void RowScatterer<IndexType>::apply_impl(const LinOp* b, LinOp* x) const
{
auto impl = [&](const auto* orig, auto* target) {
auto exec = this->get_executor();
bool invalid_access = false;

exec->run(make_row_scatter(&idxs_, orig, target, invalid_access));

// TODO: find a uniform way to handle device-side errors
if (invalid_access) {
GKO_INVALID_STATE("Out-of-bounds scatter index detected.");
}
};

run<Dense,
#if GINKGO_ENABLE_HALF
gko::half, std::complex<gko::half>,
#endif
float, double, std::complex<float>, std::complex<double>>(
b, [&](auto* orig) {
using value_type =
typename std::decay_t<decltype(*orig)>::value_type;
mixed_precision_dispatch_real_complex<value_type>(impl, orig, x);
});
}


template <typename IndexType>
void RowScatterer<IndexType>::apply_impl(const LinOp* alpha, const LinOp* b,
const LinOp* beta, LinOp* x) const
{
auto impl = [&](const auto* orig, auto* target) {
auto exec = this->get_executor();
bool invalid_access = false;

auto dense_alpha = make_temporary_conversion<
typename std::decay_t<decltype(*orig)>::value_type>(alpha);
auto dense_beta = make_temporary_conversion<
typename std::decay_t<decltype(*target)>::value_type>(beta);

exec->run(make_advanced_row_scatter(
&idxs_, dense_alpha.get(), orig, dense_beta.get(), target,
bit_packed_span<bool, IndexType, uint32>(mask_.get_data(), 1,
this->get_size()[0]),
invalid_access));

if (invalid_access) {
GKO_INVALID_STATE("Out-of-bounds scatter index detected.");
}
};

mask_.fill(uint32{});

run<Dense,
#if GINKGO_ENABLE_HALF
gko::half, std::complex<gko::half>,
#endif
float, double, std::complex<float>, std::complex<double>>(
b, [&](auto* orig) {
using value_type =
typename std::decay_t<decltype(*orig)>::value_type;
mixed_precision_dispatch_real_complex<value_type>(impl, orig, x);
});
}


#define GKO_DECLARE_ROW_SCATTER(_type) class RowScatterer<_type>
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_ROW_SCATTER);


} // namespace matrix
} // namespace gko
47 changes: 47 additions & 0 deletions core/matrix/row_scatterer_kernels.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#pragma once


#include <ginkgo/core/matrix/dense.hpp>

#include "core/base/kernel_declaration.hpp"
#include "core/components/bit_packed_storage.hpp"


namespace gko {
namespace kernels {

#define GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY(_vtype, _otype, _itype) \
void row_scatter(std::shared_ptr<const DefaultExecutor> exec, \
const array<_itype>* gather_indices, \
const matrix::Dense<_vtype>* orig, \
matrix::Dense<_otype>* target, bool& invalid_access)

#define GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY(_vtype, _otype, _itype) \
void advanced_row_scatter( \
std::shared_ptr<const DefaultExecutor> exec, \
const array<_itype>* row_idxs, const matrix::Dense<_vtype>* alpha, \
const matrix::Dense<_vtype>* orig, const matrix::Dense<_otype>* beta, \
matrix::Dense<_otype>* target, \
bit_packed_span<bool, _itype, uint32> mask, bool& invalid_access)


#define GKO_DECLARE_ALL_AS_TEMPLATES \
template <typename ValueType, typename OutputType, typename IndexType> \
GKO_DECLARE_ROW_SCATTER_SIMPLE_APPLY(ValueType, OutputType, IndexType); \
template <typename ValueType, typename OutputType, typename IndexType> \
GKO_DECLARE_ROW_SCATTER_ADVANCED_APPLY(ValueType, OutputType, IndexType)


GKO_DECLARE_FOR_ALL_EXECUTOR_NAMESPACES(row_scatter,
GKO_DECLARE_ALL_AS_TEMPLATES);


#undef GKO_DECLARE_ALL_AS_TEMPLATES


} // namespace kernels
} // namespace gko
Loading
Loading