diff --git a/stan/math/fwd/fun/lmultiply.hpp b/stan/math/fwd/fun/lmultiply.hpp index 4af4aaf27c4..a6e7f0ac74d 100644 --- a/stan/math/fwd/fun/lmultiply.hpp +++ b/stan/math/fwd/fun/lmultiply.hpp @@ -4,27 +4,8 @@ #include #include #include +#include #include #include -namespace stan { -namespace math { - -template -inline fvar lmultiply(const fvar& x1, const fvar& x2) { - return fvar(lmultiply(x1.val_, x2.val_), - x1.d_ * log(x2.val_) + x1.val_ * x2.d_ / x2.val_); -} - -template -inline fvar lmultiply(double x1, const fvar& x2) { - return fvar(lmultiply(x1, x2.val_), x1 * x2.d_ / x2.val_); -} - -template -inline fvar lmultiply(const fvar& x1, double x2) { - return fvar(lmultiply(x1.val_, x2), x1.d_ * log(x2)); -} -} // namespace math -} // namespace stan #endif diff --git a/stan/math/fwd/fun/multiply_log.hpp b/stan/math/fwd/fun/multiply_log.hpp index 3df40b33268..bad9b845610 100644 --- a/stan/math/fwd/fun/multiply_log.hpp +++ b/stan/math/fwd/fun/multiply_log.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -12,19 +13,29 @@ namespace math { template inline fvar multiply_log(const fvar& x1, const fvar& x2) { + if (value_of_rec(x1) == 0.0 && value_of_rec(x2) == 0.0) { + return fvar(0.0); + } return fvar(multiply_log(x1.val_, x2.val_), x1.d_ * log(x2.val_) + x1.val_ * x2.d_ / x2.val_); } template inline fvar multiply_log(double x1, const fvar& x2) { + if (x1 == 0.0 && value_of_rec(x2) == 0.0) { + return fvar(0.0); + } return fvar(multiply_log(x1, x2.val_), x1 * x2.d_ / x2.val_); } template inline fvar multiply_log(const fvar& x1, double x2) { + if (value_of_rec(x1) == 0.0 && x2 == 0.0) { + return fvar(0.0); + } return fvar(multiply_log(x1.val_, x2), x1.d_ * log(x2)); } + } // namespace math } // namespace stan #endif diff --git a/stan/math/opencl/kernel_generator/elt_function_cl.hpp b/stan/math/opencl/kernel_generator/elt_function_cl.hpp index a40301d716c..19b58d3683d 100644 --- a/stan/math/opencl/kernel_generator/elt_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/elt_function_cl.hpp @@ -365,8 +365,8 @@ ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp, opencl_kernels::log_diff_exp_device_function) ADD_BINARY_FUNCTION_WITH_INCLUDES( multiply_log, stan::math::opencl_kernels::multiply_log_device_function) -ADD_BINARY_FUNCTION_WITH_INCLUDES( - lmultiply, stan::math::opencl_kernels::lmultiply_device_function) +// ADD_BINARY_FUNCTION_WITH_INCLUDES( +// lmultiply, stan::math::opencl_kernels::lmultiply_device_function) #undef ADD_BINARY_FUNCTION_WITH_INCLUDES #undef ADD_UNARY_FUNCTION_WITH_INCLUDES diff --git a/stan/math/opencl/rev/lmultiply.hpp b/stan/math/opencl/rev/lmultiply.hpp index 46295525cd5..e69de29bb2d 100644 --- a/stan/math/opencl/rev/lmultiply.hpp +++ b/stan/math/opencl/rev/lmultiply.hpp @@ -1,46 +0,0 @@ -#ifndef STAN_MATH_OPENCL_REV_LMULTIPLY_HPP -#define STAN_MATH_OPENCL_REV_LMULTIPLY_HPP -#ifdef STAN_OPENCL - -#include -#include -#include -#include -#include - -namespace stan { -namespace math { - -/** - * Returns the elementwise `lmultiply()` of the input. - * - * @tparam T_a type of first expression - * @tparam T_b type of second expression - * @param a first expression - * @param b second expression - * - * @return Elementwise `lmultiply()` of the input. 
- */ -template * = nullptr, - require_any_var_t* = nullptr, - require_any_not_stan_scalar_t* = nullptr> -inline var_value> lmultiply(T_a&& a, T_b&& b) { - arena_t a_arena = std::forward(a); - arena_t b_arena = std::forward(b); - - return make_callback_var( - lmultiply(value_of(a_arena), value_of(b_arena)), - [a_arena, b_arena](const vari_value>& res) mutable { - adjoint_results(a_arena, b_arena) += expressions( - elt_multiply(res.adj(), log(value_of(b_arena))), - elt_multiply(res.adj(), - elt_divide(value_of(a_arena), value_of(b_arena)))); - }); -} - -} // namespace math -} // namespace stan - -#endif -#endif diff --git a/stan/math/opencl/rev/multiply_log.hpp b/stan/math/opencl/rev/multiply_log.hpp index 535129b3e77..a1527d8697b 100644 --- a/stan/math/opencl/rev/multiply_log.hpp +++ b/stan/math/opencl/rev/multiply_log.hpp @@ -10,6 +10,16 @@ namespace stan { namespace math { +namespace internalcl { +template +inline decltype(auto) conditional_sum(T&& x) { + if constexpr (Cond) { + return sum(std::forward(x)); + } else { + return std::forward(x); + } +} +} // namespace internalcl /** * Returns the elementwise `multiply_log()` of the input. @@ -32,10 +42,26 @@ inline var_value> multiply_log(T_a&& a, T_b&& b) { return make_callback_var( multiply_log(value_of(a_arena), value_of(b_arena)), [a_arena, b_arena](const vari_value>& res) mutable { - adjoint_results(a_arena, b_arena) += expressions( - elt_multiply(res.adj(), log(value_of(b_arena))), - elt_multiply(res.adj(), - elt_divide(value_of(a_arena), value_of(b_arena)))); + constexpr bool is_scalar_a = !is_matrix_v; + constexpr bool is_scalar_b = !is_matrix_v; + using internalcl::conditional_sum; + auto is_zero = value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0; + if constexpr (is_var::value && is_var::value) { + a_arena.adj() += conditional_sum(select( + is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena))))); + b_arena.adj() += conditional_sum(select( + is_zero, 0.0, + elt_multiply(res.adj(), + elt_divide(value_of(a_arena), value_of(b_arena))))); + } else if constexpr (is_var::value) { + a_arena.adj() += conditional_sum(select( + is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena))))); + } else if constexpr (is_var::value) { + b_arena.adj() += conditional_sum(select( + is_zero, 0.0, + elt_multiply(res.adj(), + elt_divide(value_of(a_arena), value_of(b_arena))))); + } }); } diff --git a/stan/math/prim/constraint/offset_multiplier_constrain.hpp b/stan/math/prim/constraint/offset_multiplier_constrain.hpp index 186c8f38988..0987ef7e2c5 100644 --- a/stan/math/prim/constraint/offset_multiplier_constrain.hpp +++ b/stan/math/prim/constraint/offset_multiplier_constrain.hpp @@ -29,9 +29,9 @@ namespace math { *

If the offset is zero and the multiplier is one this * reduces to identity_constrain(x). * - * @tparam T type of scalar - * @tparam M type of offset - * @tparam S type of multiplier + * @tparam T A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam M A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam S A scalar type or type inheriting from `Eigen::DenseBase` * @param[in] x Unconstrained scalar input * @param[in] mu offset of constrained output * @param[in] sigma multiplier of constrained output @@ -41,25 +41,31 @@ namespace math { */ template * = nullptr> -inline auto offset_multiplier_constrain(const T& x, const M& mu, - const S& sigma) { - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - if (is_matrix::value && is_matrix::value) { + T, M, S>* = nullptr, + require_all_not_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma) { + if (is_matrix_v && is_matrix::value) { check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - if (is_matrix::value && is_matrix::value) { + if (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - } else if (is_matrix::value && is_matrix::value) { + } else if (is_matrix::value && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", sigma); } - + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); check_finite("offset_multiplier_constrain", "offset", value_of_rec(mu_ref)); check_positive_finite("offset_multiplier_constrain", "multiplier", value_of_rec(sigma_ref)); - return stan::math::eval(fma(sigma_ref, x, mu_ref)); + return make_holder( + [](auto&& sigma_ref, auto&& x, auto&& mu_ref) { + return fma(std::forward(sigma_ref), + std::forward(x), + std::forward(mu_ref)); + }, + std::forward(sigma_ref), std::forward(x), + std::forward(mu_ref)); } /** @@ -77,10 +83,9 @@ inline auto offset_multiplier_constrain(const T& x, const M& mu, * If the offset is zero and multiplier is one, this function * reduces to identity_constraint(x, lp). 
* - * @tparam T type of scalar - * @tparam M type of offset - * @tparam S type of multiplier - * @tparam Lp Scalar type, convertable from T, M, and S + * @tparam T A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam M A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam S A scalar type or type inheriting from `Eigen::DenseBase` * @param[in] x Unconstrained scalar input * @param[in] mu offset of constrained output * @param[in] sigma multiplier of constrained output @@ -92,186 +97,110 @@ inline auto offset_multiplier_constrain(const T& x, const M& mu, template , Lp>* = nullptr, require_all_not_nonscalar_prim_or_rev_kernel_expression_t< - T, M, S>* = nullptr> -inline auto offset_multiplier_constrain(const T& x, const M& mu, const S& sigma, - Lp& lp) { - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - if (is_matrix::value && is_matrix::value) { + T, M, S>* = nullptr, + require_all_not_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma, Lp& lp) { + if constexpr (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - if (is_matrix::value && is_matrix::value) { + if constexpr (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - } else if (is_matrix::value && is_matrix::value) { + } + if constexpr (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", sigma); } - + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); check_finite("offset_multiplier_constrain", "offset", value_of_rec(mu_ref)); check_positive_finite("offset_multiplier_constrain", "multiplier", value_of_rec(sigma_ref)); - if (math::size(sigma_ref) == 1) { - lp += sum(multiply_log(math::size(x), sigma_ref)); + if (stan::math::size(sigma_ref) == 1) { + lp += sum(multiply_log(static_cast(math::size(x)), sigma_ref)); } else { lp += sum(log(sigma_ref)); } - return stan::math::eval(fma(sigma_ref, x, mu_ref)); + return make_holder( + [](auto&& sigma_ref, auto&& x, auto&& mu_ref) { + return fma(std::forward(sigma_ref), + std::forward(x), + std::forward(mu_ref)); + }, + std::forward(sigma_ref), std::forward(x), + std::forward(mu_ref)); } /** - * Overload for array of x and non-array mu and sigma + * Overload for when x and mu or sigma are `std::vectors` */ template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const S& sigma) { - std::vector< - plain_type_t> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma_ref)); - } - return ret; -} - -/** - * Overload for array of x and non-array mu and sigma with lp - */ -template , Lp>* = nullptr, - require_all_not_std_vector_t* = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const S& sigma, Lp& lp) { - std::vector< - plain_type_t> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma_ref, lp)); + require_any_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(const T& x, M&& mu, S&& sigma) { + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", 
"x", x, "sigma", sigma); } - return ret; -} - -/** - * Overload for array of x and sigma and non-array mu - */ -template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const std::vector& sigma) { - check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector< - plain_type_t> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma[i])); + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - return ret; -} - -/** - * Overload for array of x and sigma and non-array mu with lp - */ -template , Lp>* = nullptr, - require_not_std_vector_t* = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const std::vector& sigma, Lp& lp) { - check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma[i], lp)); + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", + sigma); } - return ret; -} - -/** - * Overload for array of x and mu and non-array sigma - */ -template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const S& sigma) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - std::vector< - plain_type_t> - ret; + auto iter = [](auto&& it, std::size_t i) -> decltype(auto) { + if constexpr (is_std_vector_v) { + return it[i]; + } else { + return it; + } + }; + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); + using inner_ret_t = decltype(offset_multiplier_constrain( + iter(x, 0), iter(mu_ref, 0), iter(sigma_ref, 0))); + std::vector> ret; + // In the language, if mu or sigma is a vector, x must also be a vector ret.reserve(x.size()); - const auto& sigma_ref = to_ref(sigma); for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma_ref)); + ret.emplace_back( + offset_multiplier_constrain(x[i], iter(mu_ref, i), iter(sigma_ref, i))); } return ret; } /** - * Overload for array of x and mu and non-array sigma with lp + * Overload for when x and mu or sigma are `std::vectors` */ template , Lp>* = nullptr, - require_not_std_vector_t* = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const S& sigma, Lp& lp) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - std::vector> - ret; - ret.reserve(x.size()); - const auto& sigma_ref = to_ref(sigma); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma_ref, lp)); + require_any_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(const T& x, M&& mu, S&& sigma, Lp& lp) { + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); } - return ret; -} - -/** - * Overload for array of x, mu, and sigma - */ -template -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const std::vector& sigma) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - 
check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector> - ret; - ret.reserve(x.size()); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma[i])); + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - return ret; -} - -/** - * Overload for array of x, mu, and sigma with lp - */ -template , Lp>* = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const std::vector& sigma, Lp& lp) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector> - ret; + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", + sigma); + } + auto iter = [](auto&& it, std::size_t i) -> decltype(auto) { + if constexpr (is_std_vector_v) { + return it[i]; + } else { + return it; + } + }; + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); + using inner_ret_t = decltype(offset_multiplier_constrain( + iter(x, 0), iter(mu_ref, 0), iter(sigma_ref, 0))); + std::vector> ret; + // In the language, if mu or sigma is a vector, x must also be a vector ret.reserve(x.size()); for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma[i], lp)); + ret.emplace_back(offset_multiplier_constrain(x[i], iter(mu_ref, i), + iter(sigma_ref, i), lp)); } return ret; } diff --git a/stan/math/prim/fun/lmultiply.hpp b/stan/math/prim/fun/lmultiply.hpp index e7de4468f81..351b5a7e4e5 100644 --- a/stan/math/prim/fun/lmultiply.hpp +++ b/stan/math/prim/fun/lmultiply.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -21,33 +22,20 @@ namespace math { * @param b second argument * @return the first argument times the log of the second argument */ -template * = nullptr> -inline return_type_t lmultiply(const T1 a, const T2 b) { - using std::log; - if (a == 0 && b == 0) { - return 0; +template +inline auto lmultiply(T1&& a, T2&& b) { + if constexpr (is_kernel_expression::value + || is_kernel_expression::value) { + return multiply_log(std::forward(a), std::forward(b)); + } else { + return make_holder( + [](auto&& a, auto&& b) { + return multiply_log(std::forward(a), + std::forward(b)); + }, + std::forward(a), std::forward(b)); } - return a * log(b); } - -/** - * Return the result of applying `lmultiply` to the arguments - * elementwise, with broadcasting if one of the arguments is a scalar. - * At least one of the arguments must be a container. 
- * - * @tparam T1 type of the first argument - * @tparam T2 type of the second argument - * @param a first argument - * @param b second argument - * @return result of applying `lmultiply` to the arguments - */ -template * = nullptr, - require_all_not_var_matrix_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - return apply_scalar_binary( - [](const auto& c, const auto& d) { return lmultiply(c, d); }, a, b); -} - } // namespace math } // namespace stan diff --git a/stan/math/prim/fun/multiply_log.hpp b/stan/math/prim/fun/multiply_log.hpp index 6dcd8909999..1095ef7377b 100644 --- a/stan/math/prim/fun/multiply_log.hpp +++ b/stan/math/prim/fun/multiply_log.hpp @@ -47,12 +47,10 @@ namespace math { template * = nullptr> inline return_type_t multiply_log(const T_a a, const T_b b) { - using std::log; - if (b == 0.0 && a == 0.0) { + if (a == 0.0 && b == 0.0) { return 0.0; } - - return a * log(b); + return a * std::log(b); } /** @@ -66,7 +64,7 @@ inline return_type_t multiply_log(const T_a a, const T_b b) { * @return multiply_log function applied to the two inputs. */ template * = nullptr, - require_all_not_var_matrix_t* = nullptr> + require_all_not_rev_matrix_t* = nullptr> inline auto multiply_log(const T1& a, const T2& b) { return apply_scalar_binary( [](const auto& c, const auto& d) { return multiply_log(c, d); }, a, b); diff --git a/stan/math/prim/meta/is_constant.hpp b/stan/math/prim/meta/is_constant.hpp index b3fce314539..6bb0b0bb715 100644 --- a/stan/math/prim/meta/is_constant.hpp +++ b/stan/math/prim/meta/is_constant.hpp @@ -62,5 +62,8 @@ template struct is_constant> : bool_constant::Scalar>::value> {}; +template +inline constexpr bool is_not_constant_v = !is_constant>::value; + } // namespace stan #endif diff --git a/stan/math/prim/meta/is_matrix.hpp b/stan/math/prim/meta/is_matrix.hpp index 58cb26712ad..056ba326207 100644 --- a/stan/math/prim/meta/is_matrix.hpp +++ b/stan/math/prim/meta/is_matrix.hpp @@ -17,6 +17,9 @@ template struct is_matrix : bool_constant, is_eigen>::value> {}; +template +inline constexpr bool is_matrix_v = is_matrix>::value; + /*! \ingroup require_eigens_types */ /*! \defgroup matrix_types matrix */ /*! \addtogroup matrix_types */ diff --git a/stan/math/prim/meta/is_rev_matrix.hpp b/stan/math/prim/meta/is_rev_matrix.hpp index 40ef60bb9e9..fc7a3ba9234 100644 --- a/stan/math/prim/meta/is_rev_matrix.hpp +++ b/stan/math/prim/meta/is_rev_matrix.hpp @@ -46,6 +46,13 @@ using require_any_rev_matrix_t template using require_all_not_rev_matrix_t = require_all_not_t>...>; + +/*! \brief Require at least one of the types do not satisfy @ref is_rev_matrix + */ +/*! @tparam Types The types that are checked */ +template +using require_any_not_rev_matrix_t + = require_any_not_t>...>; /*! @} */ /** \ingroup type_trait diff --git a/stan/math/prim/meta/is_vector.hpp b/stan/math/prim/meta/is_vector.hpp index b0c62d255f2..2375ac5785f 100644 --- a/stan/math/prim/meta/is_vector.hpp +++ b/stan/math/prim/meta/is_vector.hpp @@ -597,6 +597,9 @@ struct is_std_vector< T, std::enable_if_t>::value>> : std::true_type {}; +template +inline constexpr bool is_std_vector_v = is_std_vector>::value; + /** \ingroup type_trait * Specialization of scalar_type for vector to recursively return the inner * scalar type. 
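A quick illustration of the new traits (a sketch, not part of the patch): `is_matrix_v`, `is_std_vector_v`, and `is_not_constant_v` are plain variable-template shorthands for the existing `is_matrix`, `is_std_vector`, and `is_constant` traits on the decayed type, and they are what the `if constexpr` branches above rely on. The sketch below assumes the headers from this branch are on the include path; `describe` is an illustrative name, not a library function.

```cpp
#include <stan/math.hpp>
#include <iostream>
#include <vector>

// Dispatch on the argument's shape with the new variable templates.
template <typename T>
void describe(const T& /* x */) {
  if constexpr (stan::is_matrix_v<T>) {
    std::cout << "matrix-like argument (Eigen or var_value<Eigen>)\n";
  } else if constexpr (stan::is_std_vector_v<T>) {
    std::cout << "std::vector argument\n";
  } else {
    std::cout << "scalar argument\n";
  }
}

int main() {
  describe(1.5);                                 // scalar
  describe(Eigen::VectorXd::Zero(3).eval());     // Eigen vector -> matrix-like
  describe(std::vector<double>{1.0, 2.0, 3.0});  // std::vector
  describe(stan::math::var(2.0));                // autodiff scalar

  // is_not_constant_v is the negation of is_constant on the decayed type.
  std::cout << std::boolalpha
            << stan::is_not_constant_v<stan::math::var> << '\n'  // true
            << stan::is_not_constant_v<double> << '\n';          // false
  return 0;
}
```

Writing the checks as `stan::is_matrix_v<T>` instead of `is_matrix<std::decay_t<T>>::value` keeps the `if constexpr` chains in `multiply_log` and `offset_multiplier_constrain` readable.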
diff --git a/stan/math/rev/fun/lmultiply.hpp b/stan/math/rev/fun/lmultiply.hpp index 348fdb2fd4d..16783d94481 100644 --- a/stan/math/rev/fun/lmultiply.hpp +++ b/stan/math/rev/fun/lmultiply.hpp @@ -12,224 +12,4 @@ #include #include -namespace stan { -namespace math { - -namespace internal { -class lmultiply_vv_vari : public op_vv_vari { - public: - lmultiply_vv_vari(vari* avi, vari* bvi) - : op_vv_vari(lmultiply(avi->val_, bvi->val_), avi, bvi) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bvi_->val_); - bvi_->adj_ += adj_ * avi_->val_ / bvi_->val_; - } -}; -class lmultiply_vd_vari : public op_vd_vari { - public: - lmultiply_vd_vari(vari* avi, double b) - : op_vd_vari(lmultiply(avi->val_, b), avi, b) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bd_); - } -}; -class lmultiply_dv_vari : public op_dv_vari { - public: - lmultiply_dv_vari(double a, vari* bvi) - : op_dv_vari(lmultiply(a, bvi->val_), a, bvi) {} - void chain() { bvi_->adj_ += adj_ * ad_ / bvi_->val_; } -}; -} // namespace internal - -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to a is log(b). - * The partial derivative with respect to b is a/b. - * - * @param a First variable. - * @param b Second variable. - * @return Value of a*log(b) - */ -inline var lmultiply(const var& a, const var& b) { - return var(new internal::lmultiply_vv_vari(a.vi_, b.vi_)); -} -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to a is log(b). - * - * @param a First variable. - * @param b Second scalar. - * @return Value of a*log(b) - */ -inline var lmultiply(const var& a, double b) { - return var(new internal::lmultiply_vd_vari(a.vi_, b)); -} -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to b is a/b. - * - * @param a First scalar. - * @param b Second variable. - * @return Value of a*log(b) - */ -inline var lmultiply(double a, const var& b) { - if (a == 1.0) { - return log(b); - } - return var(new internal::lmultiply_dv_vari(a, b.vi_)); -} - -/** - * Return the elementwise product `a * log(b)`. 
- * - * Both `T1` and `T2` are matrices, and one of `T1` or `T2` must be a - * `var_value` - * - * @tparam T1 Type of first argument - * @tparam T2 Type of second argument - * @param a First argument - * @param b Second argument - * @return elementwise product of `a` and `log(b)` - */ -template * = nullptr, - require_any_var_matrix_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - check_matching_dims("lmultiply", "a", a, "b", b); - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = b; - - return make_callback_var( - lmultiply(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * arena_b.val().array().log(); - arena_b.adj().array() += res.adj().array() * arena_a.val().array() - / arena_b.val().array(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = value_of(b); - - return make_callback_var(lmultiply(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() - * arena_b.val().array().log(); - }); - } else { - arena_t> arena_a = value_of(a); - arena_t> arena_b = b; - - return make_callback_var(lmultiply(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj().array() += res.adj().array() - * arena_a.val().array() - / arena_b.val().array(); - }); - } -} - -/** - * Return the product `a * log(b)`. - * - * @tparam T1 Type of matrix argument - * @tparam T2 Type of scalar argument - * @param a Matrix argument - * @param b Scalar argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_stan_scalar_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - using std::log; - - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - var arena_b = b; - - return make_callback_var( - lmultiply(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() += res.adj().array() * log(arena_b.val()); - arena_b.adj() += (res.adj().array() * arena_a.val().array()).sum() - / arena_b.val(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - - return make_callback_var(lmultiply(arena_a.val(), value_of(b)), - [arena_a, b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * log(value_of(b)); - }); - } else { - arena_t> arena_a = value_of(a); - var arena_b = b; - - return make_callback_var( - lmultiply(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj() - += (res.adj().array() * arena_a.array()).sum() / arena_b.val(); - }); - } -} - -/** - * Return the product `a * log(b)`. 
- * - * @tparam T1 Type of scalar argument - * @tparam T2 Type of matrix argument - * @param a Scalar argument - * @param b Matrix argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_var_matrix_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - if (!is_constant::value && !is_constant::value) { - var arena_a = a; - arena_t> arena_b = b; - - return make_callback_var( - lmultiply(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); - arena_b.adj().array() - += arena_a.val() * res.adj().array() / arena_b.val().array(); - }); - } else if (!is_constant::value) { - var arena_a = a; - arena_t> arena_b = value_of(b); - - return make_callback_var( - lmultiply(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); - }); - } else { - arena_t> arena_b = b; - - return make_callback_var(lmultiply(value_of(a), arena_b.val()), - [a, arena_b](const auto& res) mutable { - arena_b.adj().array() += value_of(a) - * res.adj().array() - / arena_b.val().array(); - }); - } -} - -} // namespace math -} // namespace stan #endif diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index 293cb1856cb..fa8bb95dad1 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -5,47 +5,21 @@ #include #include #include -#include +#include #include #include #include +#include #include namespace stan { namespace math { -namespace internal { -class multiply_log_vv_vari : public op_vv_vari { - public: - multiply_log_vv_vari(vari* avi, vari* bvi) - : op_vv_vari(multiply_log(avi->val_, bvi->val_), avi, bvi) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bvi_->val_); - bvi_->adj_ += adj_ * avi_->val_ / bvi_->val_; - } -}; -class multiply_log_vd_vari : public op_vd_vari { - public: - multiply_log_vd_vari(vari* avi, double b) - : op_vd_vari(multiply_log(avi->val_, b), avi, b) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bd_); - } -}; -class multiply_log_dv_vari : public op_dv_vari { - public: - multiply_log_dv_vari(double a, vari* bvi) - : op_dv_vari(multiply_log(a, bvi->val_), a, bvi) {} - void chain() { bvi_->adj_ += adj_ * ad_ / bvi_->val_; } -}; -} // namespace internal - /** * Return the value of a*log(b). * - * When both a and b are 0, the value returned is 0. + * When both a and b are 0, the value returned is 0 + * and no gradients are accumulated. * The partial derivative with respect to a is log(b). * The partial derivative with respect to b is a/b. * @@ -53,180 +27,94 @@ class multiply_log_dv_vari : public op_dv_vari { * @param b Second variable. * @return Value of a*log(b) */ -inline var multiply_log(const var& a, const var& b) { - return var(new internal::multiply_log_vv_vari(a.vi_, b.vi_)); -} -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to a is log(b). - * - * @param a First variable. - * @param b Second scalar. 
- * @return Value of a*log(b) - */ -inline var multiply_log(const var& a, double b) { - return var(new internal::multiply_log_vd_vari(a.vi_, b)); +template * = nullptr, + require_any_var_t* = nullptr> +inline var multiply_log(const T1& a, const T2& b) { + if (value_of(a) == 0.0 && value_of(b) == 0.0) { + return var(0.0); + } + return make_callback_var( + multiply_log(value_of(a), value_of(b)), [a, b](const auto& res) mutable { + if constexpr (!is_constant::value && !is_constant::value) { + a.adj() += res.adj() * log(b.val()); + b.adj() += res.adj() * a.val() / b.val(); + } else if constexpr (!is_constant::value) { + a.adj() += res.adj() * log(b); + } else { + b.adj() += res.adj() * a / b.val(); + } + }); } -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to b is a/b. - * - * @param a First scalar. - * @param b Second variable. - * @return Value of a*log(b) - */ -inline var multiply_log(double a, const var& b) { - if (a == 1.0) { - return log(b); + +namespace internal { +template +inline auto conditional_sum(T&& x) { + if constexpr (Cond) { + return sum(std::forward(x)); + } else { + return std::forward(x); } - return var(new internal::multiply_log_dv_vari(a, b.vi_)); } +} // namespace internal /** * Return the elementwise product `a * log(b)`. * - * Both `T1` and `T2` are matrices, and one of `T1` or `T2` must be a - * `var_value` - * - * @tparam T1 Type of first argument - * @tparam T2 Type of second argument + * For each element of `a` and `b`, when `a[i]` and `b[i]` are 0, + * the value and adjoint returned are zero. + * @tparam T1 Either a scalar or a matrix + * @tparam T2 Either a scalar or a matrix * @param a First argument * @param b Second argument * @return elementwise product of `a` and `log(b)` */ -template * = nullptr, - require_any_var_matrix_t* = nullptr> -inline auto multiply_log(const T1& a, const T2& b) { - check_matching_dims("multiply_log", "a", a, "b", b); - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = b; - - return make_callback_var( - multiply_log(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * arena_b.val().array().log(); - arena_b.adj().array() += res.adj().array() * arena_a.val().array() - / arena_b.val().array(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = value_of(b); - - return make_callback_var(multiply_log(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() - * arena_b.val().array().log(); - }); - } else { - arena_t> arena_a = value_of(a); - arena_t> arena_b = b; - - return make_callback_var(multiply_log(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj().array() += res.adj().array() - * arena_a.val().array() - / arena_b.val().array(); - }); +template * = nullptr, + require_any_st_var* = nullptr> +inline auto multiply_log(T1&& a, T2&& b) { + constexpr bool is_a_scalar = !is_matrix_v; + constexpr bool is_b_scalar = !is_matrix_v; + if constexpr (!is_a_scalar && !is_b_scalar) { + check_matching_dims("multiply_log", "a", a, "b", b); } -} - -/** - * Return the product `a * log(b)`. 
- * - * @tparam T1 Type of matrix argument - * @tparam T2 Type of scalar argument - * @param a Matrix argument - * @param b Scalar argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_stan_scalar_t* = nullptr> -inline auto multiply_log(const T1& a, const T2& b) { - using std::log; - - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - var arena_b = b; - - return make_callback_var( - multiply_log(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() += res.adj().array() * log(arena_b.val()); - arena_b.adj() += (res.adj().array() * arena_a.val().array()).sum() - / arena_b.val(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - - return make_callback_var(multiply_log(arena_a.val(), value_of(b)), - [arena_a, b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * log(value_of(b)); - }); - } else { - arena_t> arena_a = value_of(a); - var arena_b = b; - - return make_callback_var( - multiply_log(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj() - += (res.adj().array() * arena_a.array()).sum() / arena_b.val(); - }); - } -} - -/** - * Return the product `a * log(b)`. - * - * @tparam T1 Type of scalar argument - * @tparam T2 Type of matrix argument - * @param a Scalar argument - * @param b Matrix argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_var_matrix_t* = nullptr> -inline auto multiply_log(const T1& a, const T2& b) { - if (!is_constant::value && !is_constant::value) { - var arena_a = a; - arena_t> arena_b = b; - - return make_callback_var( - multiply_log(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); - arena_b.adj().array() - += arena_a.val() * res.adj().array() / arena_b.val().array(); - }); - } else if (!is_constant::value) { - var arena_a = a; - arena_t> arena_b = value_of(b); - - return make_callback_var( - multiply_log(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); - }); + arena_t arena_a = std::forward(a); + arena_t arena_b = std::forward(b); + using return_t = return_var_matrix_t< + decltype(multiply_log(value_of(arena_a), value_of(arena_b))), T1, T2>; + arena_t res = multiply_log(value_of(arena_a), value_of(arena_b)); + using internal::conditional_sum; + if constexpr (is_not_constant_v && is_not_constant_v) { + reverse_pass_callback([res, arena_a, arena_b]() mutable { + auto arena_a_arr = as_array_or_scalar(arena_a); + auto arena_b_arr = as_array_or_scalar(arena_b); + auto res_arr = as_array_or_scalar(res); + auto is_zero + = ((arena_a_arr.val() == 0.0 + arena_b_arr.val() == 0.0) > 1); + arena_a_arr.adj() += conditional_sum( + select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr.val()))); + arena_b_arr.adj() += conditional_sum(select( + is_zero, 0.0, res_arr.adj() * arena_a_arr.val() / arena_b_arr.val())); + }); + } else if constexpr (is_not_constant_v) { + reverse_pass_callback([res, arena_a, arena_b]() mutable { + auto arena_a_arr = as_array_or_scalar(arena_a); + auto arena_b_arr = as_array_or_scalar(arena_b); + auto res_arr = as_array_or_scalar(res); + auto is_zero = ((arena_a_arr.val() == 0.0 + arena_b_arr == 0.0) > 1); + arena_a_arr.adj() += conditional_sum( + select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr))); + }); } else { - arena_t> 
arena_b = b; - - return make_callback_var(multiply_log(value_of(a), arena_b.val()), - [a, arena_b](const auto& res) mutable { - arena_b.adj().array() += value_of(a) - * res.adj().array() - / arena_b.val().array(); - }); + reverse_pass_callback([res, arena_a, arena_b]() mutable { + auto arena_a_arr = as_array_or_scalar(arena_a); + auto arena_b_arr = as_array_or_scalar(arena_b); + auto res_arr = as_array_or_scalar(res); + auto is_zero = ((arena_a_arr == 0.0 + arena_b_arr.val() == 0.0) > 1); + arena_b_arr.adj() += conditional_sum(select( + is_zero, 0.0, res_arr.adj() * arena_a_arr / arena_b_arr.val())); + }); } + return res; } } // namespace math diff --git a/test/unit/math/mix/fun/multiply_log1_test.cpp b/test/unit/math/mix/fun/multiply_log1_test.cpp index ba8e1336f01..7c5f6cbdc9e 100644 --- a/test/unit/math/mix/fun/multiply_log1_test.cpp +++ b/test/unit/math/mix/fun/multiply_log1_test.cpp @@ -5,14 +5,19 @@ TEST(mathMixScalFun, multiplyLog) { auto f = [](const auto& x1, const auto& x2) { return stan::math::multiply_log(x1, x2); }; - stan::test::expect_ad(f, 0.5, -0.4); // error stan::test::expect_ad(f, 0.5, 1.2); stan::test::expect_ad(f, 1.5, 1.8); stan::test::expect_ad(f, 2.2, 3.3); stan::test::expect_ad(f, 19.7, 1299.1); +} +TEST(mathMixScalFun, multiplyLog_errors) { + auto f = [](const auto& x1, const auto& x2) { + return stan::math::multiply_log(x1, x2); + }; + stan::test::expect_ad(f, 0.5, -0.4); // error - double nan = std::numeric_limits::quiet_NaN(); + constexpr double nan = std::numeric_limits::quiet_NaN(); stan::test::expect_ad(f, 1.0, nan); stan::test::expect_ad(f, nan, 1.0); stan::test::expect_ad(f, nan, nan); diff --git a/test/unit/math/mix/fun/multiply_log2_test.cpp b/test/unit/math/mix/fun/multiply_log2_test.cpp index a124be21bdb..1d5694deea6 100644 --- a/test/unit/math/mix/fun/multiply_log2_test.cpp +++ b/test/unit/math/mix/fun/multiply_log2_test.cpp @@ -1,12 +1,26 @@ #include #include +/** + * NOTE: We do not test values of (0.0, 0.0) as inputs. + * This is because the testing framework uses + * finite difference as the comparison. The small finite-difference + * perturbations can make the value of the + * function inputs either slightly negative or slightly positive, + * which can lead to the function returning NaN when we would expect a + * value of 0.
+ */ TEST(mathMixScalFun, multiplyLog2_vec) { auto f = [](const auto& x1, const auto& x2) { using stan::math::multiply_log; return multiply_log(x1, x2); }; - + auto expect_test = [](auto&& f, auto&& x1, auto&& x2) { + stan::test::expect_ad(f, x1, x2); + stan::test::expect_ad(f, x2, x1); + stan::test::expect_ad_matvar(f, x1, x2); + stan::test::expect_ad_matvar(f, x2, x1); + }; Eigen::VectorXd in1(2); in1 << 3, 1; Eigen::VectorXd in2(2); @@ -18,28 +32,62 @@ TEST(mathMixScalFun, multiplyLog2_vec) { x2 << 1.0, 2.0, 3.0; Eigen::MatrixXd x3(2, 3); x3 << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0; - + stan::test::expect_ad(f, x1, x1); + stan::test::expect_ad(f, x2, x2); + stan::test::expect_ad(f, x3, x3); stan::test::expect_ad_matvar(f, x1, x1); - stan::test::expect_ad_matvar(f, x1, 2.0); - stan::test::expect_ad_matvar(f, 3.0, x1); stan::test::expect_ad_matvar(f, x2, x2); - stan::test::expect_ad_matvar(f, x2, 2.5); - stan::test::expect_ad_matvar(f, 3.5, x2); stan::test::expect_ad_matvar(f, x3, x3); - stan::test::expect_ad_matvar(f, x3, 4.0); - stan::test::expect_ad_matvar(f, 5.0, x3); + expect_test(f, x1, 2.0); + expect_test(f, x2, 2.5); + expect_test(f, x3, 5.5); Eigen::VectorXd x4(0); Eigen::RowVectorXd x5(0); Eigen::MatrixXd x6(0, 0); + stan::test::expect_ad(f, x4, x4); + stan::test::expect_ad(f, x5, x5); + stan::test::expect_ad(f, x6, x6); stan::test::expect_ad_matvar(f, x4, x4); - stan::test::expect_ad_matvar(f, x4, 2.0); - stan::test::expect_ad_matvar(f, 3.0, x4); stan::test::expect_ad_matvar(f, x5, x5); - stan::test::expect_ad_matvar(f, x5, 2.5); - stan::test::expect_ad_matvar(f, 3.5, x5); stan::test::expect_ad_matvar(f, x6, x6); - stan::test::expect_ad_matvar(f, x6, 4.0); - stan::test::expect_ad_matvar(f, 5.0, x6); + expect_test(f, x4, 2.0); + expect_test(f, x5, 3.1); + expect_test(f, x6, 5.5); +} + +TEST(mathMixScalFun, multiplyLog2_zero_vec_vec) { + auto f = [](const auto& x1, const auto& x2) { + using stan::math::multiply_log; + return multiply_log(x1, x2); + }; + auto expect_test = [](auto&& f, auto&& x1, auto&& x2) { + stan::test::expect_ad(f, x1, x2); + stan::test::expect_ad(f, x2, x1); + stan::test::expect_ad_matvar(f, x1, x2); + stan::test::expect_ad_matvar(f, x2, x1); + }; + + Eigen::VectorXd zero_vec = Eigen::VectorXd::Zero(3); + Eigen::VectorXd x1(3); + x1 << 1.0, 2.0, 3.0; + expect_test(f, zero_vec, x1); +} +TEST(mathMixScalFun, multiplyLog2_zero_vec_scalar) { + auto f = [](const auto& x1, const auto& x2) { + using stan::math::multiply_log; + return multiply_log(x1, x2); + }; + + Eigen::VectorXd zero_vec = Eigen::VectorXd::Zero(3); + Eigen::VectorXd x1(3); + x1 << 1.0, 2.0, 3.0; + auto expect_test = [](auto&& f, auto&& x1, auto&& x2) { + stan::test::expect_ad(f, x1, x2); + stan::test::expect_ad(f, x2, x1); + stan::test::expect_ad_matvar(f, x1, x2); + stan::test::expect_ad_matvar(f, x2, x1); + }; + expect_test(f, x1, 0.0); } diff --git a/test/unit/math/rev/fun/multiply_log_test.cpp b/test/unit/math/rev/fun/multiply_log_test.cpp new file mode 100644 index 00000000000..fa4aa78f5a5 --- /dev/null +++ b/test/unit/math/rev/fun/multiply_log_test.cpp @@ -0,0 +1,12 @@ +#include +#include +#include +#include + +TEST(RevTest, multiply_log_issue3146) { + std::int32_t x = 1; + using stan::math::var_value; + using mat_t = Eigen::Matrix; + var_value y = Eigen::MatrixXd::Random(10, 10); + auto z = stan::math::multiply_log(x, y); +}
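To make the behavioral change concrete, here is a minimal reverse-mode sketch of what the zero guards above are meant to guarantee (illustrative only, assuming the overloads in this patch; it is not part of the test suite): at `a == 0` and `b == 0`, `multiply_log` returns 0 and accumulates zero adjoints rather than propagating `log(0)` and `0 / 0`.

```cpp
#include <stan/math.hpp>
#include <iostream>

int main() {
  using stan::math::var;
  // Both arguments exactly zero: with the zero guards in this patch the
  // result is 0 and no log(0) or 0/0 terms reach the adjoints.
  var a = 0.0;
  var b = 0.0;
  var y = stan::math::multiply_log(a, b);
  y.grad();
  std::cout << "value = " << y.val()   // expected: 0
            << ", d/da = " << a.adj()  // expected: 0 (previously -inf from log(0))
            << ", d/db = " << b.adj()  // expected: 0 (previously NaN from 0/0)
            << '\n';
  stan::math::recover_memory();
  return 0;
}
```

The forward-mode overloads receive the same early return, so `fvar` results at `(0, 0)` are likewise zero with zero tangents, and `lmultiply` now forwards to `multiply_log` (see the prim change above), so it inherits the same behavior.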