Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/unified/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ set(UNIFIED_SOURCES
solver/bicgstab_kernels.cpp
solver/cg_kernels.cpp
solver/cgs_kernels.cpp
solver/chebyshev_kernels.cpp
solver/common_gmres_kernels.cpp
solver/fcg_kernels.cpp
solver/gcr_kernels.cpp
Expand Down
114 changes: 114 additions & 0 deletions common/unified/solver/chebyshev_kernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "core/solver/chebyshev_kernels.hpp"

#include <type_traits>

#include <ginkgo/core/base/std_extensions.hpp>
#include <ginkgo/core/matrix/dense.hpp>
#include <ginkgo/core/solver/chebyshev.hpp>

#include "common/unified/base/kernel_launch.hpp"


namespace gko {
namespace kernels {
namespace GKO_DEVICE_NAMESPACE {
namespace chebyshev {


#if GINKGO_DPCPP_SINGLE_MODE


// Only the type used inside the device code is changed here, so that the
// kernel interface stays the same as in the other backends: double maps to
// float, and every other coefficient type maps to std::complex<float>.
template <typename CoeffType>
using if_single_only_type =
    typename std::conditional<std::is_same<CoeffType, double>::value, float,
                              std::complex<float>>::type;


#else


// Without GINKGO_DPCPP_SINGLE_MODE the coefficient type is used unchanged.
template <typename CoeffType>
using if_single_only_type = xstd::type_identity_t<CoeffType>;


#endif


// First Chebyshev update step. For every entry (row, col) of output:
//   update_sol(row, col) = inner_sol(row, col)
//   output(row, col)    += alpha * inner_sol(row, col)
// The arithmetic is carried out in the coefficient precision and the results
// are cast back to ValueType when they are stored.
template <typename ValueType>
void init_update(std::shared_ptr<const DefaultExecutor> exec,
                 const solver::detail::coeff_type<ValueType> alpha,
                 const matrix::Dense<ValueType>* inner_sol,
                 matrix::Dense<ValueType>* update_sol,
                 matrix::Dense<ValueType>* output)
{
    using scalar_type =
        if_single_only_type<solver::detail::coeff_type<ValueType>>;
    // The coefficient type is always the highest precision, so the matrix
    // entries are cast from ValueType to this precision before computing.
    using compute_type = device_type<scalar_type>;
    // Type in which the entries are written back to the matrices.
    using storage_type = device_type<ValueType>;

    scalar_type scale = static_cast<scalar_type>(alpha);

    run_kernel(
        exec,
        [] GKO_KERNEL(auto i, auto j, auto factor, auto inner, auto update,
                      auto result) {
            const auto inner_entry = static_cast<compute_type>(inner(i, j));
            update(i, j) = static_cast<storage_type>(inner_entry);
            result(i, j) = static_cast<storage_type>(
                static_cast<compute_type>(result(i, j)) +
                factor * inner_entry);
        },
        output->get_size(), scale, inner_sol, update_sol, output);
}

// Instantiation of the init_update kernel through the project macros
// (presumably one explicit instantiation per supported value type; confirm
// against the macro definitions, which are not part of this file).
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);


// Subsequent Chebyshev update step. For every entry (row, col) of output:
//   val                  = inner_sol(row, col) + beta * update_sol(row, col)
//   inner_sol(row, col)  = val
//   update_sol(row, col) = val
//   output(row, col)    += alpha * val
// The arithmetic is carried out in the coefficient precision and the results
// are cast back to ValueType when they are stored.
template <typename ValueType>
void update(std::shared_ptr<const DefaultExecutor> exec,
            const solver::detail::coeff_type<ValueType> alpha,
            const solver::detail::coeff_type<ValueType> beta,
            matrix::Dense<ValueType>* inner_sol,
            matrix::Dense<ValueType>* update_sol,
            matrix::Dense<ValueType>* output)
{
    using scalar_type =
        if_single_only_type<solver::detail::coeff_type<ValueType>>;
    // The coefficient type is always the highest precision, so the matrix
    // entries are cast from ValueType to this precision before computing.
    using compute_type = device_type<scalar_type>;
    // Type in which the entries are written back to the matrices.
    using storage_type = device_type<ValueType>;

    scalar_type step_scale = static_cast<scalar_type>(alpha);
    scalar_type mix_scale = static_cast<scalar_type>(beta);

    run_kernel(
        exec,
        [] GKO_KERNEL(auto i, auto j, auto step, auto mix, auto inner,
                      auto update, auto result) {
            // Combined direction; stored into both inner and update below.
            const auto combined =
                static_cast<compute_type>(inner(i, j)) +
                mix * static_cast<compute_type>(update(i, j));
            inner(i, j) = static_cast<storage_type>(combined);
            update(i, j) = static_cast<storage_type>(combined);
            result(i, j) = static_cast<storage_type>(
                static_cast<compute_type>(result(i, j)) + step * combined);
        },
        output->get_size(), step_scale, mix_scale, inner_sol, update_sol,
        output);
}

// Instantiation of the update kernel through the project macros (presumably
// one explicit instantiation per supported value type; confirm against the
// macro definitions, which are not part of this file).
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);


} // namespace chebyshev
} // namespace GKO_DEVICE_NAMESPACE
} // namespace kernels
} // namespace gko
3 changes: 2 additions & 1 deletion core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ target_sources(
matrix/scaled_permutation.cpp
matrix/sellp.cpp
matrix/sparsity_csr.cpp
multigrid/pgm.cpp
multigrid/fixed_coarsening.cpp
multigrid/pgm.cpp
preconditioner/batch_jacobi.cpp
preconditioner/gauss_seidel.cpp
preconditioner/sor.cpp
Expand All @@ -113,6 +113,7 @@ target_sources(
solver/cb_gmres.cpp
solver/cg.cpp
solver/cgs.cpp
solver/chebyshev.cpp
solver/direct.cpp
solver/fcg.cpp
solver/gcr.cpp
Expand Down
8 changes: 4 additions & 4 deletions core/config/config_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ namespace gko {
namespace config {


#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \
GKO_INVALID_STATE(std::string("The value >" + _value + \
"< is invalid for the entry >" + _entry + \
"<"))
#define GKO_INVALID_CONFIG_VALUE(_entry, _value) \
GKO_INVALID_STATE(std::string("The value >") + _value + \
"< is invalid for the entry >" + _entry + "<")


#define GKO_MISSING_CONFIG_ENTRY(_entry) \
Expand All @@ -53,6 +52,7 @@ enum class LinOpFactoryType : int {
Direct,
LowerTrs,
UpperTrs,
Chebyshev,
Factorization_Ic,
Factorization_Ilu,
Cholesky,
Expand Down
1 change: 1 addition & 0 deletions core/config/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ configuration_map generate_config_map()
{"solver::Direct", parse<LinOpFactoryType::Direct>},
{"solver::LowerTrs", parse<LinOpFactoryType::LowerTrs>},
{"solver::UpperTrs", parse<LinOpFactoryType::UpperTrs>},
{"solver::Chebyshev", parse<LinOpFactoryType::Chebyshev>},
{"factorization::Ic", parse<LinOpFactoryType::Factorization_Ic>},
{"factorization::Ilu", parse<LinOpFactoryType::Factorization_Ilu>},
{"factorization::Cholesky", parse<LinOpFactoryType::Cholesky>},
Expand Down
2 changes: 2 additions & 0 deletions core/config/solver_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ginkgo/core/solver/cb_gmres.hpp>
#include <ginkgo/core/solver/cg.hpp>
#include <ginkgo/core/solver/cgs.hpp>
#include <ginkgo/core/solver/chebyshev.hpp>
#include <ginkgo/core/solver/direct.hpp>
#include <ginkgo/core/solver/fcg.hpp>
#include <ginkgo/core/solver/gcr.hpp>
Expand Down Expand Up @@ -45,6 +46,7 @@ GKO_PARSE_VALUE_TYPE(Minres, gko::solver::Minres);
GKO_PARSE_VALUE_AND_INDEX_TYPE(Direct, gko::experimental::solver::Direct);
GKO_PARSE_VALUE_AND_INDEX_TYPE(LowerTrs, gko::solver::LowerTrs);
GKO_PARSE_VALUE_AND_INDEX_TYPE(UpperTrs, gko::solver::UpperTrs);
GKO_PARSE_VALUE_TYPE(Chebyshev, gko::solver::Chebyshev);


template <>
Expand Down
11 changes: 11 additions & 0 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include "core/solver/cb_gmres_kernels.hpp"
#include "core/solver/cg_kernels.hpp"
#include "core/solver/cgs_kernels.hpp"
#include "core/solver/chebyshev_kernels.hpp"
#include "core/solver/common_gmres_kernels.hpp"
#include "core/solver/fcg_kernels.hpp"
#include "core/solver/gcr_kernels.hpp"
Expand Down Expand Up @@ -677,6 +678,16 @@ GKO_STUB_CB_GMRES_CONST(GKO_DECLARE_CB_GMRES_SOLVE_KRYLOV_KERNEL);
} // namespace cb_gmres


// Device-hook entries for the two Chebyshev kernels (init_update and update).
// NOTE(review): GKO_STUB_VALUE_TYPE presumably emits a placeholder definition
// per value type for builds where the backend is not compiled; confirm
// against the macro definition, which is outside this excerpt.
namespace chebyshev {


GKO_STUB_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_INIT_UPDATE_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_CHEBYSHEV_UPDATE_KERNEL);


}  // namespace chebyshev


namespace ir {


Expand Down
Loading