Skip to content

Commit 22eddb4

Browse files
authored
Add eigen and eigvals rules for StridedMatrix (#321)
* Add implementation of eigen * Add implementation of eigvals * Add todo notes * Test eigen and eigvals * Choose dimension that is stable * Fix function call * Check that pullbacks are type-stable * Note why we don't check type-stability for rules * Add test for idempotence * Test that sensitivities are real when the primals are * Rearrange tests * Test sensitivities are real when primals are for eigvals * Increment version number * Don't compute eigenvectors if unused * Don't compute full matrix product * Avoid calling Matrix * Overload mutating versions for frule * Test mutating form for frule * Use fewer subscripts * Increment required patch version * Increment version number * Increment version number
1 parent 80443b1 commit 22eddb4

File tree

3 files changed

+266
-2
lines changed

3 files changed

+266
-2
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.37"
3+
version = "0.7.38"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,7 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414

1515
[compat]
1616
ChainRulesCore = "0.9.21"
17-
ChainRulesTestUtils = "0.5.1"
17+
ChainRulesTestUtils = "0.5.5"
1818
Compat = "3"
1919
FiniteDifferences = "0.11.4"
2020
Reexport = "0.2"

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,138 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
6666
return Ā
6767
end
6868

69+
#####
70+
##### `eigen`
71+
#####
72+
73+
# TODO:
74+
# - support correct differential of phase convention when A is hermitian
75+
# - simplify when A is diagonal
76+
# - support degenerate matrices (see #144)
77+
78+
function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
79+
F = eigen!(A; kwargs...)
80+
ΔA isa AbstractZero && return F, ΔA
81+
λ, V = F.values, F.vectors
82+
tmp = V \ ΔA
83+
∂K = tmp * V
84+
∂Kdiag = @view ∂K[diagind(∂K)]
85+
∂λ = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag)
86+
∂K ./= transpose(λ) .- λ
87+
fill!(∂Kdiag, 0)
88+
∂V = mul!(tmp, V, ∂K)
89+
_eigen_norm_phase_fwd!(∂V, A, V)
90+
∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂V)
91+
return F, ∂F
92+
end
93+
94+
function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}}
95+
F = eigen(A; kwargs...)
96+
function eigen_pullback(ΔF::Composite{<:Eigen})
97+
λ, V = F.values, F.vectors
98+
Δλ, ΔV = ΔF.values, ΔF.vectors
99+
if ΔV isa AbstractZero
100+
Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV)
101+
∂K = Diagonal(Δλ)
102+
∂A = V' \ ∂K * V'
103+
else
104+
∂V = copyto!(similar(ΔV), ΔV)
105+
_eigen_norm_phase_rev!(∂V, A, V)
106+
∂K = V' * ∂V
107+
∂K ./= λ' .- conj.(λ)
108+
∂K[diagind(∂K)] .= Δλ
109+
∂A = mul!(∂K, V' \ ∂K, V')
110+
end
111+
return NO_FIELDS, T <: Real ? real(∂A) : ∂A
112+
end
113+
eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF)
114+
return F, eigen_pullback
115+
end
116+
117+
# mutate ∂V to account for the (arbitrary but consistent) normalization and phase condition
118+
# applied to the eigenvectors.
119+
# these implementations assume the convention used by eigen in LinearAlgebra (i.e. that of
120+
# LAPACK.geevx!; eigenvectors have unit norm, and the element with the largest absolute
121+
# value is real), but they can be specialized for `A`
122+
123+
function _eigen_norm_phase_fwd!(∂V, A, V)
124+
@inbounds for i in axes(V, 2)
125+
v, ∂v = @views V[:, i], ∂V[:, i]
126+
# account for unit normalization
127+
∂c_norm = -real(dot(v, ∂v))
128+
if eltype(V) <: Real
129+
∂c = ∂c_norm
130+
else
131+
# account for rotation of largest element to real
132+
k = _findrealmaxabs2(v)
133+
∂c_phase = -imag(∂v[k]) / real(v[k])
134+
∂c = complex(∂c_norm, ∂c_phase)
135+
end
136+
∂v .+= v .* ∂c
137+
end
138+
return ∂V
139+
end
140+
141+
function _eigen_norm_phase_rev!(∂V, A, V)
142+
@inbounds for i in axes(V, 2)
143+
v, ∂v = @views V[:, i], ∂V[:, i]
144+
∂c = dot(v, ∂v)
145+
# account for unit normalization
146+
∂v .-= real(∂c) .* v
147+
if !(eltype(V) <: Real)
148+
# account for rotation of largest element to real
149+
k = _findrealmaxabs2(v)
150+
@inbounds ∂v[k] -= im * (imag(∂c) / real(v[k]))
151+
end
152+
end
153+
return ∂V
154+
end
155+
156+
# workaround for findmax not taking a mapped function
157+
function _findrealmaxabs2(x)
158+
amax = abs2(first(x))
159+
imax = 1
160+
@inbounds for i in 2:length(x)
161+
xi = x[i]
162+
!isreal(xi) && continue
163+
a = abs2(xi)
164+
a < amax && continue
165+
amax, imax = a, i
166+
end
167+
return imax
168+
end
169+
170+
#####
171+
##### `eigvals`
172+
#####
173+
174+
function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
175+
ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA
176+
F = eigen!(A; kwargs...)
177+
λ, V = F.values, F.vectors
178+
tmp = V \ ΔA
179+
∂λ = similar(λ)
180+
# diag(tmp * V) without computing full matrix product
181+
if eltype(∂λ) <: Real
182+
broadcast!((a, b) -> sum(real prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
183+
else
184+
broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
185+
end
186+
return λ, ∂λ
187+
end
188+
189+
function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}}
190+
F = eigen(A; kwargs...)
191+
λ = F.values
192+
function eigvals_pullback(Δλ)
193+
V = F.vectors
194+
∂A = V' \ Diagonal(Δλ) * V'
195+
return NO_FIELDS, T <: Real ? real(∂A) : ∂A
196+
end
197+
eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ)
198+
return λ, eigvals_pullback
199+
end
200+
69201
#####
70202
##### `cholesky`
71203
#####

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,138 @@ end
8585
end
8686
end
8787

88+
@testset "eigendecomposition" begin
89+
@testset "eigen/eigen!" begin
90+
# NOTE: eigen!/eigen are not type-stable, so neither are their frule/rrule
91+
92+
# avoid implementing to_vec(::Eigen)
93+
f(E::Eigen) = (values=E.values, vectors=E.vectors)
94+
95+
# NOTE: for unstructured matrices, low enough n, and this specific seed, finite
96+
# differences of eigen seems to be stable enough for direct comparison.
97+
# This allows us to directly check differential of normalization/phase
98+
# convention
99+
n = 10
100+
101+
@testset "eigen!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
102+
X = randn(T, n, n)
103+
= rand_tangent(X)
104+
F = eigen!(copy(X))
105+
F_fwd, Ḟ_ad = frule((Zero(), copy(Ẋ)), eigen!, copy(X))
106+
@test F_fwd == F
107+
@test Ḟ_ad isa Composite{typeof(F)}
108+
Ḟ_fd = jvp(_fdm, f eigen! copy, (X, Ẋ))
109+
@test Ḟ_ad.values Ḟ_fd.values
110+
@test Ḟ_ad.vectors Ḟ_fd.vectors
111+
@test frule((Zero(), Zero()), eigen!, copy(X)) == (F, Zero())
112+
113+
@testset "tangents are real when outputs are" begin
114+
# hermitian matrices have real eigenvalues and, when real, real eigenvectors
115+
X = Matrix(Hermitian(randn(T, n, n)))
116+
= Matrix(Hermitian(rand_tangent(X)))
117+
_, Ḟ = frule((Zero(), Ẋ), eigen!, X)
118+
@test eltype(Ḟ.values) <: Real
119+
T <: Real && @test eltype(Ḟ.vectors) <: Real
120+
end
121+
end
122+
123+
@testset "eigen(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
124+
# NOTE: eigen is not type-stable, so neither are is its rrule
125+
X = randn(T, n, n)
126+
F = eigen(X)
127+
= rand_tangent(F.vectors)
128+
λ̄ = rand_tangent(F.values)
129+
CT = Composite{typeof(F)}
130+
F_rev, back = rrule(eigen, X)
131+
@test F_rev == F
132+
_, X̄_values_ad = @inferred back(CT(values = λ̄))
133+
@test X̄_values_ad j′vp(_fdm, x -> eigen(x).values, λ̄, X)[1]
134+
_, X̄_vectors_ad = @inferred back(CT(vectors = V̄))
135+
@test X̄_vectors_ad j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1]
136+
= CT(values = λ̄, vectors = V̄)
137+
s̄elf, X̄_ad = @inferred back(F̄)
138+
@test s̄elf === NO_FIELDS
139+
X̄_fd = j′vp(_fdm, f eigen, F̄, X)[1]
140+
@test X̄_ad X̄_fd
141+
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())
142+
F̄zero = CT(values = Zero(), vectors = Zero())
143+
@test @inferred(back(F̄zero)) === (NO_FIELDS, Zero())
144+
145+
T <: Real && @testset "cotangent is real when input is" begin
146+
X = randn(T, n, n)
147+
= rand_tangent(X)
148+
149+
F = eigen(X)
150+
= rand_tangent(F.vectors)
151+
λ̄ = rand_tangent(F.values)
152+
= Composite{typeof(F)}(values = λ̄, vectors = V̄)
153+
= rrule(eigen, X)[2](F̄)[2]
154+
@test eltype(X̄) <: Real
155+
end
156+
end
157+
158+
@testset "normalization/phase functions are idempotent" for T in (Float64,ComplexF64)
159+
# this is as much a math check as a code check. because normalization when
160+
# applied repeatedly is idempotent, repeated pushforward/pullback should
161+
# leave the (co)tangent unchanged
162+
X = randn(T, n, n)
163+
= rand_tangent(X)
164+
F = eigen(X)
165+
166+
= rand_tangent(F.vectors)
167+
V̇proj = ChainRules._eigen_norm_phase_fwd!(copy(V̇), X, F.vectors)
168+
@test !isapprox(V̇, V̇proj)
169+
V̇proj2 = ChainRules._eigen_norm_phase_fwd!(copy(V̇proj), X, F.vectors)
170+
@test V̇proj2 V̇proj
171+
172+
= rand_tangent(F.vectors)
173+
V̄proj = ChainRules._eigen_norm_phase_rev!(copy(V̄), X, F.vectors)
174+
@test !isapprox(V̄, V̄proj)
175+
V̄proj2 = ChainRules._eigen_norm_phase_rev!(copy(V̄proj), X, F.vectors)
176+
@test V̄proj2 V̄proj
177+
end
178+
end
179+
180+
@testset "eigvals/eigvals!" begin
181+
# NOTE: eigvals!/eigvals are not type-stable, so neither are their frule/rrule
182+
@testset "eigvals!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
183+
n = 10
184+
X = randn(T, n, n)
185+
λ = eigvals!(copy(X))
186+
= rand_tangent(X)
187+
frule_test(eigvals!, (X, Ẋ))
188+
@test frule((Zero(), Zero()), eigvals!, copy(X)) == (λ, Zero())
189+
190+
@testset "tangents are real when outputs are" begin
191+
# hermitian matrices have real eigenvalues
192+
X = Matrix(Hermitian(randn(T, n, n)))
193+
= Matrix(Hermitian(rand_tangent(X)))
194+
_, λ̇ = frule((Zero(), Ẋ), eigvals!, X)
195+
@test eltype(λ̇) <: Real
196+
end
197+
end
198+
199+
@testset "eigvals(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
200+
n = 10
201+
X = randn(T, n, n)
202+
= rand_tangent(X)
203+
λ̄ = rand_tangent(eigvals(X))
204+
rrule_test(eigvals, λ̄, (X, X̄))
205+
back = rrule(eigvals, X)[2]
206+
@inferred back(λ̄)
207+
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())
208+
209+
T <: Real && @testset "cotangent is real when input is" begin
210+
X = randn(T, n, n)
211+
λ = eigvals(X)
212+
λ̄ = rand_tangent(λ)
213+
= rrule(eigvals, X)[2](λ̄)[2]
214+
@test eltype(X̄) <: Real
215+
end
216+
end
217+
end
218+
end
219+
88220
# These tests are generally a bit tricky to write because FiniteDifferences doesn't
89221
# have fantastic support for this stuff at the minute.
90222
@testset "cholesky" begin

0 commit comments

Comments
 (0)