Skip to content

Commit 99f223d

Browse files
author
Yuuichi Asahi
committed
fix: tests for pbtrs
1 parent 1ce6fd2 commit 99f223d

File tree

2 files changed

+22
-24
lines changed

2 files changed

+22
-24
lines changed

batched/dense/unit_test/Test_Batched_Dense.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@
5858
#include "Test_Batched_SerialPbtrf.hpp"
5959
#include "Test_Batched_SerialPbtrf_Real.hpp"
6060
#include "Test_Batched_SerialPbtrf_Complex.hpp"
61+
#include "Test_Batched_SerialPbtrs.hpp"
62+
#include "Test_Batched_SerialPbtrs_Real.hpp"
63+
#include "Test_Batched_SerialPbtrs_Complex.hpp"
6164

6265
// Team Kernels
6366
#include "Test_Batched_TeamAxpy.hpp"

batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include <gtest/gtest.h>
1818
#include <Kokkos_Core.hpp>
1919
#include <Kokkos_Random.hpp>
20-
20+
#include <KokkosBlas2_gemv.hpp>
2121
#include "KokkosBatched_Util.hpp"
2222
#include "KokkosBatched_Pbtrf.hpp"
2323
#include "KokkosBatched_Pbtrs.hpp"
@@ -54,7 +54,7 @@ struct Functor_BatchedSerialPbtrf {
5454
const std::string name_value_type = Test::value_type_name<value_type>();
5555
std::string name = name_region + name_value_type;
5656
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, _ab.extent(0));
57-
Kokkos::parallel_for(name.c_str(), policy, *this, info_sum);
57+
Kokkos::parallel_for(name.c_str(), policy, *this);
5858
}
5959
};
6060

@@ -89,35 +89,34 @@ struct Functor_BatchedSerialPbtrs {
8989
}
9090
};
9191

92-
template <typename DeviceType, typename ScalarType, typename AViewType, typename BViewType, typename CViewType,
93-
typename ArgTransA, typename ArgTransB>
94-
struct Functor_BatchedSerialGemm {
92+
template <typename DeviceType, typename ScalarType, typename AViewType, typename xViewType, typename yViewType>
93+
struct Functor_BatchedSerialGemv {
9594
using execution_space = typename DeviceType::execution_space;
9695
AViewType _a;
97-
BViewType _b;
98-
CViewType _c;
96+
xViewType _x;
97+
yViewType _y;
9998
ScalarType _alpha, _beta;
10099

101100
KOKKOS_INLINE_FUNCTION
102-
Functor_BatchedSerialGemm(const ScalarType alpha, const AViewType &a, const BViewType &b, const ScalarType beta,
103-
const CViewType &c)
104-
: _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {}
101+
Functor_BatchedSerialGemv(const ScalarType alpha, const AViewType &a, const xViewType &x, const ScalarType beta,
102+
const yViewType &y)
103+
: _a(a), _x(x), _y(y), _alpha(alpha), _beta(beta) {}
105104

106105
KOKKOS_INLINE_FUNCTION
107106
void operator()(const int k) const {
108107
auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL());
109-
auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL());
110-
auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL());
108+
auto xx = Kokkos::subview(_x, k, Kokkos::ALL());
109+
auto yy = Kokkos::subview(_y, k, Kokkos::ALL());
111110

112-
KokkosBatched::SerialGemm<ArgTransA, ArgTransB, Algo::Gemm::Unblocked>::invoke(_alpha, aa, bb, _beta, cc);
111+
KokkosBlas::SerialGemv<Trans::NoTranspose, Algo::Gemv::Unblocked>::invoke(_alpha, aa, xx, _beta, yy);
113112
}
114113

115114
inline void run() {
116115
using value_type = typename AViewType::non_const_value_type;
117-
std::string name_region("KokkosBatched::Test::SerialPbtrf");
116+
std::string name_region("KokkosBatched::Test::SerialPbtrs");
118117
const std::string name_value_type = Test::value_type_name<value_type>();
119118
std::string name = name_region + name_value_type;
120-
Kokkos::RangePolicy<execution_space> policy(0, _a.extent(0));
119+
Kokkos::RangePolicy<execution_space> policy(0, _x.extent(0));
121120
Kokkos::parallel_for(name.c_str(), policy, *this);
122121
}
123122
};
@@ -190,9 +189,7 @@ void impl_test_batched_pbtrs_analytical(const int N) {
190189
// Check x0 = x1
191190
for (int ib = 0; ib < N; ib++) {
192191
for (int i = 0; i < BlkSize; i++) {
193-
for (int j = 0; j < BlkSize; j++) {
194-
EXPECT_NEAR_KK(h_x0(ib, i, j), h_x_ref(ib, i, j), eps);
195-
}
192+
EXPECT_NEAR_KK(h_x0(ib, i), h_x_ref(ib, i), eps);
196193
}
197194
}
198195
}
@@ -259,9 +256,7 @@ void impl_test_batched_pbtrs(const int N, const int k, const int BlkSize) {
259256
// Check A * x0 = x_ref
260257
for (int ib = 0; ib < N; ib++) {
261258
for (int i = 0; i < BlkSize; i++) {
262-
for (int j = 0; j < BlkSize; j++) {
263-
EXPECT_NEAR_KK(h_y0(ib, i, j), h_x_ref(ib, i, j), eps);
264-
}
259+
EXPECT_NEAR_KK(h_y0(ib, i), h_x_ref(ib, i), eps);
265260
}
266261
}
267262
}
@@ -274,12 +269,12 @@ int test_batched_pbtrs() {
274269
#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT)
275270
{
276271
using LayoutType = Kokkos::LayoutLeft;
277-
Test::pbtrs::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1);
272+
Test::Pbtrs::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1);
278273
Test::Pbtrs::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2);
279274
for (int i = 0; i < 10; i++) {
280275
int k = 1;
281-
Test::pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1, k, i);
282-
Test::pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2, k, i);
276+
Test::Pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1, k, i);
277+
Test::Pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2, k, i);
283278
}
284279
}
285280
#endif

0 commit comments

Comments
 (0)