diff --git a/stan/math/mix.hpp b/stan/math/mix.hpp index 38b6a5e9c0e..f28846b29dd 100644 --- a/stan/math/mix.hpp +++ b/stan/math/mix.hpp @@ -5,15 +5,15 @@ #include #include -#ifdef STAN_OPENCL -#include -#endif - #include #include #include #include +#ifdef STAN_OPENCL +#include +#endif + #include #include #include diff --git a/stan/math/prim/fun/acos.hpp b/stan/math/prim/fun/acos.hpp index d529ad04d48..93b77a217ee 100644 --- a/stan/math/prim/fun/acos.hpp +++ b/stan/math/prim/fun/acos.hpp @@ -47,7 +47,7 @@ template * = nullptr> inline auto acos(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/cos.hpp b/stan/math/prim/fun/cos.hpp index 442c633c7b3..cea527271c3 100644 --- a/stan/math/prim/fun/cos.hpp +++ b/stan/math/prim/fun/cos.hpp @@ -42,7 +42,7 @@ template * = nullptr> inline auto cos(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/digamma.hpp b/stan/math/prim/fun/digamma.hpp index e4f8199202c..5186a5d0f35 100644 --- a/stan/math/prim/fun/digamma.hpp +++ b/stan/math/prim/fun/digamma.hpp @@ -75,7 +75,7 @@ template * = nullptr, require_not_var_matrix_t* = nullptr> inline auto digamma(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/erfc.hpp b/stan/math/prim/fun/erfc.hpp index b5196cb273d..029c1808e3c 100644 --- a/stan/math/prim/fun/erfc.hpp +++ b/stan/math/prim/fun/erfc.hpp @@ -37,7 +37,7 @@ template < require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr, require_not_var_matrix_t* = nullptr> inline auto erfc(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/exp.hpp b/stan/math/prim/fun/exp.hpp index 384d8d3c7d4..6307b7bbe0c 100644 --- a/stan/math/prim/fun/exp.hpp +++ b/stan/math/prim/fun/exp.hpp @@ -45,7 +45,7 @@ template < require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr, require_not_var_matrix_t* = nullptr> inline auto exp(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/exp2.hpp b/stan/math/prim/fun/exp2.hpp index 23b2e803131..06e8ca13c1a 100644 --- a/stan/math/prim/fun/exp2.hpp +++ b/stan/math/prim/fun/exp2.hpp @@ -40,7 +40,7 @@ template < require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr, require_not_var_matrix_t* = nullptr> inline auto exp2(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/inv.hpp b/stan/math/prim/fun/inv.hpp index 92ec18be095..ab958e658b3 100644 --- a/stan/math/prim/fun/inv.hpp +++ b/stan/math/prim/fun/inv.hpp @@ -35,7 +35,7 @@ template < typename T, require_not_container_st* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto inv(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/inv_Phi.hpp b/stan/math/prim/fun/inv_Phi.hpp index 85cf0f0e379..a7026b7397c 100644 --- a/stan/math/prim/fun/inv_Phi.hpp +++ b/stan/math/prim/fun/inv_Phi.hpp @@ -178,7 +178,7 @@ template < require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr, require_not_var_matrix_t* = nullptr> inline auto inv_Phi(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/inv_cloglog.hpp b/stan/math/prim/fun/inv_cloglog.hpp index daaec48c278..a2edfe48e49 100644 --- a/stan/math/prim/fun/inv_cloglog.hpp +++ b/stan/math/prim/fun/inv_cloglog.hpp @@ -77,7 +77,7 @@ template * = nullptr> inline auto inv_cloglog(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/inv_logit.hpp b/stan/math/prim/fun/inv_logit.hpp index 7c8355f9ffe..a931df57833 100644 --- a/stan/math/prim/fun/inv_logit.hpp +++ b/stan/math/prim/fun/inv_logit.hpp @@ -85,7 +85,7 @@ template < typename T, require_not_var_matrix_t* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto inv_logit(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } // TODO(Tadej): Eigen is introducing their implementation logistic() of this diff --git a/stan/math/prim/fun/inv_sqrt.hpp b/stan/math/prim/fun/inv_sqrt.hpp index d2ec2b01885..a77218f32e7 100644 --- a/stan/math/prim/fun/inv_sqrt.hpp +++ b/stan/math/prim/fun/inv_sqrt.hpp @@ -46,7 +46,7 @@ template * = nullptr> inline auto inv_sqrt(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/lgamma.hpp b/stan/math/prim/fun/lgamma.hpp index eeaa3006322..aae1e87c646 100644 --- a/stan/math/prim/fun/lgamma.hpp +++ b/stan/math/prim/fun/lgamma.hpp @@ -117,7 +117,7 @@ struct lgamma_fun { template * = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto lgamma(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/log.hpp b/stan/math/prim/fun/log.hpp index ba256709b95..c80138bbbfb 100644 --- a/stan/math/prim/fun/log.hpp +++ b/stan/math/prim/fun/log.hpp @@ -48,7 +48,7 @@ template < require_not_var_matrix_t* = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/log10.hpp b/stan/math/prim/fun/log10.hpp index a5702086bfd..222e3f4c1d2 100644 --- a/stan/math/prim/fun/log10.hpp +++ b/stan/math/prim/fun/log10.hpp @@ -39,7 +39,7 @@ template < require_not_container_st* = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log10(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/log1m.hpp b/stan/math/prim/fun/log1m.hpp index 2b81dbd1edb..9dc1e019c48 100644 --- a/stan/math/prim/fun/log1m.hpp +++ b/stan/math/prim/fun/log1m.hpp @@ -71,7 +71,7 @@ template < typename T, require_not_var_matrix_t* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log1m(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/log1m_exp.hpp b/stan/math/prim/fun/log1m_exp.hpp index 5243088649d..a8ba63aa997 100644 --- a/stan/math/prim/fun/log1m_exp.hpp +++ b/stan/math/prim/fun/log1m_exp.hpp @@ -81,7 +81,7 @@ template < typename T, require_not_var_matrix_t* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log1m_exp(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/log1m_inv_logit.hpp b/stan/math/prim/fun/log1m_inv_logit.hpp index 4755da4f21e..c41bc301ae4 100644 --- a/stan/math/prim/fun/log1m_inv_logit.hpp +++ b/stan/math/prim/fun/log1m_inv_logit.hpp @@ -82,9 +82,8 @@ struct log1m_inv_logit_fun { */ template * = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> -inline typename apply_scalar_unary::return_t -log1m_inv_logit(const T& x) { - return apply_scalar_unary::apply(x); +inline auto log1m_inv_logit(const T& x) { + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/log1p.hpp b/stan/math/prim/fun/log1p.hpp index 36f1c4b05e4..81d659639eb 100644 --- a/stan/math/prim/fun/log1p.hpp +++ b/stan/math/prim/fun/log1p.hpp @@ -80,7 +80,7 @@ template * = nullptr, require_not_var_matrix_t* = nullptr> inline auto log1p(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/log1p_exp.hpp b/stan/math/prim/fun/log1p_exp.hpp index 821e08889e3..cd91d4cffed 100644 --- a/stan/math/prim/fun/log1p_exp.hpp +++ b/stan/math/prim/fun/log1p_exp.hpp @@ -76,7 +76,7 @@ template * = nullptr, require_not_var_matrix_t* = nullptr> inline auto log1p_exp(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/log2.hpp b/stan/math/prim/fun/log2.hpp index ef0dd4dd844..eaf04724056 100644 --- a/stan/math/prim/fun/log2.hpp +++ b/stan/math/prim/fun/log2.hpp @@ -47,7 +47,7 @@ struct log2_fun { template * = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto log2(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/log_inv_logit.hpp b/stan/math/prim/fun/log_inv_logit.hpp index 3436f87d951..be329f01732 100644 --- a/stan/math/prim/fun/log_inv_logit.hpp +++ b/stan/math/prim/fun/log_inv_logit.hpp @@ -81,7 +81,7 @@ struct log_inv_logit_fun { template * = nullptr> inline auto log_inv_logit(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/fun/logit.hpp b/stan/math/prim/fun/logit.hpp index affdae68b73..3535f5ebcfc 100644 --- a/stan/math/prim/fun/logit.hpp +++ b/stan/math/prim/fun/logit.hpp @@ -89,7 +89,7 @@ template < require_not_var_matrix_t* = nullptr, require_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr> inline auto logit(const Container& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } /** diff --git a/stan/math/prim/fun/tgamma.hpp b/stan/math/prim/fun/tgamma.hpp index 89ae4439289..ef514f90d17 100644 --- a/stan/math/prim/fun/tgamma.hpp +++ b/stan/math/prim/fun/tgamma.hpp @@ -51,7 +51,7 @@ template < require_all_not_nonscalar_prim_or_rev_kernel_expression_t* = nullptr, require_not_var_matrix_t* = nullptr> inline auto tgamma(const T& x) { - return apply_scalar_unary::apply(x); + return apply_scalar_unary::apply(x); } } // namespace math diff --git a/stan/math/prim/functor/apply_scalar_unary.hpp b/stan/math/prim/functor/apply_scalar_unary.hpp index 1dbc8b6f27d..f651965d19d 100644 --- a/stan/math/prim/functor/apply_scalar_unary.hpp +++ b/stan/math/prim/functor/apply_scalar_unary.hpp @@ -2,12 +2,7 @@ #define STAN_MATH_PRIM_FUNCTOR_APPLY_SCALAR_UNARY_HPP #include -#include -#include -#include -#include -#include -#include +#include #include #include @@ -36,8 +31,12 @@ namespace math { * * @tparam F Type of function to apply. * @tparam T Type of argument to which function is applied. + * @tparam ApplyZero If true, the function applied is assumed to return zero for + * inputs of zero and so sparse matrices will return sparse matrices. A value + * of false will return a dense matrix for sparse matrices. */ -template +template struct apply_scalar_unary; /** @@ -48,8 +47,8 @@ struct apply_scalar_unary; * @tparam F Type of function to apply. * @tparam T Type of argument to which function is applied. */ -template -struct apply_scalar_unary> { +template +struct apply_scalar_unary> { /** * Type of underlying scalar for the matrix type T. */ @@ -59,15 +58,82 @@ struct apply_scalar_unary> { * Return the result of applying the function defined by the * template parameter F to the specified matrix argument. * + * @tparam DenseMat A type derived from `Eigen::DenseBase`. * @param x Matrix to which operation is applied. * @return Componentwise application of the function specified * by F to the specified matrix. */ - static inline auto apply(const T& x) { + template * = nullptr> + static inline auto apply(const DenseMat& x) { return x.unaryExpr( [](scalar_t x) { return apply_scalar_unary::apply(x); }); } + /** + * Special case for `ApplyZero` set to true, returning a full sparse matrix. + * Return the result of applying the function defined by the template + * parameter F to the specified matrix argument. + * + * @param SparseMat A type derived from `Eigen::SparseMatrixBase` + * @tparam NonZeroZero Shortcut trick for using class template for deduction, + * should not be set manually. + * @param x Matrix to which operation is applied. + * @return Componentwise application of the function specified + * by F to the specified matrix. + */ + template >* = nullptr, + require_eigen_sparse_base_t* = nullptr> + static inline auto apply(const SparseMat& x) { + using val_t = value_type_t; + using triplet_t = Eigen::Triplet; + auto zeroed_val = apply_scalar_unary::apply(val_t(0.0)); + const auto x_size = x.size(); + std::vector triplet_list(x_size, triplet_t(0, 0, zeroed_val)); + for (Eigen::Index i = 0; i < x.rows(); ++i) { + for (Eigen::Index j = 0; j < x.cols(); ++j) { + // Column major order + triplet_list[i * x.cols() + j] = triplet_t(i, j, zeroed_val); + } + } + for (Eigen::Index k = 0; k < x.outerSize(); ++k) { + for (typename SparseMat::InnerIterator it(x, k); it; ++it) { + triplet_list[it.row() * x.cols() + it.col()] + = triplet_t(it.row(), it.col(), + apply_scalar_unary::apply(it.value())); + } + } + plain_type_t ret(x.rows(), x.cols()); + ret.setFromTriplets(triplet_list.begin(), triplet_list.end()); + return ret; + } + + /** + * Special case for `ApplyZero` set to false, returning a sparse matrix. + * Return the result of applying the function defined by the template + * parameter F to the specified matrix argument. + * + * @tparam SparseMat A type derived from `Eigen::SparseMatrixBase` + * @tparam NonZeroZero Shortcut trick for using class template for deduction, + * should not be set manually. + * @param x Matrix to which operation is applied. + * @return Componentwise application of the function specified + * by F to the specified matrix. + */ + template >* = nullptr, + require_eigen_sparse_base_t* = nullptr> + static inline auto apply(const SparseMat& x) { + auto ret = x.eval(); + for (Eigen::Index k = 0; k < x.outerSize(); ++k) { + for (typename SparseMat::InnerIterator it(x, k), ret_it(ret, k); it; + ++it, ++ret_it) { + ret_it.valueRef() = apply_scalar_unary::apply(it.value()); + } + } + return ret; + } + /** * Return type for applying the function elementwise to a matrix * expression template of type T. @@ -82,8 +148,8 @@ struct apply_scalar_unary> { * * @tparam F Type of function defining static apply function. */ -template -struct apply_scalar_unary> { +template +struct apply_scalar_unary> { /** * The return type, double. */ @@ -107,8 +173,8 @@ struct apply_scalar_unary> { * * @tparam F Type of function defining static apply function. */ -template -struct apply_scalar_unary> { +template +struct apply_scalar_unary> { /** * The return type, double. */ @@ -134,8 +200,8 @@ struct apply_scalar_unary> { * * @tparam F Type of function defining static apply function. */ -template -struct apply_scalar_unary> { +template +struct apply_scalar_unary> { /** * The return type, double. */ @@ -162,8 +228,8 @@ struct apply_scalar_unary> { * @tparam F Class defining a static apply function. * @tparam T Type of element contained in standard vector. */ -template -struct apply_scalar_unary> { +template +struct apply_scalar_unary, ApplyZero, void> { /** * Return type, which is calculated recursively as a standard * vector of the return type of the contained type T. diff --git a/stan/math/prim/meta/is_container.hpp b/stan/math/prim/meta/is_container.hpp index b24f0f3d558..311f0e8bb91 100644 --- a/stan/math/prim/meta/is_container.hpp +++ b/stan/math/prim/meta/is_container.hpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include #include @@ -15,12 +15,13 @@ namespace stan { /** - * Deduces whether type is eigen matrix or standard vector. + * Deduces whether type is a dense eigen type or standard vector. * @tparam Container type to check */ template -using is_container = bool_constant< - math::disjunction, is_std_vector>::value>; +using is_container + = bool_constant, + is_std_vector>::value>; STAN_ADD_REQUIRE_UNARY(container, is_container, general_types); STAN_ADD_REQUIRE_CONTAINER(container, is_container, general_types); diff --git a/stan/math/rev/functor/apply_scalar_unary.hpp b/stan/math/rev/functor/apply_scalar_unary.hpp index d68bd344335..ce1411e4802 100644 --- a/stan/math/rev/functor/apply_scalar_unary.hpp +++ b/stan/math/rev/functor/apply_scalar_unary.hpp @@ -15,8 +15,8 @@ namespace math { * * @tparam F Type of function to apply. */ -template -struct apply_scalar_unary { +template +struct apply_scalar_unary { /** * Function return type, which is var. */ @@ -31,8 +31,8 @@ struct apply_scalar_unary { static inline return_t apply(const var& x) { return F::fun(x); } }; -template -struct apply_scalar_unary> { +template +struct apply_scalar_unary> { /** * Function return type, which is a `var_value` with plain value type. */ diff --git a/test/unit/math/mix/fun/acos_test.cpp b/test/unit/math/mix/fun/acos_test.cpp index 40219636ed7..9262d456a7e 100644 --- a/test/unit/math/mix/fun/acos_test.cpp +++ b/test/unit/math/mix/fun/acos_test.cpp @@ -40,4 +40,5 @@ TEST(mathMixMatFun, acos_varmat) { A(i) = all_args[i]; } expect_ad_vector_matvar(f, A); + stan::test::expect_ad(stan::test::make_sparse_mat_func(f), A); } diff --git a/test/unit/math/mix/fun/atan_test.cpp b/test/unit/math/mix/fun/atan_test.cpp index 01ff681fa15..4c3f7e90c40 100644 --- a/test/unit/math/mix/fun/atan_test.cpp +++ b/test/unit/math/mix/fun/atan_test.cpp @@ -34,4 +34,5 @@ TEST(mathMixMatFun, atan_varmat) { A(i) = all_args[i]; } expect_ad_vector_matvar(f, A); + stan::test::expect_ad(stan::test::make_sparse_mat_func(f), A); } diff --git a/test/unit/math/test_ad.hpp b/test/unit/math/test_ad.hpp index 5a8fff9ec1d..0ac6fa6abc6 100644 --- a/test/unit/math/test_ad.hpp +++ b/test/unit/math/test_ad.hpp @@ -2150,6 +2150,33 @@ std::vector square_test_matrices(int low, int high) { return xs; } +template +auto gen_sparse_diag_mat(T1&& x) { + using triplet_t = Eigen::Triplet>; + std::vector tripletList; + tripletList.reserve(x.size()); + for (int i = 0; i < x.size(); i++) { + tripletList.emplace_back(i, i, x(i)); + } + Eigen::SparseMatrix> x_sparse(x.size(), x.size()); + x_sparse.setFromTriplets(tripletList.begin(), tripletList.end()); + x_sparse.makeCompressed(); + return x_sparse; +} + +/** + * Takes a unary lambda f taking a vector and makes a diagonal sparse matrix. + * @tparam ExpectDenseReturn if true, the return is checked for whether it's + * type is derived from `Eigen::DenseBase`, otherwise the return type is checked + * for whether it's type is derived from `Eigen::SparseMatrixBase`. + * @tparam A lambda + * @param An unary lambda + */ +template +auto make_sparse_mat_func(F&& f) { + return [&f](auto&& x) { return f(stan::test::gen_sparse_diag_mat(x)); }; +} + } // namespace test } // namespace stan #endif