7
7
#include < stan/math/prim/fun/log1m.hpp>
8
8
#include < stan/math/prim/fun/Phi.hpp>
9
9
#include < stan/math/prim/fun/square.hpp>
10
+ #include < stan/math/prim/fun/log_diff_exp.hpp>
11
+ #include < stan/math/prim/fun/log_sum_exp.hpp>
10
12
#include < stan/math/prim/functor/apply_scalar_unary.hpp>
11
13
#include < cmath>
12
14
@@ -25,10 +27,12 @@ const int BIGINT = 2000000000;
25
27
/* *
26
28
* The inverse of the unit normal cumulative distribution function.
27
29
*
30
+ * @tparam LogP Whether the input probability is already on the log scale.
28
31
* @param p argument between 0 and 1 inclusive
29
32
* @return Real value of the inverse cdf for the standard normal distribution.
30
33
*/
31
- inline double inv_Phi_impl (double p, bool log_p) {
34
+ template <bool LogP = false >
35
+ inline double inv_Phi_impl (double p) {
32
36
static constexpr double log_a[8 ]
33
37
= {1.2199838032983212 , 4.8914137334471356 , 7.5865960847956080 ,
34
38
9.5274618535358388 , 10.734698580862359 , 11.116406781896242 ,
@@ -54,48 +58,57 @@ inline double inv_Phi_impl(double p, bool log_p) {
54
58
-7.147448611626374 , -10.89973190740069 , -15.76637472711685 ,
55
59
-33.82373901099482 };
56
60
57
- double q = p - 0.5 ;
58
- double r = q < 0 ? p : 1 - p;
61
+ double log_p = LogP ? p : log (p);
59
62
60
- if (r <= 0 ) {
63
+ double log_q = log_p <= LOG_HALF ? log_diff_exp (LOG_HALF, log_p)
64
+ : log_diff_exp (log_p, LOG_HALF);
65
+ int log_q_sign = log_p <= LOG_HALF ? -1 : 1 ;
66
+ double log_r = log_q_sign == -1 ? log_p : log1m_exp (log_p);
67
+
68
+ if (stan::math::is_inf (log_r)) {
61
69
return 0 ;
62
70
}
63
71
64
- double inner_r ;
65
- double pre_mult ;
72
+ double log_inner_r ;
73
+ double log_pre_mult ;
66
74
const double * num_ptr;
67
75
const double * den_ptr;
68
76
69
- if (std::fabs (q) <= .425 ) {
70
- inner_r = .180625 - square (q);
71
- pre_mult = q;
77
+ static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
78
+ static constexpr double LOG_16 = LOG_TWO * 4 ;
79
+ static constexpr double LOG_425 = 6.0520891689244171729 ;
80
+ static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3 ;
81
+
82
+ if (log_q <= LOG_425_OVER_1000) {
83
+ log_inner_r = log_diff_exp (LOG_425_OVER_1000 * 2 , log_q * 2 );
84
+ log_pre_mult = log_q;
72
85
num_ptr = &log_a[0 ];
73
86
den_ptr = &log_b[0 ];
74
87
} else {
75
- double temp_r = std::sqrt (- std:: log (r)) ;
76
- if (temp_r <= 5.0 ) {
77
- inner_r = temp_r - 1.6 ;
88
+ double log_temp_r = log (-log_r) / 2.0 ;
89
+ if (log_temp_r <= LOG_FIVE ) {
90
+ log_inner_r = log_diff_exp (log_temp_r, LOG_16 - LOG_TEN) ;
78
91
num_ptr = &log_c[0 ];
79
92
den_ptr = &log_d[0 ];
80
93
} else {
81
- inner_r = temp_r - 5.0 ;
94
+ log_inner_r = log_diff_exp (log_temp_r, LOG_FIVE) ;
82
95
num_ptr = &log_e[0 ];
83
96
den_ptr = &log_f[0 ];
84
97
}
85
- pre_mult = q < 0 ? - 1 : 1 ;
98
+ log_pre_mult = 0.0 ;
86
99
}
87
100
88
101
// As computation requires evaluating r^8, this causes a loss of precision,
89
102
// even when on the log space. We can mitigate this by scaling the
90
103
// exponentiated result (dividing by 10), since the same scaling is applied
91
104
// to the numerator and denominator.
92
- Eigen::VectorXd log_r_pow = Eigen::ArrayXd::LinSpaced (8 , 0 , 7 ) * log (inner_r)
105
+ Eigen::VectorXd log_r_pow = Eigen::ArrayXd::LinSpaced (8 , 0 , 7 ) * log_inner_r
93
106
- LOG_TEN;
94
107
Eigen::Map<const Eigen::VectorXd> num_map (num_ptr, 8 );
95
108
Eigen::Map<const Eigen::VectorXd> den_map (den_ptr, 8 );
96
109
double log_result = log_sum_exp (log_r_pow + num_map)
97
110
- log_sum_exp (log_r_pow + den_map);
98
- return pre_mult * exp (log_result);
111
+ return log_q_sign * exp (log_pre_mult + log_result);
99
112
}
100
113
} // namespace internal
101
114
@@ -121,8 +134,8 @@ inline double inv_Phi(double p) {
121
134
return INFTY;
122
135
}
123
136
return p >= 0.9999 ? -internal::inv_Phi_impl (
124
- (internal::BIGINT - internal::BIGINT * p) / internal::BIGINT, false )
125
- : internal::inv_Phi_impl (p, false );
137
+ (internal::BIGINT - internal::BIGINT * p) / internal::BIGINT)
138
+ : internal::inv_Phi_impl (p);
126
139
}
127
140
128
141
/* *
0 commit comments