Skip to content

Commit 439f482

Browse files
authored
factorizations.jl autotangent (#375)
* factorizations.jl autotangent * Finish changing to new way of tangents and push eigen's test primal away from zero * remove extra end * push further from zero * Update test/rulesets/LinearAlgebra/factorization.jl * Update test/rulesets/LinearAlgebra/factorization.jl * trying to find a stable point
1 parent f0fead0 commit 439f482

File tree

1 file changed

+33
-57
lines changed

1 file changed

+33
-57
lines changed

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 33 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ end
2929
pivot in (Val(true), Val(false)),
3030
m in (7, 10, 13)
3131

32-
A = randn(T, m, n)
33-
ΔA = rand_tangent(A)
34-
frule_test(lu!, (A, ΔA), (pivot, nothing))
32+
test_frule(lu!, randn(T, m, n), pivot nothing)
3533
end
3634
@testset "check=false passed to primal function" begin
3735
Asingular = zeros(n, n)
@@ -40,6 +38,7 @@ end
4038
(Zero(), copy(ΔAsingular)), lu!, copy(Asingular), Val(true)
4139
)
4240
frule((Zero(), ΔAsingular), lu!, Asingular, Val(true); check=false)
41+
@test true # above line would have errored if this was not working right
4342
end
4443
end
4544
@testset "lu rrule" begin
@@ -48,12 +47,7 @@ end
4847
pivot in (Val(true), Val(false)),
4948
m in (7, 10, 13)
5049

51-
A = randn(T, m, n)
52-
ΔA = rand_tangent(A)
53-
F = lu(A, pivot)
54-
Δfactors = rand_tangent(F.factors)
55-
ΔF = Composite{typeof(F)}(; factors=Δfactors)
56-
rrule_test(lu, ΔF, (A, ΔA), (pivot, nothing))
50+
test_rrule(lu, randn(T, m, n), pivot nothing)
5751
end
5852
@testset "check=false passed to primal function" begin
5953
Asingular = zeros(n, n)
@@ -62,6 +56,7 @@ end
6256
@test_throws SingularException rrule(lu, Asingular, Val(true))
6357
_, back = rrule(lu, Asingular, Val(true); check=false)
6458
back(ΔF)
59+
@test true # above line would have errored if this was not working right
6560
end
6661
end
6762
@testset "LU" begin
@@ -71,32 +66,14 @@ end
7166
k in (:U, :L, :factors),
7267
m in (7, 10, 13)
7368

74-
A = randn(m, n)
75-
F = lu(A)
76-
X = getproperty(F, k)
77-
ΔF = Composite{typeof(F)}(; factors=rand_tangent(F.factors))
78-
ΔX = rand_tangent(X)
79-
rrule_test(getproperty, ΔX, (F, ΔF), (k, nothing); check_inferred=false)
69+
F = lu(randn(m, n))
70+
test_rrule(getproperty, F, k nothing ; check_inferred=false)
8071
end
8172
end
8273
@testset "matrix inverse using LU" begin
83-
@testset "LinearAlgebra.inv!(::LU) frule" begin
84-
@testset "inv!(lu(::LU{$T,<:StridedMatrix}))" for T in (Float64,ComplexF64)
85-
A = randn(T, n, n)
86-
F = lu(A, Val(true))
87-
ΔF = Composite{typeof(F)}(; factors=rand_tangent(F.factors))
88-
frule_test(LinearAlgebra.inv!, (F, ΔF))
89-
end
90-
end
91-
@testset "inv(::LU) rrule" begin
92-
@testset "inv(::LU{$T,<:StridedMatrix})" for T in (Float64,ComplexF64)
93-
A = randn(T, n, n)
94-
F = lu(A, Val(true))
95-
Y = inv(A)
96-
ΔF = Composite{typeof(F)}(; factors=rand_tangent(F.factors))
97-
ΔY = rand_tangent(Y)
98-
rrule_test(inv, ΔY, (F, ΔF))
99-
end
74+
@testset "inv!(lu(::LU{$T,<:StridedMatrix}))" for T in (Float64,ComplexF64)
75+
test_frule(LinearAlgebra.inv!, lu(randn(T, n, n), Val(true)))
76+
test_rrule(inv, lu(randn(T, n, n), Val(true)))
10077
end
10178
end
10279
end
@@ -188,7 +165,9 @@ end
188165
n = 10
189166

190167
@testset "eigen!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
191-
X = randn(T, n, n)
168+
# get a bit away from zero so don't have finite differencing woes
169+
# TODO: this better https://github.com/JuliaDiff/ChainRules.jl/issues/379
170+
X = 10 .* (rand(T, n, n) .+ 5.0)
192171
= rand_tangent(X)
193172
F = eigen!(copy(X))
194173
F_fwd, Ḟ_ad = frule((Zero(), copy(Ẋ)), eigen!, copy(X))
@@ -209,24 +188,28 @@ end
209188
end
210189
end
211190

212-
@testset "eigen(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
213-
# NOTE: eigen is not type-stable, so neither are is its rrule
214-
X = randn(T, n, n)
191+
@testset "eigen(::Matrix{$T}) rrule" for T in (Float64, ComplexF64)
192+
# get a bit away from zero so don't have finite differencing woes
193+
# TODO: this better https://github.com/JuliaDiff/ChainRules.jl/issues/379
194+
Random.seed!(1)
195+
X = 10 .* (rand(T, n, n) .+ 5.0)
196+
215197
F = eigen(X)
216198
= rand_tangent(F.vectors)
217199
λ̄ = rand_tangent(F.values)
218200
CT = Composite{typeof(F)}
219201
F_rev, back = rrule(eigen, X)
220202
@test F_rev == F
203+
# NOTE: eigen is not type-stable, so neither are is its rrule
221204
_, X̄_values_ad = @inferred back(CT(values = λ̄))
222205
@test X̄_values_ad j′vp(_fdm, x -> eigen(x).values, λ̄, X)[1]
223206
_, X̄_vectors_ad = @inferred back(CT(vectors = V̄))
224-
@test X̄_vectors_ad j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1]
207+
@test X̄_vectors_ad j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1] rtol=1e-4
225208
= CT(values = λ̄, vectors = V̄)
226209
s̄elf, X̄_ad = @inferred back(F̄)
227210
@test s̄elf === NO_FIELDS
228211
X̄_fd = j′vp(_fdm, asnt eigen, F̄, X)[1]
229-
@test X̄_ad X̄_fd
212+
@test X̄_ad X̄_fd rtol=1e-4
230213
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())
231214
F̄zero = CT(values = Zero(), vectors = Zero())
232215
@test @inferred(back(F̄zero)) === (NO_FIELDS, Zero())
@@ -337,10 +320,8 @@ end
337320
@testset "eigvals!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
338321
n = 10
339322
X = randn(T, n, n)
340-
λ = eigvals!(copy(X))
341-
= rand_tangent(X)
342-
frule_test(eigvals!, (X, Ẋ))
343-
@test frule((Zero(), Zero()), eigvals!, copy(X)) == (λ, Zero())
323+
test_frule(eigvals!, X)
324+
@test frule((Zero(), Zero()), eigvals!, copy(X))[2] == Zero()
344325

345326
@testset "tangents are real when outputs are" begin
346327
# hermitian matrices have real eigenvalues
@@ -353,19 +334,13 @@ end
353334

354335
@testset "eigvals(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
355336
n = 10
356-
X = randn(T, n, n)
357-
= rand_tangent(X)
358-
λ̄ = rand_tangent(eigvals(X))
359-
rrule_test(eigvals, λ̄, (X, X̄))
360-
back = rrule(eigvals, X)[2]
361-
@inferred back(λ̄)
337+
test_rrule(eigvals, randn(T, n, n))
338+
339+
λ, back = rrule(eigvals, randn(T, n, n))
340+
_, X̄ = @inferred back(rand_tangent(λ))
362341
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())
363342

364343
T <: Real && @testset "cotangent is real when input is" begin
365-
X = randn(T, n, n)
366-
λ = eigvals(X)
367-
λ̄ = rand_tangent(λ)
368-
= rrule(eigvals, X)[2](λ̄)[2]
369344
@test eltype(X̄) <: Real
370345
end
371346
end
@@ -399,17 +374,18 @@ end
399374

400375
# These tests are generally a bit tricky to write because FiniteDifferences doesn't
401376
# have fantastic support for this stuff at the minute.
377+
# also we might be missing some overloads for different tangent-types in the rules
402378
@testset "cholesky" begin
403379
@testset "Real" begin
404-
C = cholesky(rand() + 0.1)
405-
ΔC = Composite{typeof(C)}((factors=rand_tangent(C.factors)))
406-
rrule_test(cholesky, ΔC, (rand() + 0.1, randn()))
380+
test_rrule(cholesky, 0.8)
407381
end
408382
@testset "Diagonal{<:Real}" begin
409383
D = Diagonal(rand(5) .+ 0.1)
410384
C = cholesky(D)
411-
ΔC = Composite{typeof(C)}((factors=Diagonal(randn(5))))
412-
rrule_test(cholesky, ΔC, (D, Diagonal(randn(5))), (Val(false), nothing))
385+
test_rrule(
386+
cholesky, D Diagonal(randn(5)), Val(false) nothing;
387+
output_tangent=Composite{typeof(C)}(factors=Diagonal(randn(5)))
388+
)
413389
end
414390

415391
X = generate_well_conditioned_matrix(10)

0 commit comments

Comments
 (0)