Skip to content

Commit 8416328

Browse files
committed
Add an extension for SparseArrays
1 parent 3fd55b7 commit 8416328

File tree

4 files changed

+16
-4
lines changed

4 files changed

+16
-4
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1616
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1717
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
1818
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
19+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1920

2021
[extensions]
2122
IntervalArithmeticDiffRulesExt = "DiffRules"
2223
IntervalArithmeticForwardDiffExt = "ForwardDiff"
2324
IntervalArithmeticIntervalSetsExt = "IntervalSets"
2425
IntervalArithmeticLinearAlgebraExt = "LinearAlgebra"
2526
IntervalArithmeticRecipesBaseExt = "RecipesBase"
27+
IntervalArithmeticSparseArraysExt = "SparseArrays"
2628

2729
[compat]
2830
CRlibm = "1.0.2"
@@ -35,4 +37,5 @@ OpenBLASConsistentFPCSR_jll = "0.3.29"
3537
Printf = "1.10"
3638
RecipesBase = "1"
3739
RoundingEmulator = "0.2"
40+
SparseArrays = "1.10.0"
3841
julia = "1.10"

ext/IntervalArithmeticForwardDiffExt.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function (constant::Constant)(::Dual{T, Interval{S}}) where {T, S}
9898
return Dual{T}(interval(S, constant.value), interval(S, 0.0))
9999
end
100100

101-
function (piecewise::Piecewise)(dual::Dual{T, <:Interval}) where {T}
101+
function (piecewise::Piecewise)(dual::Dual{T,<:Interval}) where {T}
102102
X = value(dual)
103103
input_domain = Domain(X)
104104
if !overlap_domain(input_domain, piecewise)
@@ -135,5 +135,4 @@ function (piecewise::Piecewise)(dual::Dual{T, <:Interval}) where {T}
135135
return Dual{T}(primal, tuple(partial...))
136136
end
137137

138-
139138
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module IntervalArithmeticSparseArraysExt
2+
3+
using IntervalArithmetic
4+
import SparseArrays
5+
6+
SparseArrays._iszero(x::Interval) = isthinzero(x)
7+
8+
SparseArrays._isnotzero(x::Interval) = !isthinzero(x)
9+
10+
end

test/interval_tests/forwarddiff.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ end
1919
@test ForwardDiff.derivative(f, interval(-1, 1)) === interval(-2, 2, trv)
2020

2121
g(x) = abs(x)^2
22-
@test ForwardDiff.derivative(g, interval(-1, 1) ) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)
22+
@test ForwardDiff.derivative(g, interval(-1, 1)) === interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)
2323
@test all(ForwardDiff.gradient( v -> g(v[1]), [interval(-1, 1)]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
24-
@test_broken all(ForwardDiff.hessian( v -> g(v[1]), [interval( 0 )]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
24+
@test all(ForwardDiff.hessian( v -> g(v[1]), [interval( 0 )]) .=== [interval(convert(Interval{Float64}, -2), convert(Interval{Float64}, 2), trv)])
2525
end
2626

2727
@testset "sin" begin

0 commit comments

Comments
 (0)