Skip to content

Commit 815a6c9

Browse files
committed
First pass vectorise
1 parent 35d6d53 commit 815a6c9

File tree

1 file changed

+35
-58
lines changed

1 file changed

+35
-58
lines changed

stan/math/prim/fun/inv_Phi.hpp

Lines changed: 35 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -43,8 +43,8 @@ inline double inv_Phi_lambda(double p) {
4343
1.9715909503065514427e+03, 1.3731693765509461125e+04,
4444
4.5921953931549871457e+04, 6.7265770927008700853e+04,
4545
3.3430575583588128105e+04, 2.5090809287301226727e+03};
46-
static constexpr double b[7]
47-
= {4.2313330701600911252e+01, 6.8718700749205790830e+02,
46+
static constexpr double b[8]
47+
= {1.0, 4.2313330701600911252e+01, 6.8718700749205790830e+02,
4848
5.3941960214247511077e+03, 2.1213794301586595867e+04,
4949
3.9307895800092710610e+04, 2.8729085735721942674e+04,
5050
5.2264952788528545610e+03};
@@ -53,8 +53,8 @@ inline double inv_Phi_lambda(double p) {
5353
5.76949722146069140550e+00, 3.64784832476320460504e+00,
5454
1.27045825245236838258e+00, 2.41780725177450611770e-01,
5555
2.27238449892691845833e-02, 7.74545014278341407640e-04};
56-
static constexpr double d[7]
57-
= {2.05319162663775882187e+00, 1.67638483018380384940e+00,
56+
static constexpr double d[8]
57+
= {1.0, 2.05319162663775882187e+00, 1.67638483018380384940e+00,
5858
6.89767334985100004550e-01, 1.48103976427480074590e-01,
5959
1.51986665636164571966e-02, 5.47593808499534494600e-04,
6060
1.05075007164441684324e-09};
@@ -63,72 +63,49 @@ inline double inv_Phi_lambda(double p) {
6363
1.78482653991729133580e+00, 2.96560571828504891230e-01,
6464
2.65321895265761230930e-02, 1.24266094738807843860e-03,
6565
2.71155556874348757815e-05, 2.01033439929228813265e-07};
66-
static constexpr double f[7]
67-
= {5.99832206555887937690e-01, 1.36929880922735805310e-01,
66+
static constexpr double f[8]
67+
= {1.0, 5.99832206555887937690e-01, 1.36929880922735805310e-01,
6868
1.48753612908506148525e-02, 7.86869131145613259100e-04,
6969
1.84631831751005468180e-05, 1.42151175831644588870e-07,
7070
2.04426310338993978564e-15};
7171

7272
double q = p - 0.5;
73-
double r;
73+
double r = q < 0 ? p : 1 - p;
74+
75+
if (r <= 0) {
76+
return 0;
77+
}
78+
7479
double val;
80+
double inner_r;
81+
double pre_mult;
82+
using Vector8d = Eigen::Matrix<double, 8, 1>;
83+
Eigen::Map<const Vector8d> numerator_map(NULL);
84+
Eigen::Map<const Vector8d> denonimator_map(NULL);
7585

7686
if (std::fabs(q) <= .425) {
77-
r = .180625 - square(q);
78-
return q
79-
* (((((((a[7] * r + a[6]) * r + a[5]) * r + a[4]) * r + a[3]) * r
80-
+ a[2])
81-
* r
82-
+ a[1])
83-
* r
84-
+ a[0])
85-
/ (((((((b[6] * r + b[5]) * r + b[4]) * r + b[3]) * r + b[2]) * r
86-
+ b[1])
87-
* r
88-
+ b[0])
89-
* r
90-
+ 1.0);
87+
inner_r = .180625 - square(q);
88+
pre_mult = q;
89+
90+
new (&numerator_map) Eigen::Map<const Vector8d>(a, 8);
91+
new (&denonimator_map) Eigen::Map<const Vector8d>(b, 8);
9192
} else {
92-
r = q < 0 ? p : 1 - p;
93-
94-
if (r <= 0)
95-
return 0;
96-
97-
r = std::sqrt(-std::log(r));
98-
99-
if (r <= 5.0) {
100-
r += -1.6;
101-
val = (((((((c[7] * r + c[6]) * r + c[5]) * r + c[4]) * r + c[3]) * r
102-
+ c[2])
103-
* r
104-
+ c[1])
105-
* r
106-
+ c[0])
107-
/ (((((((d[6] * r + d[5]) * r + d[4]) * r + d[3]) * r + d[2]) * r
108-
+ d[1])
109-
* r
110-
+ d[0])
111-
* r
112-
+ 1.0);
93+
94+
double temp_r = std::sqrt(-std::log(r));
95+
if (temp_r <= 5.0) {
96+
inner_r = temp_r - 1.6;
97+
new (&numerator_map) Eigen::Map<const Vector8d>(c, 8);
98+
new (&denonimator_map) Eigen::Map<const Vector8d>(d, 8);
11399
} else {
114-
r -= 5.0;
115-
val = (((((((e[7] * r + e[6]) * r + e[5]) * r + e[4]) * r + e[3]) * r
116-
+ e[2])
117-
* r
118-
+ e[1])
119-
* r
120-
+ e[0])
121-
/ (((((((f[6] * r + f[5]) * r + f[4]) * r + f[3]) * r + f[2]) * r
122-
+ f[1])
123-
* r
124-
+ f[0])
125-
* r
126-
+ 1.0);
100+
inner_r = temp_r - 5.0;
101+
new (&numerator_map) Eigen::Map<const Vector8d>(e, 8);
102+
new (&denonimator_map) Eigen::Map<const Vector8d>(f, 8);
127103
}
128-
if (q < 0.0)
129-
return -val;
104+
pre_mult = q < 0 ? -1 : 1;
130105
}
131-
return val;
106+
107+
Eigen::VectorXd r_pow = pow(inner_r, Eigen::ArrayXd::LinSpaced(8, 0, 7)) / 10.0;
108+
return pre_mult * (numerator_map.dot(r_pow) * 10.0) / (denonimator_map.dot(r_pow) * 10.0);
132109
}
133110
} // namespace internal
134111

0 commit comments

Comments
 (0)