Skip to content

Commit 6a317e0

Browse files
authored
Simpler tests that work with SparseArrays (#114)
* Simpler tests that work with SparseArrays * Add tests for sparse vector in addition to sparse matrix * Document sparse array warning * Skip sparse direct in 1.6
1 parent e725c2f commit 6a317e0

File tree

3 files changed

+75
-64
lines changed

3 files changed

+75
-64
lines changed

docs/src/faq.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ You can specify it with the `conditions_backend` keyword argument when construct
2323

2424
Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size.
2525

26-
If the output is a small array (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
26+
If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
2727

2828
### Scalars
2929

@@ -34,7 +34,10 @@ Or better yet, wrap it in a static vector: `SVector(val)`.
3434
### Sparse arrays
3535

3636
!!! danger "Danger"
37-
Sparse arrays are not supported and might give incorrect values or `NaN`s!
37+
Sparse arrays are not officially supported and might give incorrect values or `NaN`s!
38+
39+
With ForwardDiff.jl, differentiation of sparse arrays will always give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658).
40+
With Zygote.jl it appears to work, but this functionality is considered experimental and might evolve.
3841

3942
## Number of inputs and outputs
4043

@@ -45,7 +48,7 @@ Well, it depends whether you want their derivatives or not.
4548
| | Derivatives needed | Derivatives not needed |
4649
| -------------------- | --------------------------------------- | --------------------------------------- |
4750
| **Multiple inputs** | Make `x` a `ComponentVector` | Supply `args` and `kwargs` to `forward` |
48-
| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct |
51+
| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct |
4952

5053
We now detail each of these options.
5154

test/errors.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ end
5959
end
6060

6161
@testset "Weird ChainRulesTestUtils behavior" begin
62-
x = rand(2, 3)
63-
forward(x) = sqrt.(abs.(x)), 2
64-
conditions(x, y, z) = abs.(y) .^ z .- abs.(x)
62+
x = rand(3)
63+
forward(x) = sqrt.(abs.(x)), 1
64+
conditions(x, y, z) = abs.(y ./ z) .- abs.(x)
6565
implicit = ImplicitFunction(forward, conditions)
6666
y, z = implicit(x)
6767
dy = similar(y)

test/systematic.jl

Lines changed: 66 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,41 +27,38 @@ Random.seed!(63);
2727
## Utils
2828

2929
change_shape(x::AbstractArray{T,3}) where {T} = x[:, :, 1]
30+
change_shape(x::AbstractSparseArray) = x
3031

3132
function mysqrt(x::AbstractArray)
32-
return identity_break_autodiff(sqrt.(abs.(change_shape(x))))
33-
end
34-
35-
function mypower(x::AbstractArray, p)
36-
return identity_break_autodiff(abs.(change_shape(x)) .^ p)
33+
return identity_break_autodiff(sqrt.(abs.(x)))
3734
end
3835

3936
## Various signatures
4037

4138
function make_implicit_sqrt(; kwargs...)
42-
forward(x) = mysqrt(x)
43-
conditions(x, y) = y .^ 2 .- abs.(change_shape(x))
39+
forward(x) = mysqrt(change_shape(x))
40+
conditions(x, y) = abs2.(y) .- abs.(change_shape(x))
4441
implicit = ImplicitFunction(forward, conditions; kwargs...)
4542
return implicit
4643
end
4744

4845
function make_implicit_sqrt_byproduct(; kwargs...)
49-
forward(x) = mysqrt(x), 2
50-
conditions(x, y, z::Integer) = y .^ z .- abs.(change_shape(x))
46+
forward(x) = 1 * mysqrt(change_shape(x)), 1
47+
conditions(x, y, z::Integer) = abs2.(y ./ z) .- abs.(change_shape(x))
5148
implicit = ImplicitFunction(forward, conditions; kwargs...)
5249
return implicit
5350
end
5451

55-
function make_implicit_power_args(; kwargs...)
56-
forward(x, p::Integer) = mypower(x, one(eltype(x)) / p)
57-
conditions(x, y, p::Integer) = y .^ p .- abs.(change_shape(x))
52+
function make_implicit_sqrt_args(; kwargs...)
53+
forward(x, p::Integer) = p * mysqrt(change_shape(x))
54+
conditions(x, y, p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x))
5855
implicit = ImplicitFunction(forward, conditions; kwargs...)
5956
return implicit
6057
end
6158

62-
function make_implicit_power_kwargs(; kwargs...)
63-
forward(x; p::Integer) = mypower(x, one(eltype(x)) / p)
64-
conditions(x, y; p::Integer) = y .^ p .- abs.(change_shape(x))
59+
function make_implicit_sqrt_kwargs(; kwargs...)
60+
forward(x; p::Integer) = p .* mysqrt(change_shape(x))
61+
conditions(x, y; p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x))
6562
implicit = ImplicitFunction(forward, conditions; kwargs...)
6663
return implicit
6764
end
@@ -85,21 +82,21 @@ end
8582
function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T}
8683
imf1 = make_implicit_sqrt(; kwargs...)
8784
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
88-
imf3 = make_implicit_power_args(; kwargs...)
89-
imf4 = make_implicit_power_kwargs(; kwargs...)
85+
imf3 = make_implicit_sqrt_args(; kwargs...)
86+
imf4 = make_implicit_sqrt_kwargs(; kwargs...)
9087

91-
y_true = mysqrt(x)
88+
y_true = mysqrt(change_shape(x))
9289
y1 = @inferred imf1(x)
9390
y2, z2 = @inferred imf2(x)
94-
y3 = @inferred imf3(x, 2)
95-
y4 = @inferred imf4(x; p=2)
91+
y3 = @inferred imf3(x, 1)
92+
y4 = @inferred imf4(x; p=1)
9693

9794
@testset "Exact value" begin
9895
@test y1 y_true
9996
@test y2 y_true
10097
@test y3 y_true
10198
@test y4 y_true
102-
@test z2 2
99+
@test z2 1
103100
end
104101

105102
@testset "Array type" begin
@@ -112,38 +109,38 @@ function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T}
112109
@testset "JET" begin
113110
@test_opt target_modules = (ID,) imf1(x)
114111
@test_opt target_modules = (ID,) imf2(x)
115-
@test_opt target_modules = (ID,) imf3(x, 2)
116-
@test_opt target_modules = (ID,) imf4(x; p=2)
112+
@test_opt target_modules = (ID,) imf3(x, 1)
113+
@test_opt target_modules = (ID,) imf4(x; p=1)
117114

118115
@test_call target_modules = (ID,) imf1(x)
119116
@test_call target_modules = (ID,) imf2(x)
120-
@test_call target_modules = (ID,) imf3(x, 2)
121-
@test_call target_modules = (ID,) imf4(x; p=2)
117+
@test_call target_modules = (ID,) imf3(x, 1)
118+
@test_call target_modules = (ID,) imf4(x; p=1)
122119
end
123120
end
124121

125122
function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T}
126123
imf1 = make_implicit_sqrt(; kwargs...)
127124
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
128-
imf3 = make_implicit_power_args(; kwargs...)
129-
imf4 = make_implicit_power_kwargs(; kwargs...)
125+
imf3 = make_implicit_sqrt_args(; kwargs...)
126+
imf4 = make_implicit_sqrt_kwargs(; kwargs...)
130127

131-
y_true = mysqrt(x)
128+
y_true = mysqrt(change_shape(x))
132129
dx = similar(x)
133130
dx .= one(T)
134131
x_and_dx = ForwardDiff.Dual.(x, dx)
135132

136133
y_and_dy1 = @inferred imf1(x_and_dx)
137134
y_and_dy2, z2 = @inferred imf2(x_and_dx)
138-
y_and_dy3 = @inferred imf3(x_and_dx, 2)
139-
y_and_dy4 = @inferred imf4(x_and_dx; p=2)
135+
y_and_dy3 = @inferred imf3(x_and_dx, 1)
136+
y_and_dy4 = @inferred imf4(x_and_dx; p=1)
140137

141138
@testset "Dual numbers" begin
142139
@test ForwardDiff.value.(y_and_dy1) y_true
143140
@test ForwardDiff.value.(y_and_dy2) y_true
144141
@test ForwardDiff.value.(y_and_dy3) y_true
145142
@test ForwardDiff.value.(y_and_dy4) y_true
146-
@test z2 2
143+
@test z2 1
147144
end
148145

149146
@testset "Static arrays" begin
@@ -156,31 +153,31 @@ function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T}
156153
@testset "JET" begin
157154
@test_opt target_modules = (ID,) imf1(x_and_dx)
158155
@test_opt target_modules = (ID,) imf2(x_and_dx)
159-
@test_opt target_modules = (ID,) imf3(x_and_dx, 2)
160-
@test_opt target_modules = (ID,) imf4(x_and_dx; p=2)
156+
@test_opt target_modules = (ID,) imf3(x_and_dx, 1)
157+
@test_opt target_modules = (ID,) imf4(x_and_dx; p=1)
161158

162159
@test_call target_modules = (ID,) imf1(x_and_dx)
163160
@test_call target_modules = (ID,) imf2(x_and_dx)
164-
@test_call target_modules = (ID,) imf3(x_and_dx, 2)
165-
@test_call target_modules = (ID,) imf4(x_and_dx; p=2)
161+
@test_call target_modules = (ID,) imf3(x_and_dx, 1)
162+
@test_call target_modules = (ID,) imf4(x_and_dx; p=1)
166163
end
167164
end
168165

169166
function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
170167
imf1 = make_implicit_sqrt(; kwargs...)
171168
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
172-
imf3 = make_implicit_power_args(; kwargs...)
173-
imf4 = make_implicit_power_kwargs(; kwargs...)
169+
imf3 = make_implicit_sqrt_args(; kwargs...)
170+
imf4 = make_implicit_sqrt_kwargs(; kwargs...)
174171

175-
y_true = mysqrt(x)
172+
y_true = mysqrt(change_shape(x))
176173
dy = similar(y_true)
177174
dy .= one(eltype(y_true))
178175
dz = nothing
179176

180177
y1, pb1 = @inferred rrule(rc, imf1, x)
181178
(y2, z2), pb2 = @inferred rrule(rc, imf2, x)
182-
y3, pb3 = @inferred rrule(rc, imf3, x, 2)
183-
y4, pb4 = @inferred rrule(rc, imf4, x; p=2)
179+
y3, pb3 = @inferred rrule(rc, imf3, x, 1)
180+
y4, pb4 = @inferred rrule(rc, imf4, x; p=1)
184181

185182
dimf1, dx1 = @inferred pb1(dy)
186183
dimf2, dx2 = @inferred pb2((dy, dz))
@@ -192,7 +189,7 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
192189
@test y2 y_true
193190
@test y3 y_true
194191
@test y4 y_true
195-
@test z2 2
192+
@test z2 1
196193

197194
@test dimf1 isa NoTangent
198195
@test dimf2 isa NoTangent
@@ -222,8 +219,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
222219
@testset "JET" begin
223220
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf1, x)
224221
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf2, x)
225-
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 2)
226-
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=2)
222+
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1)
223+
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1)
227224

228225
@test_skip @test_opt target_modules = (ID,) pb1(dy)
229226
@test_skip @test_opt target_modules = (ID,) pb2((dy, dz))
@@ -232,8 +229,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
232229

233230
@test_call target_modules = (ID,) rrule(rc, imf1, x)
234231
@test_call target_modules = (ID,) rrule(rc, imf2, x)
235-
@test_call target_modules = (ID,) rrule(rc, imf3, x, 2)
236-
@test_call target_modules = (ID,) rrule(rc, imf4, x; p=2)
232+
@test_call target_modules = (ID,) rrule(rc, imf3, x, 1)
233+
@test_call target_modules = (ID,) rrule(rc, imf4, x; p=1)
237234

238235
@test_call target_modules = (ID,) pb1(dy)
239236
@test_call target_modules = (ID,) pb2((dy, dz))
@@ -244,8 +241,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
244241
@testset "ChainRulesTestUtils" begin
245242
test_rrule(rc, imf1, x; atol=1e-2)
246243
test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0)) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112
247-
test_rrule(rc, imf3, x, 2; atol=1e-2)
248-
test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=2,))
244+
test_rrule(rc, imf3, x, 1; atol=1e-2)
245+
test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,))
249246
end
250247
end
251248

@@ -254,13 +251,13 @@ end
254251
function test_implicit_forwarddiff(x::AbstractArray{T}; kwargs...) where {T}
255252
imf1 = make_implicit_sqrt(; kwargs...)
256253
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
257-
imf3 = make_implicit_power_args(; kwargs...)
258-
imf4 = make_implicit_power_kwargs(; kwargs...)
254+
imf3 = make_implicit_sqrt_args(; kwargs...)
255+
imf4 = make_implicit_sqrt_kwargs(; kwargs...)
259256

260257
J1 = ForwardDiff.jacobian(imf1, x)
261258
J2 = ForwardDiff.jacobian(first imf2, x)
262-
J3 = ForwardDiff.jacobian(_x -> imf3(_x, 2), x)
263-
J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=2), x)
259+
J3 = ForwardDiff.jacobian(_x -> imf3(_x, 1), x)
260+
J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=1), x)
264261
J_true = ForwardDiff.jacobian(_x -> sqrt.(change_shape(_x)), x)
265262

266263
@testset "Exact Jacobian" begin
@@ -280,13 +277,13 @@ end
280277
function test_implicit_zygote(x::AbstractArray{T}; kwargs...) where {T}
281278
imf1 = make_implicit_sqrt(; kwargs...)
282279
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
283-
imf3 = make_implicit_power_args(; kwargs...)
284-
imf4 = make_implicit_power_kwargs(; kwargs...)
280+
imf3 = make_implicit_sqrt_args(; kwargs...)
281+
imf4 = make_implicit_sqrt_kwargs(; kwargs...)
285282

286283
J1 = Zygote.jacobian(imf1, x)[1]
287284
J2 = Zygote.jacobian(first imf2, x)[1]
288-
J3 = Zygote.jacobian(imf3, x, 2)[1]
289-
J4 = Zygote.jacobian(_x -> imf4(_x; p=2), x)[1]
285+
J3 = Zygote.jacobian(imf3, x, 1)[1]
286+
J4 = Zygote.jacobian(_x -> imf4(_x; p=1), x)[1]
290287
J_true = Zygote.jacobian(_x -> sqrt.(change_shape(_x)), x)[1]
291288

292289
@testset "Exact Jacobian" begin
@@ -308,8 +305,10 @@ function test_implicit(x; kwargs...)
308305
test_implicit_call(x; kwargs...)
309306
end
310307
@testset verbose = true "ForwardDiff.jl" begin
311-
test_implicit_forwarddiff(x; kwargs...)
312-
test_implicit_duals(x; kwargs...)
308+
if !(x isa AbstractSparseArray)
309+
test_implicit_forwarddiff(x; kwargs...)
310+
test_implicit_duals(x; kwargs...)
311+
end
313312
end
314313
@testset verbose = true "Zygote.jl" begin
315314
rc = Zygote.ZygoteRuleConfig()
@@ -337,6 +336,8 @@ conditions_backend_candidates = (
337336
x_candidates = (
338337
rand(Float32, 2, 3, 2), #
339338
SArray{Tuple{2,3,2}}(rand(Float32, 2, 3, 2)), #
339+
sparse(rand(Float32, 2)), #
340+
sparse(rand(Float32, 2, 3)), #
340341
);
341342

342343
params_candidates = []
@@ -366,8 +367,15 @@ end
366367

367368
for (linear_solver, conditions_backend, x) in params_candidates
368369
testsetname = "$(typeof(linear_solver)) - $(typeof(conditions_backend)) - $(typeof(x))"
370+
if (
371+
linear_solver isa DirectLinearSolver &&
372+
x isa AbstractSparseArray &&
373+
VERSION < v"1.9"
374+
) # missing linalg function for sparse arrays in 1.6
375+
continue
376+
end
369377
@info "$testsetname"
370-
@testset "$testsetname" begin
378+
@testset verbose = true "$testsetname" begin
371379
test_implicit(x; linear_solver, conditions_backend)
372380
end
373381
end

0 commit comments

Comments
 (0)