Skip to content

Commit d51289a

Browse files
committed
Merge commit '1f94ed312376f726feb820bea90ed8df27974c17' into HEAD
2 parents d8804d1 + 1f94ed3 commit d51289a

File tree

8 files changed

+177
-0
lines changed

8 files changed

+177
-0
lines changed

stan/math/fwd/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
#include <stan/math/fwd/fun/gamma_p.hpp>
4242
#include <stan/math/fwd/fun/gamma_q.hpp>
4343
#include <stan/math/fwd/fun/grad_inc_beta.hpp>
44+
#include <stan/math/fwd/fun/hypergeometric_1F0.hpp>
4445
#include <stan/math/fwd/fun/hypergeometric_2F1.hpp>
4546
#include <stan/math/fwd/fun/hypergeometric_pFq.hpp>
4647
#include <stan/math/fwd/fun/hypot.hpp>
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#ifndef STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP
2+
#define STAN_MATH_FWD_FUN_HYPERGEOMETRIC_1F0_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
6+
#include <stan/math/fwd/core.hpp>
7+
8+
namespace stan {
9+
namespace math {
10+
11+
/**
12+
* Returns the Hypergeometric 1F0 function applied to the
13+
* input arguments:
14+
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
15+
*
16+
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
17+
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
18+
*
19+
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
20+
* a\left(1-z\right)^{-a-1} \f$
21+
*
22+
* @tparam Ta Fvar or arithmetic type of 'a' argument
23+
* @tparam Tz Fvar or arithmetic type of 'z' argument
24+
* @param[in] a Scalar 'a' argument
25+
* @param[in] z Scalar z argument
26+
* @return Hypergeometric 1F0 function
27+
*/
28+
template <typename Ta, typename Tz, typename FvarT = return_type_t<Ta, Tz>,
29+
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
30+
require_any_fvar_t<Ta, Tz>* = nullptr>
31+
FvarT hypergeometric_1f0(const Ta& a, const Tz& z) {
32+
partials_type_t<Ta> a_val = value_of(a);
33+
partials_type_t<Tz> z_val = value_of(z);
34+
FvarT rtn = FvarT(hypergeometric_1f0(a_val, z_val), 0.0);
35+
if (!is_constant_all<Ta>::value) {
36+
rtn.d_ += forward_as<FvarT>(a).d() * -rtn.val() * log1m(z_val);
37+
}
38+
if (!is_constant_all<Tz>::value) {
39+
rtn.d_ += forward_as<FvarT>(z).d() * rtn.val() * a_val * inv(1 - z_val);
40+
}
41+
return rtn;
42+
}
43+
44+
} // namespace math
45+
} // namespace stan
46+
#endif

stan/math/prim/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@
134134
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
135135
#include <stan/math/prim/fun/grad_reg_lower_inc_gamma.hpp>
136136
#include <stan/math/prim/fun/head.hpp>
137+
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
137138
#include <stan/math/prim/fun/hypergeometric_2F1.hpp>
138139
#include <stan/math/prim/fun/hypergeometric_2F2.hpp>
139140
#include <stan/math/prim/fun/hypergeometric_3F2.hpp>
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#ifndef STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP
2+
#define STAN_MATH_PRIM_FUN_HYPERGEOMETRIC_1F0_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/err/check_less.hpp>
6+
#include <stan/math/prim/fun/boost_policy.hpp>
7+
#include <boost/math/special_functions/hypergeometric_1F0.hpp>
8+
#include <cmath>
9+
10+
namespace stan {
11+
namespace math {
12+
13+
/**
14+
* Returns the Hypergeometric 1F0 function applied to the
15+
* input arguments:
16+
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
17+
*
18+
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
19+
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
20+
*
21+
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
22+
* a\left(1-z\right)^{-a-1} \f$
23+
*
24+
* @tparam Ta Arithmetic type of 'a' argument
25+
* @tparam Tz Arithmetic type of 'z' argument
26+
* @param[in] a Scalar 'a' argument
27+
* @param[in] z Scalar z argument
28+
* @return Hypergeometric 1F0 function
29+
*/
30+
template <typename Ta, typename Tz, require_all_arithmetic_t<Ta, Tz>* = nullptr>
31+
return_type_t<Ta, Tz> hypergeometric_1f0(const Ta& a, const Tz& z) {
32+
constexpr const char* function = "hypergeometric_1f0";
33+
check_less("hypergeometric_1f0", "abs(z)", std::fabs(z), 1.0);
34+
35+
return boost::math::hypergeometric_1F0(a, z, boost_policy_t<>());
36+
}
37+
38+
} // namespace math
39+
} // namespace stan
40+
#endif

stan/math/rev/fun.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
#include <stan/math/rev/fun/gp_periodic_cov.hpp>
7777
#include <stan/math/rev/fun/grad.hpp>
7878
#include <stan/math/rev/fun/grad_inc_beta.hpp>
79+
#include <stan/math/rev/fun/hypergeometric_1F0.hpp>
7980
#include <stan/math/rev/fun/hypergeometric_2F1.hpp>
8081
#include <stan/math/rev/fun/hypergeometric_pFq.hpp>
8182
#include <stan/math/rev/fun/hypot.hpp>
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#ifndef STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP
2+
#define STAN_MATH_REV_FUN_HYPERGEOMETRIC_1F0_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
#include <stan/math/prim/fun/hypergeometric_1F0.hpp>
6+
#include <stan/math/prim/fun/value_of.hpp>
7+
#include <stan/math/prim/fun/log1m.hpp>
8+
#include <stan/math/prim/fun/inv.hpp>
9+
#include <stan/math/rev/core.hpp>
10+
11+
namespace stan {
12+
namespace math {
13+
14+
/**
15+
* Returns the Hypergeometric 1F0 function applied to the
16+
* input arguments:
17+
* \f$ _1F_0(a;;z) = \sum_{k=1}^{\infty}\frac{\left(a\right)_kz^k}{k!}\f$
18+
*
19+
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial a} =
20+
* -\left(1-z\right)^{-a}\log\left(1 - z\right) \f$
21+
*
22+
* \f$ \frac{\partial _1F_0\left(a;;z\right)}{\partial z} =
23+
* a\left(1-z\right)^{-a-1} \f$
24+
*
25+
* @tparam Ta Var or arithmetic type of 'a' argument
26+
* @tparam Tz Var or arithmetic type of 'z' argument
27+
* @param[in] a Scalar 'a' argument
28+
* @param[in] z Scalar z argument
29+
* @return Hypergeometric 1F0 function
30+
*/
31+
template <typename Ta, typename Tz,
32+
require_all_stan_scalar_t<Ta, Tz>* = nullptr,
33+
require_any_var_t<Ta, Tz>* = nullptr>
34+
var hypergeometric_1f0(const Ta& a, const Tz& z) {
35+
double a_val = value_of(a);
36+
double z_val = value_of(z);
37+
double rtn = hypergeometric_1f0(a_val, z_val);
38+
return make_callback_var(rtn, [rtn, a, z, a_val, z_val](auto& vi) mutable {
39+
if (!is_constant_all<Ta>::value) {
40+
forward_as<var>(a).adj() += vi.adj() * -rtn * log1m(z_val);
41+
}
42+
if (!is_constant_all<Tz>::value) {
43+
forward_as<var>(z).adj() += vi.adj() * rtn * a_val * inv(1 - z_val);
44+
}
45+
});
46+
}
47+
48+
} // namespace math
49+
} // namespace stan
50+
#endif
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#include <test/unit/math/test_ad.hpp>
2+
3+
TEST(mathMixScalFun, hypergeometric_1f0) {
4+
auto f = [](const auto& x1, const auto& x2) {
5+
using stan::math::hypergeometric_1f0;
6+
return hypergeometric_1f0(x1, x2);
7+
};
8+
9+
stan::test::expect_ad(f, 5, 0.3);
10+
stan::test::expect_ad(f, 3.4, 0.9);
11+
stan::test::expect_ad(f, 3.4, 0.1);
12+
stan::test::expect_ad(f, 5, -0.7);
13+
stan::test::expect_ad(f, 7, -0.1);
14+
stan::test::expect_ad(f, 2.8, 0.8);
15+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <stan/math/prim.hpp>
2+
#include <gtest/gtest.h>
3+
#include <cmath>
4+
#include <limits>
5+
6+
TEST(MathFunctions, hypergeometric_1f0Double) {
7+
using stan::math::hypergeometric_1f0;
8+
using stan::math::inv;
9+
10+
EXPECT_FLOAT_EQ(4.62962962963, hypergeometric_1f0(3, 0.4));
11+
EXPECT_FLOAT_EQ(0.510204081633, hypergeometric_1f0(2, -0.4));
12+
EXPECT_FLOAT_EQ(300.906354890, hypergeometric_1f0(16.0, 0.3));
13+
EXPECT_FLOAT_EQ(0.531441, hypergeometric_1f0(-6.0, 0.1));
14+
}
15+
16+
TEST(MathFunctions, hypergeometric_1f0_throw) {
17+
using stan::math::hypergeometric_1f0;
18+
19+
EXPECT_THROW(hypergeometric_1f0(2.1, 1.0), std::domain_error);
20+
EXPECT_THROW(hypergeometric_1f0(0.5, 1.5), std::domain_error);
21+
EXPECT_THROW(hypergeometric_1f0(0.5, -1.0), std::domain_error);
22+
EXPECT_THROW(hypergeometric_1f0(0.5, -1.5), std::domain_error);
23+
}

0 commit comments

Comments
 (0)