From e5e2f1355ada767015c49511f9e026948d519558 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Wed, 25 Jun 2025 22:14:50 +0800 Subject: [PATCH 1/5] Use boost beta, simplify overloads --- stan/math/fwd/fun/beta.hpp | 38 ++--- stan/math/prim/fun/beta.hpp | 20 ++- stan/math/rev/fun/beta.hpp | 220 ++++---------------------- test/unit/math/prim/fun/beta_test.cpp | 4 +- 4 files changed, 64 insertions(+), 218 deletions(-) diff --git a/stan/math/fwd/fun/beta.hpp b/stan/math/fwd/fun/beta.hpp index a4c8b782052..8bce055b0c6 100644 --- a/stan/math/fwd/fun/beta.hpp +++ b/stan/math/fwd/fun/beta.hpp @@ -47,27 +47,23 @@ namespace math { * @param x2 Second value * @return Fvar with result beta function of arguments and gradients. */ -template -inline fvar beta(const fvar& x1, const fvar& x2) { - const T beta_ab = beta(x1.val_, x2.val_); - return fvar(beta_ab, - beta_ab - * (x1.d_ * digamma(x1.val_) + x2.d_ * digamma(x2.val_) - - (x1.d_ + x2.d_) * digamma(x1.val_ + x2.val_))); -} - -template -inline fvar beta(double x1, const fvar& x2) { - const T beta_ab = beta(x1, x2.val_); - return fvar(beta_ab, - x2.d_ * (digamma(x2.val_) - digamma(x1 + x2.val_)) * beta_ab); -} - -template -inline fvar beta(const fvar& x1, double x2) { - const T beta_ab = beta(x1.val_, x2); - return fvar(beta_ab, - x1.d_ * (digamma(x1.val_) - digamma(x1.val_ + x2)) * beta_ab); +template , + require_return_type_t* = nullptr, + require_all_stan_scalar_t* = nullptr> +inline fvar beta(const Ta& a, const Tb& b) { + const auto& a_val = value_of(a); + const auto& b_val = value_of(b); + const FvarInnerT beta_val = beta(a_val, b_val); + const FvarInnerT digamma_ab = digamma(a_val + b_val); + FvarInnerT beta_d(0); + if constexpr (!is_constant::value) { + beta_d += (digamma(a_val) - digamma_ab) * beta_val * a.d_; + } + if constexpr (!is_constant::value) { + beta_d += (digamma(b_val) - digamma_ab) * beta_val * b.d_; + } + return fvar(beta_val, beta_d); } } // namespace math diff --git a/stan/math/prim/fun/beta.hpp b/stan/math/prim/fun/beta.hpp index 969ce5e0efc..85c0e4634ee 100644 --- a/stan/math/prim/fun/beta.hpp +++ b/stan/math/prim/fun/beta.hpp @@ -2,10 +2,9 @@ #define STAN_MATH_PRIM_FUN_BETA_HPP #include -#include -#include +#include #include -#include +#include namespace stan { namespace math { @@ -51,8 +50,7 @@ namespace math { */ template * = nullptr> inline return_type_t beta(const T1 a, const T2 b) { - using std::exp; - return exp(lgamma(a) + lgamma(b) - lgamma(a + b)); + return boost::math::beta(a, b, boost_policy_t<>()); } /** @@ -65,8 +63,16 @@ inline return_type_t beta(const T1 a, const T2 b) { * @param b Second input * @return Beta function applied to the two inputs. 
*/ -template * = nullptr, - require_all_not_var_matrix_t* = nullptr> +template * = nullptr, + require_t< + math::disjunction< + is_arithmetic>, + is_fvar>, + is_std_vector, + is_std_vector + > + >* = nullptr> inline auto beta(const T1& a, const T2& b) { return apply_scalar_binary( [](const auto& c, const auto& d) { return beta(c, d); }, a, b); diff --git a/stan/math/rev/fun/beta.hpp b/stan/math/rev/fun/beta.hpp index 26303e923c5..c25d5bb1986 100644 --- a/stan/math/rev/fun/beta.hpp +++ b/stan/math/rev/fun/beta.hpp @@ -34,197 +34,41 @@ namespace math { * @param b var Argument * @return Result of beta function */ -inline var beta(const var& a, const var& b) { - double digamma_ab = digamma(a.val() + b.val()); - double digamma_a = digamma(a.val()) - digamma_ab; - double digamma_b = digamma(b.val()) - digamma_ab; - return make_callback_var(beta(a.val(), b.val()), - [a, b, digamma_a, digamma_b](auto& vi) mutable { - const double adj_val = vi.adj() * vi.val(); - a.adj() += adj_val * digamma_a; - b.adj() += adj_val * digamma_b; - }); -} +template * = nullptr, + require_return_type_t* = nullptr> +inline auto beta(const T1& a, const T2& b) { + using inner_return_t = decltype(beta(value_of(a), value_of(b))); + using return_t = return_var_matrix_t; + arena_t> arena_a = a; + arena_t> arena_b = b; -/** - * Returns the beta function and gradient for first var input. - * - \f[ - \mathrm{beta}(a,b) = \left(B\left(a,b\right)\right) - \f] + return_t res = beta(value_of(arena_a), value_of(arena_b)); + reverse_pass_callback([arena_a, arena_b, res]() mutable { + auto&& a_array = as_array_or_scalar(arena_a); + auto&& b_array = as_array_or_scalar(arena_b); + const auto& res_array = as_array_or_scalar(res); + const auto& digamma_ab = digamma(value_of(a_array) + value_of(b_array)); + const auto& adj_val = res_array.adj() * res_array.val(); - \f[ - \frac{\partial }{\partial a} = \left(\psi^{\left(0\right)}\left(a\right) - - \psi^{\left(0\right)} - \left(a + b\right)\right) - * \mathrm{beta}(a,b) - \f] - * - * @param a var Argument - * @param b double Argument - * @return Result of beta function - */ -inline var beta(const var& a, double b) { - auto digamma_ab = digamma(a.val()) - digamma(a.val() + b); - return make_callback_var(beta(a.val(), b), [a, digamma_ab](auto& vi) mutable { - a.adj() += vi.adj() * digamma_ab * vi.val(); + if constexpr (!is_constant::value) { + const auto& a_adj = adj_val * (digamma(a_array.val()) - digamma_ab); + if constexpr (is_stan_scalar::value) { + a_array.adj() += sum(a_adj); + } else { + a_array.adj() += a_adj; + } + } + if constexpr (!is_constant::value) { + const auto& b_adj = adj_val * (digamma(b_array.val()) - digamma_ab); + if constexpr (is_stan_scalar::value) { + b_array.adj() += sum(b_adj); + } else { + b_array.adj() += b_adj; + } + } }); -} - -/** - * Returns the beta function and gradient for second var input. 
- * - \f[ - \mathrm{beta}(a,b) = \left(B\left(a,b\right)\right) - \f] - - \f[ - \frac{\partial }{\partial b} = \left(\psi^{\left(0\right)}\left(b\right) - - \psi^{\left(0\right)} - \left(a + b\right)\right) - * \mathrm{beta}(a,b) - \f] - * - * @param a double Argument - * @param b var Argument - * @return Result of beta function - */ -inline var beta(double a, const var& b) { - auto beta_val = beta(a, b.val()); - auto digamma_ab = (digamma(b.val()) - digamma(a + b.val())) * beta_val; - return make_callback_var(beta_val, [b, digamma_ab](auto& vi) mutable { - b.adj() += vi.adj() * digamma_ab; - }); -} - -template * = nullptr, - require_all_matrix_t* = nullptr> -inline auto beta(const Mat1& a, const Mat2& b) { - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = b; - auto beta_val = beta(arena_a.val(), arena_b.val()); - auto digamma_ab - = to_arena(digamma(arena_a.val().array() + arena_b.val().array())); - return make_callback_var( - beta(arena_a.val(), arena_b.val()), - [arena_a, arena_b, digamma_ab](auto& vi) mutable { - const auto adj_val = (vi.adj().array() * vi.val().array()).eval(); - arena_a.adj().array() - += adj_val * (digamma(arena_a.val().array()) - digamma_ab); - arena_b.adj().array() - += adj_val * (digamma(arena_b.val().array()) - digamma_ab); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - arena_t> arena_b = value_of(b); - auto digamma_ab - = to_arena(digamma(arena_a.val()).array() - - digamma(arena_a.val().array() + arena_b.array())); - return make_callback_var(beta(arena_a.val(), arena_b), - [arena_a, arena_b, digamma_ab](auto& vi) mutable { - arena_a.adj().array() += vi.adj().array() - * digamma_ab - * vi.val().array(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = value_of(a); - arena_t> arena_b = b; - auto beta_val = beta(arena_a, arena_b.val()); - auto digamma_ab - = to_arena((digamma(arena_b.val()).array() - - digamma(arena_a.array() + arena_b.val().array())) - * beta_val.array()); - return make_callback_var( - beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable { - arena_b.adj().array() += vi.adj().array() * digamma_ab.array(); - }); - } -} - -template * = nullptr, - require_stan_scalar_t* = nullptr> -inline auto beta(const Scalar& a, const VarMat& b) { - if (!is_constant::value && !is_constant::value) { - var arena_a = a; - arena_t> arena_b = b; - auto beta_val = beta(arena_a.val(), arena_b.val()); - auto digamma_ab = to_arena(digamma(arena_a.val() + arena_b.val().array())); - return make_callback_var( - beta(arena_a.val(), arena_b.val()), - [arena_a, arena_b, digamma_ab](auto& vi) mutable { - const auto adj_val = (vi.adj().array() * vi.val().array()).eval(); - arena_a.adj() - += (adj_val * (digamma(arena_a.val()) - digamma_ab)).sum(); - arena_b.adj().array() - += adj_val * (digamma(arena_b.val().array()) - digamma_ab); - }); - } else if (!is_constant::value) { - var arena_a = a; - arena_t> arena_b = value_of(b); - auto digamma_ab = to_arena(digamma(arena_a.val()) - - digamma(arena_a.val() + arena_b.array())); - return make_callback_var( - beta(arena_a.val(), arena_b), - [arena_a, arena_b, digamma_ab](auto& vi) mutable { - arena_a.adj() - += (vi.adj().array() * digamma_ab * vi.val().array()).sum(); - }); - } else if (!is_constant::value) { - double arena_a = value_of(a); - arena_t> arena_b = b; - auto beta_val = beta(arena_a, arena_b.val()); - auto digamma_ab = to_arena((digamma(arena_b.val()).array() - - digamma(arena_a + arena_b.val().array())) - * beta_val.array()); - return 
make_callback_var(beta_val, [arena_b, digamma_ab](auto& vi) mutable { - arena_b.adj().array() += vi.adj().array() * digamma_ab.array(); - }); - } -} - -template * = nullptr, - require_stan_scalar_t* = nullptr> -inline auto beta(const VarMat& a, const Scalar& b) { - if (!is_constant::value && !is_constant::value) { - arena_t> arena_a = a; - var arena_b = b; - auto beta_val = beta(arena_a.val(), arena_b.val()); - auto digamma_ab = to_arena(digamma(arena_a.val().array() + arena_b.val())); - return make_callback_var( - beta(arena_a.val(), arena_b.val()), - [arena_a, arena_b, digamma_ab](auto& vi) mutable { - const auto adj_val = (vi.adj().array() * vi.val().array()).eval(); - arena_a.adj().array() - += adj_val * (digamma(arena_a.val().array()) - digamma_ab); - arena_b.adj() - += (adj_val * (digamma(arena_b.val()) - digamma_ab)).sum(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = a; - double arena_b = value_of(b); - auto digamma_ab = to_arena(digamma(arena_a.val()).array() - - digamma(arena_a.val().array() + arena_b)); - return make_callback_var( - beta(arena_a.val(), arena_b), [arena_a, digamma_ab](auto& vi) mutable { - arena_a.adj().array() - += vi.adj().array() * digamma_ab * vi.val().array(); - }); - } else if (!is_constant::value) { - arena_t> arena_a = value_of(a); - var arena_b = b; - auto beta_val = beta(arena_a, arena_b.val()); - auto digamma_ab = to_arena( - (digamma(arena_b.val()) - digamma(arena_a.array() + arena_b.val())) - * beta_val.array()); - return make_callback_var( - beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable { - arena_b.adj() += (vi.adj().array() * digamma_ab.array()).sum(); - }); - } + return res; } } // namespace math diff --git a/test/unit/math/prim/fun/beta_test.cpp b/test/unit/math/prim/fun/beta_test.cpp index 371bec37bd5..dc480662258 100644 --- a/test/unit/math/prim/fun/beta_test.cpp +++ b/test/unit/math/prim/fun/beta_test.cpp @@ -26,8 +26,8 @@ TEST(MathFunctions, beta_vec) { auto f = [](const auto& x1, const auto& x2) { return stan::math::beta(x1, x2); }; - Eigen::VectorXd in1 = Eigen::VectorXd::Random(6); - Eigen::VectorXd in2 = Eigen::VectorXd::Random(6); + Eigen::VectorXd in1 = Eigen::VectorXd::Random(6).cwiseAbs(); + Eigen::VectorXd in2 = Eigen::VectorXd::Random(6).cwiseAbs(); stan::test::binary_scalar_tester(f, in1, in2); } From 544a5e8de75ec482f7c730f3069c9efbfb95926b Mon Sep 17 00:00:00 2001 From: Stan Jenkins Date: Wed, 25 Jun 2025 10:24:48 -0400 Subject: [PATCH 2/5] [Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1 --- stan/math/prim/fun/beta.hpp | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/stan/math/prim/fun/beta.hpp b/stan/math/prim/fun/beta.hpp index 85c0e4634ee..d232bb8daa9 100644 --- a/stan/math/prim/fun/beta.hpp +++ b/stan/math/prim/fun/beta.hpp @@ -63,16 +63,11 @@ inline return_type_t beta(const T1 a, const T2 b) { * @param b Second input * @return Beta function applied to the two inputs. 
*/ -template * = nullptr, - require_t< - math::disjunction< - is_arithmetic>, - is_fvar>, - is_std_vector, - is_std_vector - > - >* = nullptr> +template < + typename T1, typename T2, require_any_container_t* = nullptr, + require_t>, is_fvar>, + is_std_vector, is_std_vector>>* = nullptr> inline auto beta(const T1& a, const T2& b) { return apply_scalar_binary( [](const auto& c, const auto& d) { return beta(c, d); }, a, b); From 2c8120640ca730f0a3acc81951a6a4d88d00ab8b Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Wed, 25 Jun 2025 22:35:28 +0800 Subject: [PATCH 3/5] Update doxygen --- stan/math/fwd/fun/beta.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/stan/math/fwd/fun/beta.hpp b/stan/math/fwd/fun/beta.hpp index 8bce055b0c6..5d09cb9e826 100644 --- a/stan/math/fwd/fun/beta.hpp +++ b/stan/math/fwd/fun/beta.hpp @@ -42,9 +42,10 @@ namespace math { \end{cases} \f] * - * @tparam T inner type of the fvar - * @param x1 First value - * @param x2 Second value + * @tparam Ta Type of first scalar argument + * @tparam Tb Type of second scalar argument + * @param a First value + * @param b Second value * @return Fvar with result beta function of arguments and gradients. */ template Date: Wed, 25 Jun 2025 22:57:48 +0800 Subject: [PATCH 4/5] Expression fixes --- stan/math/rev/fun/beta.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stan/math/rev/fun/beta.hpp b/stan/math/rev/fun/beta.hpp index c25d5bb1986..44b4a5a99de 100644 --- a/stan/math/rev/fun/beta.hpp +++ b/stan/math/rev/fun/beta.hpp @@ -38,12 +38,12 @@ template * = nullptr, require_return_type_t* = nullptr> inline auto beta(const T1& a, const T2& b) { - using inner_return_t = decltype(beta(value_of(a), value_of(b))); - using return_t = return_var_matrix_t; arena_t> arena_a = a; arena_t> arena_b = b; - return_t res = beta(value_of(arena_a), value_of(arena_b)); + const auto& beta_val = beta(value_of(arena_a), value_of(arena_b)); + return_var_matrix_t res(beta_val); + reverse_pass_callback([arena_a, arena_b, res]() mutable { auto&& a_array = as_array_or_scalar(arena_a); auto&& b_array = as_array_or_scalar(arena_b); From f691622074d955324950ea1555b697df6812b301 Mon Sep 17 00:00:00 2001 From: Andrew Johnson Date: Wed, 25 Jun 2025 16:39:02 +0000 Subject: [PATCH 5/5] arena for ret --- stan/math/rev/fun/beta.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stan/math/rev/fun/beta.hpp b/stan/math/rev/fun/beta.hpp index 44b4a5a99de..78e17c5e252 100644 --- a/stan/math/rev/fun/beta.hpp +++ b/stan/math/rev/fun/beta.hpp @@ -42,7 +42,8 @@ inline auto beta(const T1& a, const T2& b) { arena_t> arena_b = b; const auto& beta_val = beta(value_of(arena_a), value_of(arena_b)); - return_var_matrix_t res(beta_val); + using return_type_t = return_var_matrix_t; + arena_t res(beta_val); reverse_pass_callback([arena_a, arena_b, res]() mutable { auto&& a_array = as_array_or_scalar(arena_a); @@ -68,7 +69,7 @@ inline auto beta(const T1& a, const T2& b) { } } }); - return res; + return return_type_t(res); } } // namespace math
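
Note (not part of the patch series above): a quick way to sanity-check the gradients wired up in the new reverse- and forward-mode overloads is to compare them against the closed forms quoted in the doxygen comments, d/da B(a,b) = B(a,b) * (psi(a) - psi(a+b)) and d/db B(a,b) = B(a,b) * (psi(b) - psi(a+b)). The standalone sketch below assumes only the public stan/math.hpp entry point; the inputs 2.5 and 1.5 are arbitrary test values, not taken from the patch.

#include <stan/math.hpp>
#include <iostream>

int main() {
  using stan::math::beta;
  using stan::math::digamma;
  using stan::math::fvar;
  using stan::math::var;

  // Reverse mode: propagate the adjoint of f = beta(a, b) back to a and b.
  var a = 2.5, b = 1.5;
  var f = beta(a, b);
  f.grad();

  // Closed-form partials evaluated with plain doubles for comparison.
  const double beta_val = f.val();
  const double digamma_ab = digamma(a.val() + b.val());
  const double expected_da = beta_val * (digamma(a.val()) - digamma_ab);
  const double expected_db = beta_val * (digamma(b.val()) - digamma_ab);
  std::cout << "rev d/da: " << a.adj() << " vs " << expected_da << "\n"
            << "rev d/db: " << b.adj() << " vs " << expected_db << "\n";

  // Forward mode: a unit tangent on the first argument recovers d/da directly.
  fvar<double> fa(2.5, 1.0);
  fvar<double> fb(1.5, 0.0);
  fvar<double> ff = beta(fa, fb);
  std::cout << "fwd d/da: " << ff.d_ << " vs " << expected_da << "\n";

  stan::math::recover_memory();
  return 0;
}

Because the fvar overload now accumulates each argument's tangent contribution only when that argument is non-constant, the same check with fb's tangent set to 1.0 and fa's to 0.0 should reproduce expected_db.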