Skip to content

Commit e8ab249

Browse files
authored
Merge pull request #3033 from nhuurre/sqrt-0-gradient
sqrt(x) gradient when x=0
2 parents 33b9a82 + 7272c82 commit e8ab249

File tree

4 files changed

+28
-3
lines changed

4 files changed

+28
-3
lines changed

stan/math/fwd/fun/sqrt.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ namespace math {
1616
template <typename T>
1717
inline fvar<T> sqrt(const fvar<T>& x) {
1818
using std::sqrt;
19+
if (value_of_rec(x.val_) == 0.0) {
20+
return fvar<T>(sqrt(x.val_), 0.0 * x.d_);
21+
}
1922
return fvar<T>(sqrt(x.val_), 0.5 * x.d_ * inv_sqrt(x.val_));
2023
}
2124

stan/math/rev/fun/sqrt.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ namespace math {
4343
*/
4444
inline var sqrt(const var& a) {
4545
return make_callback_var(std::sqrt(a.val()), [a](auto& vi) mutable {
46-
a.adj() += vi.adj() / (2.0 * vi.val());
46+
if (vi.val() != 0.0) {
47+
a.adj() += vi.adj() / (2.0 * vi.val());
48+
}
4749
});
4850
}
4951

@@ -58,7 +60,9 @@ template <typename T, require_var_matrix_t<T>* = nullptr>
5860
inline auto sqrt(const T& a) {
5961
return make_callback_var(
6062
a.val().array().sqrt().matrix(), [a](auto& vi) mutable {
61-
a.adj().array() += vi.adj().array() / (2.0 * vi.val_op().array());
63+
a.adj().array()
64+
+= (vi.val_op().array() == 0.0)
65+
.select(0.0, vi.adj().array() / (2.0 * vi.val_op().array()));
6266
});
6367
}
6468

test/unit/math/mix/fun/distance_test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33
TEST(MathMixMatFun, distance) {
44
auto f
55
= [](const auto& x, const auto& y) { return stan::math::distance(x, y); };
6+
stan::test::ad_tolerances tols;
7+
tols.hessian_hessian_ = 2.0;
8+
tols.hessian_fvar_hessian_ = 2.0;
9+
tols.grad_hessian_hessian_ = 2.0;
10+
tols.grad_hessian_grad_hessian_ = 2.0;
611

712
// 0 x 0
813
Eigen::VectorXd x0(0);
@@ -13,6 +18,7 @@ TEST(MathMixMatFun, distance) {
1318
// 1 x 1
1419
Eigen::VectorXd x1(1);
1520
x1 << 1;
21+
stan::test::expect_ad(tols, f, x1, x1);
1622
Eigen::VectorXd y1(1);
1723
y1 << -2.3;
1824
stan::test::expect_ad(f, x1, y1);
@@ -21,6 +27,7 @@ TEST(MathMixMatFun, distance) {
2127
// 2 x 2
2228
Eigen::VectorXd x2(2);
2329
x2 << 2, -3;
30+
stan::test::expect_ad(tols, f, x2, x2);
2431
Eigen::VectorXd y2(2);
2532
y2 << -2.3, 1.1;
2633
stan::test::expect_ad(f, x2, y2);
@@ -29,6 +36,7 @@ TEST(MathMixMatFun, distance) {
2936
// 3 x 3
3037
Eigen::VectorXd x(3);
3138
x << 1, 3, -5;
39+
stan::test::expect_ad(tols, f, x, x);
3240
Eigen::VectorXd y(3);
3341
y << 4, -2, -1;
3442
stan::test::expect_ad(f, x, y);

test/unit/math/mix/fun/sqrt_test.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,21 @@ TEST(mathMixMatFun, sqrt) {
66
using stan::math::sqrt;
77
return sqrt(x1);
88
};
9+
auto ff = [](const auto& x1) {
10+
using stan::math::sqrt;
11+
return sqrt((x1 >= 0.0) ? x1 : -x1);
12+
};
913
stan::test::expect_common_nonzero_unary_vectorized<
1014
stan::test::ScalarSupport::Real>(f);
1115
stan::test::expect_unary_vectorized(f, -6, -5.2, 1.3, 7, 10.7, 36, 1e6);
1216

13-
// undefined with 0 in denominator
17+
stan::test::ad_tolerances tols;
18+
tols.hessian_hessian_ = 2.0;
19+
tols.hessian_fvar_hessian_ = 2.0;
20+
tols.grad_hessian_hessian_ = 2.0;
21+
tols.grad_hessian_grad_hessian_ = 2.0;
22+
stan::test::expect_ad(tols, ff, 0.0);
23+
1424
stan::test::expect_ad(f, std::complex<double>(0.9, 0.8));
1525
for (double im : std::vector<double>{-1.3, 2.3}) {
1626
for (double re : std::vector<double>{-3.6, -0.0, 0.0, 0.5}) {

0 commit comments

Comments
 (0)