Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 93 additions & 4 deletions core/config/config_helper.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -12,6 +12,7 @@

#include "core/config/registry_accessor.hpp"
#include "core/config/stop_config.hpp"
#include "type_descriptor_helper.hpp"

namespace gko {
namespace config {
Expand Down Expand Up @@ -43,7 +44,9 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
if (config.get_tag() == pnode::tag_t::string) {
return detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string());
} else if (config.get_tag() == pnode::tag_t::map) {
}

if (config.get_tag() == pnode::tag_t::map) {
static std::map<std::string,
std::function<deferred_factory_parameter<
gko::stop::CriterionFactory>(
Expand All @@ -55,9 +58,95 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
{"ImplicitResidualNorm", configure_implicit_residual}}};
return criterion_map.at(config.get("type").get_string())(config,
context, td);
} else {
GKO_INVALID_STATE("The type of config is not valid.");
}

GKO_INVALID_STATE(
"Criteria must either be defined as a string or an array.");
}


std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_minimal_criteria(const pnode& config, const registry& context,
const type_descriptor& td)
{
auto map_time = [](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode time_config{{{"time_limit", config.get("time")}}};
return configure_time(time_config, context, td);
};
auto map_iteration = [](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode iter_config{{{"max_iters", config.get("iteration")}}};
return configure_iter(iter_config, context, td);
};
auto create_residual_mapping = [](const std::string& key,
const std::string& baseline,
auto configure_fn) {
return std::make_pair(
key, [=](const pnode& config, const registry& context,
const type_descriptor& td) {
pnode res_config{{{"baseline", pnode{baseline}},
{"reduction_factor", config.get(key)}}};
return configure_fn(res_config, context, td);
});
};
std::map<
std::string,
std::function<deferred_factory_parameter<gko::stop::CriterionFactory>(
const pnode&, const registry&, type_descriptor)>>
criterion_map{
{{"time", map_time},
{"iteration", map_iteration},
create_residual_mapping("relative_residual_norm", "rhs_norm",
configure_residual),
create_residual_mapping("initial_residual_norm", "initial_resnorm",
configure_residual),
create_residual_mapping("absolute_residual_norm", "absolute",
configure_residual),
create_residual_mapping("relative_implicit_residual_norm",
"rhs_norm", configure_implicit_residual),
create_residual_mapping("initial_implicit_residual_norm",
"initial_resnorm",
configure_implicit_residual),
create_residual_mapping("absolute_implicit_residual_norm",
"absolute", configure_implicit_residual)}};

type_descriptor updated_td = update_type(config, td);

std::vector<deferred_factory_parameter<const stop::CriterionFactory>> res;
for (const auto& it : config.get_map()) {
if (it.first == "value_type") {
continue;
}
res.emplace_back(
criterion_map.at(it.first)(config, context, updated_td));
}
return res;
}


std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_or_get_criteria(const pnode& config, const registry& context,
const type_descriptor& td)
{
if (config.get_tag() == pnode::tag_t::array ||
(config.get_tag() == pnode::tag_t::map && config.get("type"))) {
return parse_or_get_factory_vector<const stop::CriterionFactory>(
config, context, td);
}

if (config.get_tag() == pnode::tag_t::map) {
return parse_minimal_criteria(config, context, td);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return parse_minimal_criteria(config, context, td);
auto updated = config::update_type(td);
return parse_minimal_criteria(config, context, updated);

It is to support no valuetype available outside

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean here. I can specify the value type of the residual nom criterion, as you can see in the tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's based on the type_descriptor from the outer loop.
I mean something like

"stop": {
  "value_type": "float64",
  "residual_norm": ...
}

in case no precision information from outside or want to specify certain precision for stop.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following the above comment,

"stop": [
  {
    "type": "ResidualNorm",
    "value_type": "float32",
    ...
  }
]

in multigrid or some solver without precision information.
it can be mapped to

"stop": {
  "value_type": "float32",
  "residual_norm": ...
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or, we can say this rare case should use the map version not minimal version

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try a bit to make it work. If it's a reasonable effort I will adapt it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was pretty simple to add, so it's now available.

}

if (config.get_tag() == pnode::tag_t::string) {
return {detail::registry_accessor::get_data<stop::CriterionFactory>(
context, config.get_string())};
}

GKO_INVALID_STATE(
"Criteria must either be defined as a string, an array,"
"or an map.");
}

} // namespace config
Expand Down
11 changes: 10 additions & 1 deletion core/config/config_helper.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -141,6 +141,15 @@ parse_or_get_factory<const stop::CriterionFactory>(const pnode& config,
const registry& context,
const type_descriptor& td);

/**
* parse or get a std::vector of criteria.
* A stored single criterion will be converted to a std::vector.
*/
std::vector<deferred_factory_parameter<const stop::CriterionFactory>>
parse_or_get_criteria(const pnode& config, const registry& context,
const type_descriptor& td);


/**
* give a vector of factory by calling parse_or_get_factory.
*/
Expand Down
5 changes: 2 additions & 3 deletions core/config/solver_config.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -27,8 +27,7 @@ inline void common_solver_parse(SolverParam& params, const pnode& config,
}
if (auto& obj = config.get("criteria")) {
params.with_criteria(
gko::config::parse_or_get_factory_vector<
const stop::CriterionFactory>(obj, context, td_for_child));
gko::config::parse_or_get_criteria(obj, context, td_for_child));
}
if (auto& obj = config.get("preconditioner")) {
params.with_preconditioner(
Expand Down
136 changes: 135 additions & 1 deletion core/test/config/config.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -12,6 +12,7 @@
#include <ginkgo/core/stop/combined.hpp>
#include <ginkgo/core/stop/iteration.hpp>
#include <ginkgo/core/stop/residual_norm.hpp>
#include <ginkgo/core/stop/time.hpp>

#include "core/config/config_helper.hpp"
#include "core/test/utils.hpp"
Expand Down Expand Up @@ -122,6 +123,139 @@ TEST_F(Config, GenerateObjectWithCustomBuild)
}


TEST_F(Config, GenerateCriteriaFromMinimalConfig)
{
// the map is ordered, since this allows for easier comparison in the test
pnode minimal_stop{{
{"absolute_implicit_residual_norm", pnode{0.01}},
{"absolute_residual_norm", pnode{0.01}},
{"initial_implicit_residual_norm", pnode{0.01}},
{"initial_residual_norm", pnode{0.01}},
{"iteration", pnode{10}},
{"relative_implicit_residual_norm", pnode{0.01}},
{"relative_residual_norm", pnode{0.01}},
{"time", pnode{100}},
}};

pnode p{{{"criteria", minimal_stop}}};
auto obj = std::dynamic_pointer_cast<gko::solver::Cg<float>::Factory>(
parse<LinOpFactoryType::Cg>(p, registry(),
type_descriptor{"float32", "void"})
.on(this->exec));

ASSERT_NE(obj, nullptr);
auto criteria = obj->get_parameters().criteria;
ASSERT_EQ(criteria.size(), minimal_stop.get_map().size());
try {
throw std::runtime_error("Criteria does not exist");
} catch (...) {
}
{
SCOPED_TRACE("Absolute Implicit Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ImplicitResidualNorm<float>::Factory>(criteria[0]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline, gko::stop::mode::absolute);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Absolute Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ResidualNorm<float>::Factory>(criteria[1]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline, gko::stop::mode::absolute);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Initial Implicit Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ImplicitResidualNorm<float>::Factory>(criteria[2]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline,
gko::stop::mode::initial_resnorm);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Initial Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ResidualNorm<float>::Factory>(criteria[3]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline,
gko::stop::mode::initial_resnorm);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Iteration Criterion");
auto it =
std::dynamic_pointer_cast<const gko::stop::Iteration::Factory>(
criteria[4]);
ASSERT_NE(it, nullptr);
EXPECT_EQ(it->get_parameters().max_iters, 10);
}
{
SCOPED_TRACE("Relative Implicit Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ImplicitResidualNorm<float>::Factory>(criteria[5]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline, gko::stop::mode::rhs_norm);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Relative Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ResidualNorm<float>::Factory>(criteria[6]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline, gko::stop::mode::rhs_norm);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01f);
}
{
SCOPED_TRACE("Time Criterion");
using namespace std::chrono_literals;
auto time = std::dynamic_pointer_cast<const gko::stop::Time::Factory>(
criteria[7]);
ASSERT_NE(time, nullptr);
EXPECT_EQ(time->get_parameters().time_limit, 100ns);
}
}


TEST_F(Config, GenerateCriteriaFromMinimalConfigWithValueType)
{
auto reg = registry();
reg.emplace("precond", this->mtx);
pnode minimal_stop{{
{"value_type", pnode{"float64"}},
{"relative_residual_norm", pnode{0.01}},
{"time", pnode{100}},
}};

pnode p{{{"criteria", minimal_stop}}};
auto obj = std::dynamic_pointer_cast<gko::solver::Cg<float>::Factory>(
parse<LinOpFactoryType::Cg>(p, reg, type_descriptor{"float32", "void"})
.on(this->exec));

ASSERT_NE(obj, nullptr);
auto criteria = obj->get_parameters().criteria;
ASSERT_EQ(criteria.size(), minimal_stop.get_map().size() - 1);
{
SCOPED_TRACE("Residual Criterion");
auto res = std::dynamic_pointer_cast<
const gko::stop::ResidualNorm<double>::Factory>(criteria[0]);
ASSERT_NE(res, nullptr);
EXPECT_EQ(res->get_parameters().baseline, gko::stop::mode::rhs_norm);
EXPECT_EQ(res->get_parameters().reduction_factor, 0.01);
}
{
SCOPED_TRACE("Time Criterion");
using namespace std::chrono_literals;
auto time = std::dynamic_pointer_cast<const gko::stop::Time::Factory>(
criteria[1]);
ASSERT_NE(time, nullptr);
EXPECT_EQ(time->get_parameters().time_limit, 100ns);
}
}


TEST(GetValue, IndexType)
{
long long int value = 123;
Expand Down
Loading
Loading