29
29
pivot in (Val (true ), Val (false )),
30
30
m in (7 , 10 , 13 )
31
31
32
- A = randn (T, m, n)
33
- ΔA = rand_tangent (A)
34
- frule_test (lu!, (A, ΔA), (pivot, nothing ))
32
+ test_frule (lu!, randn (T, m, n), pivot ⊢ nothing )
35
33
end
36
34
@testset " check=false passed to primal function" begin
37
35
Asingular = zeros (n, n)
40
38
(Zero (), copy (ΔAsingular)), lu!, copy (Asingular), Val (true )
41
39
)
42
40
frule ((Zero (), ΔAsingular), lu!, Asingular, Val (true ); check= false )
41
+ @test true # above line would have errored if this was not working right
43
42
end
44
43
end
45
44
@testset " lu rrule" begin
48
47
pivot in (Val (true ), Val (false )),
49
48
m in (7 , 10 , 13 )
50
49
51
- A = randn (T, m, n)
52
- ΔA = rand_tangent (A)
53
- F = lu (A, pivot)
54
- Δfactors = rand_tangent (F. factors)
55
- ΔF = Composite {typeof(F)} (; factors= Δfactors)
56
- rrule_test (lu, ΔF, (A, ΔA), (pivot, nothing ))
50
+ test_rrule (lu, randn (T, m, n), pivot ⊢ nothing )
57
51
end
58
52
@testset " check=false passed to primal function" begin
59
53
Asingular = zeros (n, n)
62
56
@test_throws SingularException rrule (lu, Asingular, Val (true ))
63
57
_, back = rrule (lu, Asingular, Val (true ); check= false )
64
58
back (ΔF)
59
+ @test true # above line would have errored if this was not working right
65
60
end
66
61
end
67
62
@testset " LU" begin
71
66
k in (:U , :L , :factors ),
72
67
m in (7 , 10 , 13 )
73
68
74
- A = randn (m, n)
75
- F = lu (A)
76
- X = getproperty (F, k)
77
- ΔF = Composite {typeof(F)} (; factors= rand_tangent (F. factors))
78
- ΔX = rand_tangent (X)
79
- rrule_test (getproperty, ΔX, (F, ΔF), (k, nothing ); check_inferred= false )
69
+ F = lu (randn (m, n))
70
+ test_rrule (getproperty, F, k ⊢ nothing ; check_inferred= false )
80
71
end
81
72
end
82
73
@testset " matrix inverse using LU" begin
83
- @testset " LinearAlgebra.inv!(::LU) frule" begin
84
- @testset " inv!(lu(::LU{$T ,<:StridedMatrix}))" for T in (Float64,ComplexF64)
85
- A = randn (T, n, n)
86
- F = lu (A, Val (true ))
87
- ΔF = Composite {typeof(F)} (; factors= rand_tangent (F. factors))
88
- frule_test (LinearAlgebra. inv!, (F, ΔF))
89
- end
90
- end
91
- @testset " inv(::LU) rrule" begin
92
- @testset " inv(::LU{$T ,<:StridedMatrix})" for T in (Float64,ComplexF64)
93
- A = randn (T, n, n)
94
- F = lu (A, Val (true ))
95
- Y = inv (A)
96
- ΔF = Composite {typeof(F)} (; factors= rand_tangent (F. factors))
97
- ΔY = rand_tangent (Y)
98
- rrule_test (inv, ΔY, (F, ΔF))
99
- end
74
+ @testset " inv!(lu(::LU{$T ,<:StridedMatrix}))" for T in (Float64,ComplexF64)
75
+ test_frule (LinearAlgebra. inv!, lu (randn (T, n, n), Val (true )))
76
+ test_rrule (inv, lu (randn (T, n, n), Val (true )))
100
77
end
101
78
end
102
79
end
188
165
n = 10
189
166
190
167
@testset " eigen!(::Matrix{$T }) frule" for T in (Float64,ComplexF64)
191
- X = randn (T, n, n)
168
+ # get a bit away from zero so don't have finite differencing woes
169
+ # TODO : this better https://github.com/JuliaDiff/ChainRules.jl/issues/379
170
+ X = 10 .* (rand (T, n, n) .+ 5.0 )
192
171
Ẋ = rand_tangent (X)
193
172
F = eigen! (copy (X))
194
173
F_fwd, Ḟ_ad = frule ((Zero (), copy (Ẋ)), eigen!, copy (X))
@@ -209,24 +188,28 @@ end
209
188
end
210
189
end
211
190
212
- @testset " eigen(::Matrix{$T }) rrule" for T in (Float64,ComplexF64)
213
- # NOTE: eigen is not type-stable, so neither are is its rrule
214
- X = randn (T, n, n)
191
+ @testset " eigen(::Matrix{$T }) rrule" for T in (Float64, ComplexF64)
192
+ # get a bit away from zero so don't have finite differencing woes
193
+ # TODO : this better https://github.com/JuliaDiff/ChainRules.jl/issues/379
194
+ Random. seed! (1 )
195
+ X = 10 .* (rand (T, n, n) .+ 5.0 )
196
+
215
197
F = eigen (X)
216
198
V̄ = rand_tangent (F. vectors)
217
199
λ̄ = rand_tangent (F. values)
218
200
CT = Composite{typeof (F)}
219
201
F_rev, back = rrule (eigen, X)
220
202
@test F_rev == F
203
+ # NOTE: eigen is not type-stable, so neither are is its rrule
221
204
_, X̄_values_ad = @inferred back (CT (values = λ̄))
222
205
@test X̄_values_ad ≈ j′vp (_fdm, x -> eigen (x). values, λ̄, X)[1 ]
223
206
_, X̄_vectors_ad = @inferred back (CT (vectors = V̄))
224
- @test X̄_vectors_ad ≈ j′vp (_fdm, x -> eigen (x). vectors, V̄, X)[1 ]
207
+ @test X̄_vectors_ad ≈ j′vp (_fdm, x -> eigen (x). vectors, V̄, X)[1 ] rtol = 1e-4
225
208
F̄ = CT (values = λ̄, vectors = V̄)
226
209
s̄elf, X̄_ad = @inferred back (F̄)
227
210
@test s̄elf === NO_FIELDS
228
211
X̄_fd = j′vp (_fdm, asnt ∘ eigen, F̄, X)[1 ]
229
- @test X̄_ad ≈ X̄_fd
212
+ @test X̄_ad ≈ X̄_fd rtol = 1e-4
230
213
@test @inferred (back (Zero ())) === (NO_FIELDS, Zero ())
231
214
F̄zero = CT (values = Zero (), vectors = Zero ())
232
215
@test @inferred (back (F̄zero)) === (NO_FIELDS, Zero ())
337
320
@testset " eigvals!(::Matrix{$T }) frule" for T in (Float64,ComplexF64)
338
321
n = 10
339
322
X = randn (T, n, n)
340
- λ = eigvals! (copy (X))
341
- Ẋ = rand_tangent (X)
342
- frule_test (eigvals!, (X, Ẋ))
343
- @test frule ((Zero (), Zero ()), eigvals!, copy (X)) == (λ, Zero ())
323
+ test_frule (eigvals!, X)
324
+ @test frule ((Zero (), Zero ()), eigvals!, copy (X))[2 ] == Zero ()
344
325
345
326
@testset " tangents are real when outputs are" begin
346
327
# hermitian matrices have real eigenvalues
@@ -353,19 +334,13 @@ end
353
334
354
335
@testset " eigvals(::Matrix{$T }) rrule" for T in (Float64,ComplexF64)
355
336
n = 10
356
- X = randn (T, n, n)
357
- X̄ = rand_tangent (X)
358
- λ̄ = rand_tangent (eigvals (X))
359
- rrule_test (eigvals, λ̄, (X, X̄))
360
- back = rrule (eigvals, X)[2 ]
361
- @inferred back (λ̄)
337
+ test_rrule (eigvals, randn (T, n, n))
338
+
339
+ λ, back = rrule (eigvals, randn (T, n, n))
340
+ _, X̄ = @inferred back (rand_tangent (λ))
362
341
@test @inferred (back (Zero ())) === (NO_FIELDS, Zero ())
363
342
364
343
T <: Real && @testset " cotangent is real when input is" begin
365
- X = randn (T, n, n)
366
- λ = eigvals (X)
367
- λ̄ = rand_tangent (λ)
368
- X̄ = rrule (eigvals, X)[2 ](λ̄)[2 ]
369
344
@test eltype (X̄) <: Real
370
345
end
371
346
end
@@ -399,17 +374,18 @@ end
399
374
400
375
# These tests are generally a bit tricky to write because FiniteDifferences doesn't
401
376
# have fantastic support for this stuff at the minute.
377
+ # also we might be missing some overloads for different tangent-types in the rules
402
378
@testset " cholesky" begin
403
379
@testset " Real" begin
404
- C = cholesky (rand () + 0.1 )
405
- ΔC = Composite {typeof(C)} ((factors= rand_tangent (C. factors)))
406
- rrule_test (cholesky, ΔC, (rand () + 0.1 , randn ()))
380
+ test_rrule (cholesky, 0.8 )
407
381
end
408
382
@testset " Diagonal{<:Real}" begin
409
383
D = Diagonal (rand (5 ) .+ 0.1 )
410
384
C = cholesky (D)
411
- ΔC = Composite {typeof(C)} ((factors= Diagonal (randn (5 ))))
412
- rrule_test (cholesky, ΔC, (D, Diagonal (randn (5 ))), (Val (false ), nothing ))
385
+ test_rrule (
386
+ cholesky, D ⊢ Diagonal (randn (5 )), Val (false ) ⊢ nothing ;
387
+ output_tangent= Composite {typeof(C)} (factors= Diagonal (randn (5 )))
388
+ )
413
389
end
414
390
415
391
X = generate_well_conditioned_matrix (10 )
0 commit comments