Skip to content

Commit 9202f1f

Browse files
authored
Merge pull request #3046 from stan-dev/inv_phi-opt
Vectorise inv_Phi calculations and optimise log-scale implementation
2 parents e73651b + cc1cb37 commit 9202f1f

File tree

1 file changed

+40
-41
lines changed

1 file changed

+40
-41
lines changed

stan/math/prim/prob/std_normal_log_qf.hpp

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -71,56 +71,55 @@ inline double std_normal_log_qf(double log_p) {
7171
-15.76637472711685,
7272
-33.82373901099482};
7373

74-
double val;
7574
double log_q = log_p <= LOG_HALF ? log_diff_exp(LOG_HALF, log_p)
7675
: log_diff_exp(log_p, LOG_HALF);
7776
int log_q_sign = log_p <= LOG_HALF ? -1 : 1;
77+
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);
7878

79-
if (log_q <= -0.85566611005772) {
80-
double log_r = log_diff_exp(-1.71133222011544, 2 * log_q);
81-
double log_agg_a = log_sum_exp(log_a[7] + log_r, log_a[6]);
82-
double log_agg_b = log_sum_exp(log_b[7] + log_r, log_b[6]);
83-
84-
for (int i = 0; i < 6; i++) {
85-
log_agg_a = log_sum_exp(log_agg_a + log_r, log_a[5 - i]);
86-
log_agg_b = log_sum_exp(log_agg_b + log_r, log_b[5 - i]);
87-
}
79+
if (stan::math::is_inf(log_r)) {
80+
return 0;
81+
}
8882

89-
return log_q_sign * exp(log_q + log_agg_a - log_agg_b);
83+
double log_inner_r;
84+
double log_pre_mult;
85+
const double* num_ptr;
86+
const double* den_ptr;
87+
88+
static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
89+
static constexpr double LOG_16 = LOG_TWO * 4;
90+
static constexpr double LOG_425 = 6.0520891689244171729;
91+
static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3;
92+
93+
if (log_q <= LOG_425_OVER_1000) {
94+
log_inner_r = log_diff_exp(LOG_425_OVER_1000 * 2, log_q * 2);
95+
log_pre_mult = log_q;
96+
num_ptr = &log_a[0];
97+
den_ptr = &log_b[0];
9098
} else {
91-
double log_r = log_q_sign == -1 ? log_p : log1m_exp(log_p);
92-
93-
if (stan::math::is_inf(log_r)) {
94-
return 0;
95-
}
96-
97-
log_r = log(sqrt(-log_r));
98-
99-
if (log_r <= 1.60943791243410) {
100-
log_r = log_diff_exp(log_r, 0.47000362924573);
101-
double log_agg_c = log_sum_exp(log_c[7] + log_r, log_c[6]);
102-
double log_agg_d = log_sum_exp(log_d[7] + log_r, log_d[6]);
103-
104-
for (int i = 0; i < 6; i++) {
105-
log_agg_c = log_sum_exp(log_agg_c + log_r, log_c[5 - i]);
106-
log_agg_d = log_sum_exp(log_agg_d + log_r, log_d[5 - i]);
107-
}
108-
val = exp(log_agg_c - log_agg_d);
99+
double log_temp_r = log(-log_r) / 2.0;
100+
if (log_temp_r <= LOG_FIVE) {
101+
log_inner_r = log_diff_exp(log_temp_r, LOG_16 - LOG_TEN);
102+
num_ptr = &log_c[0];
103+
den_ptr = &log_d[0];
109104
} else {
110-
log_r = log_diff_exp(log_r, 1.60943791243410);
111-
double log_agg_e = log_sum_exp(log_e[7] + log_r, log_e[6]);
112-
double log_agg_f = log_sum_exp(log_f[7] + log_r, log_f[6]);
113-
114-
for (int i = 0; i < 6; i++) {
115-
log_agg_e = log_sum_exp(log_agg_e + log_r, log_e[5 - i]);
116-
log_agg_f = log_sum_exp(log_agg_f + log_r, log_f[5 - i]);
117-
}
118-
val = exp(log_agg_e - log_agg_f);
105+
log_inner_r = log_diff_exp(log_temp_r, LOG_FIVE);
106+
num_ptr = &log_e[0];
107+
den_ptr = &log_f[0];
119108
}
120-
if (log_q_sign == -1)
121-
return -val;
109+
log_pre_mult = 0.0;
122110
}
123-
return val;
111+
112+
// As computation requires evaluating r^8, this causes a loss of precision,
113+
// even when on the log space. We can mitigate this by scaling the
114+
// exponentiated result (dividing by 10), since the same scaling is applied
115+
// to the numerator and denominator.
116+
Eigen::VectorXd log_r_pow
117+
= Eigen::ArrayXd::LinSpaced(8, 0, 7) * log_inner_r - LOG_TEN;
118+
Eigen::Map<const Eigen::VectorXd> num_map(num_ptr, 8);
119+
Eigen::Map<const Eigen::VectorXd> den_map(den_ptr, 8);
120+
double log_result
121+
= log_sum_exp(log_r_pow + num_map) - log_sum_exp(log_r_pow + den_map);
122+
return log_q_sign * exp(log_pre_mult + log_result);
124123
}
125124

126125
/**

0 commit comments

Comments
 (0)