Skip to content

Commit 55a3923

Browse files
committed
Move some code around (NFC)
* Move the BLAS definitions into the `linalg` directory. * Move some general optimizations into a specific utilities file for linear algebra definitions. * Consolidate structured matrix operations, e.g. diagonal and symmetric, into one file.
1 parent 0e551fb commit 55a3923

File tree

10 files changed

+59
-45
lines changed

10 files changed

+59
-45
lines changed

src/ChainRules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ include("rules.jl")
1313
include("rules/base.jl")
1414
include("rules/array.jl")
1515
include("rules/broadcast.jl")
16+
include("rules/linalg/utils.jl")
17+
include("rules/linalg/blas.jl")
1618
include("rules/linalg/dense.jl")
17-
include("rules/linalg/diagonal.jl")
18-
include("rules/linalg/symmetric.jl")
19+
include("rules/linalg/structured.jl")
1920
include("rules/linalg/factorization.jl")
20-
include("rules/blas.jl")
2121
include("rules/nanmath.jl")
2222
include("rules/specialfunctions.jl")
2323

File renamed without changes.

src/rules/linalg/diagonal.jl

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/rules/linalg/factorization.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -59,33 +59,6 @@ function svd_rev(USV::SVD, Ū::AbstractMatrix, s̄::AbstractVector, V̄::Abstra
5959
return
6060
end
6161

62-
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
63-
k = size(X, 1)
64-
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
65-
if i == j
66-
X[i,i] = zero(T)
67-
else
68-
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
69-
end
70-
end
71-
X
72-
end
73-
74-
function _eyesubx!(X::AbstractMatrix{T}) where T<:Real
75-
n, m = size(X)
76-
@inbounds for j = 1:m, i = 1:n
77-
X[i,j] = (i == j) - X[i,j]
78-
end
79-
X
80-
end
81-
82-
function _add!(X::AbstractMatrix{T}, Y::AbstractMatrix{T}) where T<:Real
83-
@inbounds for i = eachindex(X, Y)
84-
X[i] += Y[i]
85-
end
86-
X
87-
end
88-
8962
#####
9063
##### `cholesky`
9164
#####

src/rules/linalg/structured.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Structured matrices
2+
3+
#####
4+
##### `Diagonal`
5+
#####
6+
7+
rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag)
8+
9+
rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)
10+
11+
#####
12+
##### `Symmetric`
13+
#####
14+
15+
rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)
16+
17+
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
18+
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ

src/rules/linalg/symmetric.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/rules/linalg/utils.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Some utility functions for optimizing linear algebra operations that aren't specific
2+
# to any particular rule definition
3+
4+
# F .* (X - X'), overwrites X
5+
function _mulsubtrans!(X::AbstractMatrix{T}, F::AbstractMatrix{T}) where T<:Real
6+
k = size(X, 1)
7+
@inbounds for j = 1:k, i = 1:j # Iterate the upper triangle
8+
if i == j
9+
X[i,i] = zero(T)
10+
else
11+
X[i,j], X[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j])
12+
end
13+
end
14+
X
15+
end
16+
17+
# I - X, overwrites X
18+
function _eyesubx!(X::AbstractMatrix)
19+
n, m = size(X)
20+
@inbounds for j = 1:m, i = 1:n
21+
X[i,j] = (i == j) - X[i,j]
22+
end
23+
X
24+
end
25+
26+
# X + Y, overwrites X
27+
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
28+
@inbounds for i = eachindex(X, Y)
29+
X[i] += Y[i]
30+
end
31+
X
32+
end

test/rules/linalg/diagonal.jl renamed to test/rules/linalg/structured.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "diagonal" begin
1+
@testset "Structured Matrices" begin
22
@testset "Diagonal" begin
33
rng, N = MersenneTwister(123456), 3
44
rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N)))
@@ -14,4 +14,8 @@
1414
rrule_test(diag, randn(rng, N), (randn(rng, N, N), Diagonal(randn(rng, N))))
1515
rrule_test(diag, randn(rng, N), (Diagonal(randn(rng, N)), Diagonal(randn(rng, N))))
1616
end
17+
@testset "Symmetric" begin
18+
rng, N = MersenneTwister(123456), 3
19+
rrule_test(Symmetric, randn(rng, N, N), (randn(rng, N, N), randn(rng, N, N)))
20+
end
1721
end

test/rules/linalg/symmetric.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ include("test_util.jl")
1717
include(joinpath("rules", "array.jl"))
1818
@testset "linalg" begin
1919
include(joinpath("rules", "linalg", "dense.jl"))
20-
include(joinpath("rules", "linalg", "diagonal.jl"))
21-
include(joinpath("rules", "linalg", "symmetric.jl"))
20+
include(joinpath("rules", "linalg", "structured.jl"))
2221
include(joinpath("rules", "linalg", "factorization.jl"))
2322
end
2423
include(joinpath("rules", "broadcast.jl"))

0 commit comments

Comments
 (0)