@@ -27,41 +27,38 @@ Random.seed!(63);
27
27
# # Utils
28
28
29
29
change_shape (x:: AbstractArray{T,3} ) where {T} = x[:, :, 1 ]
30
+ change_shape (x:: AbstractSparseArray ) = x
30
31
31
32
function mysqrt (x:: AbstractArray )
32
- return identity_break_autodiff (sqrt .(abs .(change_shape (x))))
33
- end
34
-
35
- function mypower (x:: AbstractArray , p)
36
- return identity_break_autodiff (abs .(change_shape (x)) .^ p)
33
+ return identity_break_autodiff (sqrt .(abs .(x)))
37
34
end
38
35
39
36
# # Various signatures
40
37
41
38
function make_implicit_sqrt (; kwargs... )
42
- forward (x) = mysqrt (x )
43
- conditions (x, y) = y .^ 2 .- abs .(change_shape (x))
39
+ forward (x) = mysqrt (change_shape (x) )
40
+ conditions (x, y) = abs2 .(y) .- abs .(change_shape (x))
44
41
implicit = ImplicitFunction (forward, conditions; kwargs... )
45
42
return implicit
46
43
end
47
44
48
45
function make_implicit_sqrt_byproduct (; kwargs... )
49
- forward (x) = mysqrt (x), 2
50
- conditions (x, y, z:: Integer ) = y .^ z .- abs .(change_shape (x))
46
+ forward (x) = 1 * mysqrt (change_shape (x)), 1
47
+ conditions (x, y, z:: Integer ) = abs2 .( y ./ z) .- abs .(change_shape (x))
51
48
implicit = ImplicitFunction (forward, conditions; kwargs... )
52
49
return implicit
53
50
end
54
51
55
- function make_implicit_power_args (; kwargs... )
56
- forward (x, p:: Integer ) = mypower (x, one ( eltype (x)) / p )
57
- conditions (x, y, p:: Integer ) = y .^ p .- abs .(change_shape (x))
52
+ function make_implicit_sqrt_args (; kwargs... )
53
+ forward (x, p:: Integer ) = p * mysqrt ( change_shape (x))
54
+ conditions (x, y, p:: Integer ) = abs2 .( y ./ p) .- abs .(change_shape (x))
58
55
implicit = ImplicitFunction (forward, conditions; kwargs... )
59
56
return implicit
60
57
end
61
58
62
- function make_implicit_power_kwargs (; kwargs... )
63
- forward (x; p:: Integer ) = mypower (x, one ( eltype (x)) / p )
64
- conditions (x, y; p:: Integer ) = y .^ p .- abs .(change_shape (x))
59
+ function make_implicit_sqrt_kwargs (; kwargs... )
60
+ forward (x; p:: Integer ) = p .* mysqrt ( change_shape (x))
61
+ conditions (x, y; p:: Integer ) = abs2 .( y ./ p) .- abs .(change_shape (x))
65
62
implicit = ImplicitFunction (forward, conditions; kwargs... )
66
63
return implicit
67
64
end
85
82
function test_implicit_call (x:: AbstractArray{T} ; kwargs... ) where {T}
86
83
imf1 = make_implicit_sqrt (; kwargs... )
87
84
imf2 = make_implicit_sqrt_byproduct (; kwargs... )
88
- imf3 = make_implicit_power_args (; kwargs... )
89
- imf4 = make_implicit_power_kwargs (; kwargs... )
85
+ imf3 = make_implicit_sqrt_args (; kwargs... )
86
+ imf4 = make_implicit_sqrt_kwargs (; kwargs... )
90
87
91
- y_true = mysqrt (x )
88
+ y_true = mysqrt (change_shape (x) )
92
89
y1 = @inferred imf1 (x)
93
90
y2, z2 = @inferred imf2 (x)
94
- y3 = @inferred imf3 (x, 2 )
95
- y4 = @inferred imf4 (x; p= 2 )
91
+ y3 = @inferred imf3 (x, 1 )
92
+ y4 = @inferred imf4 (x; p= 1 )
96
93
97
94
@testset " Exact value" begin
98
95
@test y1 ≈ y_true
99
96
@test y2 ≈ y_true
100
97
@test y3 ≈ y_true
101
98
@test y4 ≈ y_true
102
- @test z2 ≈ 2
99
+ @test z2 ≈ 1
103
100
end
104
101
105
102
@testset " Array type" begin
@@ -112,38 +109,38 @@ function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T}
112
109
@testset " JET" begin
113
110
@test_opt target_modules = (ID,) imf1 (x)
114
111
@test_opt target_modules = (ID,) imf2 (x)
115
- @test_opt target_modules = (ID,) imf3 (x, 2 )
116
- @test_opt target_modules = (ID,) imf4 (x; p= 2 )
112
+ @test_opt target_modules = (ID,) imf3 (x, 1 )
113
+ @test_opt target_modules = (ID,) imf4 (x; p= 1 )
117
114
118
115
@test_call target_modules = (ID,) imf1 (x)
119
116
@test_call target_modules = (ID,) imf2 (x)
120
- @test_call target_modules = (ID,) imf3 (x, 2 )
121
- @test_call target_modules = (ID,) imf4 (x; p= 2 )
117
+ @test_call target_modules = (ID,) imf3 (x, 1 )
118
+ @test_call target_modules = (ID,) imf4 (x; p= 1 )
122
119
end
123
120
end
124
121
125
122
function test_implicit_duals (x:: AbstractArray{T} ; kwargs... ) where {T}
126
123
imf1 = make_implicit_sqrt (; kwargs... )
127
124
imf2 = make_implicit_sqrt_byproduct (; kwargs... )
128
- imf3 = make_implicit_power_args (; kwargs... )
129
- imf4 = make_implicit_power_kwargs (; kwargs... )
125
+ imf3 = make_implicit_sqrt_args (; kwargs... )
126
+ imf4 = make_implicit_sqrt_kwargs (; kwargs... )
130
127
131
- y_true = mysqrt (x )
128
+ y_true = mysqrt (change_shape (x) )
132
129
dx = similar (x)
133
130
dx .= one (T)
134
131
x_and_dx = ForwardDiff. Dual .(x, dx)
135
132
136
133
y_and_dy1 = @inferred imf1 (x_and_dx)
137
134
y_and_dy2, z2 = @inferred imf2 (x_and_dx)
138
- y_and_dy3 = @inferred imf3 (x_and_dx, 2 )
139
- y_and_dy4 = @inferred imf4 (x_and_dx; p= 2 )
135
+ y_and_dy3 = @inferred imf3 (x_and_dx, 1 )
136
+ y_and_dy4 = @inferred imf4 (x_and_dx; p= 1 )
140
137
141
138
@testset " Dual numbers" begin
142
139
@test ForwardDiff. value .(y_and_dy1) ≈ y_true
143
140
@test ForwardDiff. value .(y_and_dy2) ≈ y_true
144
141
@test ForwardDiff. value .(y_and_dy3) ≈ y_true
145
142
@test ForwardDiff. value .(y_and_dy4) ≈ y_true
146
- @test z2 ≈ 2
143
+ @test z2 ≈ 1
147
144
end
148
145
149
146
@testset " Static arrays" begin
@@ -156,31 +153,31 @@ function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T}
156
153
@testset " JET" begin
157
154
@test_opt target_modules = (ID,) imf1 (x_and_dx)
158
155
@test_opt target_modules = (ID,) imf2 (x_and_dx)
159
- @test_opt target_modules = (ID,) imf3 (x_and_dx, 2 )
160
- @test_opt target_modules = (ID,) imf4 (x_and_dx; p= 2 )
156
+ @test_opt target_modules = (ID,) imf3 (x_and_dx, 1 )
157
+ @test_opt target_modules = (ID,) imf4 (x_and_dx; p= 1 )
161
158
162
159
@test_call target_modules = (ID,) imf1 (x_and_dx)
163
160
@test_call target_modules = (ID,) imf2 (x_and_dx)
164
- @test_call target_modules = (ID,) imf3 (x_and_dx, 2 )
165
- @test_call target_modules = (ID,) imf4 (x_and_dx; p= 2 )
161
+ @test_call target_modules = (ID,) imf3 (x_and_dx, 1 )
162
+ @test_call target_modules = (ID,) imf4 (x_and_dx; p= 1 )
166
163
end
167
164
end
168
165
169
166
function test_implicit_rrule (rc, x:: AbstractArray{T} ; kwargs... ) where {T}
170
167
imf1 = make_implicit_sqrt (; kwargs... )
171
168
imf2 = make_implicit_sqrt_byproduct (; kwargs... )
172
- imf3 = make_implicit_power_args (; kwargs... )
173
- imf4 = make_implicit_power_kwargs (; kwargs... )
169
+ imf3 = make_implicit_sqrt_args (; kwargs... )
170
+ imf4 = make_implicit_sqrt_kwargs (; kwargs... )
174
171
175
- y_true = mysqrt (x )
172
+ y_true = mysqrt (change_shape (x) )
176
173
dy = similar (y_true)
177
174
dy .= one (eltype (y_true))
178
175
dz = nothing
179
176
180
177
y1, pb1 = @inferred rrule (rc, imf1, x)
181
178
(y2, z2), pb2 = @inferred rrule (rc, imf2, x)
182
- y3, pb3 = @inferred rrule (rc, imf3, x, 2 )
183
- y4, pb4 = @inferred rrule (rc, imf4, x; p= 2 )
179
+ y3, pb3 = @inferred rrule (rc, imf3, x, 1 )
180
+ y4, pb4 = @inferred rrule (rc, imf4, x; p= 1 )
184
181
185
182
dimf1, dx1 = @inferred pb1 (dy)
186
183
dimf2, dx2 = @inferred pb2 ((dy, dz))
@@ -192,7 +189,7 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
192
189
@test y2 ≈ y_true
193
190
@test y3 ≈ y_true
194
191
@test y4 ≈ y_true
195
- @test z2 ≈ 2
192
+ @test z2 ≈ 1
196
193
197
194
@test dimf1 isa NoTangent
198
195
@test dimf2 isa NoTangent
@@ -222,8 +219,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
222
219
@testset " JET" begin
223
220
@test_skip @test_opt target_modules = (ID,) rrule (rc, imf1, x)
224
221
@test_skip @test_opt target_modules = (ID,) rrule (rc, imf2, x)
225
- @test_skip @test_opt target_modules = (ID,) rrule (rc, imf3, x, 2 )
226
- @test_skip @test_opt target_modules = (ID,) rrule (rc, imf4, x; p= 2 )
222
+ @test_skip @test_opt target_modules = (ID,) rrule (rc, imf3, x, 1 )
223
+ @test_skip @test_opt target_modules = (ID,) rrule (rc, imf4, x; p= 1 )
227
224
228
225
@test_skip @test_opt target_modules = (ID,) pb1 (dy)
229
226
@test_skip @test_opt target_modules = (ID,) pb2 ((dy, dz))
@@ -232,8 +229,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
232
229
233
230
@test_call target_modules = (ID,) rrule (rc, imf1, x)
234
231
@test_call target_modules = (ID,) rrule (rc, imf2, x)
235
- @test_call target_modules = (ID,) rrule (rc, imf3, x, 2 )
236
- @test_call target_modules = (ID,) rrule (rc, imf4, x; p= 2 )
232
+ @test_call target_modules = (ID,) rrule (rc, imf3, x, 1 )
233
+ @test_call target_modules = (ID,) rrule (rc, imf4, x; p= 1 )
237
234
238
235
@test_call target_modules = (ID,) pb1 (dy)
239
236
@test_call target_modules = (ID,) pb2 ((dy, dz))
@@ -244,8 +241,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
244
241
@testset " ChainRulesTestUtils" begin
245
242
test_rrule (rc, imf1, x; atol= 1e-2 )
246
243
test_rrule (rc, imf2, x; atol= 5e-2 , output_tangent= (dy, 0 )) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112
247
- test_rrule (rc, imf3, x, 2 ; atol= 1e-2 )
248
- test_rrule (rc, imf4, x; atol= 1e-2 , fkwargs= (p= 2 ,))
244
+ test_rrule (rc, imf3, x, 1 ; atol= 1e-2 )
245
+ test_rrule (rc, imf4, x; atol= 1e-2 , fkwargs= (p= 1 ,))
249
246
end
250
247
end
251
248
@@ -254,13 +251,13 @@ end
254
251
function test_implicit_forwarddiff (x:: AbstractArray{T} ; kwargs... ) where {T}
255
252
imf1 = make_implicit_sqrt (; kwargs... )
256
253
imf2 = make_implicit_sqrt_byproduct (; kwargs... )
257
- imf3 = make_implicit_power_args (; kwargs... )
258
- imf4 = make_implicit_power_kwargs (; kwargs... )
254
+ imf3 = make_implicit_sqrt_args (; kwargs... )
255
+ imf4 = make_implicit_sqrt_kwargs (; kwargs... )
259
256
260
257
J1 = ForwardDiff. jacobian (imf1, x)
261
258
J2 = ForwardDiff. jacobian (first ∘ imf2, x)
262
- J3 = ForwardDiff. jacobian (_x -> imf3 (_x, 2 ), x)
263
- J4 = ForwardDiff. jacobian (_x -> imf4 (_x; p= 2 ), x)
259
+ J3 = ForwardDiff. jacobian (_x -> imf3 (_x, 1 ), x)
260
+ J4 = ForwardDiff. jacobian (_x -> imf4 (_x; p= 1 ), x)
264
261
J_true = ForwardDiff. jacobian (_x -> sqrt .(change_shape (_x)), x)
265
262
266
263
@testset " Exact Jacobian" begin
@@ -280,13 +277,13 @@ end
280
277
function test_implicit_zygote (x:: AbstractArray{T} ; kwargs... ) where {T}
281
278
imf1 = make_implicit_sqrt (; kwargs... )
282
279
imf2 = make_implicit_sqrt_byproduct (; kwargs... )
283
- imf3 = make_implicit_power_args (; kwargs... )
284
- imf4 = make_implicit_power_kwargs (; kwargs... )
280
+ imf3 = make_implicit_sqrt_args (; kwargs... )
281
+ imf4 = make_implicit_sqrt_kwargs (; kwargs... )
285
282
286
283
J1 = Zygote. jacobian (imf1, x)[1 ]
287
284
J2 = Zygote. jacobian (first ∘ imf2, x)[1 ]
288
- J3 = Zygote. jacobian (imf3, x, 2 )[1 ]
289
- J4 = Zygote. jacobian (_x -> imf4 (_x; p= 2 ), x)[1 ]
285
+ J3 = Zygote. jacobian (imf3, x, 1 )[1 ]
286
+ J4 = Zygote. jacobian (_x -> imf4 (_x; p= 1 ), x)[1 ]
290
287
J_true = Zygote. jacobian (_x -> sqrt .(change_shape (_x)), x)[1 ]
291
288
292
289
@testset " Exact Jacobian" begin
@@ -308,8 +305,10 @@ function test_implicit(x; kwargs...)
308
305
test_implicit_call (x; kwargs... )
309
306
end
310
307
@testset verbose = true " ForwardDiff.jl" begin
311
- test_implicit_forwarddiff (x; kwargs... )
312
- test_implicit_duals (x; kwargs... )
308
+ if ! (x isa AbstractSparseArray)
309
+ test_implicit_forwarddiff (x; kwargs... )
310
+ test_implicit_duals (x; kwargs... )
311
+ end
313
312
end
314
313
@testset verbose = true " Zygote.jl" begin
315
314
rc = Zygote. ZygoteRuleConfig ()
@@ -337,6 +336,8 @@ conditions_backend_candidates = (
337
336
x_candidates = (
338
337
rand (Float32, 2 , 3 , 2 ), #
339
338
SArray {Tuple{2,3,2}} (rand (Float32, 2 , 3 , 2 )), #
339
+ sparse (rand (Float32, 2 )), #
340
+ sparse (rand (Float32, 2 , 3 )), #
340
341
);
341
342
342
343
params_candidates = []
366
367
367
368
for (linear_solver, conditions_backend, x) in params_candidates
368
369
testsetname = " $(typeof (linear_solver)) - $(typeof (conditions_backend)) - $(typeof (x)) "
370
+ if (
371
+ linear_solver isa DirectLinearSolver &&
372
+ x isa AbstractSparseArray &&
373
+ VERSION < v " 1.9"
374
+ ) # missing linalg function for sparse arrays in 1.6
375
+ continue
376
+ end
369
377
@info " $testsetname "
370
- @testset " $testsetname " begin
378
+ @testset verbose = true " $testsetname " begin
371
379
test_implicit (x; linear_solver, conditions_backend)
372
380
end
373
381
end
0 commit comments