Skip to content

Commit cc1cb37

Browse files
committed
Only apply changes to log inv_Phi
1 parent 7accce6 commit cc1cb37

File tree

2 files changed

+183
-99
lines changed

2 files changed

+183
-99
lines changed

stan/math/prim/fun/inv_Phi.hpp

Lines changed: 97 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
#include <stan/math/prim/fun/log1m.hpp>
88
#include <stan/math/prim/fun/Phi.hpp>
99
#include <stan/math/prim/fun/square.hpp>
10-
#include <stan/math/prim/fun/log_diff_exp.hpp>
11-
#include <stan/math/prim/fun/log_sum_exp.hpp>
1210
#include <stan/math/prim/functor/apply_scalar_unary.hpp>
1311
#include <cmath>
1412

@@ -27,100 +25,110 @@ const int BIGINT = 2000000000;
2725
/**
2826
* The inverse of the unit normal cumulative distribution function.
2927
*
30-
* @tparam LogP Whether the input probability is already on the log scale.
3128
* @param p argument between 0 and 1 inclusive
3229
* @return Real value of the inverse cdf for the standard normal distribution.
3330
*/
34-
template <bool LogP = false>
35-
inline double inv_Phi_impl(double p) {
36-
static constexpr double log_a[8]
37-
= {1.2199838032983212, 4.8914137334471356, 7.5865960847956080,
38-
9.5274618535358388, 10.734698580862359, 11.116406781896242,
39-
10.417226196842595, 7.8276718012189362};
40-
static constexpr double log_b[8] = {0.,
41-
3.7451021830139207,
42-
6.5326064640478618,
43-
8.5930788436817044,
44-
9.9624069236663077,
45-
10.579180688621286,
46-
10.265665328832871,
47-
8.5614962136628454};
48-
static constexpr double log_c[8]
49-
= {0.3530744474482423, 1.5326298343683388, 1.7525849400614634,
50-
1.2941374937060454, 0.2393776640901312, -1.419724057885092,
51-
-3.784340465764968, -7.163234779359426};
52-
static constexpr double log_d[8] = {0.0,
53-
0.71939547349472054982,
54-
0.51663958798453168964,
55-
-0.37140093392784434556,
56-
-1.9098407084572139869,
57-
-4.186547581055928724,
58-
-7.5099767712254150709,
59-
-20.673761573859248841};
60-
static constexpr double log_e[8]
61-
= {1.8958048169567149, 1.6981417567726154, 0.5793212339927351,
62-
-1.215503791936417, -3.629396584023968, -6.690500273261249,
63-
-10.51540298415323, -15.41979457491781};
64-
static constexpr double log_f[8] = {0.,
65-
-0.511105318617135,
66-
-1.988286302259815,
67-
-4.208049039384857,
68-
-7.147448611626374,
69-
-10.89973190740069,
70-
-15.76637472711685,
71-
-33.82373901099482};
72-
73-
double log_p = LogP ? p : log(p);
74-
75-
double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
76-
: log_diff_exp(log_p, LOG_HALF);
77-
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;
78-
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);
79-
80-
if (stan::math::is_inf(log_r)) {
81-
return 0;
31+
inline double inv_Phi_lambda(double p) {
32+
check_bounded("inv_Phi", "Probability variable", p, 0, 1);
33+
34+
if (p < 8e-311) {
35+
return NEGATIVE_INFTY;
36+
}
37+
if (p == 1) {
38+
return INFTY;
8239
}
8340

84-
double log_inner_r;
85-
double log_pre_mult;
86-
const double* num_ptr;
87-
const double* den_ptr;
88-
89-
static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
90-
static constexpr double LOG_16 = LOG_TWO * 4;
91-
static constexpr double LOG_425 = 6.0520891689244171729;
92-
static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3;
93-
94-
if (log_q <= LOG_425_OVER_1000) {
95-
log_inner_r = log_diff_exp(LOG_425_OVER_1000 * 2, log_q * 2);
96-
log_pre_mult = log_q;
97-
num_ptr = &log_a[0];
98-
den_ptr = &log_b[0];
41+
static constexpr double a[8]
42+
= {3.3871328727963666080e+00, 1.3314166789178437745e+02,
43+
1.9715909503065514427e+03, 1.3731693765509461125e+04,
44+
4.5921953931549871457e+04, 6.7265770927008700853e+04,
45+
3.3430575583588128105e+04, 2.5090809287301226727e+03};
46+
static constexpr double b[7]
47+
= {4.2313330701600911252e+01, 6.8718700749205790830e+02,
48+
5.3941960214247511077e+03, 2.1213794301586595867e+04,
49+
3.9307895800092710610e+04, 2.8729085735721942674e+04,
50+
5.2264952788528545610e+03};
51+
static constexpr double c[8]
52+
= {1.42343711074968357734e+00, 4.63033784615654529590e+00,
53+
5.76949722146069140550e+00, 3.64784832476320460504e+00,
54+
1.27045825245236838258e+00, 2.41780725177450611770e-01,
55+
2.27238449892691845833e-02, 7.74545014278341407640e-04};
56+
static constexpr double d[7]
57+
= {2.05319162663775882187e+00, 1.67638483018380384940e+00,
58+
6.89767334985100004550e-01, 1.48103976427480074590e-01,
59+
1.51986665636164571966e-02, 5.47593808499534494600e-04,
60+
1.05075007164441684324e-09};
61+
static constexpr double e[8]
62+
= {6.65790464350110377720e+00, 5.46378491116411436990e+00,
63+
1.78482653991729133580e+00, 2.96560571828504891230e-01,
64+
2.65321895265761230930e-02, 1.24266094738807843860e-03,
65+
2.71155556874348757815e-05, 2.01033439929228813265e-07};
66+
static constexpr double f[7]
67+
= {5.99832206555887937690e-01, 1.36929880922735805310e-01,
68+
1.48753612908506148525e-02, 7.86869131145613259100e-04,
69+
1.84631831751005468180e-05, 1.42151175831644588870e-07,
70+
2.04426310338993978564e-15};
71+
72+
double q = p - 0.5;
73+
double r;
74+
double val;
75+
76+
if (std::fabs(q) <= .425) {
77+
r = .180625 - square(q);
78+
return q
79+
* (((((((a[7] * r + a[6]) * r + a[5]) * r + a[4]) * r + a[3]) * r
80+
+ a[2])
81+
* r
82+
+ a[1])
83+
* r
84+
+ a[0])
85+
/ (((((((b[6] * r + b[5]) * r + b[4]) * r + b[3]) * r + b[2]) * r
86+
+ b[1])
87+
* r
88+
+ b[0])
89+
* r
90+
+ 1.0);
9991
} else {
100-
double log_temp_r = log(-log_r) / 2.0;
101-
if (log_temp_r <= LOG_FIVE) {
102-
log_inner_r = log_diff_exp(log_temp_r, LOG_16 - LOG_TEN);
103-
num_ptr = &log_c[0];
104-
den_ptr = &log_d[0];
92+
r = q < 0 ? p : 1 - p;
93+
94+
if (r <= 0)
95+
return 0;
96+
97+
r = std::sqrt(-std::log(r));
98+
99+
if (r <= 5.0) {
100+
r += -1.6;
101+
val = (((((((c[7] * r + c[6]) * r + c[5]) * r + c[4]) * r + c[3]) * r
102+
+ c[2])
103+
* r
104+
+ c[1])
105+
* r
106+
+ c[0])
107+
/ (((((((d[6] * r + d[5]) * r + d[4]) * r + d[3]) * r + d[2]) * r
108+
+ d[1])
109+
* r
110+
+ d[0])
111+
* r
112+
+ 1.0);
105113
} else {
106-
log_inner_r = log_diff_exp(log_temp_r, LOG_FIVE);
107-
num_ptr = &log_e[0];
108-
den_ptr = &log_f[0];
114+
r -= 5.0;
115+
val = (((((((e[7] * r + e[6]) * r + e[5]) * r + e[4]) * r + e[3]) * r
116+
+ e[2])
117+
* r
118+
+ e[1])
119+
* r
120+
+ e[0])
121+
/ (((((((f[6] * r + f[5]) * r + f[4]) * r + f[3]) * r + f[2]) * r
122+
+ f[1])
123+
* r
124+
+ f[0])
125+
* r
126+
+ 1.0);
109127
}
110-
log_pre_mult = 0.0;
128+
if (q < 0.0)
129+
return -val;
111130
}
112-
113-
// As computation requires evaluating r^8, this causes a loss of precision,
114-
// even when on the log space. We can mitigate this by scaling the
115-
// exponentiated result (dividing by 10), since the same scaling is applied
116-
// to the numerator and denominator.
117-
Eigen::VectorXd log_r_pow
118-
= Eigen::ArrayXd::LinSpaced(8, 0, 7) * log_inner_r - LOG_TEN;
119-
Eigen::Map<const Eigen::VectorXd> num_map(num_ptr, 8);
120-
Eigen::Map<const Eigen::VectorXd> den_map(den_ptr, 8);
121-
double log_result
122-
= log_sum_exp(log_r_pow + num_map) - log_sum_exp(log_r_pow + den_map);
123-
return log_q_sign * exp(log_pre_mult + log_result);
131+
return val;
124132
}
125133
} // namespace internal
126134

@@ -137,17 +145,9 @@ inline double inv_Phi_impl(double p) {
137145
* @return real value of the inverse cdf for the standard normal distribution
138146
*/
139147
inline double inv_Phi(double p) {
140-
check_bounded("inv_Phi", "Probability variable", p, 0, 1);
141-
142-
if (p < 8e-311) {
143-
return NEGATIVE_INFTY;
144-
}
145-
if (p == 1) {
146-
return INFTY;
147-
}
148-
return p >= 0.9999 ? -internal::inv_Phi_impl(
148+
return p >= 0.9999 ? -internal::inv_Phi_lambda(
149149
(internal::BIGINT - internal::BIGINT * p) / internal::BIGINT)
150-
: internal::inv_Phi_impl(p);
150+
: internal::inv_Phi_lambda(p);
151151
}
152152

153153
/**

stan/math/prim/prob/std_normal_log_qf.hpp

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include <stan/math/prim/meta.hpp>
55
#include <stan/math/prim/err.hpp>
6-
#include <stan/math/prim/fun/inv_Phi.hpp>
76
#include <stan/math/prim/fun/constants.hpp>
87
#include <stan/math/prim/fun/log1m.hpp>
98
#include <stan/math/prim/fun/log.hpp>
@@ -35,7 +34,92 @@ inline double std_normal_log_qf(double log_p) {
3534
return INFTY;
3635
}
3736

38-
return internal::inv_Phi_impl<true>(log_p);
37+
static constexpr double log_a[8]
38+
= {1.2199838032983212, 4.8914137334471356, 7.5865960847956080,
39+
9.5274618535358388, 10.734698580862359, 11.116406781896242,
40+
10.417226196842595, 7.8276718012189362};
41+
static constexpr double log_b[8] = {0.,
42+
3.7451021830139207,
43+
6.5326064640478618,
44+
8.5930788436817044,
45+
9.9624069236663077,
46+
10.579180688621286,
47+
10.265665328832871,
48+
8.5614962136628454};
49+
static constexpr double log_c[8]
50+
= {0.3530744474482423, 1.5326298343683388, 1.7525849400614634,
51+
1.2941374937060454, 0.2393776640901312, -1.419724057885092,
52+
-3.784340465764968, -7.163234779359426};
53+
static constexpr double log_d[8] = {0.,
54+
0.7193954734947205,
55+
0.5166395879845317,
56+
-0.371400933927844,
57+
-1.909840708457214,
58+
-4.186547581055928,
59+
-7.509976771225415,
60+
-20.67376157385924};
61+
static constexpr double log_e[8]
62+
= {1.8958048169567149, 1.6981417567726154, 0.5793212339927351,
63+
-1.215503791936417, -3.629396584023968, -6.690500273261249,
64+
-10.51540298415323, -15.41979457491781};
65+
static constexpr double log_f[8] = {0.,
66+
-0.511105318617135,
67+
-1.988286302259815,
68+
-4.208049039384857,
69+
-7.147448611626374,
70+
-10.89973190740069,
71+
-15.76637472711685,
72+
-33.82373901099482};
73+
74+
double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
75+
: log_diff_exp(log_p, LOG_HALF);
76+
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;
77+
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);
78+
79+
if (stan::math::is_inf(log_r)) {
80+
return 0;
81+
}
82+
83+
double log_inner_r;
84+
double log_pre_mult;
85+
const double* num_ptr;
86+
const double* den_ptr;
87+
88+
static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
89+
static constexpr double LOG_16 = LOG_TWO * 4;
90+
static constexpr double LOG_425 = 6.0520891689244171729;
91+
static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3;
92+
93+
if (log_q <= LOG_425_OVER_1000) {
94+
log_inner_r = log_diff_exp(LOG_425_OVER_1000 * 2, log_q * 2);
95+
log_pre_mult = log_q;
96+
num_ptr = &log_a[0];
97+
den_ptr = &log_b[0];
98+
} else {
99+
double log_temp_r = log(-log_r) / 2.0;
100+
if (log_temp_r <= LOG_FIVE) {
101+
log_inner_r = log_diff_exp(log_temp_r, LOG_16 - LOG_TEN);
102+
num_ptr = &log_c[0];
103+
den_ptr = &log_d[0];
104+
} else {
105+
log_inner_r = log_diff_exp(log_temp_r, LOG_FIVE);
106+
num_ptr = &log_e[0];
107+
den_ptr = &log_f[0];
108+
}
109+
log_pre_mult = 0.0;
110+
}
111+
112+
// As computation requires evaluating r^8, this causes a loss of precision,
113+
// even when on the log space. We can mitigate this by scaling the
114+
// exponentiated result (dividing by 10), since the same scaling is applied
115+
// to the numerator and denominator.
116+
Eigen::VectorXd log_r_pow
117+
= Eigen::ArrayXd::LinSpaced(8, 0, 7) * log_inner_r - LOG_TEN;
118+
Eigen::Map<const Eigen::VectorXd> num_map(num_ptr, 8);
119+
Eigen::Map<const Eigen::VectorXd> den_map(den_ptr, 8);
120+
double log_result
121+
= log_sum_exp(log_r_pow + num_map) - log_sum_exp(log_r_pow + den_map);
122+
return log_q_sign * exp(log_pre_mult + log_result);
39123
}
40124

41125
/**

0 commit comments

Comments
 (0)