Skip to content

Commit a3a00b9

Browse files
committed
Tests and fixes
1 parent c452cc5 commit a3a00b9

File tree

3 files changed

+66
-17
lines changed

3 files changed

+66
-17
lines changed

ext/IntervalArithmeticForwardDiffExt.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,23 @@ module IntervalArithmeticForwardDiffExt
33
using IntervalArithmetic, ForwardDiff
44
using ForwardDiff: Dual, , value, partials
55

6+
function isconstant_interval(x)
7+
all(isthinzero.(values(partials(x))))
8+
end
9+
610
function Base.:(^)(x::Dual{Txy, <:Interval}, y::Dual{Txy, <:Interval}) where Txy
711
vx, vy = value(x), value(y)
8-
primal = vx^vy
12+
expv = vx^vy
913
powval = vy * vx^(vy - interval(1))
10-
logval = primal * log(vx)
11-
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
12-
return Dual{Txy}(primal, new_partials)
14+
if isconstant_interval(y)
15+
logval = one(expv)
16+
elseif isthinzero(vx) && inf(vy) > 0
17+
logval = zero(vx)
18+
else
19+
logval = expv * log(vx)
20+
end
21+
new_partials = ForwardDiff._mul_partials(partials(x), partials(y), powval, logval)
22+
return Dual{Txy}(expv, new_partials)
1323
end
1424

1525
function Base.:(^)(x::Dual{Tx, <:Interval}, y::Dual{Ty, <:Interval}) where {Tx, Ty}
@@ -22,15 +32,20 @@ end
2232

2333
function Base.:(^)(x::Dual{Tx, <:Interval}, y::Interval) where Tx
2434
v = value(x)
25-
new_partials = partials(x) * y * v^(y - interval(1))
26-
return Dual{Tx}(v^y, new_partials)
35+
expv = v^y
36+
if isthinzero(y) || isconstant_interval(x)
37+
new_partials = zero(partials(x))
38+
else
39+
new_partials = partials(x) * y * v^(y - interval(1))
40+
end
41+
return Dual{Tx}(expv, new_partials)
2742
end
2843

2944
function Base.:(^)(x::Interval, y::Dual{Ty, <:Interval}) where Ty
3045
v = value(y)
31-
primal = x^v
32-
deriv = primal*log(x)
33-
return Dual{Ty}(primal, deriv * partials(y))
46+
expv = x^v
47+
deriv = (isthinzero(x) && inf(v) > 0) ? zero(expv) : expv*log(x)
48+
return Dual{Ty}(expv, deriv * partials(y))
3449
end
3550

3651
end

test/interval_tests/forwarddiff.jl

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
@test ForwardDiff.derivative(abs, interval(-2, 2)) === interval(-1, 1, trv)
1717

1818
f(x) = abs(x)^interval(2)
19-
@test_broken ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv)
19+
@test ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv)
2020

2121
g(x) = abs(x)^2
2222
@test ForwardDiff.derivative(g, interval(-1, 1) ) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)
@@ -56,12 +56,44 @@ end
5656
end
5757

5858
@testset "Power" begin
59-
f(x) = interval(2)^x
60-
f′(x) = log(interval(2)) * f(x)
61-
df(t) = ForwardDiff.derivative(f, t)
59+
fxy(xy) = xy[1]^xy[2]
6260

63-
# g(x) = 2^x # not guaranteed
61+
for x in [0.0, 1.1, 2.2]
62+
for y in [-3.3, 0.0, 4.4]
63+
fx(xx) = xx^y
64+
fxi(xx) = xx^interval(y)
65+
fy(yy) = x^yy
66+
fyi(yy) = interval(x)^yy
6467

65-
@test f′(0) === df(0)
68+
dfdx = ForwardDiff.derivative(fxi, interval(x))
69+
dfdy = ForwardDiff.derivative(fyi, interval(y))
70+
grad = ForwardDiff.gradient(fxy, [interval(x), interval(y)])
71+
72+
@test isguaranteed(dfdx)
73+
@test isguaranteed(dfdy)
74+
@test isguaranteed(grad[1])
75+
@test isguaranteed(grad[2])
76+
77+
if iszero(x) && y < 0
78+
@test decoration(dfdx) == trv
79+
else
80+
@test in_interval(ForwardDiff.derivative(fx, x), dfdx)
81+
end
82+
83+
if iszero(x) && y <= 0
84+
@test decoration(dfdy) == trv
85+
else
86+
@test in_interval(ForwardDiff.derivative(fy, y), dfdy)
87+
end
88+
89+
if iszero(x) && iszero(y)
90+
@test decoration(grad[1]) == trv
91+
@test decoration(dfdx) == com
92+
else
93+
@test isequal_interval(dfdx, grad[1])
94+
end
95+
@test isequal_interval(dfdy, grad[2])
96+
end
97+
end
6698
end
67-
end
99+
end

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using Test
2+
3+
using ForwardDiff
24
using IntervalArithmetic
35
using InteractiveUtils
46

@@ -22,4 +24,4 @@ for f ∈ readdir("ITF1788_tests"; join = true)
2224
@testset "$f" begin
2325
include(f)
2426
end
25-
end
27+
end

0 commit comments

Comments
 (0)