Skip to content

Commit c452cc5

Browse files
committed
Add extension with custom pow rule
1 parent 98a235b commit c452cc5

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@ RoundingEmulator = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705"
99

1010
[weakdeps]
1111
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
12+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1213
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1314

1415
[extensions]
1516
IntervalArithmeticDiffRulesExt = "DiffRules"
17+
IntervalArithmeticForwardDiffExt = "ForwardDiff"
1618
IntervalArithmeticRecipesBaseExt = "RecipesBase"
1719

1820
[compat]
1921
CRlibm_jll = "1"
2022
DiffRules = "1"
23+
ForwardDiff = "0.10"
2124
RecipesBase = "1"
2225
RoundingEmulator = "0.2"
2326
julia = "1.9"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
module IntervalArithmeticForwardDiffExt
2+
3+
using IntervalArithmetic, ForwardDiff
4+
using ForwardDiff: Dual, , value, partials
5+
6+
function Base.:(^)(x::Dual{Txy, <:Interval}, y::Dual{Txy, <:Interval}) where Txy
7+
vx, vy = value(x), value(y)
8+
primal = vx^vy
9+
powval = vy * vx^(vy - interval(1))
10+
logval = primal * log(vx)
11+
new_partials = _mul_partials(partials(x), partials(y), powval, logval)
12+
return Dual{Txy}(primal, new_partials)
13+
end
14+
15+
function Base.:(^)(x::Dual{Tx, <:Interval}, y::Dual{Ty, <:Interval}) where {Tx, Ty}
16+
if Ty Tx
17+
return x^value(y)
18+
else
19+
return value(x)^y
20+
end
21+
end
22+
23+
function Base.:(^)(x::Dual{Tx, <:Interval}, y::Interval) where Tx
24+
v = value(x)
25+
new_partials = partials(x) * y * v^(y - interval(1))
26+
return Dual{Tx}(v^y, new_partials)
27+
end
28+
29+
function Base.:(^)(x::Interval, y::Dual{Ty, <:Interval}) where Ty
30+
v = value(y)
31+
primal = x^v
32+
deriv = primal*log(x)
33+
return Dual{Ty}(primal, deriv * partials(y))
34+
end
35+
36+
end

0 commit comments

Comments
 (0)