Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 21 additions & 81 deletions core/config/preconditioner_ic_config.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/base/exception_helpers.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/preconditioner/ic.hpp>
#include <ginkgo/core/preconditioner/ilu.hpp>
#include <ginkgo/core/preconditioner/isai.hpp>
#include <ginkgo/core/solver/gmres.hpp>
#include <ginkgo/core/solver/ir.hpp>
#include <ginkgo/core/solver/triangular.hpp>

#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
Expand All @@ -22,85 +17,30 @@ namespace gko {
namespace config {


// For Ic and Ilu, we use additional ValueType to help Solver type decision
template <typename Solver>
class IcSolverHelper {
public:
template <typename ValueType, typename IndexType>
class Configurator {
public:
static
typename gko::preconditioner::Ic<Solver, IndexType>::parameters_type
parse(const pnode& config, const registry& context,
const type_descriptor& td_for_child)
{
return gko::preconditioner::Ic<Solver, IndexType>::parse(
config, context, td_for_child);
}
};
};


// Do not use the partial specialization for SolverBase<V> and SolverBase<V, I>
// because the default template arguments are allowed for a template template
// argument (detail: CWG 150 after c++17
// https://en.cppreference.com/w/cpp/language/template_parameters#Template_template_arguments)
template <template <typename V> class SolverBase>
class IcHelper1 {
public:
template <typename ValueType, typename IndexType>
class Configurator
: public IcSolverHelper<SolverBase<ValueType>>::template Configurator<
ValueType, IndexType> {};
};


template <template <typename V, typename I> class SolverBase>
class IcHelper2 {
public:
template <typename ValueType, typename IndexType>
class Configurator
: public IcSolverHelper<SolverBase<ValueType, IndexType>>::
template Configurator<ValueType, IndexType> {};
};


template <>
deferred_factory_parameter<gko::LinOpFactory> parse<LinOpFactoryType::Ic>(
const pnode& config, const registry& context, const type_descriptor& td)
deferred_factory_parameter<gko::LinOpFactory>
parse<gko::config::LinOpFactoryType::Ic>(const gko::config::pnode& config,
const gko::config::registry& context,
const gko::config::type_descriptor& td)
{
auto updated = update_type(config, td);
std::string str("solver::LowerTrs");
if (auto& obj = config.get("l_solver_type")) {
str = obj.get_string();
}
if (str == "solver::LowerTrs") {
return dispatch<gko::LinOpFactory,
IcHelper2<solver::LowerTrs>::Configurator>(
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()),
make_type_selector(updated.get_index_typestr(), index_type_list()));
} else if (str == "solver::Ir") {
return dispatch<gko::LinOpFactory, IcHelper1<solver::Ir>::Configurator>(
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()),
make_type_selector(updated.get_index_typestr(), index_type_list()));
} else if (str == "preconditioner::LowerIsai") {
return dispatch<gko::LinOpFactory,
IcHelper2<preconditioner::LowerIsai>::Configurator>(
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()),
make_type_selector(updated.get_index_typestr(), index_type_list()));
} else if (str == "solver::Gmres") {
return dispatch<gko::LinOpFactory,
IcHelper1<solver::Gmres>::Configurator>(
config, context, updated,
make_type_selector(updated.get_value_typestr(), value_type_list()),
make_type_selector(updated.get_index_typestr(), index_type_list()));
} else {
GKO_INVALID_CONFIG_VALUE("l_solver_type", str);
auto updated = gko::config::update_type(config, td);
if (config.get("l_solver_type_or_value_type")) {
GKO_INVALID_STATE(
"preconditioner::Ic only allows value_type from "
"l_solver_type_or_value_type. To avoid type confusion between "
"these types and value_type, l_solver_type_or_value_type uses "
"the value_type directly.");
}
return gko::config::dispatch<gko::LinOpFactory, gko::preconditioner::Ic>(
config, context, updated,
gko::config::make_type_selector(updated.get_value_typestr(),
gko::config::value_type_list()),
gko::config::make_type_selector(updated.get_index_typestr(),
gko::config::index_type_list()));
}
static_assert(true,
"This assert is used to counter the false positive extra "
"semi-colon warnings");


} // namespace config
Expand Down
62 changes: 24 additions & 38 deletions core/preconditioner/ic.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include "ginkgo/core/preconditioner/ic.hpp"

#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/base/utils_helper.hpp>
#include <ginkgo/core/config/config.hpp>
#include <ginkgo/core/config/registry.hpp>
#include <ginkgo/core/preconditioner/isai.hpp>
#include <ginkgo/core/preconditioner/utils.hpp>
#include <ginkgo/core/solver/gmres.hpp>
#include <ginkgo/core/solver/ir.hpp>

#include "core/config/config_helper.hpp"
#include "core/config/dispatch.hpp"
Expand All @@ -21,18 +18,19 @@ namespace preconditioner {
namespace detail {


template <typename Ic,
std::enable_if_t<support_ic_parse<typename Ic::l_solver_type>>*>
template <typename Ic, std::enable_if_t<support_ic_parse<Ic>>*>
typename Ic::parameters_type ic_parse(
const config::pnode& config, const config::registry& context,
const config::type_descriptor& td_for_child)
{
auto params = Ic::build();

using l_solver_type = typename Ic::l_solver_type;
static_assert(std::is_same_v<l_solver_type, LinOp>,
"only support IC parse when l_solver_type is LinOp.");
if (auto& obj = config.get("l_solver")) {
params.with_l_solver(
gko::config::parse_or_get_specific_factory<
const typename Ic::l_solver_type>(obj, context, td_for_child));
gko::config::parse_or_get_factory<const LinOpFactory>(
obj, context, td_for_child));
}
if (auto& obj = config.get("factorization")) {
params.with_factorization(
Expand All @@ -44,35 +42,23 @@ typename Ic::parameters_type ic_parse(
}


#define GKO_DECLARE_LOWERTRS_IC_PARSE(ValueType, IndexType) \
typename Ic<solver::LowerTrs<ValueType, IndexType>, \
IndexType>::parameters_type \
ic_parse<Ic<solver::LowerTrs<ValueType, IndexType>, IndexType>>( \
const config::pnode&, const config::registry&, \
const config::type_descriptor&)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_LOWERTRS_IC_PARSE);

#define GKO_DECLARE_IR_IC_PARSE(ValueType, IndexType) \
typename Ic<solver::Ir<ValueType>, IndexType>::parameters_type \
ic_parse<Ic<solver::Ir<ValueType>, IndexType>>( \
const config::pnode&, const config::registry&, \
const config::type_descriptor&)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_IR_IC_PARSE);

#define GKO_DECLARE_GMRES_IC_PARSE(ValueType, IndexType) \
typename Ic<solver::Gmres<ValueType>, IndexType>::parameters_type \
ic_parse<Ic<solver::Gmres<ValueType>, IndexType>>( \
const config::pnode&, const config::registry&, \
const config::type_descriptor&)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_GMRES_IC_PARSE);

#define GKO_DECLARE_LOWERISAI_IC_PARSE(ValueType, IndexType) \
typename Ic<LowerIsai<ValueType, IndexType>, IndexType>::parameters_type \
ic_parse<Ic<LowerIsai<ValueType, IndexType>, IndexType>>( \
const config::pnode&, const config::registry&, \
const config::type_descriptor&)
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_LOWERISAI_IC_PARSE);
#define GKO_DECLARE_IC_PARSE(ValueType, IndexType) \
typename Ic<ValueType, IndexType>::parameters_type \
ic_parse<Ic<ValueType, IndexType>>(const config::pnode&, \
const config::registry&, \
const config::type_descriptor&)

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_IC_PARSE);


} // namespace detail


// only instantiate the value type variants of IC, whose solver is LinOp.
#define GKO_DECLARE_IC(ValueType, IndexType) class Ic<ValueType, IndexType>

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_IC);


} // namespace preconditioner
} // namespace gko
14 changes: 6 additions & 8 deletions core/test/config/preconditioner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ struct PreconditionerConfigTest {
};


struct Ic : PreconditionerConfigTest<
::gko::preconditioner::Ic<DummyIr, int>,
::gko::preconditioner::Ic<gko::solver::LowerTrs<>, int>> {
struct Ic : PreconditionerConfigTest<::gko::preconditioner::Ic<float, int>,
::gko::preconditioner::Ic<double, int>> {
static pnode::map_type setup_base()
{
return {{"type", pnode{"preconditioner::Ic"}}};
Expand All @@ -57,7 +56,6 @@ struct Ic : PreconditionerConfigTest<
static void change_template(pnode::map_type& config_map)
{
config_map["value_type"] = pnode{"float32"};
config_map["l_solver_type"] = pnode{"solver::Ir"};
}

template <bool from_reg, typename ParamType>
Expand All @@ -66,17 +64,17 @@ struct Ic : PreconditionerConfigTest<
{
if (from_reg) {
config_map["l_solver"] = pnode{"l_solver"};
param.with_l_solver(detail::registry_accessor::get_data<
typename changed_type::l_solver_type::Factory>(
reg, "l_solver"));
param.with_l_solver(
detail::registry_accessor::get_data<gko::LinOpFactory>(
reg, "l_solver"));
config_map["factorization"] = pnode{"factorization"};
param.with_factorization(
detail::registry_accessor::get_data<gko::LinOpFactory>(
reg, "factorization"));
} else {
config_map["l_solver"] = pnode{{{"type", pnode{"solver::Ir"}},
{"value_type", pnode{"float32"}}}};
param.with_l_solver(changed_type::l_solver_type::build().on(exec));
param.with_l_solver(DummyIr::build().on(exec));
config_map["factorization"] =
pnode{{{"type", pnode{"solver::Ir"}},
{"value_type", pnode{"float32"}}}};
Expand Down
2 changes: 1 addition & 1 deletion include/ginkgo/core/base/composition.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down
75 changes: 75 additions & 0 deletions include/ginkgo/core/base/type_traits.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#ifndef GKO_PUBLIC_CORE_BASE_TYPE_TRAITS_HPP_
#define GKO_PUBLIC_CORE_BASE_TYPE_TRAITS_HPP_

#include <type_traits>

#include <ginkgo/core/base/lin_op.hpp>

namespace gko {
namespace detail {


template <typename Type>
constexpr bool is_ginkgo_linop = std::is_base_of_v<LinOp, Type>;


// helper to get factory type of concrete type or LinOp
template <typename Type>
struct factory_type_impl {
using type = typename Type::Factory;
};

// It requires LinOp to be complete type
template <>
struct factory_type_impl<LinOp> {
using type = LinOpFactory;
};


template <typename Type>
using factory_type = typename factory_type_impl<Type>::type;


// helper for handle the transposed type of concrete type and LinOp
template <typename Type>
struct transposed_type_impl {
using type = typename Type::transposed_type;
};

// It requires LinOp to be complete type
template <>
struct transposed_type_impl<LinOp> {
using type = LinOp;
};


template <typename Type>
using transposed_type = typename transposed_type_impl<Type>::type;


// helper to get value_type of concrete type or void for LinOp
template <typename Type, typename = void>
struct get_value_type_impl {
using type = typename Type::value_type;
};

// We need to use SFINAE not conditional_t because both type needs to be
// valid in conditional_t
template <typename Type>
struct get_value_type_impl<Type, std::enable_if_t<!is_ginkgo_linop<Type>>> {
using type = Type;
};


template <typename Type>
using get_value_type = typename get_value_type_impl<Type>::type;


} // namespace detail
} // namespace gko

#endif // GKO_PUBLIC_CORE_BASE_TYPE_TRAITS_HPP_
2 changes: 1 addition & 1 deletion include/ginkgo/core/base/utils_helper.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down
Loading