@@ -71,56 +71,55 @@ inline double std_normal_log_qf(double log_p) {
71
71
-15.76637472711685 ,
72
72
-33.82373901099482 };
73
73
74
- double val;
75
74
double log_q = log_p <= LOG_HALF ? log_diff_exp (LOG_HALF, log_p)
76
75
: log_diff_exp (log_p, LOG_HALF);
77
76
int log_q_sign = log_p <= LOG_HALF ? -1 : 1 ;
77
+ double log_r = log_q_sign == -1 ? log_p : log1m_exp (log_p);
78
78
79
- if (log_q <= -0.85566611005772 ) {
80
- double log_r = log_diff_exp (-1.71133222011544 , 2 * log_q);
81
- double log_agg_a = log_sum_exp (log_a[7 ] + log_r, log_a[6 ]);
82
- double log_agg_b = log_sum_exp (log_b[7 ] + log_r, log_b[6 ]);
83
-
84
- for (int i = 0 ; i < 6 ; i++) {
85
- log_agg_a = log_sum_exp (log_agg_a + log_r, log_a[5 - i]);
86
- log_agg_b = log_sum_exp (log_agg_b + log_r, log_b[5 - i]);
87
- }
79
+ if (stan::math::is_inf (log_r)) {
80
+ return 0 ;
81
+ }
88
82
89
- return log_q_sign * exp (log_q + log_agg_a - log_agg_b);
83
+ double log_inner_r;
84
+ double log_pre_mult;
85
+ const double * num_ptr;
86
+ const double * den_ptr;
87
+
88
+ static constexpr double LOG_FIVE = LOG_TEN - LOG_TWO;
89
+ static constexpr double LOG_16 = LOG_TWO * 4 ;
90
+ static constexpr double LOG_425 = 6.0520891689244171729 ;
91
+ static constexpr double LOG_425_OVER_1000 = LOG_425 - LOG_TEN * 3 ;
92
+
93
+ if (log_q <= LOG_425_OVER_1000) {
94
+ log_inner_r = log_diff_exp (LOG_425_OVER_1000 * 2 , log_q * 2 );
95
+ log_pre_mult = log_q;
96
+ num_ptr = &log_a[0 ];
97
+ den_ptr = &log_b[0 ];
90
98
} else {
91
- double log_r = log_q_sign == -1 ? log_p : log1m_exp (log_p);
92
-
93
- if (stan::math::is_inf (log_r)) {
94
- return 0 ;
95
- }
96
-
97
- log_r = log (sqrt (-log_r));
98
-
99
- if (log_r <= 1.60943791243410 ) {
100
- log_r = log_diff_exp (log_r, 0.47000362924573 );
101
- double log_agg_c = log_sum_exp (log_c[7 ] + log_r, log_c[6 ]);
102
- double log_agg_d = log_sum_exp (log_d[7 ] + log_r, log_d[6 ]);
103
-
104
- for (int i = 0 ; i < 6 ; i++) {
105
- log_agg_c = log_sum_exp (log_agg_c + log_r, log_c[5 - i]);
106
- log_agg_d = log_sum_exp (log_agg_d + log_r, log_d[5 - i]);
107
- }
108
- val = exp (log_agg_c - log_agg_d);
99
+ double log_temp_r = log (-log_r) / 2.0 ;
100
+ if (log_temp_r <= LOG_FIVE) {
101
+ log_inner_r = log_diff_exp (log_temp_r, LOG_16 - LOG_TEN);
102
+ num_ptr = &log_c[0 ];
103
+ den_ptr = &log_d[0 ];
109
104
} else {
110
- log_r = log_diff_exp (log_r, 1.60943791243410 );
111
- double log_agg_e = log_sum_exp (log_e[7 ] + log_r, log_e[6 ]);
112
- double log_agg_f = log_sum_exp (log_f[7 ] + log_r, log_f[6 ]);
113
-
114
- for (int i = 0 ; i < 6 ; i++) {
115
- log_agg_e = log_sum_exp (log_agg_e + log_r, log_e[5 - i]);
116
- log_agg_f = log_sum_exp (log_agg_f + log_r, log_f[5 - i]);
117
- }
118
- val = exp (log_agg_e - log_agg_f);
105
+ log_inner_r = log_diff_exp (log_temp_r, LOG_FIVE);
106
+ num_ptr = &log_e[0 ];
107
+ den_ptr = &log_f[0 ];
119
108
}
120
- if (log_q_sign == -1 )
121
- return -val;
109
+ log_pre_mult = 0.0 ;
122
110
}
123
- return val;
111
+
112
+ // As computation requires evaluating r^8, this causes a loss of precision,
113
+ // even when on the log space. We can mitigate this by scaling the
114
+ // exponentiated result (dividing by 10), since the same scaling is applied
115
+ // to the numerator and denominator.
116
+ Eigen::VectorXd log_r_pow
117
+ = Eigen::ArrayXd::LinSpaced (8 , 0 , 7 ) * log_inner_r - LOG_TEN;
118
+ Eigen::Map<const Eigen::VectorXd> num_map (num_ptr, 8 );
119
+ Eigen::Map<const Eigen::VectorXd> den_map (den_ptr, 8 );
120
+ double log_result
121
+ = log_sum_exp (log_r_pow + num_map) - log_sum_exp (log_r_pow + den_map);
122
+ return log_q_sign * exp (log_pre_mult + log_result);
124
123
}
125
124
126
125
/* *
0 commit comments