Skip to content

Commit 5275e5c

Browse files
authored
Merge pull request #629 from Kolaru/fowarddiff_pow
ForwardDiff extension for power
2 parents 98a235b + ff4307b commit 5275e5c

File tree

4 files changed

+97
-9
lines changed

4 files changed

+97
-9
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
name = "IntervalArithmetic"
22
uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
33
repo = "https://github.com/JuliaIntervals/IntervalArithmetic.jl.git"
4-
version = "0.22.7"
4+
version = "0.22.8"
55

66
[deps]
77
CRlibm_jll = "4e9b3aee-d8a1-5a3d-ad8b-7d824db253f0"
88
RoundingEmulator = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705"
99

1010
[weakdeps]
1111
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
12+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1213
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1314

1415
[extensions]
1516
IntervalArithmeticDiffRulesExt = "DiffRules"
17+
IntervalArithmeticForwardDiffExt = "ForwardDiff"
1618
IntervalArithmeticRecipesBaseExt = "RecipesBase"
1719

1820
[compat]
1921
CRlibm_jll = "1"
2022
DiffRules = "1"
23+
ForwardDiff = "0.10"
2124
RecipesBase = "1"
2225
RoundingEmulator = "0.2"
2326
julia = "1.9"
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
module IntervalArithmeticForwardDiffExt
2+
3+
using IntervalArithmetic, ForwardDiff
4+
using ForwardDiff: Dual, , value, partials
5+
6+
function isconstant_interval(x)
7+
all(isthinzero.(values(partials(x))))
8+
end
9+
10+
function Base.:(^)(x::Dual{Txy, <:Interval}, y::Dual{Txy, <:Interval}) where Txy
11+
vx, vy = value(x), value(y)
12+
expv = vx^vy
13+
powval = vy * vx^(vy - interval(1))
14+
if isconstant_interval(y)
15+
logval = one(expv)
16+
elseif isthinzero(vx) && inf(vy) > 0
17+
logval = zero(vx)
18+
else
19+
logval = expv * log(vx)
20+
end
21+
new_partials = ForwardDiff._mul_partials(partials(x), partials(y), powval, logval)
22+
return Dual{Txy}(expv, new_partials)
23+
end
24+
25+
function Base.:(^)(x::Dual{Tx, <:Interval}, y::Dual{Ty, <:Interval}) where {Tx, Ty}
26+
if Ty Tx
27+
return x^value(y)
28+
else
29+
return value(x)^y
30+
end
31+
end
32+
33+
function Base.:(^)(x::Dual{Tx, <:Interval}, y::Interval) where Tx
34+
v = value(x)
35+
expv = v^y
36+
if isthinzero(y) || isconstant_interval(x)
37+
new_partials = zero(partials(x))
38+
else
39+
new_partials = partials(x) * y * v^(y - interval(1))
40+
end
41+
return Dual{Tx}(expv, new_partials)
42+
end
43+
44+
function Base.:(^)(x::Interval, y::Dual{Ty, <:Interval}) where Ty
45+
v = value(y)
46+
expv = x^v
47+
deriv = (isthinzero(x) && inf(v) > 0) ? zero(expv) : expv*log(x)
48+
return Dual{Ty}(expv, deriv * partials(y))
49+
end
50+
51+
end

test/interval_tests/forwarddiff.jl

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
@test ForwardDiff.derivative(abs, interval(-2, 2)) === interval(-1, 1, trv)
1717

1818
f(x) = abs(x)^interval(2)
19-
@test_broken ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv)
19+
@test ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv)
2020

2121
g(x) = abs(x)^2
2222
@test ForwardDiff.derivative(g, interval(-1, 1) ) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)
@@ -56,12 +56,44 @@ end
5656
end
5757

5858
@testset "Power" begin
59-
f(x) = interval(2)^x
60-
f′(x) = log(interval(2)) * f(x)
61-
df(t) = ForwardDiff.derivative(f, t)
59+
fxy(xy) = xy[1]^xy[2]
6260

63-
# g(x) = 2^x # not guaranteed
61+
for x in [0.0, 1.1, 2.2]
62+
for y in [-3.3, 0.0, 4.4]
63+
fx(xx) = xx^y
64+
fxi(xx) = xx^interval(y)
65+
fy(yy) = x^yy
66+
fyi(yy) = interval(x)^yy
6467

65-
@test f′(0) === df(0)
68+
dfdx = ForwardDiff.derivative(fxi, interval(x))
69+
dfdy = ForwardDiff.derivative(fyi, interval(y))
70+
grad = ForwardDiff.gradient(fxy, [interval(x), interval(y)])
71+
72+
@test isguaranteed(dfdx)
73+
@test isguaranteed(dfdy)
74+
@test isguaranteed(grad[1])
75+
@test isguaranteed(grad[2])
76+
77+
if iszero(x) && y < 0
78+
@test decoration(dfdx) == trv
79+
else
80+
@test in_interval(ForwardDiff.derivative(fx, x), dfdx)
81+
end
82+
83+
if iszero(x) && y <= 0
84+
@test decoration(dfdy) == trv
85+
else
86+
@test in_interval(ForwardDiff.derivative(fy, y), dfdy)
87+
end
88+
89+
if iszero(x) && iszero(y)
90+
@test decoration(grad[1]) == trv
91+
@test decoration(dfdx) == com
92+
else
93+
@test isequal_interval(dfdx, grad[1])
94+
end
95+
@test isequal_interval(dfdy, grad[2])
96+
end
97+
end
6698
end
67-
end
99+
end

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using Test
2+
3+
using ForwardDiff
24
using IntervalArithmetic
35
using InteractiveUtils
46

@@ -22,4 +24,4 @@ for f ∈ readdir("ITF1788_tests"; join = true)
2224
@testset "$f" begin
2325
include(f)
2426
end
25-
end
27+
end

0 commit comments

Comments
 (0)