Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions core/solver/cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,108 @@ typename Cg<ValueType>::parameters_type Cg<ValueType>::parse(
return params;
}

/**
 * Runs the conjugate gradient iteration on MultiVector operands.
 *
 * Each pass of the loop applies the preconditioner (z = M * r), forms
 * rho = dot(r, z), evaluates the stopping criterion, and, unless it reports
 * that everything has stopped, updates p, q, dense_x and r through the
 * cg::make_step_1 / cg::make_step_2 operations.
 *
 * @param b  the right-hand side; the initialize operation copies it into the
 *           residual r.
 * @param x  the initial guess; it is used to form the initial residual and
 *           the iteration updates it through dense_x.
 *
 * NOTE(review): whether the final iterate reaches `x` when its precision
 * differs from ValueType depends on what temporary_precision() returns
 * (presumably a view or a converting copy that writes back) — confirm.
 */
template <typename ValueType>
void Cg<ValueType>::apply_mv(ptr_param<const matrix::MultiVector> b,
ptr_param<matrix::MultiVector> x) const
{
// @todo: need precision dispatch
// Obtain b and x in the solver's ValueType; all later work uses these handles.
auto dense_b = b->temporary_precision<ValueType>();
auto dense_x = x->temporary_precision<ValueType>();

using std::swap;
constexpr uint8 RelativeStoppingId{1};

auto exec = this->get_executor();
this->setup_workspace();

// The solver macros below are defined outside this file section. They
// presumably declare the workspace objects named by their first argument
// (r, z, p, q, beta, prev_rho, rho), configured after dense_b — TODO confirm.
GKO_SOLVER_VECTOR(r, dense_b.get());
GKO_SOLVER_VECTOR(z, dense_b.get());
GKO_SOLVER_VECTOR(p, dense_b.get());
GKO_SOLVER_VECTOR(q, dense_b.get());

GKO_SOLVER_SCALAR(beta, dense_b.get());
GKO_SOLVER_SCALAR(prev_rho, dense_b.get());
GKO_SOLVER_SCALAR(rho, dense_b.get());

// Presumably provides one_op and neg_one_op, which are used for the initial
// residual below — TODO confirm against the macro definition.
GKO_SOLVER_ONE_MINUS_ONE();

bool one_changed{};
// Presumably provides stop_status and reduction_tmp, both used below.
// NOTE(review): the macro may refer to the local name dense_b directly, so
// renaming locals in this function is risky — confirm before refactoring.
GKO_SOLVER_STOP_REDUCTION_ARRAYS();

// r = dense_b
// rho = 0.0
// prev_rho = 1.0
// z = p = q = 0
// @todo: I think the template keyword is necessary because some of these
// variables are defined via auto.
exec->run(cg::make_initialize(
dense_b->template create_local_view<ValueType>().get(),
r->template create_local_view<ValueType>().get(),
z->template create_local_view<ValueType>().get(),
p->template create_local_view<ValueType>().get(),
q->template create_local_view<ValueType>().get(), prev_rho, rho,
&stop_status));

// Initial residual. With r == dense_b from the initialize step, this is
// presumably r = one * r + neg_one * A * dense_x, i.e. r = b - A * x.
this->get_system_matrix()->apply(neg_one_op, dense_x, one_op, r);
// The shared_ptr wrapping dense_b uses a no-op deleter, so the stopping
// criterion does not take ownership of the right-hand side.
auto stop_criterion = this->get_stop_criterion_factory()->generate(
this->get_system_matrix(),
std::shared_ptr<const LinOp>(dense_b.get(), [](const LinOp*) {}),
dense_x.get(), r);

// Starts at -1 because it is incremented before the first stopping check,
// which therefore sees iteration 0.
int iter = -1;
/* Memory movement summary:
* 18n * values + matrix/preconditioner storage
* 1x SpMV: 2n * values + storage
* 1x Preconditioner: 2n * values + storage
* 2x dot 4n
* 1x step 1 (axpy) 3n
* 1x step 2 (axpys) 6n
* 1x norm2 residual n
*/
while (true) {
// z = preconditioner * r
this->get_preconditioner()->apply(r, z);
// rho = dot(r, z)
r->compute_conj_dot(z, rho, reduction_tmp);

++iter;
// Hand the current iterate, the residual and rho (as the implicit squared
// residual norm) to the stopping criterion.
bool all_stopped =
stop_criterion->update()
.num_iterations(iter)
.residual(r)
.implicit_sq_residual_norm(rho)
.solution(dense_x.get())
.check(RelativeStoppingId, true, &stop_status, &one_changed);
// Loggers are notified on every iteration, including the final one, because
// this call precedes the break.
this->template log<log::Logger::iteration_complete>(
this, dense_b.get(), dense_x.get(), iter, r, nullptr, rho,
&stop_status, all_stopped);
if (all_stopped) {
break;
}

// tmp = rho / prev_rho
// p = z + tmp * p
exec->run(
cg::make_step_1(p->template create_local_view<ValueType>().get(),
z->template create_local_view<ValueType>().get(),
rho, prev_rho, &stop_status));
// q = A * p
this->get_system_matrix()->apply(p, q);
// beta = dot(p, q)
p->compute_conj_dot(q, beta, reduction_tmp);
// tmp = rho / beta
// x = x + tmp * p
// r = r - tmp * q
exec->run(cg::make_step_2(
dense_x->template create_local_view<ValueType>().get(),
r->template create_local_view<ValueType>().get(),
p->template create_local_view<ValueType>().get(),
q->template create_local_view<ValueType>().get(), beta, rho,
&stop_status));
// The current rho becomes prev_rho for the next pass; rho is recomputed at
// the top of the loop before it is read again.
swap(prev_rho, rho);
}
}


template <typename ValueType>
std::unique_ptr<LinOp> Cg<ValueType>::transpose() const
Expand Down
6 changes: 5 additions & 1 deletion include/ginkgo/core/solver/cg.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -54,6 +54,7 @@ class Cg : public EnableLinOp<Cg<ValueType>>,
public:
using value_type = ValueType;
using transposed_type = Cg<ValueType>;
using EnableLinOp<Cg>::apply;

std::unique_ptr<LinOp> transpose() const override;

Expand Down Expand Up @@ -93,6 +94,9 @@ class Cg : public EnableLinOp<Cg<ValueType>>,
const config::type_descriptor& td_for_child =
config::make_type_descriptor<ValueType>());

void apply_mv(ptr_param<const matrix::MultiVector> b,
ptr_param<matrix::MultiVector> x) const;

protected:
void apply_impl(const LinOp* b, LinOp* x) const override;

Expand Down
44 changes: 43 additions & 1 deletion reference/test/solver/cg_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -226,6 +226,48 @@ TYPED_TEST(Cg, SolvesStencilSystem)
}


// Checks that three CG iterations through the MultiVector entry point
// (apply_mv) reproduce the expected solution of the fixture's stencil system
// when the vectors use the fixture's own matrix type.
TYPED_TEST(Cg, SolvesStencilSystemMultiVector)
{
    using value_type = typename TestFixture::value_type;
    using Mtx = typename TestFixture::Mtx;
    using MultiVec = gko::matrix::MultiVector;

    // Build the factory first, then generate the solver from it.
    auto cg_factory =
        gko::solver::Cg<value_type>::build()
            .with_criteria(gko::stop::Iteration::build().with_max_iters(3u))
            .on(this->exec);
    auto cg = cg_factory->generate(this->mtx);

    // Right-hand side and zero initial guess, held through the MultiVector
    // base type so that the apply_mv overload is exercised.
    std::unique_ptr<MultiVec> b =
        gko::initialize<Mtx>({-1.0, 3.0, 1.0}, this->exec);
    std::unique_ptr<MultiVec> x =
        gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);

    cg->apply_mv(b, x);

    GKO_ASSERT_MTX_NEAR(gko::as<Mtx>(x.get()), l({1.0, 3.0, 2.0}),
                        r<value_type>::value);
}

// Checks apply_mv when the vectors are stored in a different precision
// (snd_value_type) than the solver's value_type.
TYPED_TEST(Cg, SolvesStencilSystemMultiVectorMixed)
{
    using value_type = typename TestFixture::value_type;
    using snd_value_type = gko::next_precision<value_type>;
    using Mtx = gko::matrix::Dense<snd_value_type>;
    auto solver =
        gko::solver::Cg<value_type>::build()
            .with_criteria(gko::stop::Iteration::build().with_max_iters(3u))
            .on(this->exec)
            ->generate(this->mtx);
    std::unique_ptr<gko::matrix::MultiVector> b =
        gko::initialize<Mtx>({-1.0, 3.0, 1.0}, this->exec);
    std::unique_ptr<gko::matrix::MultiVector> x =
        gko::initialize<Mtx>({0.0, 0.0, 0.0}, this->exec);

    solver->apply_mv(b, x);

    // The tolerance must account for both precisions involved: the vectors'
    // storage precision (snd_value_type) and the solver precision (TypeParam).
    // Passing value_type here would presumably compare the solver precision
    // with itself and ignore the precision the result is stored in.
    GKO_ASSERT_MTX_NEAR(gko::as<Mtx>(x.get()), l({1.0, 3.0, 2.0}),
                        (r_mixed<snd_value_type, TypeParam>()));
}


TYPED_TEST(Cg, SolvesStencilSystemMixed)
{
using value_type = gko::next_precision<typename TestFixture::value_type>;
Expand Down
Loading