7
7
#include < stan/math/prim/fun/log1m.hpp>
8
8
#include < stan/math/prim/fun/Phi.hpp>
9
9
#include < stan/math/prim/fun/square.hpp>
10
- #include < stan/math/prim/fun/log_diff_exp.hpp>
11
- #include < stan/math/prim/fun/log_sum_exp.hpp>
12
10
#include < stan/math/prim/functor/apply_scalar_unary.hpp>
13
11
#include < cmath>
14
12
@@ -27,100 +25,110 @@ const int BIGINT = 2000000000;
27
25
/* *
28
26
* The inverse of the unit normal cumulative distribution function.
29
27
*
30
- * @tparam LogP Whether the input probability is already on the log scale.
31
28
* @param p argument between 0 and 1 inclusive
32
29
* @return Real value of the inverse cdf for the standard normal distribution.
33
30
*/
34
- template <bool LogP = false >
35
- inline double inv_Phi_impl (double p) {
36
- static constexpr double log_a[8 ]
37
- = {1.2199838032983212 , 4.8914137334471356 , 7.5865960847956080 ,
38
- 9.5274618535358388 , 10.734698580862359 , 11.116406781896242 ,
39
- 10.417226196842595 , 7.8276718012189362 };
40
- static constexpr double log_b[8 ] = {0 .,
41
- 3.7451021830139207 ,
42
- 6.5326064640478618 ,
43
- 8.5930788436817044 ,
44
- 9.9624069236663077 ,
45
- 10.579180688621286 ,
46
- 10.265665328832871 ,
47
- 8.5614962136628454 };
48
- static constexpr double log_c[8 ]
49
- = {0.3530744474482423 , 1.5326298343683388 , 1.7525849400614634 ,
50
- 1.2941374937060454 , 0.2393776640901312 , -1.419724057885092 ,
51
- -3.784340465764968 , -7.163234779359426 };
52
- static constexpr double log_d[8 ] = {0.0 ,
53
- 0.71939547349472054982 ,
54
- 0.51663958798453168964 ,
55
- -0.37140093392784434556 ,
56
- -1.9098407084572139869 ,
57
- -4.186547581055928724 ,
58
- -7.5099767712254150709 ,
59
- -20.673761573859248841 };
60
- static constexpr double log_e[8 ]
61
- = {1.8958048169567149 , 1.6981417567726154 , 0.5793212339927351 ,
62
- -1.215503791936417 , -3.629396584023968 , -6.690500273261249 ,
63
- -10.51540298415323 , -15.41979457491781 };
64
- static constexpr double log_f[8 ] = {0 .,
65
- -0.511105318617135 ,
66
- -1.988286302259815 ,
67
- -4.208049039384857 ,
68
- -7.147448611626374 ,
69
- -10.89973190740069 ,
70
- -15.76637472711685 ,
71
- -33.82373901099482 };
72
-
73
- double log_p = LogP ? p : log (p);
74
-
75
- double log_q = log_p <= LOG_HALF ? log_diff_exp (LOG_HALF, log_p)
76
- : log_diff_exp (log_p, LOG_HALF);
77
- int log_q_sign = log_p <= LOG_HALF ? -1 : 1 ;
78
- double log_r = log_q_sign == -1 ? log_p : log1m_exp (log_p);
79
-
80
- if (stan::math::is_inf (log_r)) {
81
- return 0 ;
31
+ inline double inv_Phi_lambda (double p) {
32
+ check_bounded (" inv_Phi" , " Probability variable" , p, 0 , 1 );
33
+
34
+ if (p < 8e-311 ) {
35
+ return NEGATIVE_INFTY;
36
+ }
37
+ if (p == 1 ) {
38
+ return INFTY;
82
39
}
83
40
84
- double log_inner_r;
85
- double log_pre_mult;
86
- const double * num_ptr;
87
- const double * den_ptr;
88
-
89
- static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
90
- static constexpr double LOG_16 = LOG_TWO * 4 ;
91
- static constexpr double LOG_425 = 6.0520891689244171729 ;
92
- static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3 ;
93
-
94
- if (log_q <= LOG_425_OVER_1000) {
95
- log_inner_r = log_diff_exp (LOG_425_OVER_1000 * 2 , log_q * 2 );
96
- log_pre_mult = log_q;
97
- num_ptr = &log_a[0 ];
98
- den_ptr = &log_b[0 ];
41
+ static constexpr double a[8 ]
42
+ = {3.3871328727963666080e+00 , 1.3314166789178437745e+02 ,
43
+ 1.9715909503065514427e+03 , 1.3731693765509461125e+04 ,
44
+ 4.5921953931549871457e+04 , 6.7265770927008700853e+04 ,
45
+ 3.3430575583588128105e+04 , 2.5090809287301226727e+03 };
46
+ static constexpr double b[7 ]
47
+ = {4.2313330701600911252e+01 , 6.8718700749205790830e+02 ,
48
+ 5.3941960214247511077e+03 , 2.1213794301586595867e+04 ,
49
+ 3.9307895800092710610e+04 , 2.8729085735721942674e+04 ,
50
+ 5.2264952788528545610e+03 };
51
+ static constexpr double c[8 ]
52
+ = {1.42343711074968357734e+00 , 4.63033784615654529590e+00 ,
53
+ 5.76949722146069140550e+00 , 3.64784832476320460504e+00 ,
54
+ 1.27045825245236838258e+00 , 2.41780725177450611770e-01 ,
55
+ 2.27238449892691845833e-02 , 7.74545014278341407640e-04 };
56
+ static constexpr double d[7 ]
57
+ = {2.05319162663775882187e+00 , 1.67638483018380384940e+00 ,
58
+ 6.89767334985100004550e-01 , 1.48103976427480074590e-01 ,
59
+ 1.51986665636164571966e-02 , 5.47593808499534494600e-04 ,
60
+ 1.05075007164441684324e-09 };
61
+ static constexpr double e[8 ]
62
+ = {6.65790464350110377720e+00 , 5.46378491116411436990e+00 ,
63
+ 1.78482653991729133580e+00 , 2.96560571828504891230e-01 ,
64
+ 2.65321895265761230930e-02 , 1.24266094738807843860e-03 ,
65
+ 2.71155556874348757815e-05 , 2.01033439929228813265e-07 };
66
+ static constexpr double f[7 ]
67
+ = {5.99832206555887937690e-01 , 1.36929880922735805310e-01 ,
68
+ 1.48753612908506148525e-02 , 7.86869131145613259100e-04 ,
69
+ 1.84631831751005468180e-05 , 1.42151175831644588870e-07 ,
70
+ 2.04426310338993978564e-15 };
71
+
72
+ double q = p - 0.5 ;
73
+ double r;
74
+ double val;
75
+
76
+ if (std::fabs (q) <= .425 ) {
77
+ r = .180625 - square (q);
78
+ return q
79
+ * (((((((a[7 ] * r + a[6 ]) * r + a[5 ]) * r + a[4 ]) * r + a[3 ]) * r
80
+ + a[2 ])
81
+ * r
82
+ + a[1 ])
83
+ * r
84
+ + a[0 ])
85
+ / (((((((b[6 ] * r + b[5 ]) * r + b[4 ]) * r + b[3 ]) * r + b[2 ]) * r
86
+ + b[1 ])
87
+ * r
88
+ + b[0 ])
89
+ * r
90
+ + 1.0 );
99
91
} else {
100
- double log_temp_r = log (-log_r) / 2.0 ;
101
- if (log_temp_r <= LOG_FIVE) {
102
- log_inner_r = log_diff_exp (log_temp_r, LOG_16 - LOG_TEN);
103
- num_ptr = &log_c[0 ];
104
- den_ptr = &log_d[0 ];
92
+ r = q < 0 ? p : 1 - p;
93
+
94
+ if (r <= 0 )
95
+ return 0 ;
96
+
97
+ r = std::sqrt (-std::log (r));
98
+
99
+ if (r <= 5.0 ) {
100
+ r += -1.6 ;
101
+ val = (((((((c[7 ] * r + c[6 ]) * r + c[5 ]) * r + c[4 ]) * r + c[3 ]) * r
102
+ + c[2 ])
103
+ * r
104
+ + c[1 ])
105
+ * r
106
+ + c[0 ])
107
+ / (((((((d[6 ] * r + d[5 ]) * r + d[4 ]) * r + d[3 ]) * r + d[2 ]) * r
108
+ + d[1 ])
109
+ * r
110
+ + d[0 ])
111
+ * r
112
+ + 1.0 );
105
113
} else {
106
- log_inner_r = log_diff_exp (log_temp_r, LOG_FIVE);
107
- num_ptr = &log_e[0 ];
108
- den_ptr = &log_f[0 ];
114
+ r -= 5.0 ;
115
+ val = (((((((e[7 ] * r + e[6 ]) * r + e[5 ]) * r + e[4 ]) * r + e[3 ]) * r
116
+ + e[2 ])
117
+ * r
118
+ + e[1 ])
119
+ * r
120
+ + e[0 ])
121
+ / (((((((f[6 ] * r + f[5 ]) * r + f[4 ]) * r + f[3 ]) * r + f[2 ]) * r
122
+ + f[1 ])
123
+ * r
124
+ + f[0 ])
125
+ * r
126
+ + 1.0 );
109
127
}
110
- log_pre_mult = 0.0 ;
128
+ if (q < 0.0 )
129
+ return -val;
111
130
}
112
-
113
- // As computation requires evaluating r^8, this causes a loss of precision,
114
- // even when on the log space. We can mitigate this by scaling the
115
- // exponentiated result (dividing by 10), since the same scaling is applied
116
- // to the numerator and denominator.
117
- Eigen::VectorXd log_r_pow
118
- = Eigen::ArrayXd::LinSpaced (8 , 0 , 7 ) * log_inner_r - LOG_TEN;
119
- Eigen::Map<const Eigen::VectorXd> num_map (num_ptr, 8 );
120
- Eigen::Map<const Eigen::VectorXd> den_map (den_ptr, 8 );
121
- double log_result
122
- = log_sum_exp (log_r_pow + num_map) - log_sum_exp (log_r_pow + den_map);
123
- return log_q_sign * exp (log_pre_mult + log_result);
131
+ return val;
124
132
}
125
133
} // namespace internal
126
134
@@ -137,17 +145,9 @@ inline double inv_Phi_impl(double p) {
137
145
* @return real value of the inverse cdf for the standard normal distribution
138
146
*/
139
147
inline double inv_Phi (double p) {
140
- check_bounded (" inv_Phi" , " Probability variable" , p, 0 , 1 );
141
-
142
- if (p < 8e-311 ) {
143
- return NEGATIVE_INFTY;
144
- }
145
- if (p == 1 ) {
146
- return INFTY;
147
- }
148
- return p >= 0.9999 ? -internal::inv_Phi_impl (
148
+ return p >= 0.9999 ? -internal::inv_Phi_lambda (
149
149
(internal::BIGINT - internal::BIGINT * p) / internal::BIGINT)
150
- : internal::inv_Phi_impl (p);
150
+ : internal::inv_Phi_lambda (p);
151
151
}
152
152
153
153
/* *
0 commit comments