Skip to content

Commit d8804d1

Browse files
committed
Update log-scale handling
1 parent 76ef6a7 commit d8804d1

File tree

2 files changed

+33
-105
lines changed

2 files changed

+33
-105
lines changed

stan/math/prim/fun/inv_Phi.hpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <stan/math/prim/fun/log1m.hpp>
88
#include <stan/math/prim/fun/Phi.hpp>
99
#include <stan/math/prim/fun/square.hpp>
10+
#include <stan/math/prim/fun/log_diff_exp.hpp>
11+
#include <stan/math/prim/fun/log_sum_exp.hpp>
1012
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
1113
#include <cmath>
1214

@@ -25,10 +27,12 @@ const int BIGINT = 2000000000;
2527
/**
2628
* The inverse of the unit normal cumulative distribution function.
2729
*
30+
* @tparam LogP Whether the input probability is already on the log scale.
2831
* @param p argument between 0 and 1 inclusive
2932
* @return Real value of the inverse cdf for the standard normal distribution.
3033
*/
31-
inline double inv_Phi_impl(double p, bool log_p) {
34+
template <bool LogP = false>
35+
inline double inv_Phi_impl(double p) {
3236
static constexpr double log_a[8]
3337
= {1.2199838032983212, 4.8914137334471356, 7.5865960847956080,
3438
9.5274618535358388, 10.734698580862359, 11.116406781896242,
@@ -54,48 +58,57 @@ inline double inv_Phi_impl(double p, bool log_p) {
5458
-7.147448611626374, -10.89973190740069, -15.76637472711685,
5559
-33.82373901099482};
5660

57-
double q = p - 0.5;
58-
double r = q < 0 ? p : 1 - p;
61+
double log_p = LogP ? p : log(p);
5962

60-
if (r <= 0) {
63+
double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
64+
: log_diff_exp(log_p, LOG_HALF);
65+
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;
66+
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);
67+
68+
if (stan::math::is_inf(log_r)) {
6169
return 0;
6270
}
6371

64-
double inner_r;
65-
double pre_mult;
72+
double log_inner_r;
73+
double log_pre_mult;
6674
const double* num_ptr;
6775
const double* den_ptr;
6876

69-
if (std::fabs(q) <= .425) {
70-
inner_r = .180625 - square(q);
71-
pre_mult = q;
77+
static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
78+
static constexpr double LOG_16 = LOG_TWO * 4;
79+
static constexpr double LOG_425 = 6.0520891689244171729;
80+
static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3;
81+
82+
if (log_q <= LOG_425_OVER_1000) {
83+
log_inner_r = log_diff_exp(LOG_425_OVER_1000 * 2, log_q * 2);
84+
log_pre_mult = log_q;
7285
num_ptr = &log_a[0];
7386
den_ptr = &log_b[0];
7487
} else {
75-
double temp_r = std::sqrt(-std::log(r));
76-
if (temp_r <= 5.0) {
77-
inner_r = temp_r - 1.6;
88+
double log_temp_r = log(-log_r) / 2.0;
89+
if (log_temp_r <= LOG_FIVE) {
90+
log_inner_r = log_diff_exp(log_temp_r, LOG_16 - LOG_TEN);
7891
num_ptr = &log_c[0];
7992
den_ptr = &log_d[0];
8093
} else {
81-
inner_r = temp_r - 5.0;
94+
log_inner_r = log_diff_exp(log_temp_r, LOG_FIVE);
8295
num_ptr = &log_e[0];
8396
den_ptr = &log_f[0];
8497
}
85-
pre_mult = q < 0 ? -1 : 1;
98+
log_pre_mult = 0.0;
8699
}
87100

88101
// As computation requires evaluating r^8, this causes a loss of precision,
89102
// even when on the log space. We can mitigate this by scaling the
90103
// exponentiated result (dividing by 10), since the same scaling is applied
91104
// to the numerator and denominator.
92-
Eigen::VectorXd log_r_pow = Eigen::ArrayXd::LinSpaced(8, 0, 7) * log(inner_r)
105+
Eigen::VectorXd log_r_pow = Eigen::ArrayXd::LinSpaced(8, 0, 7) * log_inner_r
93106
- LOG_TEN;
94107
Eigen::Map<const Eigen::VectorXd> num_map(num_ptr, 8);
95108
Eigen::Map<const Eigen::VectorXd> den_map(den_ptr, 8);
96109
double log_result = log_sum_exp(log_r_pow + num_map)
97110
- log_sum_exp(log_r_pow + den_map);
98-
return pre_mult * exp(log_result);
111+
return log_q_sign * exp(log_pre_mult + log_result);
99112
}
100113
} // namespace internal
101114

@@ -121,8 +134,8 @@ inline double inv_Phi(double p) {
121134
return INFTY;
122135
}
123136
return p >= 0.9999 ? -internal::inv_Phi_impl(
124-
(internal::BIGINT - internal::BIGINT * p) / internal::BIGINT, false)
125-
: internal::inv_Phi_impl(p, false);
137+
(internal::BIGINT - internal::BIGINT * p) / internal::BIGINT)
138+
: internal::inv_Phi_impl(p);
126139
}
127140

128141
/**

stan/math/prim/prob/std_normal_log_qf.hpp

Lines changed: 2 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6+
#include <stan/math/prim/fun/inv_Phi.hpp>
67
#include <stan/math/prim/fun/constants.hpp>
78
#include <stan/math/prim/fun/log1m.hpp>
89
#include <stan/math/prim/fun/log.hpp>
@@ -34,93 +35,7 @@ inline double std_normal_log_qf(double log_p) {
3435
return INFTY;
3536
}
3637

37-
static constexpr double log_a[8]
38-
= {1.2199838032983212, 4.8914137334471356, 7.5865960847956080,
39-
9.5274618535358388, 10.734698580862359, 11.116406781896242,
40-
10.417226196842595, 7.8276718012189362};
41-
static constexpr double log_b[8] = {0.,
42-
3.7451021830139207,
43-
6.5326064640478618,
44-
8.5930788436817044,
45-
9.9624069236663077,
46-
10.579180688621286,
47-
10.265665328832871,
48-
8.5614962136628454};
49-
static constexpr double log_c[8]
50-
= {0.3530744474482423, 1.5326298343683388, 1.7525849400614634,
51-
1.2941374937060454, 0.2393776640901312, -1.419724057885092,
52-
-3.784340465764968, -7.163234779359426};
53-
static constexpr double log_d[8] = {0.,
54-
0.7193954734947205,
55-
0.5166395879845317,
56-
-0.371400933927844,
57-
-1.909840708457214,
58-
-4.186547581055928,
59-
-7.509976771225415,
60-
-20.67376157385924};
61-
static constexpr double log_e[8]
62-
= {1.8958048169567149, 1.6981417567726154, 0.5793212339927351,
63-
-1.215503791936417, -3.629396584023968, -6.690500273261249,
64-
-10.51540298415323, -15.41979457491781};
65-
static constexpr double log_f[8] = {0.,
66-
-0.511105318617135,
67-
-1.988286302259815,
68-
-4.208049039384857,
69-
-7.147448611626374,
70-
-10.89973190740069,
71-
-15.76637472711685,
72-
-33.82373901099482};
73-
74-
double val;
75-
double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
76-
: log_diff_exp(log_p, LOG_HALF);
77-
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;
78-
79-
if (log_q <= -0.85566611005772) {
80-
double log_r = log_diff_exp(-1.71133222011544, 2 * log_q);
81-
double log_agg_a = log_sum_exp(log_a[7] + log_r, log_a[6]);
82-
double log_agg_b = log_sum_exp(log_b[7] + log_r, log_b[6]);
83-
84-
for (int i = 0; i < 6; i++) {
85-
log_agg_a = log_sum_exp(log_agg_a + log_r, log_a[5 - i]);
86-
log_agg_b = log_sum_exp(log_agg_b + log_r, log_b[5 - i]);
87-
}
88-
89-
return log_q_sign * exp(log_q + log_agg_a - log_agg_b);
90-
} else {
91-
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);
92-
93-
if (stan::math::is_inf(log_r)) {
94-
return 0;
95-
}
96-
97-
log_r = log(sqrt(-log_r));
98-
99-
if (log_r <= 1.60943791243410) {
100-
log_r = log_diff_exp(log_r, 0.47000362924573);
101-
double log_agg_c = log_sum_exp(log_c[7] + log_r, log_c[6]);
102-
double log_agg_d = log_sum_exp(log_d[7] + log_r, log_d[6]);
103-
104-
for (int i = 0; i < 6; i++) {
105-
log_agg_c = log_sum_exp(log_agg_c + log_r, log_c[5 - i]);
106-
log_agg_d = log_sum_exp(log_agg_d + log_r, log_d[5 - i]);
107-
}
108-
val = exp(log_agg_c - log_agg_d);
109-
} else {
110-
log_r = log_diff_exp(log_r, 1.60943791243410);
111-
double log_agg_e = log_sum_exp(log_e[7] + log_r, log_e[6]);
112-
double log_agg_f = log_sum_exp(log_f[7] + log_r, log_f[6]);
113-
114-
for (int i = 0; i < 6; i++) {
115-
log_agg_e = log_sum_exp(log_agg_e + log_r, log_e[5 - i]);
116-
log_agg_f = log_sum_exp(log_agg_f + log_r, log_f[5 - i]);
117-
}
118-
val = exp(log_agg_e - log_agg_f);
119-
}
120-
if (log_q_sign == -1)
121-
return -val;
122-
}
123-
return val;
38+
return internal::inv_Phi_impl<true>(log_p);
12439
}
12540

12641
/**

0 commit comments

Comments
 (0)