Skip to content

Commit f1c5e6c

Browse files
committed
Improve numerical stability of fwd gradients for std_normal_log_qf
1 parent 3a196d4 commit f1c5e6c

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

stan/math/fwd/prob/std_normal_log_qf.hpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,9 @@ namespace math {
1616
template <typename T>
1717
inline fvar<T> std_normal_log_qf(const fvar<T>& p) {
1818
const T xv = std_normal_log_qf(p.val_);
19-
int p_sign = 1;
20-
auto p_d = p.d_;
21-
if (p.d_ < 0) {
22-
p_sign = -1;
23-
p_d *= -1;
24-
}
2519
return fvar<T>(
2620
xv,
27-
p_sign * exp(p.val_ + log(p_d) - NEG_LOG_SQRT_TWO_PI + 0.5 * square(xv)));
21+
p.d_ * exp(p.val_ - std_normal_lpdf(xv)));
2822
}
2923
} // namespace math
3024
} // namespace stan

test/unit/math/mix/prob/std_normal_log_qf_test.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <test/unit/math/test_ad.hpp>
55
#include <stan/math/fwd/prob/std_normal_log_qf.hpp>
66

7+
78
TEST_F(AgradRev, mathMixLogFun_stdNormalLogQf) {
89
auto f = [](const auto& x1) { return stan::math::std_normal_log_qf(x1); };
910
stan::test::expect_ad(f, -100.25);
@@ -46,3 +47,14 @@ TEST_F(AgradRev, mathMixMatFunLog_stdNormalLogQfVarmat) {
4647
}
4748
expect_ad_vector_matvar(f, A);
4849
}
50+
51+
TEST_F(AgradRev, GradientStabilityStdNormalLogQf) {
52+
auto f = [](const auto& y) {
53+
return stan::math::sum(stan::math::std_normal_log_qf(y));
54+
};
55+
56+
Eigen::VectorXd y1(2);
57+
y1 << -10, -2;
58+
59+
stan::test::expect_ad(f, y1);
60+
}

0 commit comments

Comments
 (0)