Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions common/cuda_hip/base/batch_multi_vector_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL);


Expand All @@ -81,7 +81,7 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
}
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL);


Expand All @@ -101,7 +101,7 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
x_ub, y_ub, res_ub, [] __device__(auto val) { return val; });
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL);


Expand All @@ -121,7 +121,7 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
x_ub, y_ub, res_ub, [] __device__(auto val) { return conj(val); });
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL);


Expand All @@ -139,7 +139,7 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
x_ub, res_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL);


Expand All @@ -156,7 +156,8 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
x_ub, result_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);


} // namespace batch_multi_vector
Expand Down
8 changes: 4 additions & 4 deletions common/cuda_hip/matrix/batch_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
}


GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL);


Expand All @@ -72,7 +72,7 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, b_ub, beta_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL);


Expand All @@ -91,7 +91,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_SCALE_KERNEL);


Expand All @@ -110,7 +110,7 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, beta_ub, mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);


Expand Down
12 changes: 7 additions & 5 deletions common/cuda_hip/matrix/batch_dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
mat_ub, b_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL);


Expand All @@ -71,7 +71,7 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, b_ub, beta_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL);


Expand All @@ -90,7 +90,8 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);


template <typename ValueType>
Expand All @@ -108,7 +109,8 @@ void scale_add(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, in_out_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);


template <typename ValueType>
Expand All @@ -126,7 +128,7 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, beta_ub, mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);


Expand Down
8 changes: 4 additions & 4 deletions common/cuda_hip/matrix/batch_ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void simple_apply(std::shared_ptr<const DefaultExecutor> exec,
}


GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL);


Expand All @@ -72,7 +72,7 @@ void advanced_apply(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, mat_ub, b_ub, beta_ub, x_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);


Expand All @@ -91,7 +91,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);


Expand All @@ -110,7 +110,7 @@ void add_scaled_identity(std::shared_ptr<const DefaultExecutor> exec,
alpha_ub, beta_ub, mat_ub);
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);


Expand Down
4 changes: 2 additions & 2 deletions common/cuda_hip/solver/batch_bicgstab_launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ void launch_apply_kernel(

#define GKO_DECLARE_BATCH_BICGSTAB_LAUNCH(_vtype, _n_shared, _prec_shared, \
mat_t, log_t, pre_t, stop_t) \
void launch_apply_kernel<device_type<_vtype>, _n_shared, _prec_shared, \
void launch_apply_kernel<_vtype, _n_shared, _prec_shared, \
stop_t<device_type<_vtype>>>( \
std::shared_ptr<const DefaultExecutor> exec, \
const gko::kernels::batch_bicgstab::storage_config& sconf, \
const settings<remove_complex<device_type<_vtype>>>& settings, \
const settings<remove_complex<_vtype>>& settings, \
log_t<gko::remove_complex<device_type<_vtype>>>& logger, \
pre_t<device_type<_vtype>>& prec, \
const mat_t<const device_type<_vtype>>& mat, \
Expand Down
26 changes: 13 additions & 13 deletions common/cuda_hip/solver/batch_cg_launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,19 @@ void launch_apply_kernel(
device_type<ValueType>* const __restrict__ workspace_data,
const int& block_size, const size_t& shared_size);

#define GKO_DECLARE_BATCH_CG_LAUNCH(_vtype, _n_shared, _prec_shared, mat_t, \
log_t, pre_t, stop_t) \
void launch_apply_kernel<device_type<_vtype>, _n_shared, _prec_shared, \
stop_t<device_type<_vtype>>>( \
std::shared_ptr<const DefaultExecutor> exec, \
const gko::kernels::batch_cg::storage_config& sconf, \
const settings<remove_complex<_vtype>>& settings, \
log_t<device_type<gko::remove_complex<device_type<_vtype>>>>& logger, \
pre_t<device_type<_vtype>>& prec, \
const mat_t<const device_type<_vtype>>& mat, \
const device_type<_vtype>* const __restrict__ b_values, \
device_type<_vtype>* const __restrict__ x_values, \
device_type<_vtype>* const __restrict__ workspace_data, \
#define GKO_DECLARE_BATCH_CG_LAUNCH(_vtype, _n_shared, _prec_shared, mat_t, \
log_t, pre_t, stop_t) \
void launch_apply_kernel<_vtype, _n_shared, _prec_shared, \
stop_t<device_type<_vtype>>>( \
std::shared_ptr<const DefaultExecutor> exec, \
const gko::kernels::batch_cg::storage_config& sconf, \
const settings<remove_complex<_vtype>>& settings, \
log_t<gko::remove_complex<device_type<_vtype>>>& logger, \
pre_t<device_type<_vtype>>& prec, \
const mat_t<const device_type<_vtype>>& mat, \
const device_type<_vtype>* const __restrict__ b_values, \
device_type<_vtype>* const __restrict__ x_values, \
device_type<_vtype>* const __restrict__ workspace_data, \
const int& block_size, const size_t& shared_size)

#define GKO_INSTANTIATE_BATCH_CG_LAUNCH_0_FALSE \
Expand Down
2 changes: 1 addition & 1 deletion core/base/batch_instantiation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace batch {
#define GKO_INSTANTIATE_FOR_BATCH_VALUE_MATRIX_PRECONDITIONER(...) \
GKO_CALL(GKO_BATCH_INSTANTIATE_MATRIX, \
GKO_BATCH_INSTANTIATE_PRECONDITIONER, \
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS, __VA_ARGS__)
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_VARGS_WITH_HALF, __VA_ARGS__)


} // namespace batch
Expand Down
27 changes: 24 additions & 3 deletions core/base/batch_multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ void MultiVector<ValueType>::compute_norm2(

template <typename ValueType>
void MultiVector<ValueType>::convert_to(
MultiVector<next_precision<ValueType>>* result) const
MultiVector<next_precision_with_half<ValueType>>* result) const
{
result->values_ = this->values_;
result->set_size(this->get_size());
Expand All @@ -290,14 +290,35 @@ void MultiVector<ValueType>::convert_to(

template <typename ValueType>
void MultiVector<ValueType>::move_to(
MultiVector<next_precision<ValueType>>* result)
MultiVector<next_precision_with_half<ValueType>>* result)
{
this->convert_to(result);
}


#if GINKGO_ENABLE_HALF
// Conversion overload whose destination value type is ValueType advanced by
// two applications of next_precision_with_half.
//
// The destination's values_ member is assigned from this object's values_,
// and afterwards the destination's size is set to this object's size.
// `result` is dereferenced without a null check.
template <typename ValueType>
void MultiVector<ValueType>::convert_to(
    MultiVector<next_precision_with_half<next_precision_with_half<ValueType>>>*
        result) const
{
    auto& target = *result;
    // Keep the order: values first, then the size.
    target.values_ = this->values_;
    target.set_size(this->get_size());
}


// Move overload whose destination value type is ValueType advanced by two
// applications of next_precision_with_half.
//
// This simply forwards to convert_to with the same destination pointer, so
// which convert_to overload runs is decided by the static type of `result`.
// NOTE(review): no separate move path is visible here; whether the source is
// expected to be left untouched after a "move" should be confirmed against the
// class declaration.
template <typename ValueType>
void MultiVector<ValueType>::move_to(
MultiVector<next_precision_with_half<next_precision_with_half<ValueType>>>*
result)
{
// Delegate to the conversion overload taking the same pointer type.
this->convert_to(result);
}
#endif


#define GKO_DECLARE_BATCH_MULTI_VECTOR(_type) class MultiVector<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR);


} // namespace batch
Expand Down
53 changes: 32 additions & 21 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,12 +362,15 @@ GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(GKO_DECLARE_SEPARATE_LOCAL_NONLOCAL);
namespace batch_multi_vector {


GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_ADD_SCALED_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_DOT_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_CONJ_DOT_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_MULTI_VECTOR_COMPUTE_NORM2_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);


} // namespace batch_multi_vector
Expand All @@ -376,10 +379,13 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_MULTI_VECTOR_COPY_KERNEL);
namespace batch_csr {


GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CSR_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);


} // namespace batch_csr
Expand All @@ -388,11 +394,12 @@ GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_CSR_ADD_SCALED_IDENTITY_KERNEL);
namespace batch_dense {


GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);
GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SCALE_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_DENSE_SCALE_ADD_KERNEL);
GKO_STUB_VALUE_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);


} // namespace batch_dense
Expand All @@ -401,10 +408,13 @@ GKO_STUB_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_ADD_SCALED_IDENTITY_KERNEL);
namespace batch_ell {


GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_SIMPLE_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADVANCED_APPLY_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(GKO_DECLARE_BATCH_ELL_SCALE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_ELL_ADD_SCALED_IDENTITY_KERNEL);


} // namespace batch_ell
Expand Down Expand Up @@ -941,9 +951,10 @@ namespace batch_jacobi {
GKO_STUB_INDEX_TYPE(
GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_CUMULATIVE_BLOCK_STORAGE);
GKO_STUB_INDEX_TYPE(GKO_DECLARE_BATCH_BLOCK_JACOBI_FIND_ROW_BLOCK_MAP);
GKO_STUB_VALUE_AND_INT32_TYPE(
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_BLOCK_JACOBI_EXTRACT_PATTERN_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE(GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL);
GKO_STUB_VALUE_AND_INT32_TYPE_WITH_HALF(
GKO_DECLARE_BATCH_BLOCK_JACOBI_COMPUTE_KERNEL);


} // namespace batch_jacobi
Expand Down
4 changes: 2 additions & 2 deletions core/log/batch_logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ log_data<ValueType>::log_data(std::shared_ptr<const Executor> exec,

#define GKO_DECLARE_LOG_DATA(_type) struct log_data<_type>

GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE(GKO_DECLARE_LOG_DATA);
GKO_INSTANTIATE_FOR_EACH_NON_COMPLEX_VALUE_TYPE_WITH_HALF(GKO_DECLARE_LOG_DATA);

#undef GKO_DECLARE_LOG_DATA

Expand All @@ -92,7 +92,7 @@ void BatchConvergence<ValueType>::on_batch_solver_completed(


#define GKO_DECLARE_BATCH_CONVERGENCE(_type) class BatchConvergence<_type>
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CONVERGENCE);
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE_WITH_HALF(GKO_DECLARE_BATCH_CONVERGENCE);


} // namespace log
Expand Down
Loading
Loading