Skip to content

Commit eb10848

Browse files
authored
Move Symmetric/Hermitian rules and tests to own file (#322)
* Move symmetric rules to own file * Move symmetric tests to own file * Increment version number
1 parent fa4b93a commit eb10848

File tree

7 files changed

+126
-123
lines changed

7 files changed

+126
-123
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.35"
3+
version = "0.7.36"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ include("rulesets/LinearAlgebra/blas.jl")
4545
include("rulesets/LinearAlgebra/dense.jl")
4646
include("rulesets/LinearAlgebra/norm.jl")
4747
include("rulesets/LinearAlgebra/structured.jl")
48+
include("rulesets/LinearAlgebra/symmetric.jl")
4849
include("rulesets/LinearAlgebra/factorization.jl")
4950

5051
include("rulesets/Random/random.jl")

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -86,86 +86,6 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
8686
return D * V, times_pullback
8787
end
8888

89-
#####
90-
##### `Symmetric`/`Hermitian`
91-
#####
92-
93-
function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
94-
return T(A, uplo), T(ΔA, uplo)
95-
end
96-
97-
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
98-
Ω = T(A, uplo)
99-
function HermOrSym_pullback(ΔΩ)
100-
return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist())
101-
end
102-
return Ω, HermOrSym_pullback
103-
end
104-
105-
function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
106-
return TM(A), TM(_symherm_forward(A, ΔA))
107-
end
108-
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
109-
return Array(A), Array(_symherm_forward(A, ΔA))
110-
end
111-
112-
function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
113-
function Matrix_pullback(ΔΩ)
114-
TA = _symhermtype(A)
115-
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
116-
uplo = A.uplo
117-
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
118-
return NO_FIELDS, ∂A
119-
end
120-
return TM(A), Matrix_pullback
121-
end
122-
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)
123-
124-
# Get type (Symmetric or Hermitian) from type or matrix
125-
_symhermtype(::Type{<:Symmetric}) = Symmetric
126-
_symhermtype(::Type{<:Hermitian}) = Hermitian
127-
_symhermtype(A) = _symhermtype(typeof(A))
128-
129-
# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
130-
function _symherm_forward(A, ΔA)
131-
TA = _symhermtype(A)
132-
return if ΔA isa TA
133-
ΔA
134-
else
135-
TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
136-
end
137-
end
138-
139-
# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
140-
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
141-
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
142-
return _symmetric_back(ΔΩ, uplo)
143-
end
144-
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
145-
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)
146-
147-
function _symmetric_back(ΔΩ, uplo)
148-
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
149-
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
150-
end
151-
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
152-
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
153-
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)
154-
155-
function _hermitian_back(ΔΩ, uplo)
156-
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
157-
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
158-
end
159-
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
160-
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
161-
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
162-
return if istriu(ΔΩ)
163-
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
164-
else
165-
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
166-
end
167-
end
168-
16989
#####
17090
##### `Adjoint`
17191
#####
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
#####
2+
##### `Symmetric`/`Hermitian`
3+
#####
4+
5+
function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
6+
return T(A, uplo), T(ΔA, uplo)
7+
end
8+
9+
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
10+
Ω = T(A, uplo)
11+
function HermOrSym_pullback(ΔΩ)
12+
return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist())
13+
end
14+
return Ω, HermOrSym_pullback
15+
end
16+
17+
function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
18+
return TM(A), TM(_symherm_forward(A, ΔA))
19+
end
20+
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
21+
return Array(A), Array(_symherm_forward(A, ΔA))
22+
end
23+
24+
function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
25+
function Matrix_pullback(ΔΩ)
26+
TA = _symhermtype(A)
27+
T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)}
28+
uplo = A.uplo
29+
∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo)
30+
return NO_FIELDS, ∂A
31+
end
32+
return TM(A), Matrix_pullback
33+
end
34+
rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A)
35+
36+
# Get type (Symmetric or Hermitian) from type or matrix
37+
_symhermtype(::Type{<:Symmetric}) = Symmetric
38+
_symhermtype(::Type{<:Hermitian}) = Hermitian
39+
_symhermtype(A) = _symhermtype(typeof(A))
40+
41+
# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
42+
function _symherm_forward(A, ΔA)
43+
TA = _symhermtype(A)
44+
return if ΔA isa TA
45+
ΔA
46+
else
47+
TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
48+
end
49+
end
50+
51+
# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
52+
_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo)
53+
function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo)
54+
return _symmetric_back(ΔΩ, uplo)
55+
end
56+
_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo)
57+
_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo)
58+
59+
function _symmetric_back(ΔΩ, uplo)
60+
L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ)
61+
return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D
62+
end
63+
_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ
64+
_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
65+
_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)
66+
67+
function _hermitian_back(ΔΩ, uplo)
68+
L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ))
69+
return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD
70+
end
71+
_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ)
72+
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
73+
∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
74+
return if istriu(ΔΩ)
75+
return Matrix(uplo == 'U' ? ∂UL : ∂UL')
76+
else
77+
return Matrix(uplo == 'U' ? ∂UL' : ∂UL)
78+
end
79+
end

test/rulesets/LinearAlgebra/structured.jl

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -104,48 +104,6 @@
104104
end
105105
end
106106
end
107-
@testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for
108-
SymHerm in (Symmetric, Hermitian),
109-
T in (Float64, ComplexF64),
110-
uplo in (:U, :L)
111-
112-
N = 3
113-
@testset "frule" begin
114-
x = randn(T, N, N)
115-
Δx = randn(T, N, N)
116-
# can't use frule_test here because it doesn't yet ignore nothing tangents
117-
Ω = SymHerm(x, uplo)
118-
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
119-
@test Ω_ad == Ω
120-
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
121-
@test ∂Ω_ad ∂Ω_fd
122-
end
123-
@testset "rrule" begin
124-
x = randn(T, N, N)
125-
∂x = randn(T, N, N)
126-
ΔΩ = randn(T, N, N)
127-
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
128-
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
129-
end
130-
@testset "back(::Diagonal)" begin
131-
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
132-
end
133-
end
134-
end
135-
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
136-
SymHerm in (Symmetric, Hermitian),
137-
T in (Float64, ComplexF64),
138-
uplo in (:U, :L)
139-
140-
N = 3
141-
x = SymHerm(randn(T, N, N), uplo)
142-
Δx = randn(T, N, N)
143-
∂x = SymHerm(randn(T, N, N), uplo)
144-
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
145-
frule_test(f, (x, Δx))
146-
frule_test(f, (x, SymHerm(Δx, uplo)))
147-
rrule_test(f, ΔΩ, (x, ∂x))
148-
end
149107
@testset "$f" for f in (Adjoint, adjoint, Transpose, transpose)
150108
n = 5
151109
m = 3
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
@testset "Symmetric/Hermitian rules" begin
2+
@testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for
3+
SymHerm in (Symmetric, Hermitian),
4+
T in (Float64, ComplexF64),
5+
uplo in (:U, :L)
6+
7+
N = 3
8+
@testset "frule" begin
9+
x = randn(T, N, N)
10+
Δx = randn(T, N, N)
11+
# can't use frule_test here because it doesn't yet ignore nothing tangents
12+
Ω = SymHerm(x, uplo)
13+
Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo)
14+
@test Ω_ad == Ω
15+
∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx))
16+
@test ∂Ω_ad ∂Ω_fd
17+
end
18+
@testset "rrule" begin
19+
x = randn(T, N, N)
20+
∂x = randn(T, N, N)
21+
ΔΩ = randn(T, N, N)
22+
@testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular)
23+
rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing))
24+
end
25+
@testset "back(::Diagonal)" begin
26+
rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing))
27+
end
28+
end
29+
end
30+
@testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array),
31+
SymHerm in (Symmetric, Hermitian),
32+
T in (Float64, ComplexF64),
33+
uplo in (:U, :L)
34+
35+
N = 3
36+
x = SymHerm(randn(T, N, N), uplo)
37+
Δx = randn(T, N, N)
38+
∂x = SymHerm(randn(T, N, N), uplo)
39+
ΔΩ = f(SymHerm(randn(T, N, N), uplo))
40+
frule_test(f, (x, Δx))
41+
frule_test(f, (x, SymHerm(Δx, uplo)))
42+
rrule_test(f, ΔΩ, (x, ∂x))
43+
end
44+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ println("Testing ChainRules.jl")
4343
include_test("rulesets/LinearAlgebra/dense.jl")
4444
include_test("rulesets/LinearAlgebra/norm.jl")
4545
include_test("rulesets/LinearAlgebra/structured.jl")
46+
include_test("rulesets/LinearAlgebra/symmetric.jl")
4647
include_test("rulesets/LinearAlgebra/factorization.jl")
4748
include_test("rulesets/LinearAlgebra/blas.jl")
4849
end

0 commit comments

Comments
 (0)