Skip to content

Commit c54a880

Browse files
committed
Add hessian test for ForwardDiff.jl
1 parent 3b1aff6 commit c54a880

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

ext/IntervalArithmeticForwardDiffExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ ForwardDiff.Dual{T,V}(x::ExactReal) where {T,V} = convert(Dual{T,V}, x)
1414

1515
Base.convert(::Type{Dual{T,V,N}}, x::ExactReal) where {T,V,N} = Dual{T}(V(x), zero(Partials{N,V}))
1616

17-
Base.promote_rule(::Type{Dual{T, V, N}}, ::Type{Interval{S}}) where {T, V, N, S<:Union{AbstractFloat, Rational}} =
17+
Base.promote_rule(::Type{Dual{T, V, N}}, ::Type{Interval{S}}) where {T, V, N, S<:IntervalArithmetic.NumTypes} =
1818
Dual{T,Interval{IntervalArithmetic.promote_numtype(V, S)},N}
19-
Base.promote_rule(::Type{Interval{S}}, ::Type{Dual{T, V, N}}) where {S<:Union{AbstractFloat, Rational}, T, V, N} =
19+
Base.promote_rule(::Type{Interval{S}}, ::Type{Dual{T, V, N}}) where {S<:IntervalArithmetic.NumTypes, T, V, N} =
2020
Dual{T,Interval{IntervalArithmetic.promote_numtype(V, S)},N}
2121
Base.promote_rule(::Type{ExactReal{S}}, ::Type{Dual{T, V, N}}) where {S<:Real, T, V, N} =
2222
Dual{T,ExactReal{IntervalArithmetic.promote_numtype(V, S)},N}
@@ -98,7 +98,7 @@ function (constant::Constant)(::Dual{T, Interval{S}}) where {T, S}
9898
return Dual{T}(interval(S, constant.value), interval(S, 0.0))
9999
end
100100

101-
function (piecewise::Piecewise)(dual::Dual{T, <:Interval}) where {T}
101+
function (piecewise::Piecewise)(dual::Dual{T,<:Interval}) where {T}
102102
X = value(dual)
103103
input_domain = Domain(X)
104104
if !overlap_domain(input_domain, piecewise)
@@ -135,6 +135,8 @@ function (piecewise::Piecewise)(dual::Dual{T, <:Interval}) where {T}
135135
return Dual{T}(primal, tuple(partial...))
136136
end
137137

138+
#
139+
138140
ForwardDiff.DiffRules._abs_deriv(x::Dual{T,<:Interval}) where {T} =
139141
Dual{T}(ForwardDiff.DiffRules._abs_deriv(value(x)), zero(partials(x)))
140142

test/interval_tests/forwarddiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ end
2222
@test ForwardDiff.derivative(g, interval(-1, 1) ) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)
2323
@test all(ForwardDiff.gradient( v -> g(v[1]), [interval(-1, 1)]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
2424
@test all(ForwardDiff.hessian( v -> g(v[1]), [interval( 0 )]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
25+
@test all(ForwardDiff.hessian( v -> g(v[1]), [interval(-1, 1)]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
2526
end
2627

2728
@testset "sin" begin

0 commit comments

Comments
 (0)