From 18af903ae1f90b88bcdb5969cbba3ef3be46e705 Mon Sep 17 00:00:00 2001 From: stevebronder Date: Thu, 13 Feb 2025 13:44:51 -0500 Subject: [PATCH 01/19] Fix issue 3146 by using if constexpr for path deduction. Also remove most of the template function for offset_multiplier to simplify --- .../offset_multiplier_constrain.hpp | 240 ++++++------------ stan/math/prim/meta/is_constant.hpp | 3 + stan/math/prim/meta/is_matrix.hpp | 3 + stan/math/prim/meta/is_vector.hpp | 3 + stan/math/rev/fun/multiply_log.hpp | 105 +++----- test/unit/math/rev/fun/multiply_log_test.cpp | 12 + 6 files changed, 137 insertions(+), 229 deletions(-) create mode 100644 test/unit/math/rev/fun/multiply_log_test.cpp diff --git a/stan/math/prim/constraint/offset_multiplier_constrain.hpp b/stan/math/prim/constraint/offset_multiplier_constrain.hpp index a226ae2e15f..d9d448cce55 100644 --- a/stan/math/prim/constraint/offset_multiplier_constrain.hpp +++ b/stan/math/prim/constraint/offset_multiplier_constrain.hpp @@ -28,9 +28,9 @@ namespace math { *

If the offset is zero and the multiplier is one this * reduces to identity_constrain(x). * - * @tparam T type of scalar - * @tparam M type of offset - * @tparam S type of multiplier + * @tparam T A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam M A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam S A scalar type or type inheriting from `Eigen::DenseBase` * @param[in] x Unconstrained scalar input * @param[in] mu offset of constrained output * @param[in] sigma multiplier of constrained output @@ -40,25 +40,25 @@ namespace math { */ template * = nullptr> -inline auto offset_multiplier_constrain(const T& x, const M& mu, - const S& sigma) { - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - if (is_matrix::value && is_matrix::value) { + T, M, S>* = nullptr, + require_all_not_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma) { + if (is_matrix_v && is_matrix::value) { check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - if (is_matrix::value && is_matrix::value) { + if (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - } else if (is_matrix::value && is_matrix::value) { + } else if (is_matrix::value && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", sigma); } - + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); check_finite("offset_multiplier_constrain", "offset", value_of_rec(mu_ref)); check_positive_finite("offset_multiplier_constrain", "multiplier", value_of_rec(sigma_ref)); - return fma(sigma_ref, x, mu_ref); + return fma(std::forward(sigma_ref), std::forward(x), + std::forward(mu_ref)); } /** @@ -76,9 +76,9 @@ inline auto offset_multiplier_constrain(const T& x, const M& mu, * If the offset is zero and multiplier is one, this function * reduces to identity_constraint(x, lp). * - * @tparam T type of scalar - * @tparam M type of offset - * @tparam S type of multiplier + * @tparam T A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam M A scalar type or type inheriting from `Eigen::DenseBase` + * @tparam S A scalar type or type inheriting from `Eigen::DenseBase` * @param[in] x Unconstrained scalar input * @param[in] mu offset of constrained output * @param[in] sigma multiplier of constrained output @@ -89,186 +89,104 @@ inline auto offset_multiplier_constrain(const T& x, const M& mu, */ template * = nullptr> -inline auto offset_multiplier_constrain(const T& x, const M& mu, const S& sigma, + T, M, S>* = nullptr, + require_all_not_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma, return_type_t& lp) { - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - if (is_matrix::value && is_matrix::value) { + if (is_matrix_v && is_matrix::value) { check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - if (is_matrix::value && is_matrix::value) { + if (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - } else if (is_matrix::value && is_matrix::value) { + } else if (is_matrix::value && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", sigma); } - + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); check_finite("offset_multiplier_constrain", "offset", value_of_rec(mu_ref)); check_positive_finite("offset_multiplier_constrain", "multiplier", value_of_rec(sigma_ref)); - if (math::size(sigma_ref) == 1) { - lp += sum(multiply_log(math::size(x), sigma_ref)); + if (stan::math::size(sigma_ref) == 1) { + lp += sum(multiply_log(static_cast(math::size(x)), sigma_ref)); } else { lp += sum(log(sigma_ref)); } - return fma(sigma_ref, x, mu_ref); + return fma(std::forward(sigma_ref), std::forward(x), + std::forward(mu_ref)); } /** - * Overload for array of x and non-array mu and sigma + * Overload for when x and mu or sigma are `std::vectors` */ template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const S& sigma) { - std::vector< - plain_type_t> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma_ref)); - } - return ret; -} - -/** - * Overload for array of x and non-array mu and sigma with lp - */ -template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const S& sigma, - return_type_t& lp) { - std::vector< - plain_type_t> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - const auto& sigma_ref = to_ref(sigma); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma_ref, lp)); + require_any_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(const T& x, M&& mu, S&& sigma) { + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); } - return ret; -} - -/** - * Overload for array of x and sigma and non-array mu - */ -template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const std::vector& sigma) { - check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector< - plain_type_t> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma[i])); + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - return ret; -} - -/** - * Overload for array of x and sigma and non-array mu with lp - */ -template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, const M& mu, - const std::vector& sigma, - return_type_t& lp) { - check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector> - ret; - ret.reserve(x.size()); - const auto& mu_ref = to_ref(mu); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu_ref, sigma[i], lp)); + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", + sigma); } - return ret; -} - -/** - * Overload for array of x and mu and non-array sigma - */ -template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const S& sigma) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - std::vector< - plain_type_t> - ret; + auto iter = [](auto&& it, std::size_t i) -> decltype(auto) { + if constexpr (is_std_vector_v) { + return it[i]; + } else { + return it; + } + }; + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); + using inner_ret_t = decltype(offset_multiplier_constrain( + iter(x, 0), iter(mu_ref, 0), iter(sigma_ref, 0))); + std::vector> ret; + // In the language, if mu or sigma is a vector, x must also be a vector ret.reserve(x.size()); - const auto& sigma_ref = to_ref(sigma); for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma_ref)); + ret.emplace_back( + offset_multiplier_constrain(x[i], iter(mu_ref, i), iter(sigma_ref, i))); } return ret; } /** - * Overload for array of x and mu and non-array sigma with lp + * Overload for when x and mu or sigma are `std::vectors` */ template * = nullptr> -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const S& sigma, + require_any_std_vector_t* = nullptr> +inline auto offset_multiplier_constrain(const T& x, M&& mu, S&& sigma, return_type_t& lp) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - std::vector> - ret; - ret.reserve(x.size()); - const auto& sigma_ref = to_ref(sigma); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma_ref, lp)); + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); } - return ret; -} - -/** - * Overload for array of x, mu, and sigma - */ -template -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const std::vector& sigma) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector> - ret; - ret.reserve(x.size()); - for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma[i])); + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } - return ret; -} - -/** - * Overload for array of x, mu, and sigma with lp - */ -template -inline auto offset_multiplier_constrain(const std::vector& x, - const std::vector& mu, - const std::vector& sigma, - return_type_t& lp) { - check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); - check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - std::vector> - ret; + if constexpr (is_std_vector_v && is_std_vector_v) { + check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", + sigma); + } + auto iter = [](auto&& it, std::size_t i) -> decltype(auto) { + if constexpr (is_std_vector_v) { + return it[i]; + } else { + return it; + } + }; + auto&& mu_ref = to_ref(std::forward(mu)); + auto&& sigma_ref = to_ref(std::forward(sigma)); + using inner_ret_t = decltype(offset_multiplier_constrain( + iter(x, 0), iter(mu_ref, 0), iter(sigma_ref, 0), lp)); + std::vector> ret; + // In the language, if mu or sigma is a vector, x must also be a vector ret.reserve(x.size()); for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back(offset_multiplier_constrain(x[i], mu[i], sigma[i], lp)); + ret.emplace_back(offset_multiplier_constrain(x[i], iter(mu_ref, i), + iter(sigma_ref, i), lp)); } return ret; } diff --git a/stan/math/prim/meta/is_constant.hpp b/stan/math/prim/meta/is_constant.hpp index b3fce314539..6bb0b0bb715 100644 --- a/stan/math/prim/meta/is_constant.hpp +++ b/stan/math/prim/meta/is_constant.hpp @@ -62,5 +62,8 @@ template struct is_constant> : bool_constant::Scalar>::value> {}; +template +inline constexpr bool is_not_constant_v = !is_constant>::value; + } // namespace stan #endif diff --git a/stan/math/prim/meta/is_matrix.hpp b/stan/math/prim/meta/is_matrix.hpp index 58cb26712ad..056ba326207 100644 --- a/stan/math/prim/meta/is_matrix.hpp +++ b/stan/math/prim/meta/is_matrix.hpp @@ -17,6 +17,9 @@ template struct is_matrix : bool_constant, is_eigen>::value> {}; +template +inline constexpr bool is_matrix_v = is_matrix>::value; + /*! \ingroup require_eigens_types */ /*! \defgroup matrix_types matrix */ /*! \addtogroup matrix_types */ diff --git a/stan/math/prim/meta/is_vector.hpp b/stan/math/prim/meta/is_vector.hpp index b0c62d255f2..2375ac5785f 100644 --- a/stan/math/prim/meta/is_vector.hpp +++ b/stan/math/prim/meta/is_vector.hpp @@ -597,6 +597,9 @@ struct is_std_vector< T, std::enable_if_t>::value>> : std::true_type {}; +template +inline constexpr bool is_std_vector_v = is_std_vector>::value; + /** \ingroup type_trait * Specialization of scalar_type for vector to recursively return the inner * scalar type. diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index 293cb1856cb..0c9ae82a5c8 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -100,12 +100,11 @@ inline var multiply_log(double a, const var& b) { */ template * = nullptr, require_any_var_matrix_t* = nullptr> -inline auto multiply_log(const T1& a, const T2& b) { +inline auto multiply_log(T1&& a, T2&& b) { check_matching_dims("multiply_log", "a", a, "b", b); - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = b; - + arena_t arena_a = std::forward(a); + arena_t arena_b = std::forward(b); + if constexpr (is_not_constant_v && is_not_constant_v) { return make_callback_var( multiply_log(arena_a.val(), arena_b.val()), [arena_a, arena_b](const auto& res) mutable { @@ -114,24 +113,17 @@ inline auto multiply_log(const T1& a, const T2& b) { arena_b.adj().array() += res.adj().array() * arena_a.val().array() / arena_b.val().array(); }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = value_of(b); - + } else if constexpr (is_not_constant_v) { return make_callback_var(multiply_log(arena_a.val(), arena_b), [arena_a, arena_b](const auto& res) mutable { arena_a.adj().array() - += res.adj().array() - * arena_b.val().array().log(); + += res.adj().array() * arena_b.array().log(); }); } else { - arena_t> arena_a = value_of(a); - arena_t> arena_b = b; - return make_callback_var(multiply_log(arena_a, arena_b.val()), [arena_a, arena_b](const auto& res) mutable { arena_b.adj().array() += res.adj().array() - * arena_a.val().array() + * arena_a.array() / arena_b.val().array(); }); } @@ -148,37 +140,25 @@ inline auto multiply_log(const T1& a, const T2& b) { */ template * = nullptr, require_stan_scalar_t* = nullptr> -inline auto multiply_log(const T1& a, const T2& b) { - using std::log; - - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - var arena_b = b; - +inline auto multiply_log(T1&& a, T2&& b) { + arena_t arena_a = a; + if constexpr (is_not_constant_v && is_not_constant_v) { return make_callback_var( - multiply_log(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() += res.adj().array() * log(arena_b.val()); - arena_b.adj() += (res.adj().array() * arena_a.val().array()).sum() - / arena_b.val(); + multiply_log(arena_a.val(), b.val()), + [arena_a, b](const auto& res) mutable { + arena_a.adj().array() += res.adj().array() * log(b.val()); + b.adj() + += (res.adj().array() * arena_a.val().array()).sum() / b.val(); + }); + } else if constexpr (is_not_constant_v) { + return make_callback_var( + multiply_log(arena_a.val(), b), [arena_a, b](const auto& res) mutable { + arena_a.adj().array() += res.adj().array() * log(b); }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - - return make_callback_var(multiply_log(arena_a.val(), value_of(b)), - [arena_a, b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * log(value_of(b)); - }); } else { - arena_t> arena_a = value_of(a); - var arena_b = b; - return make_callback_var( - multiply_log(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj() - += (res.adj().array() * arena_a.array()).sum() / arena_b.val(); + multiply_log(arena_a, b.val()), [arena_a, b](const auto& res) mutable { + b.adj() += (res.adj().array() * arena_a.array()).sum() / b.val(); }); } } @@ -194,38 +174,27 @@ inline auto multiply_log(const T1& a, const T2& b) { */ template * = nullptr, require_var_matrix_t* = nullptr> -inline auto multiply_log(const T1& a, const T2& b) { - if (!is_constant::value && !is_constant::value) { - var arena_a = a; - arena_t> arena_b = b; - +inline auto multiply_log(T1&& a, T2&& b) { + arena_t arena_b = std::forward(b); + if constexpr (is_not_constant_v && is_not_constant_v) { return make_callback_var( - multiply_log(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); + multiply_log(a.val(), arena_b.val()), + [a, arena_b](const auto& res) mutable { + a.adj() += (res.adj().array() * arena_b.val().array().log()).sum(); arena_b.adj().array() - += arena_a.val() * res.adj().array() / arena_b.val().array(); + += a.val() * res.adj().array() / arena_b.val().array(); }); - } else if (!is_constant::value) { - var arena_a = a; - arena_t> arena_b = value_of(b); - + } else if constexpr (is_not_constant_v) { return make_callback_var( - multiply_log(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); + multiply_log(a.val(), arena_b), [a, arena_b](const auto& res) mutable { + a.adj() += (res.adj().array() * arena_b.array().log()).sum(); }); } else { - arena_t> arena_b = b; - - return make_callback_var(multiply_log(value_of(a), arena_b.val()), - [a, arena_b](const auto& res) mutable { - arena_b.adj().array() += value_of(a) - * res.adj().array() - / arena_b.val().array(); - }); + return make_callback_var( + multiply_log(a, arena_b.val()), [a, arena_b](const auto& res) mutable { + arena_b.adj().array() + += a * res.adj().array() / arena_b.val().array(); + }); } } diff --git a/test/unit/math/rev/fun/multiply_log_test.cpp b/test/unit/math/rev/fun/multiply_log_test.cpp new file mode 100644 index 00000000000..23f5c324e36 --- /dev/null +++ b/test/unit/math/rev/fun/multiply_log_test.cpp @@ -0,0 +1,12 @@ +#include +#include +#include +#include + +TEST(RevTest, multiply_log_issue3146) { + long int x = 1; + using stan::math::var_value; + using mat_t = Eigen::Matrix; + var_value y = Eigen::MatrixXd::Random(10, 10); + auto z = stan::math::multiply_log(x, y); +} \ No newline at end of file From 7afd1a97982d99d4404408ab727cc307aabafe0e Mon Sep 17 00:00:00 2001 From: stevebronder Date: Thu, 13 Feb 2025 15:17:22 -0500 Subject: [PATCH 02/19] fix cpplint --- test/unit/math/rev/fun/multiply_log_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unit/math/rev/fun/multiply_log_test.cpp b/test/unit/math/rev/fun/multiply_log_test.cpp index 23f5c324e36..fa4aa78f5a5 100644 --- a/test/unit/math/rev/fun/multiply_log_test.cpp +++ b/test/unit/math/rev/fun/multiply_log_test.cpp @@ -4,9 +4,9 @@ #include TEST(RevTest, multiply_log_issue3146) { - long int x = 1; + std::int32_t x = 1; using stan::math::var_value; using mat_t = Eigen::Matrix; var_value y = Eigen::MatrixXd::Random(10, 10); auto z = stan::math::multiply_log(x, y); -} \ No newline at end of file +} From dbf99e8b087687c5e6254d76d2b4020987732422 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 18 Feb 2025 12:04:22 -0500 Subject: [PATCH 03/19] update multiply log to fix #2494 --- stan/math/rev/fun/multiply_log.hpp | 83 +++++++++--------------------- 1 file changed, 24 insertions(+), 59 deletions(-) diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index 0c9ae82a5c8..93f8b1e4a8f 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -14,38 +14,11 @@ namespace stan { namespace math { -namespace internal { -class multiply_log_vv_vari : public op_vv_vari { - public: - multiply_log_vv_vari(vari* avi, vari* bvi) - : op_vv_vari(multiply_log(avi->val_, bvi->val_), avi, bvi) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bvi_->val_); - bvi_->adj_ += adj_ * avi_->val_ / bvi_->val_; - } -}; -class multiply_log_vd_vari : public op_vd_vari { - public: - multiply_log_vd_vari(vari* avi, double b) - : op_vd_vari(multiply_log(avi->val_, b), avi, b) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bd_); - } -}; -class multiply_log_dv_vari : public op_dv_vari { - public: - multiply_log_dv_vari(double a, vari* bvi) - : op_dv_vari(multiply_log(a, bvi->val_), a, bvi) {} - void chain() { bvi_->adj_ += adj_ * ad_ / bvi_->val_; } -}; -} // namespace internal - /** * Return the value of a*log(b). * - * When both a and b are 0, the value returned is 0. + * When both a and b are 0, the value returned is 0 + * and no gradients are accumulated. * The partial derivative with respect to a is log(b). * The partial derivative with respect to b is a/b. * @@ -53,37 +26,29 @@ class multiply_log_dv_vari : public op_dv_vari { * @param b Second variable. * @return Value of a*log(b) */ -inline var multiply_log(const var& a, const var& b) { - return var(new internal::multiply_log_vv_vari(a.vi_, b.vi_)); -} -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to a is log(b). - * - * @param a First variable. - * @param b Second scalar. - * @return Value of a*log(b) - */ -inline var multiply_log(const var& a, double b) { - return var(new internal::multiply_log_vd_vari(a.vi_, b)); -} -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to b is a/b. - * - * @param a First scalar. - * @param b Second variable. - * @return Value of a*log(b) - */ -inline var multiply_log(double a, const var& b) { - if (a == 1.0) { - return log(b); +template * = nullptr, + require_any_var_t* = nullptr> +inline var multiply_log(const T1& a, const T2& b) { + if (value_of(a) == 0.0 && value_of(b) == 0.0) { + return var(0.0); + } + if constexpr (!is_constant::value && !is_constant::value) { + return make_callback_var(multiply_log(a.val(), b.val()), + [a, b](const auto& res) mutable { + a.adj() += res.adj() * log(b.val()); + b.adj() += res.adj() * a.val() / b.val(); + }); + } else if constexpr (!is_constant::value) { + return make_callback_var(multiply_log(a.val(), b), + [a, b](const auto& res) mutable { + a.adj() += res.adj() * log(b); + }); + } else { + return make_callback_var(multiply_log(a, b.val()), + [a, b](const auto& res) mutable { + b.adj() += res.adj() * a / b.val(); + }); } - return var(new internal::multiply_log_dv_vari(a, b.vi_)); } /** From 00c3f8d006e9f5d802c0f4e19f4507def126e96a Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 18 Feb 2025 12:05:42 -0500 Subject: [PATCH 04/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/rev/fun/multiply_log.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index 93f8b1e4a8f..6d1fe290c79 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -26,8 +26,9 @@ namespace math { * @param b Second variable. * @return Value of a*log(b) */ -template * = nullptr, - require_any_var_t* = nullptr> +template * = nullptr, + require_any_var_t* = nullptr> inline var multiply_log(const T1& a, const T2& b) { if (value_of(a) == 0.0 && value_of(b) == 0.0) { return var(0.0); @@ -39,10 +40,9 @@ inline var multiply_log(const T1& a, const T2& b) { b.adj() += res.adj() * a.val() / b.val(); }); } else if constexpr (!is_constant::value) { - return make_callback_var(multiply_log(a.val(), b), - [a, b](const auto& res) mutable { - a.adj() += res.adj() * log(b); - }); + return make_callback_var( + multiply_log(a.val(), b), + [a, b](const auto& res) mutable { a.adj() += res.adj() * log(b); }); } else { return make_callback_var(multiply_log(a, b.val()), [a, b](const auto& res) mutable { From 8e0e20773d7c831e326613c53f98421c9ddef5f2 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 18 Feb 2025 16:42:59 -0500 Subject: [PATCH 05/19] update multiply_log --- stan/math/fwd/fun/multiply_log.hpp | 11 ++ stan/math/prim/fun/elt_multiply.hpp | 2 +- stan/math/prim/fun/multiply_log.hpp | 8 +- stan/math/prim/meta/is_rev_matrix.hpp | 6 + stan/math/rev/fun/multiply_log.hpp | 135 ++++++++++-------- test/unit/math/mix/fun/multiply_log1_test.cpp | 9 +- test/unit/math/mix/fun/multiply_log2_test.cpp | 80 +++++++++-- 7 files changed, 170 insertions(+), 81 deletions(-) diff --git a/stan/math/fwd/fun/multiply_log.hpp b/stan/math/fwd/fun/multiply_log.hpp index 3df40b33268..bad9b845610 100644 --- a/stan/math/fwd/fun/multiply_log.hpp +++ b/stan/math/fwd/fun/multiply_log.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -12,19 +13,29 @@ namespace math { template inline fvar multiply_log(const fvar& x1, const fvar& x2) { + if (value_of_rec(x1) == 0.0 && value_of_rec(x2) == 0.0) { + return fvar(0.0); + } return fvar(multiply_log(x1.val_, x2.val_), x1.d_ * log(x2.val_) + x1.val_ * x2.d_ / x2.val_); } template inline fvar multiply_log(double x1, const fvar& x2) { + if (x1 == 0.0 && value_of_rec(x2) == 0.0) { + return fvar(0.0); + } return fvar(multiply_log(x1, x2.val_), x1 * x2.d_ / x2.val_); } template inline fvar multiply_log(const fvar& x1, double x2) { + if (value_of_rec(x1) == 0.0 && x2 == 0.0) { + return fvar(0.0); + } return fvar(multiply_log(x1.val_, x2), x1.d_ * log(x2)); } + } // namespace math } // namespace stan #endif diff --git a/stan/math/prim/fun/elt_multiply.hpp b/stan/math/prim/fun/elt_multiply.hpp index bad90c7a26d..4e2d3a649d4 100644 --- a/stan/math/prim/fun/elt_multiply.hpp +++ b/stan/math/prim/fun/elt_multiply.hpp @@ -57,7 +57,7 @@ auto elt_multiply(const Scalar1& a, const Scalar2& b) { * @param B second argument * @return product of matrix and scalar */ -template * = nullptr, +template * = nullptr, require_any_stan_scalar_t* = nullptr> inline auto elt_multiply(const T1& A, const T2& B) { return multiply(A, B); diff --git a/stan/math/prim/fun/multiply_log.hpp b/stan/math/prim/fun/multiply_log.hpp index a4f67af5311..f28d53d1d3f 100644 --- a/stan/math/prim/fun/multiply_log.hpp +++ b/stan/math/prim/fun/multiply_log.hpp @@ -47,12 +47,10 @@ namespace math { template * = nullptr> inline return_type_t multiply_log(const T_a a, const T_b b) { - using std::log; - if (b == 0.0 && a == 0.0) { + if (a == 0.0 && b == 0.0) { return 0.0; } - - return a * log(b); + return a * std::log(b); } /** @@ -66,7 +64,7 @@ inline return_type_t multiply_log(const T_a a, const T_b b) { * @return multiply_log function applied to the two inputs. */ template * = nullptr, - require_all_not_var_matrix_t* = nullptr> + require_all_not_rev_matrix_t* = nullptr> inline auto multiply_log(const T1& a, const T2& b) { return apply_scalar_binary( a, b, [&](const auto& c, const auto& d) { return multiply_log(c, d); }); diff --git a/stan/math/prim/meta/is_rev_matrix.hpp b/stan/math/prim/meta/is_rev_matrix.hpp index 40ef60bb9e9..fa941a6f8c0 100644 --- a/stan/math/prim/meta/is_rev_matrix.hpp +++ b/stan/math/prim/meta/is_rev_matrix.hpp @@ -46,6 +46,12 @@ using require_any_rev_matrix_t template using require_all_not_rev_matrix_t = require_all_not_t>...>; + +/*! \brief Require at least one of the types do not satisfy @ref is_rev_matrix */ +/*! @tparam Types The types that are checked */ +template +using require_any_not_rev_matrix_t + = require_any_not_t>...>; /*! @} */ /** \ingroup type_trait diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index 93f8b1e4a8f..fb98e755028 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -29,26 +29,20 @@ namespace math { template * = nullptr, require_any_var_t* = nullptr> inline var multiply_log(const T1& a, const T2& b) { - if (value_of(a) == 0.0 && value_of(b) == 0.0) { + if (value_of(a) == 0.0 && value_of(b) == 0.0){ return var(0.0); } - if constexpr (!is_constant::value && !is_constant::value) { - return make_callback_var(multiply_log(a.val(), b.val()), - [a, b](const auto& res) mutable { - a.adj() += res.adj() * log(b.val()); - b.adj() += res.adj() * a.val() / b.val(); - }); - } else if constexpr (!is_constant::value) { - return make_callback_var(multiply_log(a.val(), b), - [a, b](const auto& res) mutable { - a.adj() += res.adj() * log(b); - }); - } else { - return make_callback_var(multiply_log(a, b.val()), - [a, b](const auto& res) mutable { - b.adj() += res.adj() * a / b.val(); - }); - } + return make_callback_var(multiply_log(value_of(a), value_of(b)), + [a, b](const auto& res) mutable { + if constexpr (!is_constant::value && !is_constant::value) { + a.adj() += res.adj() * log(b.val()); + b.adj() += res.adj() * a.val() / b.val(); + } else if constexpr (!is_constant::value) { + a.adj() += res.adj() * log(b); + } else { + b.adj() += res.adj() * a / b.val(); + } + }); } /** @@ -64,38 +58,50 @@ inline var multiply_log(const T1& a, const T2& b) { * @return elementwise product of `a` and `log(b)` */ template * = nullptr, - require_any_var_matrix_t* = nullptr> + require_any_rev_matrix_t* = nullptr> inline auto multiply_log(T1&& a, T2&& b) { check_matching_dims("multiply_log", "a", a, "b", b); arena_t arena_a = std::forward(a); arena_t arena_b = std::forward(b); + using return_t + = return_var_matrix_t; + arena_t res = multiply_log(value_of(arena_a), value_of(arena_b)); + if constexpr (is_not_constant_v && is_not_constant_v) { - return make_callback_var( - multiply_log(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { + reverse_pass_callback( + [res, arena_a, arena_b]() mutable { + auto is_zero = (arena_a.val().array() == 0.0 && arena_b.val().array() == 0.0); arena_a.adj().array() - += res.adj().array() * arena_b.val().array().log(); - arena_b.adj().array() += res.adj().array() * arena_a.val().array() - / arena_b.val().array(); + += is_zero.select(0.0, res.adj().array() * arena_b.val().array().log()); + arena_b.adj().array() += is_zero.select(0.0, res.adj().array() * arena_a.val().array() + / arena_b.val().array()); }); } else if constexpr (is_not_constant_v) { - return make_callback_var(multiply_log(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * arena_b.array().log(); + reverse_pass_callback( + [res, arena_a, arena_b]() mutable { + auto is_zero = (arena_a.val().array() == 0.0 && arena_b.array() == 0.0); + arena_a.adj().array() + += is_zero.select(0.0, res.adj().array() * arena_b.array().log()); }); } else { - return make_callback_var(multiply_log(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj().array() += res.adj().array() - * arena_a.array() - / arena_b.val().array(); + reverse_pass_callback( + [res, arena_a, arena_b]() mutable { + auto is_zero = (arena_a.array() == 0.0 && arena_b.val().array() == 0.0); + arena_b.adj().array() += is_zero.select(0.0, res.adj().array() + * arena_a.array() + / arena_b.val().array()); }); } + return res; } /** * Return the product `a * log(b)`. + * In the case where b is a scalar and and element of `a` and `b` are zero + * the value returned is 0 and no gradients are accumulated. + * For `b`'s adjoint, this function can still return NaN as the adjoint + * of `b` if `a` is nonzero anywhere, but `b` is zero. Likewise, + * `a`'s adjoint can have undefined values if `a` is nonzero but `b` is zero. * * @tparam T1 Type of matrix argument * @tparam T2 Type of scalar argument @@ -103,29 +109,34 @@ inline auto multiply_log(T1&& a, T2&& b) { * @param b Scalar argument * @return Product of `a` and `log(b)` */ -template * = nullptr, +template * = nullptr, require_stan_scalar_t* = nullptr> inline auto multiply_log(T1&& a, T2&& b) { arena_t arena_a = a; + using return_t + = return_var_matrix_t; + arena_t res = multiply_log(value_of(arena_a), value_of(b)); if constexpr (is_not_constant_v && is_not_constant_v) { - return make_callback_var( - multiply_log(arena_a.val(), b.val()), - [arena_a, b](const auto& res) mutable { - arena_a.adj().array() += res.adj().array() * log(b.val()); - b.adj() - += (res.adj().array() * arena_a.val().array()).sum() / b.val(); + reverse_pass_callback( + [res, arena_a, b]() mutable { + auto is_zero = ((arena_a.val().array() == 0.0) + (b.val() == 0.0) > 1); + arena_a.adj().array() += is_zero.select(0.0, res.adj().array() * log(b.val())); + b.adj() += is_zero.select(0.0, (res.adj().array() * arena_a.val().array()) / b.val()).sum(); }); } else if constexpr (is_not_constant_v) { - return make_callback_var( - multiply_log(arena_a.val(), b), [arena_a, b](const auto& res) mutable { - arena_a.adj().array() += res.adj().array() * log(b); + reverse_pass_callback( + [res, arena_a, b]() mutable { + auto is_zero = ((arena_a.val().array() == 0.0) + (b == 0.0) > 1); + arena_a.adj().array() += is_zero.select(0.0, res.adj().array() * log(b)); }); } else { - return make_callback_var( - multiply_log(arena_a, b.val()), [arena_a, b](const auto& res) mutable { - b.adj() += (res.adj().array() * arena_a.array()).sum() / b.val(); + reverse_pass_callback( + [res, arena_a, b]() mutable { + auto is_zero = ((arena_a.array() == 0.0) + (b.val() == 0.0) > 1); + b.adj() += is_zero.select(0.0, (res.adj().array() * arena_a.val().array()) / b.val()).sum(); }); } + return res; } /** @@ -138,29 +149,35 @@ inline auto multiply_log(T1&& a, T2&& b) { * @return Product of `a` and `log(b)` */ template * = nullptr, - require_var_matrix_t* = nullptr> + require_rev_matrix_t* = nullptr> inline auto multiply_log(T1&& a, T2&& b) { arena_t arena_b = std::forward(b); + using return_t + = return_var_matrix_t; + arena_t res = multiply_log(value_of(a), value_of(arena_b)); if constexpr (is_not_constant_v && is_not_constant_v) { - return make_callback_var( - multiply_log(a.val(), arena_b.val()), - [a, arena_b](const auto& res) mutable { - a.adj() += (res.adj().array() * arena_b.val().array().log()).sum(); + reverse_pass_callback( + [res, a, arena_b]() mutable { + auto is_zero = ((a.val() == 0.0) + (arena_b.val().array() == 0.0) > 1); + a.adj() += is_zero.select(0.0, res.adj().array() * arena_b.val().array().log()).sum(); arena_b.adj().array() - += a.val() * res.adj().array() / arena_b.val().array(); + += is_zero.select(0.0, a.val() * res.adj().array() / arena_b.val().array()); }); } else if constexpr (is_not_constant_v) { - return make_callback_var( - multiply_log(a.val(), arena_b), [a, arena_b](const auto& res) mutable { - a.adj() += (res.adj().array() * arena_b.array().log()).sum(); + reverse_pass_callback( + [res, a, arena_b]() mutable { + auto is_zero = ((a.val() == 0.0) + (arena_b.array() == 0.0) > 1); + a.adj() += is_zero.select(0.0, res.adj().array() * arena_b.array().log()).sum(); }); } else { - return make_callback_var( - multiply_log(a, arena_b.val()), [a, arena_b](const auto& res) mutable { + reverse_pass_callback( + [res, a, arena_b]() mutable { + auto is_zero = ((a == 0.0) + (arena_b.val().array() == 0.0) > 1); arena_b.adj().array() - += a * res.adj().array() / arena_b.val().array(); + += is_zero.select(0.0, a * res.adj().array() / arena_b.val().array()); }); } + return res; } } // namespace math diff --git a/test/unit/math/mix/fun/multiply_log1_test.cpp b/test/unit/math/mix/fun/multiply_log1_test.cpp index ba8e1336f01..7c5f6cbdc9e 100644 --- a/test/unit/math/mix/fun/multiply_log1_test.cpp +++ b/test/unit/math/mix/fun/multiply_log1_test.cpp @@ -5,14 +5,19 @@ TEST(mathMixScalFun, multiplyLog) { auto f = [](const auto& x1, const auto& x2) { return stan::math::multiply_log(x1, x2); }; - stan::test::expect_ad(f, 0.5, -0.4); // error stan::test::expect_ad(f, 0.5, 1.2); stan::test::expect_ad(f, 1.5, 1.8); stan::test::expect_ad(f, 2.2, 3.3); stan::test::expect_ad(f, 19.7, 1299.1); +} +TEST(mathMixScalFun, multiplyLog_errors) { + auto f = [](const auto& x1, const auto& x2) { + return stan::math::multiply_log(x1, x2); + }; + stan::test::expect_ad(f, 0.5, -0.4); // error - double nan = std::numeric_limits::quiet_NaN(); + constexpr double nan = std::numeric_limits::quiet_NaN(); stan::test::expect_ad(f, 1.0, nan); stan::test::expect_ad(f, nan, 1.0); stan::test::expect_ad(f, nan, nan); diff --git a/test/unit/math/mix/fun/multiply_log2_test.cpp b/test/unit/math/mix/fun/multiply_log2_test.cpp index a124be21bdb..26e31cc33c4 100644 --- a/test/unit/math/mix/fun/multiply_log2_test.cpp +++ b/test/unit/math/mix/fun/multiply_log2_test.cpp @@ -1,12 +1,26 @@ #include #include +/** + * NOTE: We do not test values of (0.0, 0.0) as inputs. + * This is because the testing framework uses + * finite difference as the comparison. The small finite + * purtubations lead to cases where the value of the + * function inputs is either slightly negative or slightly positive, + * which can lead to the function returning NaN when we would expect a + * value of 0. + */ TEST(mathMixScalFun, multiplyLog2_vec) { auto f = [](const auto& x1, const auto& x2) { using stan::math::multiply_log; return multiply_log(x1, x2); }; - + auto expect_test = [](auto&& f, auto&& x1, auto&& x2) { + stan::test::expect_ad(f, x1, x2); + stan::test::expect_ad(f, x2, x1); + stan::test::expect_ad_matvar(f, x1, x2); + stan::test::expect_ad_matvar(f, x2, x1); + }; Eigen::VectorXd in1(2); in1 << 3, 1; Eigen::VectorXd in2(2); @@ -18,28 +32,66 @@ TEST(mathMixScalFun, multiplyLog2_vec) { x2 << 1.0, 2.0, 3.0; Eigen::MatrixXd x3(2, 3); x3 << 1.0, 2.0, 3.0, 4.0, 5.0, 6.0; - + stan::test::expect_ad(f, x1, x1); + stan::test::expect_ad(f, x2, x2); + stan::test::expect_ad(f, x3, x3); stan::test::expect_ad_matvar(f, x1, x1); - stan::test::expect_ad_matvar(f, x1, 2.0); - stan::test::expect_ad_matvar(f, 3.0, x1); stan::test::expect_ad_matvar(f, x2, x2); - stan::test::expect_ad_matvar(f, x2, 2.5); - stan::test::expect_ad_matvar(f, 3.5, x2); stan::test::expect_ad_matvar(f, x3, x3); - stan::test::expect_ad_matvar(f, x3, 4.0); - stan::test::expect_ad_matvar(f, 5.0, x3); + expect_test(f, x1, 2.0); + expect_test(f, x2, 2.5); + expect_test(f, x3, 5.5); Eigen::VectorXd x4(0); Eigen::RowVectorXd x5(0); Eigen::MatrixXd x6(0, 0); + stan::test::expect_ad(f, x4, x4); + stan::test::expect_ad(f, x5, x5); + stan::test::expect_ad(f, x6, x6); stan::test::expect_ad_matvar(f, x4, x4); - stan::test::expect_ad_matvar(f, x4, 2.0); - stan::test::expect_ad_matvar(f, 3.0, x4); stan::test::expect_ad_matvar(f, x5, x5); - stan::test::expect_ad_matvar(f, x5, 2.5); - stan::test::expect_ad_matvar(f, 3.5, x5); stan::test::expect_ad_matvar(f, x6, x6); - stan::test::expect_ad_matvar(f, x6, 4.0); - stan::test::expect_ad_matvar(f, 5.0, x6); + expect_test(f, x4, 2.0); + expect_test(f, x5, 3.1); + expect_test(f, x6, 5.5); +} + + + +TEST(mathMixScalFun, multiplyLog2_zero_vec_vec) { + auto f = [](const auto& x1, const auto& x2) { + using stan::math::multiply_log; + return multiply_log(x1, x2); + }; + auto expect_test = [](auto&& f, auto&& x1, auto&& x2) { + stan::test::expect_ad(f, x1, x2); + stan::test::expect_ad(f, x2, x1); + stan::test::expect_ad_matvar(f, x1, x2); + stan::test::expect_ad_matvar(f, x2, x1); + }; + + Eigen::VectorXd zero_vec = Eigen::VectorXd::Zero(3); + Eigen::VectorXd x1(3); + x1 << 1.0, 2.0, 3.0; + expect_test(f, zero_vec, x1); } +TEST(mathMixScalFun, multiplyLog2_zero_vec_scalar) { + auto f = [](const auto& x1, const auto& x2) { + using stan::math::multiply_log; + return multiply_log(x1, x2); + }; + + Eigen::VectorXd zero_vec = Eigen::VectorXd::Zero(3); + Eigen::VectorXd x1(3); + x1 << 1.0, 2.0, 3.0; + auto expect_test = [](auto&& f, auto&& x1, auto&& x2) { + stan::test::expect_ad(f, x1, x2); + stan::test::expect_ad(f, x2, x1); + stan::test::expect_ad_matvar(f, x1, x2); + stan::test::expect_ad_matvar(f, x2, x1); + }; + expect_test(f, x1, 0.0); +} + + From df6a972852892d2432ca98ec936c39a2b1ba0e1f Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Tue, 18 Feb 2025 16:44:26 -0500 Subject: [PATCH 06/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/prim/fun/elt_multiply.hpp | 3 +- stan/math/prim/meta/is_rev_matrix.hpp | 3 +- stan/math/rev/fun/multiply_log.hpp | 149 +++++++++--------- test/unit/math/mix/fun/multiply_log2_test.cpp | 4 - 4 files changed, 80 insertions(+), 79 deletions(-) diff --git a/stan/math/prim/fun/elt_multiply.hpp b/stan/math/prim/fun/elt_multiply.hpp index 4e2d3a649d4..cb60e861397 100644 --- a/stan/math/prim/fun/elt_multiply.hpp +++ b/stan/math/prim/fun/elt_multiply.hpp @@ -57,7 +57,8 @@ auto elt_multiply(const Scalar1& a, const Scalar2& b) { * @param B second argument * @return product of matrix and scalar */ -template * = nullptr, +template * = nullptr, require_any_stan_scalar_t* = nullptr> inline auto elt_multiply(const T1& A, const T2& B) { return multiply(A, B); diff --git a/stan/math/prim/meta/is_rev_matrix.hpp b/stan/math/prim/meta/is_rev_matrix.hpp index fa941a6f8c0..fc7a3ba9234 100644 --- a/stan/math/prim/meta/is_rev_matrix.hpp +++ b/stan/math/prim/meta/is_rev_matrix.hpp @@ -47,7 +47,8 @@ template using require_all_not_rev_matrix_t = require_all_not_t>...>; -/*! \brief Require at least one of the types do not satisfy @ref is_rev_matrix */ +/*! \brief Require at least one of the types do not satisfy @ref is_rev_matrix + */ /*! @tparam Types The types that are checked */ template using require_any_not_rev_matrix_t diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index c9d22068896..bdb4ba9a896 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -30,20 +30,20 @@ template * = nullptr, require_any_var_t* = nullptr> inline var multiply_log(const T1& a, const T2& b) { - if (value_of(a) == 0.0 && value_of(b) == 0.0){ + if (value_of(a) == 0.0 && value_of(b) == 0.0) { return var(0.0); } - return make_callback_var(multiply_log(value_of(a), value_of(b)), - [a, b](const auto& res) mutable { - if constexpr (!is_constant::value && !is_constant::value) { - a.adj() += res.adj() * log(b.val()); - b.adj() += res.adj() * a.val() / b.val(); - } else if constexpr (!is_constant::value) { - a.adj() += res.adj() * log(b); - } else { - b.adj() += res.adj() * a / b.val(); - } - }); + return make_callback_var( + multiply_log(value_of(a), value_of(b)), [a, b](const auto& res) mutable { + if constexpr (!is_constant::value && !is_constant::value) { + a.adj() += res.adj() * log(b.val()); + b.adj() += res.adj() * a.val() / b.val(); + } else if constexpr (!is_constant::value) { + a.adj() += res.adj() * log(b); + } else { + b.adj() += res.adj() * a / b.val(); + } + }); } /** @@ -64,34 +64,32 @@ inline auto multiply_log(T1&& a, T2&& b) { check_matching_dims("multiply_log", "a", a, "b", b); arena_t arena_a = std::forward(a); arena_t arena_b = std::forward(b); - using return_t - = return_var_matrix_t; + using return_t = return_var_matrix_t< + decltype(multiply_log(value_of(arena_a), value_of(arena_b))), T1, T2>; arena_t res = multiply_log(value_of(arena_a), value_of(arena_b)); if constexpr (is_not_constant_v && is_not_constant_v) { - reverse_pass_callback( - [res, arena_a, arena_b]() mutable { - auto is_zero = (arena_a.val().array() == 0.0 && arena_b.val().array() == 0.0); - arena_a.adj().array() - += is_zero.select(0.0, res.adj().array() * arena_b.val().array().log()); - arena_b.adj().array() += is_zero.select(0.0, res.adj().array() * arena_a.val().array() - / arena_b.val().array()); - }); + reverse_pass_callback([res, arena_a, arena_b]() mutable { + auto is_zero + = (arena_a.val().array() == 0.0 && arena_b.val().array() == 0.0); + arena_a.adj().array() += is_zero.select( + 0.0, res.adj().array() * arena_b.val().array().log()); + arena_b.adj().array() + += is_zero.select(0.0, res.adj().array() * arena_a.val().array() + / arena_b.val().array()); + }); } else if constexpr (is_not_constant_v) { - reverse_pass_callback( - [res, arena_a, arena_b]() mutable { - auto is_zero = (arena_a.val().array() == 0.0 && arena_b.array() == 0.0); - arena_a.adj().array() - += is_zero.select(0.0, res.adj().array() * arena_b.array().log()); - }); + reverse_pass_callback([res, arena_a, arena_b]() mutable { + auto is_zero = (arena_a.val().array() == 0.0 && arena_b.array() == 0.0); + arena_a.adj().array() + += is_zero.select(0.0, res.adj().array() * arena_b.array().log()); + }); } else { - reverse_pass_callback( - [res, arena_a, arena_b]() mutable { - auto is_zero = (arena_a.array() == 0.0 && arena_b.val().array() == 0.0); - arena_b.adj().array() += is_zero.select(0.0, res.adj().array() - * arena_a.array() - / arena_b.val().array()); - }); + reverse_pass_callback([res, arena_a, arena_b]() mutable { + auto is_zero = (arena_a.array() == 0.0 && arena_b.val().array() == 0.0); + arena_b.adj().array() += is_zero.select( + 0.0, res.adj().array() * arena_a.array() / arena_b.val().array()); + }); } return res; } @@ -114,28 +112,32 @@ template * = nullptr, require_stan_scalar_t* = nullptr> inline auto multiply_log(T1&& a, T2&& b) { arena_t arena_a = a; - using return_t - = return_var_matrix_t; + using return_t = return_var_matrix_t< + decltype(multiply_log(value_of(arena_a), value_of(b))), T1, T2>; arena_t res = multiply_log(value_of(arena_a), value_of(b)); if constexpr (is_not_constant_v && is_not_constant_v) { - reverse_pass_callback( - [res, arena_a, b]() mutable { - auto is_zero = ((arena_a.val().array() == 0.0) + (b.val() == 0.0) > 1); - arena_a.adj().array() += is_zero.select(0.0, res.adj().array() * log(b.val())); - b.adj() += is_zero.select(0.0, (res.adj().array() * arena_a.val().array()) / b.val()).sum(); - }); + reverse_pass_callback([res, arena_a, b]() mutable { + auto is_zero = ((arena_a.val().array() == 0.0) + (b.val() == 0.0) > 1); + arena_a.adj().array() + += is_zero.select(0.0, res.adj().array() * log(b.val())); + b.adj() += is_zero + .select(0.0, (res.adj().array() * arena_a.val().array()) + / b.val()) + .sum(); + }); } else if constexpr (is_not_constant_v) { - reverse_pass_callback( - [res, arena_a, b]() mutable { - auto is_zero = ((arena_a.val().array() == 0.0) + (b == 0.0) > 1); - arena_a.adj().array() += is_zero.select(0.0, res.adj().array() * log(b)); - }); + reverse_pass_callback([res, arena_a, b]() mutable { + auto is_zero = ((arena_a.val().array() == 0.0) + (b == 0.0) > 1); + arena_a.adj().array() += is_zero.select(0.0, res.adj().array() * log(b)); + }); } else { - reverse_pass_callback( - [res, arena_a, b]() mutable { - auto is_zero = ((arena_a.array() == 0.0) + (b.val() == 0.0) > 1); - b.adj() += is_zero.select(0.0, (res.adj().array() * arena_a.val().array()) / b.val()).sum(); - }); + reverse_pass_callback([res, arena_a, b]() mutable { + auto is_zero = ((arena_a.array() == 0.0) + (b.val() == 0.0) > 1); + b.adj() += is_zero + .select(0.0, (res.adj().array() * arena_a.val().array()) + / b.val()) + .sum(); + }); } return res; } @@ -153,30 +155,31 @@ template * = nullptr, require_rev_matrix_t* = nullptr> inline auto multiply_log(T1&& a, T2&& b) { arena_t arena_b = std::forward(b); - using return_t - = return_var_matrix_t; + using return_t = return_var_matrix_t< + decltype(multiply_log(value_of(a), value_of(arena_b))), T1, T2>; arena_t res = multiply_log(value_of(a), value_of(arena_b)); if constexpr (is_not_constant_v && is_not_constant_v) { - reverse_pass_callback( - [res, a, arena_b]() mutable { - auto is_zero = ((a.val() == 0.0) + (arena_b.val().array() == 0.0) > 1); - a.adj() += is_zero.select(0.0, res.adj().array() * arena_b.val().array().log()).sum(); - arena_b.adj().array() - += is_zero.select(0.0, a.val() * res.adj().array() / arena_b.val().array()); - }); + reverse_pass_callback([res, a, arena_b]() mutable { + auto is_zero = ((a.val() == 0.0) + (arena_b.val().array() == 0.0) > 1); + a.adj() + += is_zero + .select(0.0, res.adj().array() * arena_b.val().array().log()) + .sum(); + arena_b.adj().array() += is_zero.select( + 0.0, a.val() * res.adj().array() / arena_b.val().array()); + }); } else if constexpr (is_not_constant_v) { - reverse_pass_callback( - [res, a, arena_b]() mutable { - auto is_zero = ((a.val() == 0.0) + (arena_b.array() == 0.0) > 1); - a.adj() += is_zero.select(0.0, res.adj().array() * arena_b.array().log()).sum(); - }); + reverse_pass_callback([res, a, arena_b]() mutable { + auto is_zero = ((a.val() == 0.0) + (arena_b.array() == 0.0) > 1); + a.adj() += is_zero.select(0.0, res.adj().array() * arena_b.array().log()) + .sum(); + }); } else { - reverse_pass_callback( - [res, a, arena_b]() mutable { - auto is_zero = ((a == 0.0) + (arena_b.val().array() == 0.0) > 1); - arena_b.adj().array() - += is_zero.select(0.0, a * res.adj().array() / arena_b.val().array()); - }); + reverse_pass_callback([res, a, arena_b]() mutable { + auto is_zero = ((a == 0.0) + (arena_b.val().array() == 0.0) > 1); + arena_b.adj().array() + += is_zero.select(0.0, a * res.adj().array() / arena_b.val().array()); + }); } return res; } diff --git a/test/unit/math/mix/fun/multiply_log2_test.cpp b/test/unit/math/mix/fun/multiply_log2_test.cpp index 26e31cc33c4..1d5694deea6 100644 --- a/test/unit/math/mix/fun/multiply_log2_test.cpp +++ b/test/unit/math/mix/fun/multiply_log2_test.cpp @@ -57,8 +57,6 @@ TEST(mathMixScalFun, multiplyLog2_vec) { expect_test(f, x6, 5.5); } - - TEST(mathMixScalFun, multiplyLog2_zero_vec_vec) { auto f = [](const auto& x1, const auto& x2) { using stan::math::multiply_log; @@ -93,5 +91,3 @@ TEST(mathMixScalFun, multiplyLog2_zero_vec_scalar) { }; expect_test(f, x1, 0.0); } - - From 4444ee9165f686f9d0eef3e0a92c7644aa1d5144 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Tue, 18 Feb 2025 17:30:12 -0500 Subject: [PATCH 07/19] fix accidental change to elt_multiply prim --- stan/math/prim/fun/elt_multiply.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stan/math/prim/fun/elt_multiply.hpp b/stan/math/prim/fun/elt_multiply.hpp index cb60e861397..bad90c7a26d 100644 --- a/stan/math/prim/fun/elt_multiply.hpp +++ b/stan/math/prim/fun/elt_multiply.hpp @@ -57,8 +57,7 @@ auto elt_multiply(const Scalar1& a, const Scalar2& b) { * @param B second argument * @return product of matrix and scalar */ -template * = nullptr, +template * = nullptr, require_any_stan_scalar_t* = nullptr> inline auto elt_multiply(const T1& A, const T2& B) { return multiply(A, B); From f37df6f4f4e74eba217f8ff81b18d78766b44051 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 6 Mar 2025 14:42:28 -0500 Subject: [PATCH 08/19] update lmultiply to just call multiply_log --- stan/math/fwd/fun/lmultiply.hpp | 21 +-- stan/math/prim/fun/lmultiply.hpp | 32 +---- stan/math/rev/fun/lmultiply.hpp | 220 ------------------------------- 3 files changed, 7 insertions(+), 266 deletions(-) diff --git a/stan/math/fwd/fun/lmultiply.hpp b/stan/math/fwd/fun/lmultiply.hpp index 4af4aaf27c4..a6e7f0ac74d 100644 --- a/stan/math/fwd/fun/lmultiply.hpp +++ b/stan/math/fwd/fun/lmultiply.hpp @@ -4,27 +4,8 @@ #include #include #include +#include #include #include -namespace stan { -namespace math { - -template -inline fvar lmultiply(const fvar& x1, const fvar& x2) { - return fvar(lmultiply(x1.val_, x2.val_), - x1.d_ * log(x2.val_) + x1.val_ * x2.d_ / x2.val_); -} - -template -inline fvar lmultiply(double x1, const fvar& x2) { - return fvar(lmultiply(x1, x2.val_), x1 * x2.d_ / x2.val_); -} - -template -inline fvar lmultiply(const fvar& x1, double x2) { - return fvar(lmultiply(x1.val_, x2), x1.d_ * log(x2)); -} -} // namespace math -} // namespace stan #endif diff --git a/stan/math/prim/fun/lmultiply.hpp b/stan/math/prim/fun/lmultiply.hpp index e8747dcacea..14204af0971 100644 --- a/stan/math/prim/fun/lmultiply.hpp +++ b/stan/math/prim/fun/lmultiply.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -21,33 +22,12 @@ namespace math { * @param b second argument * @return the first argument times the log of the second argument */ -template * = nullptr> -inline return_type_t lmultiply(const T1 a, const T2 b) { - using std::log; - if (a == 0 && b == 0) { - return 0; - } - return a * log(b); +template +inline auto lmultiply(T1&& a, T2&& b) { + return make_holder([](auto&& a, auto&& b) { + return multiply_log(std::forward(a), std::forward(b)); + }, std::forward(a), std::forward(b)); } - -/** - * Return the result of applying `lmultiply` to the arguments - * elementwise, with broadcasting if one of the arguments is a scalar. - * At least one of the arguments must be a container. - * - * @tparam T1 type of the first argument - * @tparam T2 type of the second argument - * @param a first argument - * @param b second argument - * @return result of applying `lmultiply` to the arguments - */ -template * = nullptr, - require_all_not_var_matrix_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - return apply_scalar_binary( - a, b, [&](const auto& c, const auto& d) { return lmultiply(c, d); }); -} - } // namespace math } // namespace stan diff --git a/stan/math/rev/fun/lmultiply.hpp b/stan/math/rev/fun/lmultiply.hpp index 348fdb2fd4d..16783d94481 100644 --- a/stan/math/rev/fun/lmultiply.hpp +++ b/stan/math/rev/fun/lmultiply.hpp @@ -12,224 +12,4 @@ #include #include -namespace stan { -namespace math { - -namespace internal { -class lmultiply_vv_vari : public op_vv_vari { - public: - lmultiply_vv_vari(vari* avi, vari* bvi) - : op_vv_vari(lmultiply(avi->val_, bvi->val_), avi, bvi) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bvi_->val_); - bvi_->adj_ += adj_ * avi_->val_ / bvi_->val_; - } -}; -class lmultiply_vd_vari : public op_vd_vari { - public: - lmultiply_vd_vari(vari* avi, double b) - : op_vd_vari(lmultiply(avi->val_, b), avi, b) {} - void chain() { - using std::log; - avi_->adj_ += adj_ * log(bd_); - } -}; -class lmultiply_dv_vari : public op_dv_vari { - public: - lmultiply_dv_vari(double a, vari* bvi) - : op_dv_vari(lmultiply(a, bvi->val_), a, bvi) {} - void chain() { bvi_->adj_ += adj_ * ad_ / bvi_->val_; } -}; -} // namespace internal - -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to a is log(b). - * The partial derivative with respect to b is a/b. - * - * @param a First variable. - * @param b Second variable. - * @return Value of a*log(b) - */ -inline var lmultiply(const var& a, const var& b) { - return var(new internal::lmultiply_vv_vari(a.vi_, b.vi_)); -} -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to a is log(b). - * - * @param a First variable. - * @param b Second scalar. - * @return Value of a*log(b) - */ -inline var lmultiply(const var& a, double b) { - return var(new internal::lmultiply_vd_vari(a.vi_, b)); -} -/** - * Return the value of a*log(b). - * - * When both a and b are 0, the value returned is 0. - * The partial derivative with respect to b is a/b. - * - * @param a First scalar. - * @param b Second variable. - * @return Value of a*log(b) - */ -inline var lmultiply(double a, const var& b) { - if (a == 1.0) { - return log(b); - } - return var(new internal::lmultiply_dv_vari(a, b.vi_)); -} - -/** - * Return the elementwise product `a * log(b)`. - * - * Both `T1` and `T2` are matrices, and one of `T1` or `T2` must be a - * `var_value` - * - * @tparam T1 Type of first argument - * @tparam T2 Type of second argument - * @param a First argument - * @param b Second argument - * @return elementwise product of `a` and `log(b)` - */ -template * = nullptr, - require_any_var_matrix_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - check_matching_dims("lmultiply", "a", a, "b", b); - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = b; - - return make_callback_var( - lmultiply(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * arena_b.val().array().log(); - arena_b.adj().array() += res.adj().array() * arena_a.val().array() - / arena_b.val().array(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = value_of(b); - - return make_callback_var(lmultiply(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() - * arena_b.val().array().log(); - }); - } else { - arena_t> arena_a = value_of(a); - arena_t> arena_b = b; - - return make_callback_var(lmultiply(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj().array() += res.adj().array() - * arena_a.val().array() - / arena_b.val().array(); - }); - } -} - -/** - * Return the product `a * log(b)`. - * - * @tparam T1 Type of matrix argument - * @tparam T2 Type of scalar argument - * @param a Matrix argument - * @param b Scalar argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_stan_scalar_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - using std::log; - - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - var arena_b = b; - - return make_callback_var( - lmultiply(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj().array() += res.adj().array() * log(arena_b.val()); - arena_b.adj() += (res.adj().array() * arena_a.val().array()).sum() - / arena_b.val(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - - return make_callback_var(lmultiply(arena_a.val(), value_of(b)), - [arena_a, b](const auto& res) mutable { - arena_a.adj().array() - += res.adj().array() * log(value_of(b)); - }); - } else { - arena_t> arena_a = value_of(a); - var arena_b = b; - - return make_callback_var( - lmultiply(arena_a, arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_b.adj() - += (res.adj().array() * arena_a.array()).sum() / arena_b.val(); - }); - } -} - -/** - * Return the product `a * log(b)`. - * - * @tparam T1 Type of scalar argument - * @tparam T2 Type of matrix argument - * @param a Scalar argument - * @param b Matrix argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_var_matrix_t* = nullptr> -inline auto lmultiply(const T1& a, const T2& b) { - if (!is_constant::value && !is_constant::value) { - var arena_a = a; - arena_t> arena_b = b; - - return make_callback_var( - lmultiply(arena_a.val(), arena_b.val()), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); - arena_b.adj().array() - += arena_a.val() * res.adj().array() / arena_b.val().array(); - }); - } else if (!is_constant::value) { - var arena_a = a; - arena_t> arena_b = value_of(b); - - return make_callback_var( - lmultiply(arena_a.val(), arena_b), - [arena_a, arena_b](const auto& res) mutable { - arena_a.adj() - += (res.adj().array() * arena_b.val().array().log()).sum(); - }); - } else { - arena_t> arena_b = b; - - return make_callback_var(lmultiply(value_of(a), arena_b.val()), - [a, arena_b](const auto& res) mutable { - arena_b.adj().array() += value_of(a) - * res.adj().array() - / arena_b.val().array(); - }); - } -} - -} // namespace math -} // namespace stan #endif From 81986392b0dc211a7acd532216fa4af17dfa963d Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 6 Mar 2025 14:53:40 -0500 Subject: [PATCH 09/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/prim/fun/lmultiply.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/stan/math/prim/fun/lmultiply.hpp b/stan/math/prim/fun/lmultiply.hpp index 14204af0971..66190def577 100644 --- a/stan/math/prim/fun/lmultiply.hpp +++ b/stan/math/prim/fun/lmultiply.hpp @@ -24,9 +24,12 @@ namespace math { */ template inline auto lmultiply(T1&& a, T2&& b) { - return make_holder([](auto&& a, auto&& b) { - return multiply_log(std::forward(a), std::forward(b)); - }, std::forward(a), std::forward(b)); + return make_holder( + [](auto&& a, auto&& b) { + return multiply_log(std::forward(a), + std::forward(b)); + }, + std::forward(a), std::forward(b)); } } // namespace math } // namespace stan From b064d8a3c4ab4850851fe20f2d3490306087d608 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 6 Mar 2025 16:17:12 -0500 Subject: [PATCH 10/19] update reverse mode multiply_log to have one signature to accept combinations of matrices and scalars --- stan/math/rev/fun/multiply_log.hpp | 150 ++++++++--------------------- 1 file changed, 40 insertions(+), 110 deletions(-) diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index bdb4ba9a896..0adf964b6ad 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -5,10 +5,11 @@ #include #include #include -#include +#include #include #include #include +#include #include namespace stan { @@ -46,143 +47,72 @@ inline var multiply_log(const T1& a, const T2& b) { }); } +namespace internal { +template +inline auto conditional_sum(T&& x) { + if constexpr (Cond) { + return sum(std::forward(x)); + } else { + return std::forward(x); + } +} +} + /** * Return the elementwise product `a * log(b)`. * - * Both `T1` and `T2` are matrices, and one of `T1` or `T2` must be a - * `var_value` + * If both `T1` and `T2` are matrices, the dimensions of `a` and `b` must match. * - * @tparam T1 Type of first argument - * @tparam T2 Type of second argument + * @tparam T1 Either a scalar or a matrix + * @tparam T2 Either a scalar or a matrix * @param a First argument * @param b Second argument * @return elementwise product of `a` and `log(b)` */ -template * = nullptr, - require_any_rev_matrix_t* = nullptr> +template * = nullptr, + require_any_st_var* = nullptr> inline auto multiply_log(T1&& a, T2&& b) { - check_matching_dims("multiply_log", "a", a, "b", b); + constexpr bool is_a_scalar = !is_matrix_v; + constexpr bool is_b_scalar = !is_matrix_v; + if constexpr (!is_a_scalar && !is_b_scalar) { + check_matching_dims("multiply_log", "a", a, "b", b); + } arena_t arena_a = std::forward(a); arena_t arena_b = std::forward(b); using return_t = return_var_matrix_t< decltype(multiply_log(value_of(arena_a), value_of(arena_b))), T1, T2>; arena_t res = multiply_log(value_of(arena_a), value_of(arena_b)); - + using internal::conditional_sum; if constexpr (is_not_constant_v && is_not_constant_v) { reverse_pass_callback([res, arena_a, arena_b]() mutable { + auto arena_a_arr = as_array_or_scalar(arena_a); + auto arena_b_arr = as_array_or_scalar(arena_b); + auto res_arr = as_array_or_scalar(res); auto is_zero - = (arena_a.val().array() == 0.0 && arena_b.val().array() == 0.0); - arena_a.adj().array() += is_zero.select( - 0.0, res.adj().array() * arena_b.val().array().log()); - arena_b.adj().array() - += is_zero.select(0.0, res.adj().array() * arena_a.val().array() - / arena_b.val().array()); + = ((arena_a_arr.val() == 0.0 + arena_b_arr.val() == 0.0) > 1); + arena_a_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr.val()))); + arena_b_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * arena_a_arr.val() / arena_b_arr.val())); }); } else if constexpr (is_not_constant_v) { reverse_pass_callback([res, arena_a, arena_b]() mutable { - auto is_zero = (arena_a.val().array() == 0.0 && arena_b.array() == 0.0); - arena_a.adj().array() - += is_zero.select(0.0, res.adj().array() * arena_b.array().log()); + auto arena_a_arr = as_array_or_scalar(arena_a); + auto arena_b_arr = as_array_or_scalar(arena_b); + auto res_arr = as_array_or_scalar(res); + auto is_zero = ((arena_a_arr.val() == 0.0 + arena_b_arr == 0.0) > 1); + arena_a_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr))); }); } else { reverse_pass_callback([res, arena_a, arena_b]() mutable { - auto is_zero = (arena_a.array() == 0.0 && arena_b.val().array() == 0.0); - arena_b.adj().array() += is_zero.select( - 0.0, res.adj().array() * arena_a.array() / arena_b.val().array()); + auto arena_a_arr = as_array_or_scalar(arena_a); + auto arena_b_arr = as_array_or_scalar(arena_b); + auto res_arr = as_array_or_scalar(res); + auto is_zero = ((arena_a_arr == 0.0 + arena_b_arr.val() == 0.0) > 1); + arena_b_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * arena_a_arr / arena_b_arr.val())); }); } return res; } -/** - * Return the product `a * log(b)`. - * In the case where b is a scalar and and element of `a` and `b` are zero - * the value returned is 0 and no gradients are accumulated. - * For `b`'s adjoint, this function can still return NaN as the adjoint - * of `b` if `a` is nonzero anywhere, but `b` is zero. Likewise, - * `a`'s adjoint can have undefined values if `a` is nonzero but `b` is zero. - * - * @tparam T1 Type of matrix argument - * @tparam T2 Type of scalar argument - * @param a Matrix argument - * @param b Scalar argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_stan_scalar_t* = nullptr> -inline auto multiply_log(T1&& a, T2&& b) { - arena_t arena_a = a; - using return_t = return_var_matrix_t< - decltype(multiply_log(value_of(arena_a), value_of(b))), T1, T2>; - arena_t res = multiply_log(value_of(arena_a), value_of(b)); - if constexpr (is_not_constant_v && is_not_constant_v) { - reverse_pass_callback([res, arena_a, b]() mutable { - auto is_zero = ((arena_a.val().array() == 0.0) + (b.val() == 0.0) > 1); - arena_a.adj().array() - += is_zero.select(0.0, res.adj().array() * log(b.val())); - b.adj() += is_zero - .select(0.0, (res.adj().array() * arena_a.val().array()) - / b.val()) - .sum(); - }); - } else if constexpr (is_not_constant_v) { - reverse_pass_callback([res, arena_a, b]() mutable { - auto is_zero = ((arena_a.val().array() == 0.0) + (b == 0.0) > 1); - arena_a.adj().array() += is_zero.select(0.0, res.adj().array() * log(b)); - }); - } else { - reverse_pass_callback([res, arena_a, b]() mutable { - auto is_zero = ((arena_a.array() == 0.0) + (b.val() == 0.0) > 1); - b.adj() += is_zero - .select(0.0, (res.adj().array() * arena_a.val().array()) - / b.val()) - .sum(); - }); - } - return res; -} - -/** - * Return the product `a * log(b)`. - * - * @tparam T1 Type of scalar argument - * @tparam T2 Type of matrix argument - * @param a Scalar argument - * @param b Matrix argument - * @return Product of `a` and `log(b)` - */ -template * = nullptr, - require_rev_matrix_t* = nullptr> -inline auto multiply_log(T1&& a, T2&& b) { - arena_t arena_b = std::forward(b); - using return_t = return_var_matrix_t< - decltype(multiply_log(value_of(a), value_of(arena_b))), T1, T2>; - arena_t res = multiply_log(value_of(a), value_of(arena_b)); - if constexpr (is_not_constant_v && is_not_constant_v) { - reverse_pass_callback([res, a, arena_b]() mutable { - auto is_zero = ((a.val() == 0.0) + (arena_b.val().array() == 0.0) > 1); - a.adj() - += is_zero - .select(0.0, res.adj().array() * arena_b.val().array().log()) - .sum(); - arena_b.adj().array() += is_zero.select( - 0.0, a.val() * res.adj().array() / arena_b.val().array()); - }); - } else if constexpr (is_not_constant_v) { - reverse_pass_callback([res, a, arena_b]() mutable { - auto is_zero = ((a.val() == 0.0) + (arena_b.array() == 0.0) > 1); - a.adj() += is_zero.select(0.0, res.adj().array() * arena_b.array().log()) - .sum(); - }); - } else { - reverse_pass_callback([res, a, arena_b]() mutable { - auto is_zero = ((a == 0.0) + (arena_b.val().array() == 0.0) > 1); - arena_b.adj().array() - += is_zero.select(0.0, a * res.adj().array() / arena_b.val().array()); - }); - } - return res; -} } // namespace math } // namespace stan From 511b06a31a4d27fa1934e4a742a51ba17f9b3eb3 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 6 Mar 2025 16:18:07 -0500 Subject: [PATCH 11/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/rev/fun/multiply_log.hpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index 0adf964b6ad..10c3d8b5ec0 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -56,7 +56,7 @@ inline auto conditional_sum(T&& x) { return std::forward(x); } } -} +} // namespace internal /** * Return the elementwise product `a * log(b)`. @@ -90,8 +90,10 @@ inline auto multiply_log(T1&& a, T2&& b) { auto res_arr = as_array_or_scalar(res); auto is_zero = ((arena_a_arr.val() == 0.0 + arena_b_arr.val() == 0.0) > 1); - arena_a_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr.val()))); - arena_b_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * arena_a_arr.val() / arena_b_arr.val())); + arena_a_arr.adj() += conditional_sum( + select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr.val()))); + arena_b_arr.adj() += conditional_sum(select( + is_zero, 0.0, res_arr.adj() * arena_a_arr.val() / arena_b_arr.val())); }); } else if constexpr (is_not_constant_v) { reverse_pass_callback([res, arena_a, arena_b]() mutable { @@ -99,7 +101,8 @@ inline auto multiply_log(T1&& a, T2&& b) { auto arena_b_arr = as_array_or_scalar(arena_b); auto res_arr = as_array_or_scalar(res); auto is_zero = ((arena_a_arr.val() == 0.0 + arena_b_arr == 0.0) > 1); - arena_a_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr))); + arena_a_arr.adj() += conditional_sum( + select(is_zero, 0.0, res_arr.adj() * log(arena_b_arr))); }); } else { reverse_pass_callback([res, arena_a, arena_b]() mutable { @@ -107,13 +110,13 @@ inline auto multiply_log(T1&& a, T2&& b) { auto arena_b_arr = as_array_or_scalar(arena_b); auto res_arr = as_array_or_scalar(res); auto is_zero = ((arena_a_arr == 0.0 + arena_b_arr.val() == 0.0) > 1); - arena_b_arr.adj() += conditional_sum(select(is_zero, 0.0, res_arr.adj() * arena_a_arr / arena_b_arr.val())); + arena_b_arr.adj() += conditional_sum(select( + is_zero, 0.0, res_arr.adj() * arena_a_arr / arena_b_arr.val())); }); } return res; } - } // namespace math } // namespace stan #endif From 026583c2cbb9c3bfbcc51f26555fa7dcf8487bde Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 6 Mar 2025 16:19:02 -0500 Subject: [PATCH 12/19] update reverse mode multiply_log to have one signature to accept combinations of matrices and scalars --- stan/math/rev/fun/multiply_log.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stan/math/rev/fun/multiply_log.hpp b/stan/math/rev/fun/multiply_log.hpp index 10c3d8b5ec0..fa8bb95dad1 100644 --- a/stan/math/rev/fun/multiply_log.hpp +++ b/stan/math/rev/fun/multiply_log.hpp @@ -61,8 +61,8 @@ inline auto conditional_sum(T&& x) { /** * Return the elementwise product `a * log(b)`. * - * If both `T1` and `T2` are matrices, the dimensions of `a` and `b` must match. - * + * For each element of `a` and `b`, when `a[i]` and `b[i]` are 0, + * the value and adjoint returned are zero. * @tparam T1 Either a scalar or a matrix * @tparam T2 Either a scalar or a matrix * @param a First argument From 90c4235c65bc4dc168b8e67e57eed12feffe3634 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Fri, 7 Mar 2025 13:58:31 -0500 Subject: [PATCH 13/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- .../offset_multiplier_constrain.hpp | 38 ++++++++++--------- stan/math/prim/fun/lmultiply.hpp | 4 +- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/stan/math/prim/constraint/offset_multiplier_constrain.hpp b/stan/math/prim/constraint/offset_multiplier_constrain.hpp index d9203fac43e..0987ef7e2c5 100644 --- a/stan/math/prim/constraint/offset_multiplier_constrain.hpp +++ b/stan/math/prim/constraint/offset_multiplier_constrain.hpp @@ -58,11 +58,14 @@ inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma) { check_finite("offset_multiplier_constrain", "offset", value_of_rec(mu_ref)); check_positive_finite("offset_multiplier_constrain", "multiplier", value_of_rec(sigma_ref)); - return make_holder([](auto&& sigma_ref, auto&& x, auto&& mu_ref) { - return fma(std::forward(sigma_ref), std::forward(x), - std::forward(mu_ref)); - }, std::forward(sigma_ref), std::forward(x), - std::forward(mu_ref)); + return make_holder( + [](auto&& sigma_ref, auto&& x, auto&& mu_ref) { + return fma(std::forward(sigma_ref), + std::forward(x), + std::forward(mu_ref)); + }, + std::forward(sigma_ref), std::forward(x), + std::forward(mu_ref)); } /** @@ -96,14 +99,13 @@ template * = nullptr, require_all_not_std_vector_t* = nullptr> -inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma, - Lp& lp) { +inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma, Lp& lp) { if constexpr (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "x", x, "mu", mu); } if constexpr (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "x", x, "sigma", sigma); - } + } if constexpr (is_matrix_v && is_matrix_v) { check_matching_dims("offset_multiplier_constrain", "mu", mu, "sigma", sigma); @@ -118,12 +120,14 @@ inline auto offset_multiplier_constrain(T&& x, M&& mu, S&& sigma, } else { lp += sum(log(sigma_ref)); } - return make_holder([](auto&& sigma_ref, auto&& x, auto&& mu_ref) { - return fma(std::forward(sigma_ref), std::forward(x), - std::forward(mu_ref)); - }, std::forward(sigma_ref), std::forward(x), - std::forward(mu_ref)); - + return make_holder( + [](auto&& sigma_ref, auto&& x, auto&& mu_ref) { + return fma(std::forward(sigma_ref), + std::forward(x), + std::forward(mu_ref)); + }, + std::forward(sigma_ref), std::forward(x), + std::forward(mu_ref)); } /** @@ -167,7 +171,7 @@ inline auto offset_multiplier_constrain(const T& x, M&& mu, S&& sigma) { * Overload for when x and mu or sigma are `std::vectors` */ template , Lp>* = nullptr, + require_convertible_t, Lp>* = nullptr, require_any_std_vector_t* = nullptr> inline auto offset_multiplier_constrain(const T& x, M&& mu, S&& sigma, Lp& lp) { if constexpr (is_std_vector_v && is_std_vector_v) { @@ -195,8 +199,8 @@ inline auto offset_multiplier_constrain(const T& x, M&& mu, S&& sigma, Lp& lp) { // In the language, if mu or sigma is a vector, x must also be a vector ret.reserve(x.size()); for (size_t i = 0; i < x.size(); ++i) { - ret.emplace_back( - offset_multiplier_constrain(x[i], iter(mu_ref, i), iter(sigma_ref, i), lp)); + ret.emplace_back(offset_multiplier_constrain(x[i], iter(mu_ref, i), + iter(sigma_ref, i), lp)); } return ret; } diff --git a/stan/math/prim/fun/lmultiply.hpp b/stan/math/prim/fun/lmultiply.hpp index 9ec3308b1cd..c4c484aead9 100644 --- a/stan/math/prim/fun/lmultiply.hpp +++ b/stan/math/prim/fun/lmultiply.hpp @@ -22,7 +22,9 @@ namespace math { * @param b second argument * @return the first argument times the log of the second argument */ -template * = nullptr> +template * = nullptr> inline auto lmultiply(T1&& a, T2&& b) { return make_holder( [](auto&& a, auto&& b) { From 0b4a667a3eff0fb52eaf1a96387bbd42a5f40d0f Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Fri, 14 Mar 2025 11:31:29 -0400 Subject: [PATCH 14/19] fix opencl lmultiply --- stan/math/opencl/rev/lmultiply.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stan/math/opencl/rev/lmultiply.hpp b/stan/math/opencl/rev/lmultiply.hpp index 46295525cd5..a2d5d78c9dc 100644 --- a/stan/math/opencl/rev/lmultiply.hpp +++ b/stan/math/opencl/rev/lmultiply.hpp @@ -33,9 +33,9 @@ inline var_value> lmultiply(T_a&& a, T_b&& b) { lmultiply(value_of(a_arena), value_of(b_arena)), [a_arena, b_arena](const vari_value>& res) mutable { adjoint_results(a_arena, b_arena) += expressions( - elt_multiply(res.adj(), log(value_of(b_arena))), - elt_multiply(res.adj(), - elt_divide(value_of(a_arena), value_of(b_arena)))); + select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))), + select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, elt_multiply(res.adj(), + elt_divide(value_of(a_arena), value_of(b_arena))))); }); } From eb7db19522ca15c00d6a004d4c49f9da5152fb3f Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Fri, 14 Mar 2025 11:32:25 -0400 Subject: [PATCH 15/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/opencl/rev/lmultiply.hpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stan/math/opencl/rev/lmultiply.hpp b/stan/math/opencl/rev/lmultiply.hpp index a2d5d78c9dc..ddaf0958aed 100644 --- a/stan/math/opencl/rev/lmultiply.hpp +++ b/stan/math/opencl/rev/lmultiply.hpp @@ -33,9 +33,11 @@ inline var_value> lmultiply(T_a&& a, T_b&& b) { lmultiply(value_of(a_arena), value_of(b_arena)), [a_arena, b_arena](const vari_value>& res) mutable { adjoint_results(a_arena, b_arena) += expressions( - select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))), - select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, elt_multiply(res.adj(), - elt_divide(value_of(a_arena), value_of(b_arena))))); + select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, + elt_multiply(res.adj(), log(value_of(b_arena)))), + select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, + elt_multiply(res.adj(), elt_divide(value_of(a_arena), + value_of(b_arena))))); }); } From d55f7a6120853e7ccad216b431225dee4a140cf9 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Wed, 19 Mar 2025 17:18:48 -0400 Subject: [PATCH 16/19] remove lmultiply from opencl and instead use opencl multiply_log --- .../kernel_generator/elt_function_cl.hpp | 4 +- stan/math/opencl/rev/lmultiply.hpp | 48 ------------------- stan/math/opencl/rev/multiply_log.hpp | 20 ++++++-- stan/math/prim/fun/lmultiply.hpp | 21 ++++---- 4 files changed, 30 insertions(+), 63 deletions(-) diff --git a/stan/math/opencl/kernel_generator/elt_function_cl.hpp b/stan/math/opencl/kernel_generator/elt_function_cl.hpp index a40301d716c..bb5d5639e12 100644 --- a/stan/math/opencl/kernel_generator/elt_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/elt_function_cl.hpp @@ -365,8 +365,8 @@ ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp, opencl_kernels::log_diff_exp_device_function) ADD_BINARY_FUNCTION_WITH_INCLUDES( multiply_log, stan::math::opencl_kernels::multiply_log_device_function) -ADD_BINARY_FUNCTION_WITH_INCLUDES( - lmultiply, stan::math::opencl_kernels::lmultiply_device_function) +//ADD_BINARY_FUNCTION_WITH_INCLUDES( +// lmultiply, stan::math::opencl_kernels::lmultiply_device_function) #undef ADD_BINARY_FUNCTION_WITH_INCLUDES #undef ADD_UNARY_FUNCTION_WITH_INCLUDES diff --git a/stan/math/opencl/rev/lmultiply.hpp b/stan/math/opencl/rev/lmultiply.hpp index ddaf0958aed..e69de29bb2d 100644 --- a/stan/math/opencl/rev/lmultiply.hpp +++ b/stan/math/opencl/rev/lmultiply.hpp @@ -1,48 +0,0 @@ -#ifndef STAN_MATH_OPENCL_REV_LMULTIPLY_HPP -#define STAN_MATH_OPENCL_REV_LMULTIPLY_HPP -#ifdef STAN_OPENCL - -#include -#include -#include -#include -#include - -namespace stan { -namespace math { - -/** - * Returns the elementwise `lmultiply()` of the input. - * - * @tparam T_a type of first expression - * @tparam T_b type of second expression - * @param a first expression - * @param b second expression - * - * @return Elementwise `lmultiply()` of the input. - */ -template * = nullptr, - require_any_var_t* = nullptr, - require_any_not_stan_scalar_t* = nullptr> -inline var_value> lmultiply(T_a&& a, T_b&& b) { - arena_t a_arena = std::forward(a); - arena_t b_arena = std::forward(b); - - return make_callback_var( - lmultiply(value_of(a_arena), value_of(b_arena)), - [a_arena, b_arena](const vari_value>& res) mutable { - adjoint_results(a_arena, b_arena) += expressions( - select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, - elt_multiply(res.adj(), log(value_of(b_arena)))), - select(value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0, 0.0, - elt_multiply(res.adj(), elt_divide(value_of(a_arena), - value_of(b_arena))))); - }); -} - -} // namespace math -} // namespace stan - -#endif -#endif diff --git a/stan/math/opencl/rev/multiply_log.hpp b/stan/math/opencl/rev/multiply_log.hpp index 535129b3e77..19ce898014d 100644 --- a/stan/math/opencl/rev/multiply_log.hpp +++ b/stan/math/opencl/rev/multiply_log.hpp @@ -32,10 +32,22 @@ inline var_value> multiply_log(T_a&& a, T_b&& b) { return make_callback_var( multiply_log(value_of(a_arena), value_of(b_arena)), [a_arena, b_arena](const vari_value>& res) mutable { - adjoint_results(a_arena, b_arena) += expressions( - elt_multiply(res.adj(), log(value_of(b_arena))), - elt_multiply(res.adj(), - elt_divide(value_of(a_arena), value_of(b_arena)))); + auto is_zero = value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0; + if constexpr (is_var::value && is_var::value) { + a_arena.adj() += select(is_zero, 0.0, + elt_multiply(res.adj(), log(value_of(b_arena)))); + b_arena.adj() += select(is_zero, 0.0, + elt_multiply(res.adj(), elt_divide(value_of(a_arena), + value_of(b_arena)))); + } else if constexpr (is_var::value) { + a_arena.adj() += select(is_zero, 0.0, + elt_multiply(res.adj(), log(value_of(b_arena)))); + } else if constexpr (is_var::value) { + b_arena.adj() += + select(is_zero, 0.0, + elt_multiply(res.adj(), elt_divide(value_of(a_arena), + value_of(b_arena)))); + } }); } diff --git a/stan/math/prim/fun/lmultiply.hpp b/stan/math/prim/fun/lmultiply.hpp index c4c484aead9..351b5a7e4e5 100644 --- a/stan/math/prim/fun/lmultiply.hpp +++ b/stan/math/prim/fun/lmultiply.hpp @@ -22,16 +22,19 @@ namespace math { * @param b second argument * @return the first argument times the log of the second argument */ -template * = nullptr> +template inline auto lmultiply(T1&& a, T2&& b) { - return make_holder( - [](auto&& a, auto&& b) { - return multiply_log(std::forward(a), - std::forward(b)); - }, - std::forward(a), std::forward(b)); + if constexpr (is_kernel_expression::value + || is_kernel_expression::value) { + return multiply_log(std::forward(a), std::forward(b)); + } else { + return make_holder( + [](auto&& a, auto&& b) { + return multiply_log(std::forward(a), + std::forward(b)); + }, + std::forward(a), std::forward(b)); + } } } // namespace math } // namespace stan From f1c9094c33d86aa94f63d39d51acc995b0d4eb1c Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Wed, 19 Mar 2025 17:19:44 -0400 Subject: [PATCH 17/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- .../kernel_generator/elt_function_cl.hpp | 2 +- stan/math/opencl/rev/multiply_log.hpp | 23 ++++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/stan/math/opencl/kernel_generator/elt_function_cl.hpp b/stan/math/opencl/kernel_generator/elt_function_cl.hpp index bb5d5639e12..19b58d3683d 100644 --- a/stan/math/opencl/kernel_generator/elt_function_cl.hpp +++ b/stan/math/opencl/kernel_generator/elt_function_cl.hpp @@ -365,7 +365,7 @@ ADD_BINARY_FUNCTION_WITH_INCLUDES(log_diff_exp, opencl_kernels::log_diff_exp_device_function) ADD_BINARY_FUNCTION_WITH_INCLUDES( multiply_log, stan::math::opencl_kernels::multiply_log_device_function) -//ADD_BINARY_FUNCTION_WITH_INCLUDES( +// ADD_BINARY_FUNCTION_WITH_INCLUDES( // lmultiply, stan::math::opencl_kernels::lmultiply_device_function) #undef ADD_BINARY_FUNCTION_WITH_INCLUDES diff --git a/stan/math/opencl/rev/multiply_log.hpp b/stan/math/opencl/rev/multiply_log.hpp index 19ce898014d..6f16788379c 100644 --- a/stan/math/opencl/rev/multiply_log.hpp +++ b/stan/math/opencl/rev/multiply_log.hpp @@ -34,19 +34,20 @@ inline var_value> multiply_log(T_a&& a, T_b&& b) { [a_arena, b_arena](const vari_value>& res) mutable { auto is_zero = value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0; if constexpr (is_var::value && is_var::value) { - a_arena.adj() += select(is_zero, 0.0, - elt_multiply(res.adj(), log(value_of(b_arena)))); - b_arena.adj() += select(is_zero, 0.0, - elt_multiply(res.adj(), elt_divide(value_of(a_arena), - value_of(b_arena)))); + a_arena.adj() += select( + is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))); + b_arena.adj() += select( + is_zero, 0.0, + elt_multiply(res.adj(), + elt_divide(value_of(a_arena), value_of(b_arena)))); } else if constexpr (is_var::value) { - a_arena.adj() += select(is_zero, 0.0, - elt_multiply(res.adj(), log(value_of(b_arena)))); + a_arena.adj() += select( + is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))); } else if constexpr (is_var::value) { - b_arena.adj() += - select(is_zero, 0.0, - elt_multiply(res.adj(), elt_divide(value_of(a_arena), - value_of(b_arena)))); + b_arena.adj() += select( + is_zero, 0.0, + elt_multiply(res.adj(), + elt_divide(value_of(a_arena), value_of(b_arena)))); } }); } From c6ad6b0abc576268255c6e8d8d5ae88883a9a923 Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 20 Mar 2025 13:26:03 -0400 Subject: [PATCH 18/19] conditional sum for opencl multiply_log --- stan/math/opencl/rev/multiply_log.hpp | 30 ++++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/stan/math/opencl/rev/multiply_log.hpp b/stan/math/opencl/rev/multiply_log.hpp index 6f16788379c..3a8a5f041e3 100644 --- a/stan/math/opencl/rev/multiply_log.hpp +++ b/stan/math/opencl/rev/multiply_log.hpp @@ -10,6 +10,17 @@ namespace stan { namespace math { +namespace internalcl { + template + inline decltype(auto) conditional_sum(T&& x) { + if constexpr (Cond) { + return sum(std::forward(x)); + } else { + return std::forward(x); + } + } + } // namespace internal + /** * Returns the elementwise `multiply_log()` of the input. @@ -32,22 +43,25 @@ inline var_value> multiply_log(T_a&& a, T_b&& b) { return make_callback_var( multiply_log(value_of(a_arena), value_of(b_arena)), [a_arena, b_arena](const vari_value>& res) mutable { + constexpr bool is_scalar_a = !is_matrix_v; + constexpr bool is_scalar_b = !is_matrix_v; + using internalcl::conditional_sum; auto is_zero = value_of(a_arena) == 0.0 && value_of(b_arena) == 0.0; if constexpr (is_var::value && is_var::value) { - a_arena.adj() += select( - is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))); - b_arena.adj() += select( + a_arena.adj() += conditional_sum(select( + is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena))))); + b_arena.adj() += conditional_sum(select( is_zero, 0.0, elt_multiply(res.adj(), - elt_divide(value_of(a_arena), value_of(b_arena)))); + elt_divide(value_of(a_arena), value_of(b_arena))))); } else if constexpr (is_var::value) { - a_arena.adj() += select( - is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena)))); + a_arena.adj() += conditional_sum(select( + is_zero, 0.0, elt_multiply(res.adj(), log(value_of(b_arena))))); } else if constexpr (is_var::value) { - b_arena.adj() += select( + b_arena.adj() += conditional_sum(select( is_zero, 0.0, elt_multiply(res.adj(), - elt_divide(value_of(a_arena), value_of(b_arena)))); + elt_divide(value_of(a_arena), value_of(b_arena))))); } }); } From cbc3af505d6a8c9edc31567da6f2c14a9775a8b9 Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Thu, 20 Mar 2025 13:26:56 -0400 Subject: [PATCH 19/19] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/opencl/rev/multiply_log.hpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/stan/math/opencl/rev/multiply_log.hpp b/stan/math/opencl/rev/multiply_log.hpp index 3a8a5f041e3..a1527d8697b 100644 --- a/stan/math/opencl/rev/multiply_log.hpp +++ b/stan/math/opencl/rev/multiply_log.hpp @@ -11,16 +11,15 @@ namespace stan { namespace math { namespace internalcl { - template - inline decltype(auto) conditional_sum(T&& x) { - if constexpr (Cond) { - return sum(std::forward(x)); - } else { - return std::forward(x); - } +template +inline decltype(auto) conditional_sum(T&& x) { + if constexpr (Cond) { + return sum(std::forward(x)); + } else { + return std::forward(x); } - } // namespace internal - +} +} // namespace internalcl /** * Returns the elementwise `multiply_log()` of the input.