1717#include < gtest/gtest.h>
1818#include < Kokkos_Core.hpp>
1919#include < Kokkos_Random.hpp>
20-
20+ # include < KokkosBlas2_gemv.hpp >
2121#include " KokkosBatched_Util.hpp"
2222#include " KokkosBatched_Pbtrf.hpp"
2323#include " KokkosBatched_Pbtrs.hpp"
@@ -54,7 +54,7 @@ struct Functor_BatchedSerialPbtrf {
5454 const std::string name_value_type = Test::value_type_name<value_type>();
5555 std::string name = name_region + name_value_type;
5656 Kokkos::RangePolicy<execution_space, ParamTagType> policy (0 , _ab.extent (0 ));
57- Kokkos::parallel_for (name.c_str (), policy, *this , info_sum );
57+ Kokkos::parallel_for (name.c_str (), policy, *this );
5858 }
5959};
6060
@@ -89,35 +89,34 @@ struct Functor_BatchedSerialPbtrs {
8989 }
9090};
9191
92- template <typename DeviceType, typename ScalarType, typename AViewType, typename BViewType, typename CViewType,
93- typename ArgTransA, typename ArgTransB>
94- struct Functor_BatchedSerialGemm {
92+ template <typename DeviceType, typename ScalarType, typename AViewType, typename xViewType, typename yViewType>
93+ struct Functor_BatchedSerialGemv {
9594 using execution_space = typename DeviceType::execution_space;
9695 AViewType _a;
97- BViewType _b ;
98- CViewType _c ;
96+ xViewType _x ;
97+ yViewType _y ;
9998 ScalarType _alpha, _beta;
10099
101100 KOKKOS_INLINE_FUNCTION
102- Functor_BatchedSerialGemm (const ScalarType alpha, const AViewType &a, const BViewType &b , const ScalarType beta,
103- const CViewType &c )
104- : _a(a), _b(b ), _c(c ), _alpha(alpha), _beta(beta) {}
101+ Functor_BatchedSerialGemv (const ScalarType alpha, const AViewType &a, const xViewType &x , const ScalarType beta,
102+ const yViewType &y )
103+ : _a(a), _x(x ), _y(y ), _alpha(alpha), _beta(beta) {}
105104
106105 KOKKOS_INLINE_FUNCTION
107106 void operator ()(const int k) const {
108107 auto aa = Kokkos::subview (_a, k, Kokkos::ALL (), Kokkos::ALL ());
109- auto bb = Kokkos::subview (_b , k, Kokkos::ALL () , Kokkos::ALL ());
110- auto cc = Kokkos::subview (_c , k, Kokkos::ALL () , Kokkos::ALL ());
108+ auto xx = Kokkos::subview (_x , k, Kokkos::ALL ());
109+ auto yy = Kokkos::subview (_y , k, Kokkos::ALL ());
111110
112- KokkosBatched::SerialGemm<ArgTransA, ArgTransB, Algo::Gemm ::Unblocked>::invoke (_alpha, aa, bb , _beta, cc );
111+ KokkosBlas::SerialGemv<Trans::NoTranspose, Algo::Gemv ::Unblocked>::invoke (_alpha, aa, xx , _beta, yy );
113112 }
114113
115114 inline void run () {
116115 using value_type = typename AViewType::non_const_value_type;
117- std::string name_region (" KokkosBatched::Test::SerialPbtrf " );
116+ std::string name_region (" KokkosBatched::Test::SerialPbtrs " );
118117 const std::string name_value_type = Test::value_type_name<value_type>();
119118 std::string name = name_region + name_value_type;
120- Kokkos::RangePolicy<execution_space> policy (0 , _a .extent (0 ));
119+ Kokkos::RangePolicy<execution_space> policy (0 , _x .extent (0 ));
121120 Kokkos::parallel_for (name.c_str (), policy, *this );
122121 }
123122};
@@ -190,9 +189,7 @@ void impl_test_batched_pbtrs_analytical(const int N) {
190189 // Check x0 = x1
191190 for (int ib = 0 ; ib < N; ib++) {
192191 for (int i = 0 ; i < BlkSize; i++) {
193- for (int j = 0 ; j < BlkSize; j++) {
194- EXPECT_NEAR_KK (h_x0 (ib, i, j), h_x_ref (ib, i, j), eps);
195- }
192+ EXPECT_NEAR_KK (h_x0 (ib, i), h_x_ref (ib, i), eps);
196193 }
197194 }
198195}
@@ -259,9 +256,7 @@ void impl_test_batched_pbtrs(const int N, const int k, const int BlkSize) {
259256 // Check A * x0 = x_ref
260257 for (int ib = 0 ; ib < N; ib++) {
261258 for (int i = 0 ; i < BlkSize; i++) {
262- for (int j = 0 ; j < BlkSize; j++) {
263- EXPECT_NEAR_KK (h_y0 (ib, i, j), h_x_ref (ib, i, j), eps);
264- }
259+ EXPECT_NEAR_KK (h_y0 (ib, i), h_x_ref (ib, i), eps);
265260 }
266261 }
267262}
@@ -274,12 +269,12 @@ int test_batched_pbtrs() {
274269#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT)
275270 {
276271 using LayoutType = Kokkos::LayoutLeft;
277- Test::pbtrs ::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1 );
272+ Test::Pbtrs ::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1 );
278273 Test::Pbtrs::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2 );
279274 for (int i = 0 ; i < 10 ; i++) {
280275 int k = 1 ;
281- Test::pbtrs ::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1 , k, i);
282- Test::pbtrs ::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2 , k, i);
276+ Test::Pbtrs ::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1 , k, i);
277+ Test::Pbtrs ::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2 , k, i);
283278 }
284279 }
285280#endif
0 commit comments