Use boost beta, simplify overloads #3212

Open · wants to merge 5 commits into base: develop
45 changes: 21 additions & 24 deletions stan/math/fwd/fun/beta.hpp
@@ -42,32 +42,29 @@ namespace math {
\end{cases}
\f]
*
* @tparam T inner type of the fvar
* @param x1 First value
* @param x2 Second value
* @tparam Ta Type of first scalar argument
* @tparam Tb Type of second scalar argument
* @param a First value
* @param b Second value
 * @return fvar with the value of the beta function of the arguments and its gradient.
*/
template <typename T>
inline fvar<T> beta(const fvar<T>& x1, const fvar<T>& x2) {
const T beta_ab = beta(x1.val_, x2.val_);
return fvar<T>(beta_ab,
beta_ab
* (x1.d_ * digamma(x1.val_) + x2.d_ * digamma(x2.val_)
- (x1.d_ + x2.d_) * digamma(x1.val_ + x2.val_)));
}

template <typename T>
inline fvar<T> beta(double x1, const fvar<T>& x2) {
const T beta_ab = beta(x1, x2.val_);
return fvar<T>(beta_ab,
x2.d_ * (digamma(x2.val_) - digamma(x1 + x2.val_)) * beta_ab);
}

template <typename T>
inline fvar<T> beta(const fvar<T>& x1, double x2) {
const T beta_ab = beta(x1.val_, x2);
return fvar<T>(beta_ab,
x1.d_ * (digamma(x1.val_) - digamma(x1.val_ + x2)) * beta_ab);
template <typename Ta, typename Tb,
typename FvarInnerT = partials_return_t<Ta, Tb>,
require_return_type_t<is_fvar, Ta, Tb>* = nullptr,
require_all_stan_scalar_t<Ta, Tb>* = nullptr>
inline fvar<FvarInnerT> beta(const Ta& a, const Tb& b) {
const auto& a_val = value_of(a);
const auto& b_val = value_of(b);
const FvarInnerT beta_val = beta(a_val, b_val);
const FvarInnerT digamma_ab = digamma(a_val + b_val);
FvarInnerT beta_d(0);
if constexpr (!is_constant<Ta>::value) {
beta_d += (digamma(a_val) - digamma_ab) * beta_val * a.d_;
}
if constexpr (!is_constant<Tb>::value) {
beta_d += (digamma(b_val) - digamma_ab) * beta_val * b.d_;
}
return fvar<FvarInnerT>(beta_val, beta_d);
}

} // namespace math
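For context on the consolidated overload above, here is a minimal forward-mode sketch (not part of this diff) that exercises the new mixed fvar/double path, which previously needed a dedicated overload. The commented reference values follow from the identity d/da B(a,b) = B(a,b) * (digamma(a) - digamma(a+b)).

#include <stan/math/fwd.hpp>
#include <iostream>

int main() {
  using stan::math::fvar;
  fvar<double> a(2.0, 1.0);  // value 2.0, tangent 1.0: seeds d/da
  double b = 3.0;            // plain double now routes through the same template
  fvar<double> c = stan::math::beta(a, b);
  // beta(2, 3) = 1/12; derivative = beta(a, b) * (digamma(a) - digamma(a + b))
  std::cout << c.val_ << " " << c.d_ << "\n";
}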
15 changes: 8 additions & 7 deletions stan/math/prim/fun/beta.hpp
@@ -2,10 +2,9 @@
#define STAN_MATH_PRIM_FUN_BETA_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/lgamma.hpp>
#include <stan/math/prim/fun/boost_policy.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <cmath>
#include <boost/math/special_functions/beta.hpp>

namespace stan {
namespace math {
@@ -51,8 +50,7 @@ namespace math {
*/
template <typename T1, typename T2, require_all_arithmetic_t<T1, T2>* = nullptr>
inline return_type_t<T1, T2> beta(const T1 a, const T2 b) {
using std::exp;
return exp(lgamma(a) + lgamma(b) - lgamma(a + b));
return boost::math::beta(a, b, boost_policy_t<>());
}

/**
@@ -65,8 +63,11 @@ inline return_type_t<T1, T2> beta(const T1 a, const T2 b) {
* @param b Second input
* @return Beta function applied to the two inputs.
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_all_not_var_matrix_t<T1, T2>* = nullptr>
template <
typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr,
require_t<math::disjunction<
is_arithmetic<return_type_t<T1, T2>>, is_fvar<return_type_t<T1, T2>>,
is_std_vector<T1>, is_std_vector<T2>>>* = nullptr>
inline auto beta(const T1& a, const T2& b) {
return apply_scalar_binary(
[](const auto& c, const auto& d) { return beta(c, d); }, a, b);
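To illustrate the primitive change above, a sketch (not from the PR) comparing the two implementations. The old code exponentiates a difference of large lgamma values, which can cost relative precision; Boost evaluates the beta function directly. The policy argument is omitted here; the PR's boost_policy_t<> presumably routes Boost errors through Stan's error handling, as in Stan's other Boost wrappers.

#include <boost/math/special_functions/beta.hpp>
#include <cmath>
#include <iostream>

int main() {
  const double a = 150.0, b = 175.0;
  // Old implementation: exp of a large cancellation.
  const double via_lgamma
      = std::exp(std::lgamma(a) + std::lgamma(b) - std::lgamma(a + b));
  // New implementation: direct evaluation by Boost.
  const double via_boost = boost::math::beta(a, b);
  std::cout << via_lgamma << "\n" << via_boost << "\n";
}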
219 changes: 32 additions & 187 deletions stan/math/rev/fun/beta.hpp
@@ -34,197 +34,42 @@ namespace math {
* @param b var Argument
* @return Result of beta function
*/
inline var beta(const var& a, const var& b) {
double digamma_ab = digamma(a.val() + b.val());
double digamma_a = digamma(a.val()) - digamma_ab;
double digamma_b = digamma(b.val()) - digamma_ab;
return make_callback_var(beta(a.val(), b.val()),
[a, b, digamma_a, digamma_b](auto& vi) mutable {
const double adj_val = vi.adj() * vi.val();
a.adj() += adj_val * digamma_a;
b.adj() += adj_val * digamma_b;
});
}

/**
* Returns the beta function and gradient for first var input.
*
\f[
\mathrm{beta}(a,b) = \left(B\left(a,b\right)\right)
\f]
template <typename T1, typename T2,
require_all_not_std_vector_t<T1, T2>* = nullptr,
require_return_type_t<is_var, T1, T2>* = nullptr>
inline auto beta(const T1& a, const T2& b) {
arena_t<ref_type_t<T1>> arena_a = a;
arena_t<ref_type_t<T2>> arena_b = b;

\f[
\frac{\partial }{\partial a} = \left(\psi^{\left(0\right)}\left(a\right)
- \psi^{\left(0\right)}
\left(a + b\right)\right)
* \mathrm{beta}(a,b)
\f]
*
* @param a var Argument
* @param b double Argument
* @return Result of beta function
*/
inline var beta(const var& a, double b) {
auto digamma_ab = digamma(a.val()) - digamma(a.val() + b);
return make_callback_var(beta(a.val(), b), [a, digamma_ab](auto& vi) mutable {
a.adj() += vi.adj() * digamma_ab * vi.val();
});
}
const auto& beta_val = beta(value_of(arena_a), value_of(arena_b));
using return_type_t = return_var_matrix_t<decltype(beta_val), T1, T2>;
arena_t<return_type_t> res(beta_val);

/**
* Returns the beta function and gradient for second var input.
*
\f[
\mathrm{beta}(a,b) = \left(B\left(a,b\right)\right)
\f]
reverse_pass_callback([arena_a, arena_b, res]() mutable {
auto&& a_array = as_array_or_scalar(arena_a);
auto&& b_array = as_array_or_scalar(arena_b);
const auto& res_array = as_array_or_scalar(res);
const auto& digamma_ab = digamma(value_of(a_array) + value_of(b_array));
const auto& adj_val = res_array.adj() * res_array.val();

\f[
\frac{\partial }{\partial b} = \left(\psi^{\left(0\right)}\left(b\right)
- \psi^{\left(0\right)}
\left(a + b\right)\right)
* \mathrm{beta}(a,b)
\f]
*
* @param a double Argument
* @param b var Argument
* @return Result of beta function
*/
inline var beta(double a, const var& b) {
auto beta_val = beta(a, b.val());
auto digamma_ab = (digamma(b.val()) - digamma(a + b.val())) * beta_val;
return make_callback_var(beta_val, [b, digamma_ab](auto& vi) mutable {
b.adj() += vi.adj() * digamma_ab;
if constexpr (!is_constant<T1>::value) {
const auto& a_adj = adj_val * (digamma(a_array.val()) - digamma_ab);
if constexpr (is_stan_scalar<T1>::value) {
a_array.adj() += sum(a_adj);
} else {
a_array.adj() += a_adj;
}
}
if constexpr (!is_constant<T2>::value) {
const auto& b_adj = adj_val * (digamma(b_array.val()) - digamma_ab);
if constexpr (is_stan_scalar<T2>::value) {
b_array.adj() += sum(b_adj);
} else {
b_array.adj() += b_adj;
}
}
});
}

template <typename Mat1, typename Mat2,
require_any_var_matrix_t<Mat1, Mat2>* = nullptr,
require_all_matrix_t<Mat1, Mat2>* = nullptr>
inline auto beta(const Mat1& a, const Mat2& b) {
if (!is_constant<Mat1>::value && !is_constant<Mat2>::value) {
arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
auto beta_val = beta(arena_a.val(), arena_b.val());
auto digamma_ab
= to_arena(digamma(arena_a.val().array() + arena_b.val().array()));
return make_callback_var(
beta(arena_a.val(), arena_b.val()),
[arena_a, arena_b, digamma_ab](auto& vi) mutable {
const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
arena_a.adj().array()
+= adj_val * (digamma(arena_a.val().array()) - digamma_ab);
arena_b.adj().array()
+= adj_val * (digamma(arena_b.val().array()) - digamma_ab);
});
} else if (!is_constant<Mat1>::value) {
arena_t<promote_scalar_t<var, Mat1>> arena_a = a;
arena_t<promote_scalar_t<double, Mat2>> arena_b = value_of(b);
auto digamma_ab
= to_arena(digamma(arena_a.val()).array()
- digamma(arena_a.val().array() + arena_b.array()));
return make_callback_var(beta(arena_a.val(), arena_b),
[arena_a, arena_b, digamma_ab](auto& vi) mutable {
arena_a.adj().array() += vi.adj().array()
* digamma_ab
* vi.val().array();
});
} else if (!is_constant<Mat2>::value) {
arena_t<promote_scalar_t<double, Mat1>> arena_a = value_of(a);
arena_t<promote_scalar_t<var, Mat2>> arena_b = b;
auto beta_val = beta(arena_a, arena_b.val());
auto digamma_ab
= to_arena((digamma(arena_b.val()).array()
- digamma(arena_a.array() + arena_b.val().array()))
* beta_val.array());
return make_callback_var(
beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
});
}
}

template <typename Scalar, typename VarMat,
require_var_matrix_t<VarMat>* = nullptr,
require_stan_scalar_t<Scalar>* = nullptr>
inline auto beta(const Scalar& a, const VarMat& b) {
if (!is_constant<Scalar>::value && !is_constant<VarMat>::value) {
var arena_a = a;
arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
auto beta_val = beta(arena_a.val(), arena_b.val());
auto digamma_ab = to_arena(digamma(arena_a.val() + arena_b.val().array()));
return make_callback_var(
beta(arena_a.val(), arena_b.val()),
[arena_a, arena_b, digamma_ab](auto& vi) mutable {
const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
arena_a.adj()
+= (adj_val * (digamma(arena_a.val()) - digamma_ab)).sum();
arena_b.adj().array()
+= adj_val * (digamma(arena_b.val().array()) - digamma_ab);
});
} else if (!is_constant<Scalar>::value) {
var arena_a = a;
arena_t<promote_scalar_t<double, VarMat>> arena_b = value_of(b);
auto digamma_ab = to_arena(digamma(arena_a.val())
- digamma(arena_a.val() + arena_b.array()));
return make_callback_var(
beta(arena_a.val(), arena_b),
[arena_a, arena_b, digamma_ab](auto& vi) mutable {
arena_a.adj()
+= (vi.adj().array() * digamma_ab * vi.val().array()).sum();
});
} else if (!is_constant<VarMat>::value) {
double arena_a = value_of(a);
arena_t<promote_scalar_t<var, VarMat>> arena_b = b;
auto beta_val = beta(arena_a, arena_b.val());
auto digamma_ab = to_arena((digamma(arena_b.val()).array()
- digamma(arena_a + arena_b.val().array()))
* beta_val.array());
return make_callback_var(beta_val, [arena_b, digamma_ab](auto& vi) mutable {
arena_b.adj().array() += vi.adj().array() * digamma_ab.array();
});
}
}

template <typename VarMat, typename Scalar,
require_var_matrix_t<VarMat>* = nullptr,
require_stan_scalar_t<Scalar>* = nullptr>
inline auto beta(const VarMat& a, const Scalar& b) {
if (!is_constant<VarMat>::value && !is_constant<Scalar>::value) {
arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
var arena_b = b;
auto beta_val = beta(arena_a.val(), arena_b.val());
auto digamma_ab = to_arena(digamma(arena_a.val().array() + arena_b.val()));
return make_callback_var(
beta(arena_a.val(), arena_b.val()),
[arena_a, arena_b, digamma_ab](auto& vi) mutable {
const auto adj_val = (vi.adj().array() * vi.val().array()).eval();
arena_a.adj().array()
+= adj_val * (digamma(arena_a.val().array()) - digamma_ab);
arena_b.adj()
+= (adj_val * (digamma(arena_b.val()) - digamma_ab)).sum();
});
} else if (!is_constant<VarMat>::value) {
arena_t<promote_scalar_t<var, VarMat>> arena_a = a;
double arena_b = value_of(b);
auto digamma_ab = to_arena(digamma(arena_a.val()).array()
- digamma(arena_a.val().array() + arena_b));
return make_callback_var(
beta(arena_a.val(), arena_b), [arena_a, digamma_ab](auto& vi) mutable {
arena_a.adj().array()
+= vi.adj().array() * digamma_ab * vi.val().array();
});
} else if (!is_constant<Scalar>::value) {
arena_t<promote_scalar_t<double, VarMat>> arena_a = value_of(a);
var arena_b = b;
auto beta_val = beta(arena_a, arena_b.val());
auto digamma_ab = to_arena(
(digamma(arena_b.val()) - digamma(arena_a.array() + arena_b.val()))
* beta_val.array());
return make_callback_var(
beta_val, [arena_a, arena_b, digamma_ab](auto& vi) mutable {
arena_b.adj() += (vi.adj().array() * digamma_ab.array()).sum();
});
}
return return_type_t(res);
}

} // namespace math
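A reverse-mode usage sketch (not part of this diff) checking the single templated overload on scalars against the analytic gradient d/da B(a,b) = B(a,b) * (digamma(a) - digamma(a+b)); the test values are arbitrary.

#include <stan/math/rev.hpp>
#include <iostream>

int main() {
  using stan::math::var;
  var a = 2.0, b = 3.0;
  var c = stan::math::beta(a, b);
  c.grad();  // accumulate dc/da and dc/db into the adjoints
  const double expect_a = stan::math::beta(2.0, 3.0)
                          * (stan::math::digamma(2.0) - stan::math::digamma(5.0));
  std::cout << a.adj() << " (expect " << expect_a << ")\n" << b.adj() << "\n";
  stan::math::recover_memory();
}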
4 changes: 2 additions & 2 deletions test/unit/math/prim/fun/beta_test.cpp
@@ -26,8 +26,8 @@ TEST(MathFunctions, beta_vec) {
auto f
= [](const auto& x1, const auto& x2) { return stan::math::beta(x1, x2); };

Eigen::VectorXd in1 = Eigen::VectorXd::Random(6);
Eigen::VectorXd in2 = Eigen::VectorXd::Random(6);
Eigen::VectorXd in1 = Eigen::VectorXd::Random(6).cwiseAbs();
Eigen::VectorXd in2 = Eigen::VectorXd::Random(6).cwiseAbs();

stan::test::binary_scalar_tester(f, in1, in2);
}
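The .cwiseAbs() addition keeps the random test inputs positive: Eigen::VectorXd::Random draws from [-1, 1], and boost::math::beta requires both arguments to be positive (non-positive values trigger a domain error), whereas the old exp(lgamma(...)) expression quietly returned finite values for many negative inputs. That is presumably why this test needed updating alongside the switch to Boost.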