Skip to content

Commit 87d04dc

Browse files
authored
Merge pull request #3037 from stan-dev/robust-weibull-cdf
Make Weibull cdf & lcdf more robust, handle y = 0 inputs
2 parents b010193 + f627b37 commit 87d04dc

File tree

4 files changed

+105
-32
lines changed

4 files changed

+105
-32
lines changed

stan/math/prim/prob/weibull_cdf.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ return_type_t<T_y, T_shape, T_scale> weibull_cdf(const T_y& y,
4343
using T_y_ref = ref_type_if_not_constant_t<T_y>;
4444
using T_alpha_ref = ref_type_if_not_constant_t<T_shape>;
4545
using T_sigma_ref = ref_type_if_not_constant_t<T_scale>;
46-
using std::pow;
4746
static constexpr const char* function = "weibull_cdf";
4847

4948
T_y_ref y_ref = y;
@@ -63,34 +62,36 @@ return_type_t<T_y, T_shape, T_scale> weibull_cdf(const T_y& y,
6362
}
6463

6564
auto ops_partials = make_partials_propagator(y_ref, alpha_ref, sigma_ref);
65+
if (any(value_of_rec(y_val) == 0)) {
66+
return ops_partials.build(0.0);
67+
}
6668

6769
constexpr bool any_derivs = !is_constant_all<T_y, T_shape, T_scale>::value;
68-
const auto& pow_n = to_ref_if<any_derivs>(pow(y_val / sigma_val, alpha_val));
69-
const auto& exp_n = to_ref_if<any_derivs>(exp(-pow_n));
70-
const auto& cdf_n = to_ref_if<any_derivs>(1 - exp_n);
70+
const auto& log_y = to_ref_if<any_derivs>(log(y_val));
71+
const auto& log_sigma = to_ref_if<any_derivs>(log(sigma_val));
72+
const auto& log_y_div_sigma = to_ref_if<any_derivs>(log_y - log_sigma);
73+
const auto& log_pow_n = to_ref_if<any_derivs>(alpha_val * log_y_div_sigma);
74+
const auto& pow_n = to_ref_if<any_derivs>(exp(log_pow_n));
75+
const auto& log_cdf_n = to_ref_if<any_derivs>(log1m_exp(-pow_n));
7176

72-
T_partials_return cdf = prod(cdf_n);
77+
T_partials_return log_cdf = sum(log_cdf_n);
7378

7479
if (any_derivs) {
75-
const auto& rep_deriv = to_ref_if<(!is_constant_all<T_y, T_scale>::value
76-
&& !is_constant_all<T_shape>::value)>(
77-
exp_n * pow_n * cdf / cdf_n);
80+
const auto& log_rep_deriv = to_ref(log_pow_n + log_cdf - log_cdf_n - pow_n);
7881
if (!is_constant_all<T_y, T_scale>::value) {
79-
const auto& deriv_y_sigma = to_ref_if<(
80-
!is_constant_all<T_y>::value && !is_constant_all<T_scale>::value)>(
81-
rep_deriv * alpha_val);
82+
const auto& log_deriv_y_sigma = to_ref(log_rep_deriv + log(alpha_val));
8283
if (!is_constant_all<T_y>::value) {
83-
partials<0>(ops_partials) = deriv_y_sigma / y_val;
84+
partials<0>(ops_partials) = exp(log_deriv_y_sigma - log_y);
8485
}
8586
if (!is_constant_all<T_scale>::value) {
86-
partials<2>(ops_partials) = -deriv_y_sigma / sigma_val;
87+
partials<2>(ops_partials) = -exp(log_deriv_y_sigma - log_sigma);
8788
}
8889
}
8990
if (!is_constant_all<T_shape>::value) {
90-
partials<1>(ops_partials) = rep_deriv * log(y_val / sigma_val);
91+
partials<1>(ops_partials) = exp(log_rep_deriv) * log_y_div_sigma;
9192
}
9293
}
93-
return ops_partials.build(cdf);
94+
return ops_partials.build(exp(log_cdf));
9495
}
9596

9697
} // namespace math

stan/math/prim/prob/weibull_lcdf.hpp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,9 @@ template <typename T_y, typename T_shape, typename T_scale,
3939
return_type_t<T_y, T_shape, T_scale> weibull_lcdf(const T_y& y,
4040
const T_shape& alpha,
4141
const T_scale& sigma) {
42-
using T_partials_return = partials_return_t<T_y, T_shape, T_scale>;
4342
using T_y_ref = ref_type_if_not_constant_t<T_y>;
4443
using T_alpha_ref = ref_type_if_not_constant_t<T_shape>;
4544
using T_sigma_ref = ref_type_if_not_constant_t<T_scale>;
46-
using std::pow;
4745
static constexpr const char* function = "weibull_lcdf";
4846

4947
T_y_ref y_ref = y;
@@ -63,34 +61,34 @@ return_type_t<T_y, T_shape, T_scale> weibull_lcdf(const T_y& y,
6361
}
6462

6563
auto ops_partials = make_partials_propagator(y_ref, alpha_ref, sigma_ref);
64+
if (any(value_of_rec(y_val) == 0)) {
65+
return ops_partials.build(stan::math::NEGATIVE_INFTY);
66+
}
6667

6768
constexpr bool any_derivs = !is_constant_all<T_y, T_shape, T_scale>::value;
68-
const auto& pow_n = to_ref_if<any_derivs>(pow(y_val / sigma_val, alpha_val));
69-
const auto& exp_n = to_ref_if<any_derivs>(exp(-pow_n));
69+
const auto& log_y = to_ref_if<any_derivs>(log(y_val));
70+
const auto& log_sigma = to_ref_if<any_derivs>(log(sigma_val));
71+
const auto& log_y_div_sigma = to_ref_if<any_derivs>(log_y - log_sigma);
72+
const auto& log_pow_n = to_ref_if<any_derivs>(alpha_val * log_y_div_sigma);
73+
const auto& pow_n = to_ref_if<any_derivs>(exp(log_pow_n));
7074

71-
// TODO(Andrew) Further simplify derivatives and log1m_exp below
72-
T_partials_return cdf_log = sum(log1m(exp_n));
75+
if (any_derivs) {
76+
const auto& log_rep_deriv = to_ref(log_pow_n - log_diff_exp(pow_n, 0.0));
7377

74-
if (!is_constant_all<T_y, T_scale, T_shape>::value) {
75-
const auto& rep_deriv = to_ref_if<(!is_constant_all<T_y, T_scale>::value
76-
&& !is_constant_all<T_shape>::value)>(
77-
pow_n / (1.0 / exp_n - 1.0));
7878
if (!is_constant_all<T_y, T_scale>::value) {
79-
const auto& deriv_y_sigma = to_ref_if<(
80-
!is_constant_all<T_y>::value && !is_constant_all<T_scale>::value)>(
81-
rep_deriv * alpha_val);
79+
const auto& log_deriv_y_sigma = to_ref(log_rep_deriv + log(alpha_val));
8280
if (!is_constant_all<T_y>::value) {
83-
partials<0>(ops_partials) = deriv_y_sigma / y_val;
81+
partials<0>(ops_partials) = exp(log_deriv_y_sigma - log_y);
8482
}
8583
if (!is_constant_all<T_scale>::value) {
86-
partials<2>(ops_partials) = -deriv_y_sigma / sigma_val;
84+
partials<2>(ops_partials) = -exp(log_deriv_y_sigma - log_sigma);
8785
}
8886
}
8987
if (!is_constant_all<T_shape>::value) {
90-
partials<1>(ops_partials) = rep_deriv * log(y_val / sigma_val);
88+
partials<1>(ops_partials) = exp(log_rep_deriv) * log_y_div_sigma;
9189
}
9290
}
93-
return ops_partials.build(cdf_log);
91+
return ops_partials.build(sum(log1m_exp(-pow_n)));
9492
}
9593

9694
} // namespace math
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include <stan/math/mix.hpp>
2+
#include <test/unit/math/test_ad.hpp>
3+
4+
TEST(mathMixScalFun, weibull_cdf) {
5+
// Inputs are tested on the log (i.e., unconstrained) scale so that the
6+
// finite-diffs don't result in invalid inputs.
7+
auto f = [](const auto& y, const auto& alpha, const auto& sigma) {
8+
using stan::math::lb_constrain;
9+
using stan::math::positive_constrain;
10+
11+
return stan::math::weibull_cdf(lb_constrain(y, 0.0),
12+
positive_constrain(alpha),
13+
positive_constrain(sigma));
14+
};
15+
16+
using stan::math::log;
17+
18+
Eigen::VectorXd y(3);
19+
y << stan::math::NEGATIVE_INFTY, 1.2, 0.0; // lb_constrain(y[0], 0.0) = 0.0
20+
21+
Eigen::VectorXd alpha(3);
22+
alpha << 2.0, 3.0, 4.0;
23+
24+
Eigen::VectorXd sigma(3);
25+
sigma << 5.0, 6.0, 7.0;
26+
27+
stan::test::expect_ad(f, y, alpha, sigma);
28+
stan::test::expect_ad(f, y[0], alpha, sigma);
29+
stan::test::expect_ad(f, y, alpha[0], sigma);
30+
stan::test::expect_ad(f, y, alpha, sigma[0]);
31+
32+
stan::test::expect_ad(f, y[0], alpha[0], sigma);
33+
stan::test::expect_ad(f, y[0], alpha, sigma[0]);
34+
stan::test::expect_ad(f, y, alpha[0], sigma[0]);
35+
36+
stan::test::expect_ad(f, y[0], alpha[0], sigma[0]);
37+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include <stan/math/mix.hpp>
2+
#include <test/unit/math/test_ad.hpp>
3+
4+
TEST(mathMixScalFun, weibull_lcdf) {
5+
// Inputs are tested on the log (i.e., unconstrained) scale so that the
6+
// finite-diffs don't result in invalid inputs.
7+
auto f = [](const auto& y, const auto& alpha, const auto& sigma) {
8+
using stan::math::lb_constrain;
9+
using stan::math::positive_constrain;
10+
11+
return stan::math::weibull_lcdf(lb_constrain(y, 0.0),
12+
positive_constrain(alpha),
13+
positive_constrain(sigma));
14+
};
15+
16+
using stan::math::log;
17+
18+
Eigen::VectorXd y(3);
19+
y << stan::math::NEGATIVE_INFTY, 1.2, 0.0; // lb_constrain(y[0], 0.0) = 0.0
20+
21+
Eigen::VectorXd alpha(3);
22+
alpha << 2.0, 3.0, 4.0;
23+
24+
Eigen::VectorXd sigma(3);
25+
sigma << 5.0, 6.0, 7.0;
26+
27+
stan::test::expect_ad(f, y, alpha, sigma);
28+
stan::test::expect_ad(f, y[0], alpha, sigma);
29+
stan::test::expect_ad(f, y, alpha[0], sigma);
30+
stan::test::expect_ad(f, y, alpha, sigma[0]);
31+
32+
stan::test::expect_ad(f, y[0], alpha[0], sigma);
33+
stan::test::expect_ad(f, y[0], alpha, sigma[0]);
34+
stan::test::expect_ad(f, y, alpha[0], sigma[0]);
35+
36+
stan::test::expect_ad(f, y[0], alpha[0], sigma[0]);
37+
}

0 commit comments

Comments
 (0)