Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions batched/KokkosBatched_Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -682,13 +682,13 @@ KOKKOS_INLINE_FUNCTION int get_extent_int(const ViewType &v, const int r) {
static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2.");

if (r == 0) {
int V_extent_0 = V_rank == 0 ? 0 : v.extent_int(0);
int V_extent_0 = V_rank < 1 ? 1 : v.extent_int(0);
return V_extent_0;
} else if (r == 1) {
int V_extent_1 = V_rank == 0 ? 0 : V_rank == 1 ? 1 : v.extent_int(1);
int V_extent_1 = V_rank < 2 ? 1 : v.extent_int(1);
return V_extent_1;
} else {
return -1;
return 1;
}
}

Expand All @@ -699,13 +699,13 @@ KOKKOS_INLINE_FUNCTION std::size_t get_stride(const ViewType &v, const int r) {
static_assert(V_rank <= 2, "KokkosBatched: ViewType must have rank 0, 1 or 2.");

if (r == 0) {
std::size_t V_stride_0 = V_rank == 0 ? 0 : v.stride(0);
std::size_t V_stride_0 = V_rank < 1 ? 1 : v.stride(0);
return V_stride_0;
} else if (r == 1) {
std::size_t V_stride_1 = V_rank == 0 ? 0 : V_rank == 1 ? 1 : v.stride(1);
std::size_t V_stride_1 = V_rank < 2 ? 1 : v.stride(1);
return V_stride_1;
} else {
return 0;
return 1;
}
}
} // namespace Impl
Expand Down
79 changes: 0 additions & 79 deletions batched/dense/impl/KokkosBatched_Gemm_Common_Impl.hpp

This file was deleted.

56 changes: 55 additions & 1 deletion batched/dense/impl/KokkosBatched_Gemm_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,64 @@

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Gemm_Common_Impl.hpp"
#include "KokkosBatched_Gemm_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {
template <typename ArgTransA, typename ArgTransB, typename AViewType, typename BViewType, typename CViewType>
KOKKOS_INLINE_FUNCTION static int checkGemmInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const BViewType &B,
[[maybe_unused]] const CViewType &C) {
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::gemm: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<BViewType>, "KokkosBatched::gemm: BViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<CViewType>, "KokkosBatched::gemm: CViewType is not a Kokkos::View.");

static_assert(AViewType::rank <= 2, "KokkosBatched::gemm: AViewType must have rank 0, 1 or 2.");
static_assert(BViewType::rank <= 2, "KokkosBatched::gemm: BViewType must have rank 0, 1 or 2.");
static_assert(CViewType::rank <= 2, "KokkosBatched::gemm: CViewType must have rank 0, 1 or 2.");

#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
const int m = C.extent(0), n = C.extent(1);
const int lda = A.extent(0);
const int ldb = B.extent(0);

const int ka = std::is_same_v<ArgTransA, Trans::NoTranspose> ? A.extent(1) : A.extent(0);
const int kb = std::is_same_v<ArgTransB, Trans::NoTranspose> ? B.extent(0) : B.extent(1);

if (ka != kb) {
Kokkos::printf(
"KokkosBatched::gemm: Dimensions of A and B do not match: A: %d x %d, "
"B: %d x %d\n",
A.extent(0), A.extent(1), B.extent(0), B.extent(1));
return 1;
}

const int nrowa = std::is_same_v<ArgTransA, Trans::NoTranspose> ? m : ka;
const int nrowb = std::is_same_v<ArgTransB, Trans::NoTranspose> ? kb : n;

if (lda < Kokkos::max(1, nrowa)) {
Kokkos::printf(
"KokkosBatched::gemm: leading dimension of A must not be smaller than "
"max(1, nrowa): "
"lda = %d, nrowa = %d\n",
lda, nrowa);
return 1;
}
if (ldb < Kokkos::max(1, nrowb)) {
Kokkos::printf(
"KokkosBatched::gemm: leading dimension of B must not be smaller than "
"max(1, nrowb): "
"ldb = %d, nrowb = %d\n",
ldb, nrowb);
return 1;
}

#endif

return 0;
}
} // namespace Impl

///
/// Serial Impl
/// ===========
Expand Down
Loading
Loading